comparison: ludwig_experiment.py @ 8:0ee0bc6736a2 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/Galaxy-Ludwig.git commit e2ab4c0f9ce8b7a0a48f749ef5dd9899d6c2b1f8
| author | goeckslab |
|---|---|
| date | Sat, 22 Nov 2025 01:17:09 +0000 |
| parents | 78b1e3921576 |
| children | |
| 7:4d0d8859a0f2 | 8:0ee0bc6736a2 |
|---|---|
| 1 import base64 | |
| 2 import html | |
| 1 import json | 3 import json |
| 2 import logging | 4 import logging |
| 3 import os | 5 import os |
| 4 import pickle | 6 import pickle |
| 7 import re | |
| 5 import sys | 8 import sys |
| 9 from io import BytesIO | |
| 6 | 10 |
| 7 import pandas as pd | 11 import pandas as pd |
| 12 from ludwig.api import LudwigModel | |
| 8 from ludwig.experiment import cli | 13 from ludwig.experiment import cli |
| 9 from ludwig.globals import ( | 14 from ludwig.globals import ( |
| 10 DESCRIPTION_FILE_NAME, | 15 DESCRIPTION_FILE_NAME, |
| 11 PREDICTIONS_PARQUET_FILE_NAME, | 16 PREDICTIONS_PARQUET_FILE_NAME, |
| 12 TEST_STATISTICS_FILE_NAME, | 17 TEST_STATISTICS_FILE_NAME, |
| … | … |
| 18 from utils import ( | 23 from utils import ( |
| 19 encode_image_to_base64, | 24 encode_image_to_base64, |
| 20 get_html_closing, | 25 get_html_closing, |
| 21 get_html_template | 26 get_html_template |
| 22 ) | 27 ) |
| 28 | |
| 29 try: # pragma: no cover - optional dependency in runtime containers | |
| 30 import matplotlib.pyplot as plt | |
| 31 except ImportError: # pragma: no cover | |
| 32 plt = None | |
| 23 | 33 |
| 24 | 34 |
| 25 logging.basicConfig(level=logging.DEBUG) | 35 logging.basicConfig(level=logging.DEBUG) |
| 26 | 36 |
| 27 LOG = logging.getLogger(__name__) | 37 LOG = logging.getLogger(__name__) |
| … | … |
| 156 LOG.info(f"Converted Parquet to CSV: {csv_path}") | 166 LOG.info(f"Converted Parquet to CSV: {csv_path}") |
| 157 except Exception as e: | 167 except Exception as e: |
| 158 LOG.error(f"Error converting Parquet to CSV: {e}") | 168 LOG.error(f"Error converting Parquet to CSV: {e}") |
| 159 | 169 |
| 160 | 170 |
| 171 def _resolve_dataset_path(dataset_path): | |
| 172 if not dataset_path: | |
| 173 return None | |
| 174 | |
| 175 candidates = [dataset_path] | |
| 176 | |
| 177 if not os.path.isabs(dataset_path): | |
| 178 candidates.extend([ | |
| 179 os.path.join(output_directory, dataset_path), | |
| 180 os.path.join(os.getcwd(), dataset_path), | |
| 181 ]) | |
| 182 | |
| 183 for candidate in candidates: | |
| 184 if candidate and os.path.exists(candidate): | |
| 185 return os.path.abspath(candidate) | |
| 186 | |
| 187 return None | |
| 188 | |
| 189 | |
| 190 def _load_dataset_dataframe(dataset_path): | |
| 191 if not dataset_path: | |
| 192 return None | |
| 193 | |
| 194 _, ext = os.path.splitext(dataset_path.lower()) | |
| 195 | |
| 196 try: | |
| 197 if ext in {".csv", ".tsv"}: | |
| 198 sep = "\t" if ext == ".tsv" else "," | |
| 199 return pd.read_csv(dataset_path, sep=sep) | |
| 200 if ext == ".parquet": | |
| 201 return pd.read_parquet(dataset_path) | |
| 202 if ext == ".json": | |
| 203 return pd.read_json(dataset_path) | |
| 204 if ext == ".h5": | |
| 205 return pd.read_hdf(dataset_path) | |
| 206 except Exception as exc: | |
| 207 LOG.warning(f"Unable to load dataset '{dataset_path}': {exc}") | |
| 208 | |
| 209 LOG.warning("Unsupported dataset format for feature importance computation") | |
| 210 return None | |
| 211 | |
| 212 | |
| 213 def sanitize_feature_name(name): | |
| 214 """Mirror Ludwig's get_sanitized_feature_name implementation.""" | |
| 215 return re.sub(r"[(){}.:\"\"\'\'\[\]]", "_", str(name)) | |
| 216 | |
| 217 | |
| 218 def _sanitize_dataframe_columns(dataframe): | |
| 219 """Rename dataframe columns to Ludwig-sanitized names for explainability.""" | |
| 220 column_map = {col: sanitize_feature_name(col) for col in dataframe.columns} | |
| 221 | |
| 222 sanitized_df = dataframe.rename(columns=column_map) | |
| 223 if len(set(column_map.values())) != len(column_map.values()): | |
| 224 LOG.warning( | |
| 225 "Column name collision after sanitization; feature importance may be unreliable" | |
| 226 ) | |
| 227 | |
| 228 return sanitized_df | |
| 229 | |
| 230 | |
| 231 def _feature_importance_plot(label_df, label_name, top_n=10, max_abs_importance=None): | |
| 232 """ | |
| 233 Return base64-encoded bar plot for a label's top-N feature importances. | |
| 234 | |
| 235 max_abs_importance lets us pin the x-axis across labels so readers can | |
| 236 compare magnitudes. | |
| 237 """ | |
| 238 if plt is None or label_df.empty: | |
| 239 return "" | |
| 240 | |
| 241 top_features = label_df.nlargest(top_n, "abs_importance") | |
| 242 if top_features.empty: | |
| 243 return "" | |
| 244 | |
| 245 fig, ax = plt.subplots(figsize=(6, 3 + 0.2 * len(top_features))) | |
| 246 ax.barh(top_features["feature"], top_features["abs_importance"], color="#3f8fd2") | |
| 247 ax.set_xlabel("|importance|") | |
| 248 if max_abs_importance and max_abs_importance > 0: | |
| 249 ax.set_xlim(0, max_abs_importance * 1.05) | |
| 250 ax.invert_yaxis() | |
| 251 fig.tight_layout() | |
| 252 | |
| 253 buf = BytesIO() | |
| 254 fig.savefig(buf, format="png", dpi=150) | |
| 255 plt.close(fig) | |
| 256 encoded = base64.b64encode(buf.getvalue()).decode("utf-8") | |
| 257 return encoded | |
| 258 | |
| 259 | |
| 260 def render_feature_importance_table(df: pd.DataFrame) -> str: | |
| 261 """Render a sortable HTML table for feature importance values.""" | |
| 262 if df.empty: | |
| 263 return "" | |
| 264 | |
| 265 columns = list(df.columns) | |
| 266 headers = "".join( | |
| 267 f"<th class='sortable'>{html.escape(str(col).replace('_', ' '))}</th>" | |
| 268 for col in columns | |
| 269 ) | |
| 270 | |
| 271 body_rows = [] | |
| 272 for _, row in df.iterrows(): | |
| 273 cells = [] | |
| 274 for col in columns: | |
| 275 val = row[col] | |
| 276 if isinstance(val, float): | |
| 277 val_str = f"{val:.6f}" | |
| 278 else: | |
| 279 val_str = str(val) | |
| 280 cells.append(f"<td>{html.escape(val_str)}</td>") | |
| 281 body_rows.append("<tr>" + "".join(cells) + "</tr>") | |
| 282 | |
| 283 return ( | |
| 284 "<div class='scroll-rows-30'>" | |
| 285 "<table class='feature-importance-table sortable-table'>" | |
| 286 f"<thead><tr>{headers}</tr></thead>" | |
| 287 f"<tbody>{''.join(body_rows)}</tbody>" | |
| 288 "</table>" | |
| 289 "</div>" | |
| 290 ) | |
| 291 | |
| 292 | |
| 293 def compute_feature_importance(ludwig_output_directory_name, | |
| 294 sample_size=200, | |
| 295 random_seed=42): | |
| 296 ludwig_output_directory = os.path.join( | |
| 297 output_directory, ludwig_output_directory_name) | |
| 298 model_dir = os.path.join(ludwig_output_directory, "model") | |
| 299 | |
| 300 output_csv_path = os.path.join( | |
| 301 ludwig_output_directory, "feature_importance.csv") | |
| 302 | |
| 303 if not os.path.exists(model_dir): | |
| 304 LOG.info("Model directory not found; skipping feature importance computation") | |
| 305 return | |
| 306 | |
| 307 try: | |
| 308 ludwig_model = LudwigModel.load(model_dir) | |
| 309 except Exception as exc: | |
| 310 LOG.warning(f"Unable to load Ludwig model for explanations: {exc}") | |
| 311 return | |
| 312 | |
| 313 training_metadata = getattr(ludwig_model, "training_set_metadata", {}) | |
| 314 | |
| 315 output_feature_name, dataset_path = get_output_feature_name( | |
| 316 ludwig_output_directory) | |
| 317 | |
| 318 if not output_feature_name or not dataset_path: | |
| 319 LOG.warning("Output feature or dataset path missing; skipping feature importance") | |
| 320 if hasattr(ludwig_model, "close"): | |
| 321 ludwig_model.close() | |
| 322 return | |
| 323 | |
| 324 dataset_full_path = _resolve_dataset_path(dataset_path) | |
| 325 if not dataset_full_path: | |
| 326 LOG.warning(f"Unable to resolve dataset path '{dataset_path}' for explanations") | |
| 327 if hasattr(ludwig_model, "close"): | |
| 328 ludwig_model.close() | |
| 329 return | |
| 330 | |
| 331 dataframe = _load_dataset_dataframe(dataset_full_path) | |
| 332 if dataframe is None or dataframe.empty: | |
| 333 LOG.warning("Dataset unavailable or empty; skipping feature importance") | |
| 334 if hasattr(ludwig_model, "close"): | |
| 335 ludwig_model.close() | |
| 336 return | |
| 337 | |
| 338 dataframe = _sanitize_dataframe_columns(dataframe) | |
| 339 | |
| 340 data_subset = dataframe if len(dataframe) <= sample_size else dataframe.head(sample_size) | |
| 341 sample_df = dataframe.sample( | |
| 342 n=min(sample_size, len(dataframe)), | |
| 343 random_state=random_seed, | |
| 344 replace=False, | |
| 345 ) if len(dataframe) > sample_size else dataframe | |
| 346 | |
| 347 try: | |
| 348 from ludwig.explain.captum import IntegratedGradientsExplainer | |
| 349 except ImportError as exc: | |
| 350 LOG.warning(f"Integrated Gradients explainer unavailable: {exc}") | |
| 351 if hasattr(ludwig_model, "close"): | |
| 352 ludwig_model.close() | |
| 353 return | |
| 354 | |
| 355 sanitized_output_feature = sanitize_feature_name(output_feature_name) | |
| 356 | |
| 357 try: | |
| 358 explainer = IntegratedGradientsExplainer( | |
| 359 ludwig_model, | |
| 360 data_subset, | |
| 361 sample_df, | |
| 362 sanitized_output_feature, | |
| 363 ) | |
| 364 explanations = explainer.explain() | |
| 365 except Exception as exc: | |
| 366 LOG.warning(f"Unable to compute feature importance: {exc}") | |
| 367 if hasattr(ludwig_model, "close"): | |
| 368 ludwig_model.close() | |
| 369 return | |
| 370 | |
| 371 if hasattr(ludwig_model, "close"): | |
| 372 try: | |
| 373 ludwig_model.close() | |
| 374 except Exception: | |
| 375 pass | |
| 376 | |
| 377 label_names = [] | |
| 378 target_metadata = {} | |
| 379 if isinstance(training_metadata, dict): | |
| 380 target_metadata = training_metadata.get(sanitized_output_feature, {}) | |
| 381 | |
| 382 if isinstance(target_metadata, dict): | |
| 383 if "idx2str" in target_metadata: | |
| 384 idx2str = target_metadata["idx2str"] | |
| 385 if isinstance(idx2str, dict): | |
| 386 def _idx_key(item): | |
| 387 idx_key = item[0] | |
| 388 try: | |
| 389 return (0, int(idx_key)) | |
| 390 except (TypeError, ValueError): | |
| 391 return (1, str(idx_key)) | |
| 392 | |
| 393 label_names = [value for key, value in sorted( | |
| 394 idx2str.items(), key=_idx_key)] | |
| 395 else: | |
| 396 label_names = idx2str | |
| 397 elif "str2idx" in target_metadata and isinstance( | |
| 398 target_metadata["str2idx"], dict): | |
| 399 # invert mapping | |
| 400 label_names = [label for label, _ in sorted( | |
| 401 target_metadata["str2idx"].items(), | |
| 402 key=lambda item: item[1])] | |
| 403 | |
| 404 rows = [] | |
| 405 global_explanation = explanations.global_explanation | |
| 406 for label_index, label_explanation in enumerate( | |
| 407 global_explanation.label_explanations): | |
| 408 if label_names and label_index < len(label_names): | |
| 409 label_value = str(label_names[label_index]) | |
| 410 elif len(global_explanation.label_explanations) == 1: | |
| 411 label_value = output_feature_name | |
| 412 else: | |
| 413 label_value = str(label_index) | |
| 414 | |
| 415 for feature in label_explanation.feature_attributions: | |
| 416 rows.append({ | |
| 417 "label": label_value, | |
| 418 "feature": feature.feature_name, | |
| 419 "importance": feature.attribution, | |
| 420 "abs_importance": abs(feature.attribution), | |
| 421 }) | |
| 422 | |
| 423 if not rows: | |
| 424 LOG.warning("No feature importance rows produced") | |
| 425 return | |
| 426 | |
| 427 importance_df = pd.DataFrame(rows) | |
| 428 importance_df.sort_values([ | |
| 429 "label", | |
| 430 "abs_importance" | |
| 431 ], ascending=[True, False], inplace=True) | |
| 432 | |
| 433 importance_df.to_csv(output_csv_path, index=False) | |
| 434 | |
| 435 LOG.info(f"Feature importance saved to {output_csv_path}") | |
| 436 | |
| 437 | |
| 161 def generate_html_report(title, ludwig_output_directory_name): | 438 def generate_html_report(title, ludwig_output_directory_name): |
| 162 # ludwig_output_directory = os.path.join( | |
| 163 # output_directory, ludwig_output_directory_name) | |
| 164 | |
| 165 # test_statistics_html = "" | |
| 166 # # Read test statistics JSON and convert to HTML table | |
| 167 # try: | |
| 168 # test_statistics_path = os.path.join( | |
| 169 # ludwig_output_directory, TEST_STATISTICS_FILE_NAME) | |
| 170 # with open(test_statistics_path, "r") as f: | |
| 171 # test_statistics = json.load(f) | |
| 172 # test_statistics_html = "<h2>Test Statistics</h2>" | |
| 173 # test_statistics_html += json_to_html_table( | |
| 174 # test_statistics) | |
| 175 # except Exception as e: | |
| 176 # LOG.info(f"Error reading test statistics: {e}") | |
| 177 | |
| 178 # Convert visualizations to HTML | |
| 179 plots_html = "" | 439 plots_html = "" |
| 180 if len(os.listdir(viz_output_directory)) > 0: | 440 plot_files = [] |
| 441 if os.path.isdir(viz_output_directory): | |
| 442 plot_files = sorted(os.listdir(viz_output_directory)) | |
| 443 if plot_files: | |
| 181 plots_html = "<h2>Visualizations</h2>" | 444 plots_html = "<h2>Visualizations</h2>" |
| 182 for plot_file in sorted(os.listdir(viz_output_directory)): | 445 for plot_file in plot_files: |
| 183 plot_path = os.path.join(viz_output_directory, plot_file) | 446 plot_path = os.path.join(viz_output_directory, plot_file) |
| 184 if os.path.isfile(plot_path) and plot_file.endswith((".png", ".jpg")): | 447 if os.path.isfile(plot_path) and plot_file.endswith((".png", ".jpg")): |
| 185 encoded_image = encode_image_to_base64(plot_path) | 448 encoded_image = encode_image_to_base64(plot_path) |
| 449 plot_title = os.path.splitext(plot_file)[0].replace("_", " ") | |
| 186 plots_html += ( | 450 plots_html += ( |
| 187 f'<div class="plot">' | 451 f'<div class="plot">' |
| 188 f'<h3>{os.path.splitext(plot_file)[0]}</h3>' | 452 f'<h3>{plot_title}</h3>' |
| 189 '<img src="data:image/png;base64,' | 453 '<img src="data:image/png;base64,' |
| 190 f'{encoded_image}" alt="{plot_file}">' | 454 f'{encoded_image}" alt="{plot_file}">' |
| 191 f'</div>' | 455 f'</div>' |
| 192 ) | 456 ) |
| 193 | 457 |
| 458 feature_importance_html = "" | |
| 459 importance_path = os.path.join( | |
| 460 output_directory, | |
| 461 ludwig_output_directory_name, | |
| 462 "feature_importance.csv", | |
| 463 ) | |
| 464 if os.path.exists(importance_path): | |
| 465 try: | |
| 466 importance_df = pd.read_csv(importance_path) | |
| 467 if not importance_df.empty: | |
| 468 sorted_df = ( | |
| 469 importance_df | |
| 470 .sort_values(["label", "abs_importance"], ascending=[True, False]) | |
| 471 ) | |
| 472 top_rows = ( | |
| 473 sorted_df | |
| 474 .groupby("label", as_index=False) | |
| 475 .head(5) | |
| 476 ) | |
| 477 max_abs_importance = pd.to_numeric( | |
| 478 importance_df.get("abs_importance", pd.Series(dtype=float)), | |
| 479 errors="coerce", | |
| 480 ).max() | |
| 481 if pd.isna(max_abs_importance): | |
| 482 max_abs_importance = None | |
| 483 | |
| 484 plot_sections = [] | |
| 485 for label in sorted(importance_df["label"].unique()): | |
| 486 encoded_plot = _feature_importance_plot( | |
| 487 importance_df[importance_df["label"] == label], | |
| 488 label, | |
| 489 max_abs_importance=max_abs_importance, | |
| 490 ) | |
| 491 if encoded_plot: | |
| 492 plot_sections.append( | |
| 493 f'<div class="plot feature-importance-plot">' | |
| 494 f'<h3>Top features for {label}</h3>' | |
| 495 f'<img src="data:image/png;base64,{encoded_plot}" ' | |
| 496 f'alt="Feature importance plot for {label}">' | |
| 497 f'</div>' | |
| 498 ) | |
| 499 explanation_text = ( | |
| 500 "<p>Feature importance scores come from Ludwig's Integrated Gradients explainer. " | |
| 501 "It interpolates between each example and a neutral baseline sample, summing " | |
| 502 "the change in the model output along that path. Higher |importance| values " | |
| 503 "indicate stronger influence. Plots share a common x-axis to make magnitudes " | |
| 504 "comparable across labels, and the table columns can be sorted for quick scans.</p>" | |
| 505 ) | |
| 506 feature_importance_html = ( | |
| 507 "<h2>Feature Importance</h2>" | |
| 508 + explanation_text | |
| 509 + render_feature_importance_table(top_rows) | |
| 510 + "".join(plot_sections) | |
| 511 ) | |
| 512 except Exception as exc: | |
| 513 LOG.info(f"Unable to embed feature importance table: {exc}") | |
| 514 | |
| 194 # Generate the full HTML content | 515 # Generate the full HTML content |
| 516 feature_section = feature_importance_html or "<p>No feature importance artifacts were generated.</p>" | |
| 517 viz_section = plots_html or "<p>No visualizations were generated.</p>" | |
| 518 tabs_style = """ | |
| 519 <style> | |
| 520 .tabs { | |
| 521 display: flex; | |
| 522 border-bottom: 2px solid #ccc; | |
| 523 margin-top: 20px; | |
| 524 margin-bottom: 1rem; | |
| 525 } | |
| 526 .tablink { | |
| 527 padding: 9px 18px; | |
| 528 cursor: pointer; | |
| 529 border: 1px solid #ccc; | |
| 530 border-bottom: none; | |
| 531 background: #f9f9f9; | |
| 532 margin-right: 5px; | |
| 533 border-top-left-radius: 8px; | |
| 534 border-top-right-radius: 8px; | |
| 535 font-size: 0.95rem; | |
| 536 font-weight: 500; | |
| 537 font-family: Arial, sans-serif; | |
| 538 color: #4A4A4A; | |
| 539 } | |
| 540 .tablink.active { | |
| 541 background: #ffffff; | |
| 542 font-weight: bold; | |
| 543 } | |
| 544 .tabcontent { | |
| 545 border: 1px solid #ccc; | |
| 546 border-top: none; | |
| 547 padding: 20px; | |
| 548 display: none; | |
| 549 } | |
| 550 .tabcontent.active { | |
| 551 display: block; | |
| 552 } | |
| 553 </style> | |
| 554 """ | |
| 555 tabs_script = """ | |
| 556 <script> | |
| 557 function openTab(evt, tabId) { | |
| 558 var i, tabcontent, tablinks; | |
| 559 tabcontent = document.getElementsByClassName("tabcontent"); | |
| 560 for (i = 0; i < tabcontent.length; i++) { | |
| 561 tabcontent[i].style.display = "none"; | |
| 562 tabcontent[i].classList.remove("active"); | |
| 563 } | |
| 564 tablinks = document.getElementsByClassName("tablink"); | |
| 565 for (i = 0; i < tablinks.length; i++) { | |
| 566 tablinks[i].classList.remove("active"); | |
| 567 } | |
| 568 var current = document.getElementById(tabId); | |
| 569 if (current) { | |
| 570 current.style.display = "block"; | |
| 571 current.classList.add("active"); | |
| 572 } | |
| 573 if (evt && evt.currentTarget) { | |
| 574 evt.currentTarget.classList.add("active"); | |
| 575 } | |
| 576 } | |
| 577 document.addEventListener("DOMContentLoaded", function() { | |
| 578 openTab({currentTarget: document.querySelector(".tablink")}, "viz-tab"); | |
| 579 }); | |
| 580 </script> | |
| 581 """ | |
| 582 tabs_html = f""" | |
| 583 <div class="tabs"> | |
| 584 <button class="tablink active" onclick="openTab(event, 'viz-tab')">Visualizations</button> | |
| 585 <button class="tablink" onclick="openTab(event, 'feature-tab')">Feature Importance</button> | |
| 586 </div> | |
| 587 <div id="viz-tab" class="tabcontent active"> | |
| 588 {viz_section} | |
| 589 </div> | |
| 590 <div id="feature-tab" class="tabcontent"> | |
| 591 {feature_section} | |
| 592 </div> | |
| 593 """ | |
| 195 html_content = f""" | 594 html_content = f""" |
| 196 {get_html_template()} | 595 {get_html_template()} |
| 197 <h1>{title}</h1> | 596 <h1>{title}</h1> |
| 198 {plots_html} | 597 {tabs_style} |
| 598 {tabs_html} | |
| 599 {tabs_script} | |
| 199 {get_html_closing()} | 600 {get_html_closing()} |
| 200 """ | 601 """ |
| 201 | 602 |
| 202 # Save the HTML report | 603 # Save the HTML report |
| 203 title: str | 604 title: str |
| … | … |
| 215 | 216 | 616 |
| 216 ludwig_output_directory_name = "experiment_run" | 617 ludwig_output_directory_name = "experiment_run" |
| 217 | 618 |
| 218 make_visualizations(ludwig_output_directory_name) | 619 make_visualizations(ludwig_output_directory_name) |
| 219 convert_parquet_to_csv(ludwig_output_directory_name) | 620 convert_parquet_to_csv(ludwig_output_directory_name) |
| 621 compute_feature_importance(ludwig_output_directory_name) | |
| 220 generate_html_report("Ludwig Experiment", ludwig_output_directory_name) | 622 generate_html_report("Ludwig Experiment", ludwig_output_directory_name) |
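
The explanation text embedded in the report above summarizes how Integrated Gradients attributes predictions. For readers who want the mechanics, here is a minimal, self-contained sketch of that attribution rule for a toy PyTorch model. It is illustrative only: `integrated_gradients` and its arguments are hypothetical names, not part of Ludwig's `IntegratedGradientsExplainer` or the Captum API it wraps, which additionally handle batching, baseline selection, and per-feature aggregation.

```python
import torch

def integrated_gradients(model, x, baseline, steps=50):
    """Average the gradients along the straight path from baseline to x,
    then scale by (x - baseline), per the Integrated Gradients formula."""
    total = torch.zeros_like(x)
    for alpha in torch.linspace(0.0, 1.0, steps):
        # Interpolate between the baseline and the input, and accumulate
        # the gradient of the model output at that interpolation point.
        point = (baseline + alpha * (x - baseline)).requires_grad_(True)
        model(point).sum().backward()
        total += point.grad
    return (x - baseline) * total / steps

# Toy usage: with a zero baseline, a linear model's attributions reduce
# to weight * input, which is easy to verify by hand.
model = torch.nn.Linear(4, 1)
x = torch.randn(1, 4)
attributions = integrated_gradients(model, x, torch.zeros_like(x))
print(attributions)  # one signed importance score per input feature
```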

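Both report tabs inline their PNGs rather than linking to files on disk, which is what keeps the HTML report self-contained. The snippet below sketches that data-URI step, under the assumption that the repository's `encode_image_to_base64` helper (imported from `utils`) performs essentially the same byte-level encoding; `png_to_img_tag` is a hypothetical name used here for illustration.

```python
import base64

def png_to_img_tag(path, alt=""):
    # Read the raw PNG bytes and base64-encode them so the image can be
    # inlined in the report's <img> tag instead of shipped as a file.
    with open(path, "rb") as fh:
        encoded = base64.b64encode(fh.read()).decode("utf-8")
    return f'<img src="data:image/png;base64,{encoded}" alt="{alt}">'

# Example: print(png_to_img_tag("learning_curves.png", alt="learning curves"))
```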