# HG changeset patch
# User goeckslab
# Date 1763774133 0
# Node ID 9abf329ff04590df2d78e5a41332420e1fa9206b
# Parent d5a7ca6bc64076b11187df895305559dcd0456dd
planemo upload for repository https://github.com/goeckslab/Galaxy-Ludwig.git commit e2ab4c0f9ce8b7a0a48f749ef5dd9899d6c2b1f8
diff -r d5a7ca6bc640 -r 9abf329ff045 ludwig_experiment.py
--- a/ludwig_experiment.py Sat Sep 06 01:53:57 2025 +0000
+++ b/ludwig_experiment.py Sat Nov 22 01:15:33 2025 +0000
@@ -1,10 +1,15 @@
+import base64
+import html
import json
import logging
import os
import pickle
+import re
import sys
+from io import BytesIO
import pandas as pd
+from ludwig.api import LudwigModel
from ludwig.experiment import cli
from ludwig.globals import (
DESCRIPTION_FILE_NAME,
@@ -21,6 +26,11 @@
get_html_template
)
+try: # pragma: no cover - optional dependency in runtime containers
+ import matplotlib.pyplot as plt
+except ImportError: # pragma: no cover
+ plt = None
+
logging.basicConfig(level=logging.DEBUG)
@@ -158,44 +168,435 @@
LOG.error(f"Error converting Parquet to CSV: {e}")
-def generate_html_report(title, ludwig_output_directory_name):
- # ludwig_output_directory = os.path.join(
- # output_directory, ludwig_output_directory_name)
+def _resolve_dataset_path(dataset_path):
+ if not dataset_path:
+ return None
+
+ candidates = [dataset_path]
+
+ if not os.path.isabs(dataset_path):
+ candidates.extend([
+ os.path.join(output_directory, dataset_path),
+ os.path.join(os.getcwd(), dataset_path),
+ ])
+
+ for candidate in candidates:
+ if candidate and os.path.exists(candidate):
+ return os.path.abspath(candidate)
+
+ return None
+
+
+def _load_dataset_dataframe(dataset_path):
+ if not dataset_path:
+ return None
+
+ _, ext = os.path.splitext(dataset_path.lower())
+
+ try:
+ if ext in {".csv", ".tsv"}:
+ sep = "\t" if ext == ".tsv" else ","
+ return pd.read_csv(dataset_path, sep=sep)
+ if ext == ".parquet":
+ return pd.read_parquet(dataset_path)
+ if ext == ".json":
+ return pd.read_json(dataset_path)
+ if ext == ".h5":
+ return pd.read_hdf(dataset_path)
+    except Exception as exc:
+        LOG.warning(f"Unable to load dataset '{dataset_path}': {exc}")
+        return None
+
+    LOG.warning("Unsupported dataset format for feature importance computation")
+    return None
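Taken together, the two helpers above give a resolve-then-load flow. A minimal usage sketch, assuming `output_directory` is the Galaxy job directory configured elsewhere in this script; the file name is purely illustrative:

resolved = _resolve_dataset_path("training_data.csv")  # hypothetical relative path
frame = _load_dataset_dataframe(resolved)
if frame is None or frame.empty:
    LOG.warning("No usable dataset; feature importance will be skipped")
else:
    LOG.info(f"Loaded {len(frame)} rows from {resolved}")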
+
+
+def sanitize_feature_name(name):
+ """Mirror Ludwig's get_sanitized_feature_name implementation."""
+ return re.sub(r"[(){}.:\"\"\'\'\[\]]", "_", str(name))
+
+
+def _sanitize_dataframe_columns(dataframe):
+ """Rename dataframe columns to Ludwig-sanitized names for explainability."""
+ column_map = {col: sanitize_feature_name(col) for col in dataframe.columns}
+
+ sanitized_df = dataframe.rename(columns=column_map)
+ if len(set(column_map.values())) != len(column_map.values()):
+ LOG.warning(
+ "Column name collision after sanitization; feature importance may be unreliable"
+ )
+
+ return sanitized_df
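A quick illustration of what the sanitization does, with made-up column names: characters such as parentheses, dots and colons become underscores, while spaces are left alone.

raw = pd.DataFrame({"age (years)": [42], "income.usd": [50000]})
clean = _sanitize_dataframe_columns(raw)
print(list(clean.columns))  # ['age _years_', 'income_usd']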
+
+
+def _feature_importance_plot(label_df, label_name, top_n=10, max_abs_importance=None):
+ """
+ Return base64-encoded bar plot for a label's top-N feature importances.
+
+ max_abs_importance lets us pin the x-axis across labels so readers can
+ compare magnitudes.
+ """
+ if plt is None or label_df.empty:
+ return ""
+
+ top_features = label_df.nlargest(top_n, "abs_importance")
+ if top_features.empty:
+ return ""
+
+ fig, ax = plt.subplots(figsize=(6, 3 + 0.2 * len(top_features)))
+ ax.barh(top_features["feature"], top_features["abs_importance"], color="#3f8fd2")
+ ax.set_xlabel("|importance|")
+ if max_abs_importance and max_abs_importance > 0:
+ ax.set_xlim(0, max_abs_importance * 1.05)
+ ax.invert_yaxis()
+ fig.tight_layout()
+
+ buf = BytesIO()
+ fig.savefig(buf, format="png", dpi=150)
+ plt.close(fig)
+ encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
+ return encoded
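One way a caller could use the shared x-axis limit described in the docstring, assuming an `importance_df` with label / feature / abs_importance columns like the one assembled later in this patch; the surrounding markup is only a sketch:

shared_max = importance_df["abs_importance"].max() if not importance_df.empty else None
plot_sections = []
for label, label_df in importance_df.groupby("label"):
    encoded = _feature_importance_plot(label_df, str(label), max_abs_importance=shared_max)
    if encoded:  # empty string when matplotlib is missing or the group is empty
        plot_sections.append(
            f"<h3>{html.escape(str(label))}</h3>"
            f'<img src="data:image/png;base64,{encoded}" alt="Feature importance" />'
        )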
+
+
+def render_feature_importance_table(df: pd.DataFrame) -> str:
+ """Render a sortable HTML table for feature importance values."""
+ if df.empty:
+ return ""
+
+ columns = list(df.columns)
+ headers = "".join(
+        f"<th>{html.escape(str(col).replace('_', ' '))}</th>"
+ for col in columns
+ )
+
+ body_rows = []
+ for _, row in df.iterrows():
+ cells = []
+ for col in columns:
+ val = row[col]
+ if isinstance(val, float):
+ val_str = f"{val:.6f}"
+ else:
+ val_str = str(val)
+            cells.append(f"<td>{html.escape(val_str)}</td>")
- # test_statistics_html += json_to_html_table(
- # test_statistics)
- # except Exception as e:
- # LOG.info(f"Error reading test statistics: {e}")
+ try:
+ ludwig_model = LudwigModel.load(model_dir)
+ except Exception as exc:
+ LOG.warning(f"Unable to load Ludwig model for explanations: {exc}")
+ return
+
+ training_metadata = getattr(ludwig_model, "training_set_metadata", {})
+
+ output_feature_name, dataset_path = get_output_feature_name(
+ ludwig_output_directory)
+
+ if not output_feature_name or not dataset_path:
+ LOG.warning("Output feature or dataset path missing; skipping feature importance")
+ if hasattr(ludwig_model, "close"):
+ ludwig_model.close()
+ return
+
+ dataset_full_path = _resolve_dataset_path(dataset_path)
+ if not dataset_full_path:
+ LOG.warning(f"Unable to resolve dataset path '{dataset_path}' for explanations")
+ if hasattr(ludwig_model, "close"):
+ ludwig_model.close()
+ return
+
+ dataframe = _load_dataset_dataframe(dataset_full_path)
+ if dataframe is None or dataframe.empty:
+ LOG.warning("Dataset unavailable or empty; skipping feature importance")
+ if hasattr(ludwig_model, "close"):
+ ludwig_model.close()
+ return
+
+ dataframe = _sanitize_dataframe_columns(dataframe)
+
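+    # Two size-capped views of the data: data_subset (first sample_size rows) is what gets
+    # explained; sample_df (a seeded random sample) is passed as the explainer's sample data.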
+ data_subset = dataframe if len(dataframe) <= sample_size else dataframe.head(sample_size)
+ sample_df = dataframe.sample(
+ n=min(sample_size, len(dataframe)),
+ random_state=random_seed,
+ replace=False,
+ ) if len(dataframe) > sample_size else dataframe
+
+ try:
+ from ludwig.explain.captum import IntegratedGradientsExplainer
+ except ImportError as exc:
+ LOG.warning(f"Integrated Gradients explainer unavailable: {exc}")
+ if hasattr(ludwig_model, "close"):
+ ludwig_model.close()
+ return
+
+ sanitized_output_feature = sanitize_feature_name(output_feature_name)
+
+ try:
+ explainer = IntegratedGradientsExplainer(
+ ludwig_model,
+ data_subset,
+ sample_df,
+ sanitized_output_feature,
+ )
+ explanations = explainer.explain()
+ except Exception as exc:
+ LOG.warning(f"Unable to compute feature importance: {exc}")
+ if hasattr(ludwig_model, "close"):
+ ludwig_model.close()
+ return
+
+ if hasattr(ludwig_model, "close"):
+ try:
+ ludwig_model.close()
+ except Exception:
+ pass
- # Convert visualizations to HTML
+ label_names = []
+ target_metadata = {}
+ if isinstance(training_metadata, dict):
+ target_metadata = training_metadata.get(sanitized_output_feature, {})
+
+ if isinstance(target_metadata, dict):
+ if "idx2str" in target_metadata:
+ idx2str = target_metadata["idx2str"]
+ if isinstance(idx2str, dict):
+ def _idx_key(item):
+ idx_key = item[0]
+ try:
+ return (0, int(idx_key))
+ except (TypeError, ValueError):
+ return (1, str(idx_key))
+
+ label_names = [value for key, value in sorted(
+ idx2str.items(), key=_idx_key)]
+ else:
+ label_names = idx2str
+ elif "str2idx" in target_metadata and isinstance(
+ target_metadata["str2idx"], dict):
+ # invert mapping
+ label_names = [label for label, _ in sorted(
+ target_metadata["str2idx"].items(),
+ key=lambda item: item[1])]
+
+ rows = []
+ global_explanation = explanations.global_explanation
+ for label_index, label_explanation in enumerate(
+ global_explanation.label_explanations):
+ if label_names and label_index < len(label_names):
+ label_value = str(label_names[label_index])
+ elif len(global_explanation.label_explanations) == 1:
+ label_value = output_feature_name
+ else:
+ label_value = str(label_index)
+
+ for feature in label_explanation.feature_attributions:
+ rows.append({
+ "label": label_value,
+ "feature": feature.feature_name,
+ "importance": feature.attribution,
+ "abs_importance": abs(feature.attribution),
+ })
+
+ if not rows:
+ LOG.warning("No feature importance rows produced")
+ return
+
+ importance_df = pd.DataFrame(rows)
+ importance_df.sort_values([
+ "label",
+ "abs_importance"
+ ], ascending=[True, False], inplace=True)
+
+ importance_df.to_csv(output_csv_path, index=False)
+
+ LOG.info(f"Feature importance saved to {output_csv_path}")
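The report code below concatenates `render_feature_importance_table(top_rows)` and `plot_sections`; one plausible way to assemble `top_rows` from the CSV written above (the top-10 cutoff and the reuse of `output_csv_path` are assumptions):

importance_df = pd.read_csv(output_csv_path)
top_rows = (
    importance_df
    .sort_values(["label", "abs_importance"], ascending=[True, False])
    .groupby("label", sort=False)
    .head(10)
    .reset_index(drop=True)
)
# plot_sections would then be built per label as sketched after _feature_importance_plot.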
+
+
+def generate_html_report(title, ludwig_output_directory_name):
plots_html = ""
- if len(os.listdir(viz_output_directory)) > 0:
+ plot_files = []
+ if os.path.isdir(viz_output_directory):
+ plot_files = sorted(os.listdir(viz_output_directory))
+ if plot_files:
        plots_html = "<h2>Visualizations</h2>"
- for plot_file in sorted(os.listdir(viz_output_directory)):
+ for plot_file in plot_files:
plot_path = os.path.join(viz_output_directory, plot_file)
if os.path.isfile(plot_path) and plot_file.endswith((".png", ".jpg")):
encoded_image = encode_image_to_base64(plot_path)
+ plot_title = os.path.splitext(plot_file)[0].replace("_", " ")
plots_html += (
                    f'<div class="plot"><h3>{plot_title}</h3>'
                    f'<img src="data:image/png;base64,{encoded_image}" /></div>'
                )
+
+    feature_importance_html = ""
+    try:
+        explanation_text = (
+            "<p>Feature importance scores come from Ludwig's Integrated Gradients explainer. "
+            "It interpolates between each example and a neutral baseline sample, summing "
+            "the change in the model output along that path. Higher |importance| values "
+            "indicate stronger influence. Plots share a common x-axis to make magnitudes "
+            "comparable across labels, and the table columns can be sorted for quick scans.</p>"
+        )
+        feature_importance_html = (
+            "<h2>Feature Importance</h2>"
+            + explanation_text
+            + render_feature_importance_table(top_rows)
+            + "".join(plot_sections)
+        )
+ except Exception as exc:
+ LOG.info(f"Unable to embed feature importance table: {exc}")
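The report text above summarises Integrated Gradients. As a standalone illustration of that idea (not Ludwig's implementation; the toy model, step count and finite-difference gradients are all made up for the example), attributions are the input-minus-baseline difference weighted by the average gradient along the interpolation path, and they sum to roughly the change in model output:

import numpy as np

def toy_model(x):
    # Stand-in differentiable "model": weighted sum squashed through tanh.
    w = np.array([0.5, -1.2, 2.0])
    return np.tanh(x @ w)

def integrated_gradients(x, baseline, steps=64, eps=1e-4):
    # Average finite-difference gradients along the straight path from baseline to x,
    # then scale by (x - baseline); sum(attributions) approximately equals
    # toy_model(x) - toy_model(baseline).
    grads = np.zeros_like(x)
    for alpha in np.linspace(0.0, 1.0, steps):
        point = baseline + alpha * (x - baseline)
        for i in range(x.size):
            bump = np.zeros_like(x)
            bump[i] = eps
            grads[i] += (toy_model(point + bump) - toy_model(point - bump)) / (2 * eps)
    return (x - baseline) * grads / steps

x = np.array([1.0, 0.5, -0.3])
baseline = np.zeros_like(x)
attributions = integrated_gradients(x, baseline)
print(attributions, attributions.sum(), toy_model(x) - toy_model(baseline))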
+
# Generate the full HTML content
+ feature_section = feature_importance_html or "