Mercurial > repos > goeckslab > ludwig_evaluate
diff ludwig_experiment.py @ 6:cbea2960cca2 draft default tip
planemo upload for repository https://github.com/goeckslab/Galaxy-Ludwig.git commit e2ab4c0f9ce8b7a0a48f749ef5dd9899d6c2b1f8
| author | goeckslab |
|---|---|
| date | Sat, 22 Nov 2025 01:16:50 +0000 |
| parents | 777be50bc321 |
| children |
line wrap: on
line diff
--- a/ludwig_experiment.py Sat Sep 06 01:53:06 2025 +0000 +++ b/ludwig_experiment.py Sat Nov 22 01:16:50 2025 +0000 @@ -1,10 +1,15 @@ +import base64 +import html import json import logging import os import pickle +import re import sys +from io import BytesIO import pandas as pd +from ludwig.api import LudwigModel from ludwig.experiment import cli from ludwig.globals import ( DESCRIPTION_FILE_NAME, @@ -21,6 +26,11 @@ get_html_template ) +try: # pragma: no cover - optional dependency in runtime containers + import matplotlib.pyplot as plt +except ImportError: # pragma: no cover + plt = None + logging.basicConfig(level=logging.DEBUG) @@ -158,44 +168,435 @@ LOG.error(f"Error converting Parquet to CSV: {e}") -def generate_html_report(title, ludwig_output_directory_name): - # ludwig_output_directory = os.path.join( - # output_directory, ludwig_output_directory_name) +def _resolve_dataset_path(dataset_path): + if not dataset_path: + return None + + candidates = [dataset_path] + + if not os.path.isabs(dataset_path): + candidates.extend([ + os.path.join(output_directory, dataset_path), + os.path.join(os.getcwd(), dataset_path), + ]) + + for candidate in candidates: + if candidate and os.path.exists(candidate): + return os.path.abspath(candidate) + + return None + + +def _load_dataset_dataframe(dataset_path): + if not dataset_path: + return None + + _, ext = os.path.splitext(dataset_path.lower()) + + try: + if ext in {".csv", ".tsv"}: + sep = "\t" if ext == ".tsv" else "," + return pd.read_csv(dataset_path, sep=sep) + if ext == ".parquet": + return pd.read_parquet(dataset_path) + if ext == ".json": + return pd.read_json(dataset_path) + if ext == ".h5": + return pd.read_hdf(dataset_path) + except Exception as exc: + LOG.warning(f"Unable to load dataset '{dataset_path}': {exc}") + + LOG.warning("Unsupported dataset format for feature importance computation") + return None + + +def sanitize_feature_name(name): + """Mirror Ludwig's get_sanitized_feature_name implementation.""" + return re.sub(r"[(){}.:\"\"\'\'\[\]]", "_", str(name)) + + +def _sanitize_dataframe_columns(dataframe): + """Rename dataframe columns to Ludwig-sanitized names for explainability.""" + column_map = {col: sanitize_feature_name(col) for col in dataframe.columns} + + sanitized_df = dataframe.rename(columns=column_map) + if len(set(column_map.values())) != len(column_map.values()): + LOG.warning( + "Column name collision after sanitization; feature importance may be unreliable" + ) + + return sanitized_df + + +def _feature_importance_plot(label_df, label_name, top_n=10, max_abs_importance=None): + """ + Return base64-encoded bar plot for a label's top-N feature importances. + + max_abs_importance lets us pin the x-axis across labels so readers can + compare magnitudes. + """ + if plt is None or label_df.empty: + return "" + + top_features = label_df.nlargest(top_n, "abs_importance") + if top_features.empty: + return "" + + fig, ax = plt.subplots(figsize=(6, 3 + 0.2 * len(top_features))) + ax.barh(top_features["feature"], top_features["abs_importance"], color="#3f8fd2") + ax.set_xlabel("|importance|") + if max_abs_importance and max_abs_importance > 0: + ax.set_xlim(0, max_abs_importance * 1.05) + ax.invert_yaxis() + fig.tight_layout() + + buf = BytesIO() + fig.savefig(buf, format="png", dpi=150) + plt.close(fig) + encoded = base64.b64encode(buf.getvalue()).decode("utf-8") + return encoded + + +def render_feature_importance_table(df: pd.DataFrame) -> str: + """Render a sortable HTML table for feature importance values.""" + if df.empty: + return "" + + columns = list(df.columns) + headers = "".join( + f"<th class='sortable'>{html.escape(str(col).replace('_', ' '))}</th>" + for col in columns + ) + + body_rows = [] + for _, row in df.iterrows(): + cells = [] + for col in columns: + val = row[col] + if isinstance(val, float): + val_str = f"{val:.6f}" + else: + val_str = str(val) + cells.append(f"<td>{html.escape(val_str)}</td>") + body_rows.append("<tr>" + "".join(cells) + "</tr>") + + return ( + "<div class='scroll-rows-30'>" + "<table class='feature-importance-table sortable-table'>" + f"<thead><tr>{headers}</tr></thead>" + f"<tbody>{''.join(body_rows)}</tbody>" + "</table>" + "</div>" + ) + + +def compute_feature_importance(ludwig_output_directory_name, + sample_size=200, + random_seed=42): + ludwig_output_directory = os.path.join( + output_directory, ludwig_output_directory_name) + model_dir = os.path.join(ludwig_output_directory, "model") + + output_csv_path = os.path.join( + ludwig_output_directory, "feature_importance.csv") + + if not os.path.exists(model_dir): + LOG.info("Model directory not found; skipping feature importance computation") + return - # test_statistics_html = "" - # # Read test statistics JSON and convert to HTML table - # try: - # test_statistics_path = os.path.join( - # ludwig_output_directory, TEST_STATISTICS_FILE_NAME) - # with open(test_statistics_path, "r") as f: - # test_statistics = json.load(f) - # test_statistics_html = "<h2>Test Statistics</h2>" - # test_statistics_html += json_to_html_table( - # test_statistics) - # except Exception as e: - # LOG.info(f"Error reading test statistics: {e}") + try: + ludwig_model = LudwigModel.load(model_dir) + except Exception as exc: + LOG.warning(f"Unable to load Ludwig model for explanations: {exc}") + return + + training_metadata = getattr(ludwig_model, "training_set_metadata", {}) + + output_feature_name, dataset_path = get_output_feature_name( + ludwig_output_directory) + + if not output_feature_name or not dataset_path: + LOG.warning("Output feature or dataset path missing; skipping feature importance") + if hasattr(ludwig_model, "close"): + ludwig_model.close() + return + + dataset_full_path = _resolve_dataset_path(dataset_path) + if not dataset_full_path: + LOG.warning(f"Unable to resolve dataset path '{dataset_path}' for explanations") + if hasattr(ludwig_model, "close"): + ludwig_model.close() + return + + dataframe = _load_dataset_dataframe(dataset_full_path) + if dataframe is None or dataframe.empty: + LOG.warning("Dataset unavailable or empty; skipping feature importance") + if hasattr(ludwig_model, "close"): + ludwig_model.close() + return + + dataframe = _sanitize_dataframe_columns(dataframe) + + data_subset = dataframe if len(dataframe) <= sample_size else dataframe.head(sample_size) + sample_df = dataframe.sample( + n=min(sample_size, len(dataframe)), + random_state=random_seed, + replace=False, + ) if len(dataframe) > sample_size else dataframe + + try: + from ludwig.explain.captum import IntegratedGradientsExplainer + except ImportError as exc: + LOG.warning(f"Integrated Gradients explainer unavailable: {exc}") + if hasattr(ludwig_model, "close"): + ludwig_model.close() + return + + sanitized_output_feature = sanitize_feature_name(output_feature_name) + + try: + explainer = IntegratedGradientsExplainer( + ludwig_model, + data_subset, + sample_df, + sanitized_output_feature, + ) + explanations = explainer.explain() + except Exception as exc: + LOG.warning(f"Unable to compute feature importance: {exc}") + if hasattr(ludwig_model, "close"): + ludwig_model.close() + return + + if hasattr(ludwig_model, "close"): + try: + ludwig_model.close() + except Exception: + pass - # Convert visualizations to HTML + label_names = [] + target_metadata = {} + if isinstance(training_metadata, dict): + target_metadata = training_metadata.get(sanitized_output_feature, {}) + + if isinstance(target_metadata, dict): + if "idx2str" in target_metadata: + idx2str = target_metadata["idx2str"] + if isinstance(idx2str, dict): + def _idx_key(item): + idx_key = item[0] + try: + return (0, int(idx_key)) + except (TypeError, ValueError): + return (1, str(idx_key)) + + label_names = [value for key, value in sorted( + idx2str.items(), key=_idx_key)] + else: + label_names = idx2str + elif "str2idx" in target_metadata and isinstance( + target_metadata["str2idx"], dict): + # invert mapping + label_names = [label for label, _ in sorted( + target_metadata["str2idx"].items(), + key=lambda item: item[1])] + + rows = [] + global_explanation = explanations.global_explanation + for label_index, label_explanation in enumerate( + global_explanation.label_explanations): + if label_names and label_index < len(label_names): + label_value = str(label_names[label_index]) + elif len(global_explanation.label_explanations) == 1: + label_value = output_feature_name + else: + label_value = str(label_index) + + for feature in label_explanation.feature_attributions: + rows.append({ + "label": label_value, + "feature": feature.feature_name, + "importance": feature.attribution, + "abs_importance": abs(feature.attribution), + }) + + if not rows: + LOG.warning("No feature importance rows produced") + return + + importance_df = pd.DataFrame(rows) + importance_df.sort_values([ + "label", + "abs_importance" + ], ascending=[True, False], inplace=True) + + importance_df.to_csv(output_csv_path, index=False) + + LOG.info(f"Feature importance saved to {output_csv_path}") + + +def generate_html_report(title, ludwig_output_directory_name): plots_html = "" - if len(os.listdir(viz_output_directory)) > 0: + plot_files = [] + if os.path.isdir(viz_output_directory): + plot_files = sorted(os.listdir(viz_output_directory)) + if plot_files: plots_html = "<h2>Visualizations</h2>" - for plot_file in sorted(os.listdir(viz_output_directory)): + for plot_file in plot_files: plot_path = os.path.join(viz_output_directory, plot_file) if os.path.isfile(plot_path) and plot_file.endswith((".png", ".jpg")): encoded_image = encode_image_to_base64(plot_path) + plot_title = os.path.splitext(plot_file)[0].replace("_", " ") plots_html += ( f'<div class="plot">' - f'<h3>{os.path.splitext(plot_file)[0]}</h3>' + f'<h3>{plot_title}</h3>' '<img src="data:image/png;base64,' f'{encoded_image}" alt="{plot_file}">' f'</div>' ) + feature_importance_html = "" + importance_path = os.path.join( + output_directory, + ludwig_output_directory_name, + "feature_importance.csv", + ) + if os.path.exists(importance_path): + try: + importance_df = pd.read_csv(importance_path) + if not importance_df.empty: + sorted_df = ( + importance_df + .sort_values(["label", "abs_importance"], ascending=[True, False]) + ) + top_rows = ( + sorted_df + .groupby("label", as_index=False) + .head(5) + ) + max_abs_importance = pd.to_numeric( + importance_df.get("abs_importance", pd.Series(dtype=float)), + errors="coerce", + ).max() + if pd.isna(max_abs_importance): + max_abs_importance = None + + plot_sections = [] + for label in sorted(importance_df["label"].unique()): + encoded_plot = _feature_importance_plot( + importance_df[importance_df["label"] == label], + label, + max_abs_importance=max_abs_importance, + ) + if encoded_plot: + plot_sections.append( + f'<div class="plot feature-importance-plot">' + f'<h3>Top features for {label}</h3>' + f'<img src="data:image/png;base64,{encoded_plot}" ' + f'alt="Feature importance plot for {label}">' + f'</div>' + ) + explanation_text = ( + "<p>Feature importance scores come from Ludwig's Integrated Gradients explainer. " + "It interpolates between each example and a neutral baseline sample, summing " + "the change in the model output along that path. Higher |importance| values " + "indicate stronger influence. Plots share a common x-axis to make magnitudes " + "comparable across labels, and the table columns can be sorted for quick scans.</p>" + ) + feature_importance_html = ( + "<h2>Feature Importance</h2>" + + explanation_text + + render_feature_importance_table(top_rows) + + "".join(plot_sections) + ) + except Exception as exc: + LOG.info(f"Unable to embed feature importance table: {exc}") + # Generate the full HTML content + feature_section = feature_importance_html or "<p>No feature importance artifacts were generated.</p>" + viz_section = plots_html or "<p>No visualizations were generated.</p>" + tabs_style = """ + <style> + .tabs { + display: flex; + border-bottom: 2px solid #ccc; + margin-top: 20px; + margin-bottom: 1rem; + } + .tablink { + padding: 9px 18px; + cursor: pointer; + border: 1px solid #ccc; + border-bottom: none; + background: #f9f9f9; + margin-right: 5px; + border-top-left-radius: 8px; + border-top-right-radius: 8px; + font-size: 0.95rem; + font-weight: 500; + font-family: Arial, sans-serif; + color: #4A4A4A; + } + .tablink.active { + background: #ffffff; + font-weight: bold; + } + .tabcontent { + border: 1px solid #ccc; + border-top: none; + padding: 20px; + display: none; + } + .tabcontent.active { + display: block; + } + </style> + """ + tabs_script = """ + <script> + function openTab(evt, tabId) { + var i, tabcontent, tablinks; + tabcontent = document.getElementsByClassName("tabcontent"); + for (i = 0; i < tabcontent.length; i++) { + tabcontent[i].style.display = "none"; + tabcontent[i].classList.remove("active"); + } + tablinks = document.getElementsByClassName("tablink"); + for (i = 0; i < tablinks.length; i++) { + tablinks[i].classList.remove("active"); + } + var current = document.getElementById(tabId); + if (current) { + current.style.display = "block"; + current.classList.add("active"); + } + if (evt && evt.currentTarget) { + evt.currentTarget.classList.add("active"); + } + } + document.addEventListener("DOMContentLoaded", function() { + openTab({currentTarget: document.querySelector(".tablink")}, "viz-tab"); + }); + </script> + """ + tabs_html = f""" + <div class="tabs"> + <button class="tablink active" onclick="openTab(event, 'viz-tab')">Visualizations</button> + <button class="tablink" onclick="openTab(event, 'feature-tab')">Feature Importance</button> + </div> + <div id="viz-tab" class="tabcontent active"> + {viz_section} + </div> + <div id="feature-tab" class="tabcontent"> + {feature_section} + </div> + """ html_content = f""" {get_html_template()} <h1>{title}</h1> - {plots_html} + {tabs_style} + {tabs_html} + {tabs_script} {get_html_closing()} """ @@ -217,4 +618,5 @@ make_visualizations(ludwig_output_directory_name) convert_parquet_to_csv(ludwig_output_directory_name) + compute_feature_importance(ludwig_output_directory_name) generate_html_report("Ludwig Experiment", ludwig_output_directory_name)
