diff ludwig_experiment.py @ 8:0ee0bc6736a2 draft default tip

planemo upload for repository https://github.com/goeckslab/Galaxy-Ludwig.git commit e2ab4c0f9ce8b7a0a48f749ef5dd9899d6c2b1f8
author goeckslab
date Sat, 22 Nov 2025 01:17:09 +0000
parents 78b1e3921576
children
--- a/ludwig_experiment.py	Sat Sep 06 01:53:23 2025 +0000
+++ b/ludwig_experiment.py	Sat Nov 22 01:17:09 2025 +0000
@@ -1,10 +1,15 @@
+import base64
+import html
 import json
 import logging
 import os
 import pickle
+import re
 import sys
+from io import BytesIO
 
 import pandas as pd
+from ludwig.api import LudwigModel
 from ludwig.experiment import cli
 from ludwig.globals import (
     DESCRIPTION_FILE_NAME,
@@ -21,6 +26,11 @@
     get_html_template
 )
 
+try:  # pragma: no cover - optional dependency in runtime containers
+    import matplotlib.pyplot as plt
+except ImportError:  # pragma: no cover
+    plt = None
+
 
 logging.basicConfig(level=logging.DEBUG)
 
@@ -158,44 +168,435 @@
         LOG.error(f"Error converting Parquet to CSV: {e}")
 
 
-def generate_html_report(title, ludwig_output_directory_name):
-    # ludwig_output_directory = os.path.join(
-    #     output_directory, ludwig_output_directory_name)
+def _resolve_dataset_path(dataset_path):
+    if not dataset_path:
+        return None
+
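+    # Ludwig records dataset paths that may be relative to the job's output
+    # directory or the current working directory; try both before giving up.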
+    candidates = [dataset_path]
+
+    if not os.path.isabs(dataset_path):
+        candidates.extend([
+            os.path.join(output_directory, dataset_path),
+            os.path.join(os.getcwd(), dataset_path),
+        ])
+
+    for candidate in candidates:
+        if candidate and os.path.exists(candidate):
+            return os.path.abspath(candidate)
+
+    return None
+
+
+def _load_dataset_dataframe(dataset_path):
+    if not dataset_path:
+        return None
+
+    _, ext = os.path.splitext(dataset_path.lower())
+
+    try:
+        if ext in {".csv", ".tsv"}:
+            sep = "\t" if ext == ".tsv" else ","
+            return pd.read_csv(dataset_path, sep=sep)
+        if ext == ".parquet":
+            return pd.read_parquet(dataset_path)
+        if ext == ".json":
+            return pd.read_json(dataset_path)
+        if ext == ".h5":
+            return pd.read_hdf(dataset_path)
+    except Exception as exc:
+        LOG.warning(f"Unable to load dataset '{dataset_path}': {exc}")
+        return None
+
+    LOG.warning(
+        f"Unsupported dataset format '{ext}' for feature importance computation")
+    return None
+
+
+def sanitize_feature_name(name):
+    """Mirror Ludwig's get_sanitized_feature_name implementation."""
+    return re.sub(r"[(){}.:\"\"\'\'\[\]]", "_", str(name))
+
+
+def _sanitize_dataframe_columns(dataframe):
+    """Rename dataframe columns to Ludwig-sanitized names for explainability."""
+    column_map = {col: sanitize_feature_name(col) for col in dataframe.columns}
+
+    sanitized_df = dataframe.rename(columns=column_map)
+    if len(set(column_map.values())) != len(column_map.values()):
+        LOG.warning(
+            "Column name collision after sanitization; feature importance may be unreliable"
+        )
+
+    return sanitized_df
+
+
+def _feature_importance_plot(label_df, label_name, top_n=10, max_abs_importance=None):
+    """
+    Return base64-encoded bar plot for a label's top-N feature importances.
+
+    max_abs_importance lets us pin the x-axis across labels so readers can
+    compare magnitudes.
+    """
+    if plt is None or label_df.empty:
+        return ""
+
+    top_features = label_df.nlargest(top_n, "abs_importance")
+    if top_features.empty:
+        return ""
+
+    fig, ax = plt.subplots(figsize=(6, 3 + 0.2 * len(top_features)))
+    ax.barh(top_features["feature"], top_features["abs_importance"], color="#3f8fd2")
+    ax.set_xlabel("|importance|")
+    if max_abs_importance and max_abs_importance > 0:
+        ax.set_xlim(0, max_abs_importance * 1.05)
+    ax.invert_yaxis()
+    fig.tight_layout()
+
+    buf = BytesIO()
+    fig.savefig(buf, format="png", dpi=150)
+    plt.close(fig)
+    encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
+    return encoded
+
+
+def render_feature_importance_table(df: pd.DataFrame) -> str:
+    """Render a sortable HTML table for feature importance values."""
+    if df.empty:
+        return ""
+
+    columns = list(df.columns)
+    headers = "".join(
+        f"<th class='sortable'>{html.escape(str(col).replace('_', ' '))}</th>"
+        for col in columns
+    )
+
+    body_rows = []
+    for _, row in df.iterrows():
+        cells = []
+        for col in columns:
+            val = row[col]
+            if isinstance(val, float):
+                val_str = f"{val:.6f}"
+            else:
+                val_str = str(val)
+            cells.append(f"<td>{html.escape(val_str)}</td>")
+        body_rows.append("<tr>" + "".join(cells) + "</tr>")
+
+    return (
+        "<div class='scroll-rows-30'>"
+        "<table class='feature-importance-table sortable-table'>"
+        f"<thead><tr>{headers}</tr></thead>"
+        f"<tbody>{''.join(body_rows)}</tbody>"
+        "</table>"
+        "</div>"
+    )
+
+
+def compute_feature_importance(ludwig_output_directory_name,
+                               sample_size=200,
+                               random_seed=42):
+    ludwig_output_directory = os.path.join(
+        output_directory, ludwig_output_directory_name)
+    model_dir = os.path.join(ludwig_output_directory, "model")
+
+    output_csv_path = os.path.join(
+        ludwig_output_directory, "feature_importance.csv")
+
+    if not os.path.exists(model_dir):
+        LOG.info("Model directory not found; skipping feature importance computation")
+        return
 
-    # test_statistics_html = ""
-    # # Read test statistics JSON and convert to HTML table
-    # try:
-    #     test_statistics_path = os.path.join(
-    #         ludwig_output_directory, TEST_STATISTICS_FILE_NAME)
-    #     with open(test_statistics_path, "r") as f:
-    #         test_statistics = json.load(f)
-    #     test_statistics_html = "<h2>Test Statistics</h2>"
-    #     test_statistics_html += json_to_html_table(
-    #         test_statistics)
-    # except Exception as e:
-    #     LOG.info(f"Error reading test statistics: {e}")
+    try:
+        ludwig_model = LudwigModel.load(model_dir)
+    except Exception as exc:
+        LOG.warning(f"Unable to load Ludwig model for explanations: {exc}")
+        return
+
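+    # Capture the training metadata now so label indices can be mapped back to
+    # human-readable class names after the model has been closed.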
+    training_metadata = getattr(ludwig_model, "training_set_metadata", {})
+
+    output_feature_name, dataset_path = get_output_feature_name(
+        ludwig_output_directory)
+
+    if not output_feature_name or not dataset_path:
+        LOG.warning("Output feature or dataset path missing; skipping feature importance")
+        if hasattr(ludwig_model, "close"):
+            ludwig_model.close()
+        return
+
+    dataset_full_path = _resolve_dataset_path(dataset_path)
+    if not dataset_full_path:
+        LOG.warning(f"Unable to resolve dataset path '{dataset_path}' for explanations")
+        if hasattr(ludwig_model, "close"):
+            ludwig_model.close()
+        return
+
+    dataframe = _load_dataset_dataframe(dataset_full_path)
+    if dataframe is None or dataframe.empty:
+        LOG.warning("Dataset unavailable or empty; skipping feature importance")
+        if hasattr(ludwig_model, "close"):
+            ludwig_model.close()
+        return
+
+    dataframe = _sanitize_dataframe_columns(dataframe)
+
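+    # data_subset holds the rows to explain; sample_df provides the sample of
+    # the dataset the explainer interpolates against as its baseline.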
+    data_subset = (
+        dataframe if len(dataframe) <= sample_size
+        else dataframe.head(sample_size)
+    )
+    if len(dataframe) > sample_size:
+        sample_df = dataframe.sample(
+            n=sample_size,
+            random_state=random_seed,
+            replace=False,
+        )
+    else:
+        sample_df = dataframe
+
+    try:
+        from ludwig.explain.captum import IntegratedGradientsExplainer
+    except ImportError as exc:
+        LOG.warning(f"Integrated Gradients explainer unavailable: {exc}")
+        if hasattr(ludwig_model, "close"):
+            ludwig_model.close()
+        return
+
+    sanitized_output_feature = sanitize_feature_name(output_feature_name)
+
+    try:
+        explainer = IntegratedGradientsExplainer(
+            ludwig_model,
+            data_subset,
+            sample_df,
+            sanitized_output_feature,
+        )
+        explanations = explainer.explain()
+    except Exception as exc:
+        LOG.warning(f"Unable to compute feature importance: {exc}")
+        if hasattr(ludwig_model, "close"):
+            ludwig_model.close()
+        return
+
+    if hasattr(ludwig_model, "close"):
+        try:
+            ludwig_model.close()
+        except Exception:
+            pass
 
-    # Convert visualizations to HTML
+    label_names = []
+    target_metadata = {}
+    if isinstance(training_metadata, dict):
+        target_metadata = training_metadata.get(sanitized_output_feature, {})
+
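+    # Category targets store the index-to-label mapping as "idx2str" (list or
+    # dict) or its inverse "str2idx"; use whichever is present to name labels.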
+    if isinstance(target_metadata, dict):
+        if "idx2str" in target_metadata:
+            idx2str = target_metadata["idx2str"]
+            if isinstance(idx2str, dict):
+                def _idx_key(item):
+                    idx_key = item[0]
+                    try:
+                        return (0, int(idx_key))
+                    except (TypeError, ValueError):
+                        return (1, str(idx_key))
+
+                label_names = [value for key, value in sorted(
+                    idx2str.items(), key=_idx_key)]
+            else:
+                label_names = idx2str
+        elif "str2idx" in target_metadata and isinstance(
+                target_metadata["str2idx"], dict):
+            # invert mapping
+            label_names = [label for label, _ in sorted(
+                target_metadata["str2idx"].items(),
+                key=lambda item: item[1])]
+
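+    # Flatten the per-label attributions into long format: one row per
+    # (label, feature) pair with signed and absolute importance values.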
+    rows = []
+    global_explanation = explanations.global_explanation
+    for label_index, label_explanation in enumerate(
+            global_explanation.label_explanations):
+        if label_names and label_index < len(label_names):
+            label_value = str(label_names[label_index])
+        elif len(global_explanation.label_explanations) == 1:
+            label_value = output_feature_name
+        else:
+            label_value = str(label_index)
+
+        for feature in label_explanation.feature_attributions:
+            rows.append({
+                "label": label_value,
+                "feature": feature.feature_name,
+                "importance": feature.attribution,
+                "abs_importance": abs(feature.attribution),
+            })
+
+    if not rows:
+        LOG.warning("No feature importance rows produced")
+        return
+
+    importance_df = pd.DataFrame(rows)
+    importance_df.sort_values(
+        ["label", "abs_importance"],
+        ascending=[True, False],
+        inplace=True,
+    )
+
+    importance_df.to_csv(output_csv_path, index=False)
+
+    LOG.info(f"Feature importance saved to {output_csv_path}")
+
+
+def generate_html_report(title, ludwig_output_directory_name):
     plots_html = ""
-    if len(os.listdir(viz_output_directory)) > 0:
+    plot_files = []
+    if os.path.isdir(viz_output_directory):
+        plot_files = sorted(os.listdir(viz_output_directory))
+    if plot_files:
         plots_html = "<h2>Visualizations</h2>"
-    for plot_file in sorted(os.listdir(viz_output_directory)):
+    for plot_file in plot_files:
         plot_path = os.path.join(viz_output_directory, plot_file)
         if os.path.isfile(plot_path) and plot_file.endswith((".png", ".jpg")):
             encoded_image = encode_image_to_base64(plot_path)
+            plot_title = os.path.splitext(plot_file)[0].replace("_", " ")
             plots_html += (
                 f'<div class="plot">'
-                f'<h3>{os.path.splitext(plot_file)[0]}</h3>'
+                f'<h3>{plot_title}</h3>'
                 '<img src="data:image/png;base64,'
                 f'{encoded_image}" alt="{plot_file}">'
                 f'</div>'
             )
 
+    feature_importance_html = ""
+    importance_path = os.path.join(
+        output_directory,
+        ludwig_output_directory_name,
+        "feature_importance.csv",
+    )
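+    # If compute_feature_importance produced a CSV, embed it as a sortable
+    # table of the top-5 features per label plus per-label bar plots.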
+    if os.path.exists(importance_path):
+        try:
+            importance_df = pd.read_csv(importance_path)
+            if not importance_df.empty:
+                sorted_df = (
+                    importance_df
+                    .sort_values(["label", "abs_importance"], ascending=[True, False])
+                )
+                top_rows = (
+                    sorted_df
+                    .groupby("label", as_index=False)
+                    .head(5)
+                )
+                max_abs_importance = pd.to_numeric(
+                    importance_df.get("abs_importance", pd.Series(dtype=float)),
+                    errors="coerce",
+                ).max()
+                if pd.isna(max_abs_importance):
+                    max_abs_importance = None
+
+                plot_sections = []
+                for label in sorted(importance_df["label"].unique()):
+                    encoded_plot = _feature_importance_plot(
+                        importance_df[importance_df["label"] == label],
+                        label,
+                        max_abs_importance=max_abs_importance,
+                    )
+                    if encoded_plot:
+                        plot_sections.append(
+                            f'<div class="plot feature-importance-plot">'
+                            f'<h3>Top features for {label}</h3>'
+                            f'<img src="data:image/png;base64,{encoded_plot}" '
+                            f'alt="Feature importance plot for {label}">'
+                            f'</div>'
+                        )
+                explanation_text = (
+                    "<p>Feature importance scores come from Ludwig's Integrated Gradients explainer. "
+                    "It interpolates between each example and a neutral baseline sample, accumulating "
+                    "the model's gradients along that path to attribute the prediction to each input "
+                    "feature. Higher |importance| values indicate stronger influence. Plots share a "
+                    "common x-axis to make magnitudes comparable across labels, and the table columns "
+                    "can be sorted for quick scans.</p>"
+                )
+                feature_importance_html = (
+                    "<h2>Feature Importance</h2>"
+                    + explanation_text
+                    + render_feature_importance_table(top_rows)
+                    + "".join(plot_sections)
+                )
+        except Exception as exc:
+            LOG.info(f"Unable to embed feature importance table: {exc}")
+
     # Generate the full HTML content
+    feature_section = feature_importance_html or "<p>No feature importance artifacts were generated.</p>"
+    viz_section = plots_html or "<p>No visualizations were generated.</p>"
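+    # The report body is split into two client-side tabs: Ludwig's standard
+    # visualizations and the feature importance artifacts.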
+    tabs_style = """
+    <style>
+        .tabs {
+            display: flex;
+            border-bottom: 2px solid #ccc;
+            margin-top: 20px;
+            margin-bottom: 1rem;
+        }
+        .tablink {
+            padding: 9px 18px;
+            cursor: pointer;
+            border: 1px solid #ccc;
+            border-bottom: none;
+            background: #f9f9f9;
+            margin-right: 5px;
+            border-top-left-radius: 8px;
+            border-top-right-radius: 8px;
+            font-size: 0.95rem;
+            font-weight: 500;
+            font-family: Arial, sans-serif;
+            color: #4A4A4A;
+        }
+        .tablink.active {
+            background: #ffffff;
+            font-weight: bold;
+        }
+        .tabcontent {
+            border: 1px solid #ccc;
+            border-top: none;
+            padding: 20px;
+            display: none;
+        }
+        .tabcontent.active {
+            display: block;
+        }
+    </style>
+    """
+    tabs_script = """
+    <script>
+        function openTab(evt, tabId) {
+            var i, tabcontent, tablinks;
+            tabcontent = document.getElementsByClassName("tabcontent");
+            for (i = 0; i < tabcontent.length; i++) {
+                tabcontent[i].style.display = "none";
+                tabcontent[i].classList.remove("active");
+            }
+            tablinks = document.getElementsByClassName("tablink");
+            for (i = 0; i < tablinks.length; i++) {
+                tablinks[i].classList.remove("active");
+            }
+            var current = document.getElementById(tabId);
+            if (current) {
+                current.style.display = "block";
+                current.classList.add("active");
+            }
+            if (evt && evt.currentTarget) {
+                evt.currentTarget.classList.add("active");
+            }
+        }
+        document.addEventListener("DOMContentLoaded", function() {
+            openTab({currentTarget: document.querySelector(".tablink")}, "viz-tab");
+        });
+    </script>
+    """
+    tabs_html = f"""
+    <div class="tabs">
+        <button class="tablink active" onclick="openTab(event, 'viz-tab')">Visualizations</button>
+        <button class="tablink" onclick="openTab(event, 'feature-tab')">Feature Importance</button>
+    </div>
+    <div id="viz-tab" class="tabcontent active">
+        {viz_section}
+    </div>
+    <div id="feature-tab" class="tabcontent">
+        {feature_section}
+    </div>
+    """
     html_content = f"""
     {get_html_template()}
         <h1>{title}</h1>
-        {plots_html}
+        {tabs_style}
+        {tabs_html}
+        {tabs_script}
     {get_html_closing()}
     """
 
@@ -217,4 +618,5 @@
 
     make_visualizations(ludwig_output_directory_name)
     convert_parquet_to_csv(ludwig_output_directory_name)
+    compute_feature_importance(ludwig_output_directory_name)
     generate_html_report("Ludwig Experiment", ludwig_output_directory_name)