changeset 9:9e912fce264c draft default tip

planemo upload for repository https://github.com/goeckslab/gleam.git commit eace0d7c2b2939029c052991d238a54947d2e191
author goeckslab
date Wed, 27 Aug 2025 21:02:48 +0000
parents 85e6f4b2ad18
children
files image_learner.xml image_learner_cli.py utils.py
diffstat 3 files changed, 232 insertions(+), 149 deletions(-)
--- a/image_learner.xml	Thu Aug 14 14:53:10 2025 +0000
+++ b/image_learner.xml	Wed Aug 27 21:02:48 2025 +0000
@@ -1,5 +1,5 @@
-<tool id="image_learner" name="Image Learner for Classification" version="0.1.2" profile="22.05">
-    <description>trains and evaluates a image classification model</description>
+<tool id="image_learner" name="Image Learner" version="0.1.2" profile="22.05">
+    <description>trains and evaluates an image classification/regression model</description>
     <requirements>
         <container type="docker">quay.io/goeckslab/galaxy-ludwig-gpu:latest</container>
     </requirements>
@@ -46,6 +46,9 @@
                         --batch-size "$batch_size"
                     #end if
                     --split-probabilities "$train_split" "$val_split" "$test_split"
+                    #if $threshold
+                        --threshold "$threshold"
+                    #end if
                 #end if
                 #if $augmentation
                     --augmentation "$augmentation"
@@ -144,8 +147,7 @@
         <conditional name="scratch_fine_tune">
             <param name="use_pretrained" type="select"
                 label="Use pretrained weights?"
-                help="If select no, the encoder, combiner, and decoder will all be initialized and trained from scratch.  
-               (e.g. when your images are very different from ImageNet or no suitable pretrained model exists.)">
+                help="If select no, the encoder, combiner, and decoder will all be initialized and trained from scratch. (e.g. when your images are very different from ImageNet or no suitable pretrained model exists.)">
                 <option value="false">No</option>
                 <option value="true" selected="true">Yes</option>
             </param>
@@ -317,16 +319,17 @@
                 </element>
             </output_collection>
         </test>
-        </tests>
+    </tests>
     <help>
         <![CDATA[
 **What it does**
-Image Learner for Classification: trains and evaluates a image classification model. 
+Image Learner for Classification/Regression: trains and evaluates an image classification/regression model.
 It uses the metadata csv to find the image paths and labels. 
 The metadata csv should contain a column with the name 'image_path' and a column with the name 'label'.
 Optionally, you can also add a column with the name 'split' to specify which split each row belongs to (train, val, test). 
 If you do not provide a split column, the tool will automatically split the data into train, val, and test sets based on the proportions you specify or [0.7, 0.1, 0.2] by default.
 
+**If the selected label column is numeric and has more than 10 unique values, the tool will automatically treat the task as a regression problem and apply appropriate metrics (e.g., MSE, RMSE, R²).**
 
 **Outputs**
 The tool will output a trained model in the form of a ludwig_model file,
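
Note: as a concrete illustration of the metadata CSV described in the help text above, a minimal table with the required 'image_path' and 'label' columns and the optional 'split' column could be built as below (file names, labels, and the output path are hypothetical). A numeric label column with more than 10 unique values would be treated as a regression target instead.

    import pandas as pd

    # Hypothetical metadata: 'image_path' and 'label' are required;
    # 'split' (0=train, 1=val, 2=test) is optional.
    df = pd.DataFrame(
        {
            "image_path": ["images/cat_01.png", "images/dog_01.png", "images/cat_02.png"],
            "label": ["cat", "dog", "cat"],
            "split": [0, 1, 2],  # omit this column to let the tool split automatically
        }
    )
    df.to_csv("metadata.csv", index=False)
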
--- a/image_learner_cli.py	Thu Aug 14 14:53:10 2025 +0000
+++ b/image_learner_cli.py	Wed Aug 27 21:02:48 2025 +0000
@@ -21,7 +21,7 @@
     SPLIT_COLUMN_NAME,
     TEMP_CONFIG_FILENAME,
     TEMP_CSV_FILENAME,
-    TEMP_DIR_PREFIX
+    TEMP_DIR_PREFIX,
 )
 from ludwig.globals import (
     DESCRIPTION_FILE_NAME,
@@ -38,13 +38,13 @@
     encode_image_to_base64,
     get_html_closing,
     get_html_template,
-    get_metrics_help_modal
+    get_metrics_help_modal,
 )
 
 # --- Logging Setup ---
 logging.basicConfig(
     level=logging.INFO,
-    format='%(asctime)s %(levelname)s %(name)s: %(message)s',
+    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
 )
 logger = logging.getLogger("ImageLearner")
 
@@ -67,7 +67,9 @@
         "early_stop",
         "threshold",
     ]
+
     rows = []
+
     for key in display_keys:
         val = config.get(key, None)
         if key == "threshold":
@@ -134,7 +136,9 @@
                         val_str = val
             else:
                 val_str = val if val is not None else "N/A"
-            if val_str == "N/A" and key not in ["task_type"]:  # Skip if N/A for non-essential
+            if val_str == "N/A" and key not in [
+                "task_type"
+            ]:  # Skip if N/A for non-essential
                 continue
         rows.append(
             f"<tr>"
@@ -166,7 +170,7 @@
               <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th>
             </tr></thead>
             <tbody>
-              {''.join(rows)}
+              {"".join(rows)}
             </tbody>
           </table>
         </div><br>
@@ -251,6 +255,7 @@
                 "roc_auc": get_last_value(label_stats, "roc_auc"),
                 "hits_at_k": get_last_value(label_stats, "hits_at_k"),
             }
+
     # Test metrics: dynamic extraction according to exclusions
     test_label_stats = test_stats.get("label", {})
     if not test_label_stats:
@@ -258,11 +263,13 @@
     else:
         combined_stats = test_stats.get("combined", {})
         overall_stats = test_label_stats.get("overall_stats", {})
+
         # Define exclusions
         if output_type == "binary":
             exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"}
         else:
             exclude = {"per_class_stats", "confusion_matrix"}
+
         # 1. Get all scalar test_label_stats not excluded
         test_metrics = {}
         for k, v in test_label_stats.items():
@@ -272,9 +279,11 @@
                 continue
             if isinstance(v, (int, float, str, bool)):
                 test_metrics[k] = v
+
         # 2. Add overall_stats (flattened)
         for k, v in overall_stats.items():
             test_metrics[k] = v
+
         # 3. Optionally include combined/loss if present and not already
         if "loss" in combined_stats and "loss" not in test_metrics:
             test_metrics["loss"] = combined_stats["loss"]
@@ -315,8 +324,10 @@
             te = all_metrics["test"].get(metric_key)
             if all(x is not None for x in [t, v, te]):
                 rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"])
+
     if not rows:
         return "<table><tr><td>No metric values found.</td></tr></table>"
+
     html = (
         "<h2 style='text-align: center;'>Model Performance Summary</h2>"
         "<div style='display: flex; justify-content: center;'>"
@@ -331,7 +342,7 @@
     for row in rows:
         html += generate_table_row(
             row,
-            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;"
+            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;",
         )
     html += "</tbody></table></div><br>"
     return html
@@ -357,8 +368,10 @@
             v = all_metrics["validation"].get(metric_key)
             if t is not None and v is not None:
                 rows.append([display_name, f"{t:.4f}", f"{v:.4f}"])
+
     if not rows:
         return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>"
+
     html = (
         "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>"
         "<div style='display: flex; justify-content: center;'>"
@@ -372,7 +385,7 @@
     for row in rows:
         html += generate_table_row(
             row,
-            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;"
+            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;",
         )
     html += "</tbody></table></div><br>"
     return html
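
Note: generate_table_row is a small helper from utils.py that is not modified in this changeset (its import falls outside the visible hunks). A hypothetical minimal implementation consistent with the calls above would be:

    def generate_table_row(cells, cell_style):
        """Render one <tr> whose <td> cells all share the given inline style (sketch)."""
        tds = "".join(f'<td style="{cell_style}">{cell}</td>' for cell in cells)
        return f"<tr>{tds}</tr>"
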
@@ -393,8 +406,10 @@
         value = test_metrics[key]
         if value is not None:
             rows.append([display_name, f"{value:.4f}"])
+
     if not rows:
         return "<table><tr><td>No test metric values found.</td></tr></table>"
+
     html = (
         "<h2 style='text-align: center;'>Test Performance Summary</h2>"
         "<div style='display: flex; justify-content: center;'>"
@@ -407,7 +422,7 @@
     for row in rows:
         html += generate_table_row(
             row,
-            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;"
+            "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;",
         )
     html += "</tbody></table></div><br>"
     return html
@@ -436,10 +451,14 @@
             min_samples_per_class = label_counts.min()
             if min_samples_per_class * validation_size < 1:
                 # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size
-                adjusted_validation_size = min(validation_size, 1.0 / min_samples_per_class)
+                adjusted_validation_size = min(
+                    validation_size, 1.0 / min_samples_per_class
+                )
                 if adjusted_validation_size != validation_size:
                     validation_size = adjusted_validation_size
-                    logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation")
+                    logger.info(
+                        f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation"
+                    )
             stratify_arr = out.loc[idx_train, label_column]
             logger.info("Using stratified split for validation set")
         else:
@@ -486,7 +505,9 @@
     # initialize split column
     out[split_column] = 0
     if not label_column or label_column not in out.columns:
-        logger.warning("No label column found; using random split without stratification")
+        logger.warning(
+            "No label column found; using random split without stratification"
+        )
         # fall back to simple random assignment
         indices = out.index.tolist()
         np.random.seed(random_state)
@@ -529,7 +550,9 @@
         stratify=out[label_column],
     )
     # second split: separate training and validation from remaining data
-    val_size_adjusted = split_probabilities[1] / (split_probabilities[0] + split_probabilities[1])
+    val_size_adjusted = split_probabilities[1] / (
+        split_probabilities[0] + split_probabilities[1]
+    )
     train_idx, val_idx = train_test_split(
         train_val_idx,
         test_size=val_size_adjusted,
@@ -541,12 +564,15 @@
     out.loc[val_idx, split_column] = 1
     out.loc[test_idx, split_column] = 2
     logger.info("Successfully applied stratified random split")
-    logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}")
+    logger.info(
+        f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}"
+    )
     return out.astype({split_column: int})
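
Note: as a worked example of the val_size_adjusted computation above, with the default split probabilities [0.7, 0.1, 0.2] the first split holds out 20% of the rows as the test set, so the second split must take 0.1 / (0.7 + 0.1) = 0.125 of the remaining data to end up with 10% validation overall.

    # Worked example with the tool's default split probabilities.
    split_probabilities = [0.7, 0.1, 0.2]  # train, val, test
    val_size_adjusted = split_probabilities[1] / (
        split_probabilities[0] + split_probabilities[1]
    )
    # 12.5% of the 80% kept after the test split == 10% of the full dataset.
    assert abs(val_size_adjusted - 0.125) < 1e-9
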
 
 
 class Backend(Protocol):
     """Interface for a machine learning backend."""
+
     def prepare_config(
         self,
         config_params: Dict[str, Any],
@@ -578,12 +604,14 @@
 
 class LudwigDirectBackend:
     """Backend for running Ludwig experiments directly via the internal experiment_cli function."""
+
     def prepare_config(
         self,
         config_params: Dict[str, Any],
         split_config: Dict[str, Any],
     ) -> str:
         logger.info("LudwigDirectBackend: Preparing YAML configuration.")
+
         model_name = config_params.get("model_name", "resnet18")
         use_pretrained = config_params.get("use_pretrained", False)
         fine_tune = config_params.get("fine_tune", False)
@@ -606,7 +634,9 @@
             }
         else:
             encoder_config = {"type": raw_encoder}
+
         batch_size_cfg = batch_size or "auto"
+
         label_column_path = config_params.get("label_column_data_path")
         label_series = None
         if label_column_path is not None and Path(label_column_path).exists():
@@ -614,6 +644,7 @@
                 label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME]
             except Exception as e:
                 logger.warning(f"Could not read label column for task detection: {e}")
+
         if (
             label_series is not None
             and ptypes.is_numeric_dtype(label_series.dtype)
@@ -622,7 +653,9 @@
             task_type = "regression"
         else:
             task_type = "classification"
+
         config_params["task_type"] = task_type
+
         image_feat: Dict[str, Any] = {
             "name": IMAGE_PATH_COLUMN_NAME,
             "type": "image",
@@ -630,6 +663,7 @@
         }
         if config_params.get("augmentation") is not None:
             image_feat["augmentation"] = config_params["augmentation"]
+
         if task_type == "regression":
             output_feat = {
                 "name": LABEL_COLUMN_NAME,
@@ -645,6 +679,7 @@
                 },
             }
             val_metric = config_params.get("validation_metric", "mean_squared_error")
+
         else:
             num_unique_labels = (
                 label_series.nunique() if label_series is not None else 2
@@ -654,6 +689,7 @@
             if output_type == "binary" and config_params.get("threshold") is not None:
                 output_feat["threshold"] = float(config_params["threshold"])
             val_metric = None
+
         conf: Dict[str, Any] = {
             "model_type": "ecd",
             "input_features": [image_feat],
@@ -673,6 +709,7 @@
                 "in_memory": False,
             },
         }
+
         logger.debug("LudwigDirectBackend: Config dict built.")
         try:
             yaml_str = yaml.dump(conf, sort_keys=False, indent=2)
@@ -694,6 +731,7 @@
     ) -> None:
         """Invoke Ludwig's internal experiment_cli function to run the experiment."""
         logger.info("LudwigDirectBackend: Starting experiment execution.")
+
         try:
             from ludwig.experiment import experiment_cli
         except ImportError as e:
@@ -702,7 +740,9 @@
                 exc_info=True,
             )
             raise RuntimeError("Ludwig import failed.") from e
+
         output_dir.mkdir(parents=True, exist_ok=True)
+
         try:
             experiment_cli(
                 dataset=str(dataset_path),
@@ -733,13 +773,16 @@
             output_dir.glob("experiment_run*"),
             key=lambda p: p.stat().st_mtime,
         )
+
         if not exp_dirs:
             logger.warning(f"No experiment run directories found in {output_dir}")
             return None
+
         progress_file = exp_dirs[-1] / "model" / "training_progress.json"
         if not progress_file.exists():
             logger.warning(f"No training_progress.json found in {progress_file}")
             return None
+
         try:
             with progress_file.open("r", encoding="utf-8") as f:
                 data = json.load(f)
@@ -775,6 +818,7 @@
     def generate_plots(self, output_dir: Path) -> None:
         """Generate all registered Ludwig visualizations for the latest experiment run."""
         logger.info("Generating all Ludwig visualizations…")
+
         test_plots = {
             "compare_performance",
             "compare_classifiers_performance_from_prob",
@@ -798,6 +842,7 @@
             "learning_curves",
             "compare_classifiers_performance_subset",
         }
+
         output_dir = Path(output_dir)
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
@@ -807,6 +852,7 @@
             logger.warning(f"No experiment run dirs found in {output_dir}")
             return
         exp_dir = exp_dirs[-1]
+
         viz_dir = exp_dir / "visualizations"
         viz_dir.mkdir(exist_ok=True)
         train_viz = viz_dir / "train"
@@ -821,6 +867,7 @@
         test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME)
         probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME)
         gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME)
+
         dataset_path = None
         split_file = None
         desc = exp_dir / DESCRIPTION_FILE_NAME
@@ -829,6 +876,7 @@
                 cfg = json.load(f)
             dataset_path = _check(Path(cfg.get("dataset", "")))
             split_file = _check(Path(get_split_path(cfg.get("dataset", ""))))
+
         output_feature = ""
         if desc.exists():
             try:
@@ -839,6 +887,7 @@
             with open(test_stats, "r") as f:
                 stats = json.load(f)
             output_feature = next(iter(stats.keys()), "")
+
         viz_registry = get_visualizations_registry()
         for viz_name, viz_func in viz_registry.items():
             if viz_name in train_plots:
@@ -847,6 +896,7 @@
                 viz_dir_plot = test_viz
             else:
                 continue
+
             try:
                 viz_func(
                     training_statistics=[training_stats] if training_stats else [],
@@ -866,6 +916,7 @@
                 logger.info(f"✔ Generated {viz_name}")
             except Exception as e:
                 logger.warning(f"✘ Skipped {viz_name}: {e}")
+
         logger.info(f"All visualizations written to {viz_dir}")
 
     def generate_html_report(
@@ -881,6 +932,7 @@
         report_path = cwd / report_name
         output_dir = Path(output_dir)
         output_type = None
+
         exp_dirs = sorted(
             output_dir.glob("experiment_run*"),
             key=lambda p: p.stat().st_mtime,
@@ -888,11 +940,14 @@
         if not exp_dirs:
             raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}")
         exp_dir = exp_dirs[-1]
+
         base_viz_dir = exp_dir / "visualizations"
         train_viz_dir = base_viz_dir / "train"
         test_viz_dir = base_viz_dir / "test"
+
         html = get_html_template()
         html += f"<h1>{title}</h1>"
+
         metrics_html = ""
         train_val_metrics_html = ""
         test_metrics_html = ""
@@ -918,11 +973,12 @@
             logger.warning(
                 f"Could not load stats for HTML report: {type(e).__name__}: {e}"
             )
+
         config_html = ""
         training_progress = self.get_training_process(output_dir)
         try:
             config_html = format_config_table_html(
-                config, split_info, training_progress
+                config, split_info, training_progress, output_type
             )
         except Exception as e:
             logger.warning(f"Could not load config for HTML report: {e}")
@@ -936,7 +992,8 @@
             imgs = list(dir_path.glob("*.png"))
             # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files ---
             imgs = [
-                img for img in imgs
+                img
+                for img in imgs
                 if not (
                     img.name == "confusion_matrix.png"
                     or img.name.startswith("confusion_matrix__label_top")
@@ -972,7 +1029,9 @@
                 valid_imgs = [img for img in imgs if img.name not in unwanted]
                 img_map = {img.name: img for img in valid_imgs}
                 ordered = [img_map[n] for n in display_order if n in img_map]
-                others = sorted(img for img in valid_imgs if img.name not in display_order)
+                others = sorted(
+                    img for img in valid_imgs if img.name not in display_order
+                )
                 imgs = ordered + others
             else:
                 # regression: just sort whatever's left
@@ -1012,7 +1071,9 @@
                 df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"})
                 # 2) load ground truth for the test split from prepared CSV
                 df_all = pd.read_csv(config["label_column_data_path"])
-                df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][LABEL_COLUMN_NAME].reset_index(drop=True)
+                df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][
+                    LABEL_COLUMN_NAME
+                ].reset_index(drop=True)
                 # 3) concatenate side-by-side
                 df_table = pd.concat([df_gt, df_pred], axis=1)
                 df_table.columns = [LABEL_COLUMN_NAME, "prediction"]
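
Note: the block above lines up Ludwig's test-split predictions with the ground-truth labels from the prepared CSV. A self-contained sketch of the same idea, using toy data and an assumed prediction column name:

    import pandas as pd

    # Toy stand-ins for Ludwig's predictions parquet and the prepared metadata CSV.
    df_preds = pd.DataFrame({"label_predictions": ["dog", "cat"]})  # column name assumed
    df_all = pd.DataFrame(
        {
            "label": ["cat", "dog", "dog", "cat"],
            "split": [0, 1, 2, 2],  # 0=train, 1=val, 2=test
        }
    )

    # Keep only the test rows' ground truth and line it up with the predictions.
    df_gt = df_all.loc[df_all["split"] == 2, "label"].reset_index(drop=True)
    df_pred = df_preds.rename(columns={"label_predictions": "prediction"})
    df_table = pd.concat([df_gt, df_pred], axis=1)
    df_table.columns = ["label", "prediction"]
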
@@ -1036,7 +1097,9 @@
             for plot in interactive_plots:
                 # 2) inject the static "roc_curves_from_prediction_statistics.png"
                 if plot["title"] == "ROC-AUC":
-                    static_img = test_viz_dir / "roc_curves_from_prediction_statistics.png"
+                    static_img = (
+                        test_viz_dir / "roc_curves_from_prediction_statistics.png"
+                    )
                     if static_img.exists():
                         b64 = encode_image_to_base64(str(static_img))
                         tab3_content += (
@@ -1054,14 +1117,13 @@
                     + plot["html"]
                 )
             tab3_content += render_img_section(
-                "Test Visualizations",
-                test_viz_dir,
-                output_type
+                "Test Visualizations", test_viz_dir, output_type
             )
         # assemble the tabs and help modal
         tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content)
         modal_html = get_metrics_help_modal()
         html += tabbed_html + modal_html + get_html_closing()
+
         try:
             with open(report_path, "w") as f:
                 f.write(html)
@@ -1069,11 +1131,13 @@
         except Exception as e:
             logger.error(f"Failed to write HTML report: {e}")
             raise
+
         return report_path
 
 
 class WorkflowOrchestrator:
     """Manages the image-classification workflow."""
+
     def __init__(self, args: argparse.Namespace, backend: Backend):
         self.args = args
         self.backend = backend
@@ -1113,16 +1177,19 @@
         """Load CSV, update image paths, handle splits, and write prepared CSV."""
         if not self.temp_dir or not self.image_extract_dir:
             raise RuntimeError("Temp dirs not initialized before data prep.")
+
         try:
             df = pd.read_csv(self.args.csv_file)
             logger.info(f"Loaded CSV: {self.args.csv_file}")
         except Exception:
             logger.error("Error loading CSV file", exc_info=True)
             raise
+
         required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME}
         missing = required - set(df.columns)
         if missing:
             raise ValueError(f"Missing CSV columns: {', '.join(missing)}")
+
         try:
             df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply(
                 lambda p: str((self.image_extract_dir / p).resolve())
@@ -1150,13 +1217,16 @@
                 f"{[int(p * 100) for p in self.args.split_probabilities]}% "
                 f"for train/val/test with balanced label distribution."
             )
+
         final_csv = self.temp_dir / TEMP_CSV_FILENAME
+
         try:
             df.to_csv(final_csv, index=False)
             logger.info(f"Saved prepared data to {final_csv}")
         except Exception:
             logger.error("Error saving prepared CSV", exc_info=True)
             raise
+
         return final_csv, split_config, split_info
 
     def _process_fixed_split(
@@ -1171,6 +1241,7 @@
             )
             if df[SPLIT_COLUMN_NAME].isna().any():
                 logger.warning("Split column contains non-numeric/missing values.")
+
             unique = set(df[SPLIT_COLUMN_NAME].dropna().unique())
             logger.info(f"Unique split values: {unique}")
             if unique == {0, 2}:
@@ -1193,7 +1264,9 @@
                 logger.info("Using fixed split as-is.")
             else:
                 raise ValueError(f"Unexpected split values: {unique}")
+
             return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info
+
         except Exception:
             logger.error("Error processing fixed split", exc_info=True)
             raise
@@ -1209,11 +1282,14 @@
         """Execute the full workflow end-to-end."""
         logger.info("Starting workflow...")
         self.args.output_dir.mkdir(parents=True, exist_ok=True)
+
         try:
             self._create_temp_dirs()
             self._extract_images()
             csv_path, split_cfg, split_info = self._prepare_data()
+
             use_pretrained = self.args.use_pretrained or self.args.fine_tune
+
             backend_args = {
                 "model_name": self.args.model_name,
                 "fine_tune": self.args.fine_tune,
@@ -1230,9 +1306,11 @@
                 "threshold": self.args.threshold,
             }
             yaml_str = self.backend.prepare_config(backend_args, split_cfg)
+
             config_file = self.temp_dir / TEMP_CONFIG_FILENAME
             config_file.write_text(yaml_str)
             logger.info(f"Wrote backend config: {config_file}")
+
             self.backend.run_experiment(
                 csv_path,
                 config_file,
@@ -1374,8 +1452,7 @@
         action=SplitProbAction,
         default=[0.7, 0.1, 0.2],
         help=(
-            "Random split proportions (e.g., 0.7 0.1 0.2)."
-            "Only used if no split column."
+            "Random split proportions (e.g., 0.7 0.1 0.2).Only used if no split column."
         ),
     )
     parser.add_argument(
@@ -1408,9 +1485,10 @@
         help=(
             "Decision threshold for binary classification (0.0–1.0)."
             "Overrides default 0.5."
-        )
+        ),
     )
     args = parser.parse_args()
+
     if not 0.0 <= args.validation_size <= 1.0:
         parser.error("validation-size must be between 0.0 and 1.0")
     if not args.csv_file.is_file():
@@ -1423,8 +1501,10 @@
             setattr(args, "augmentation", augmentation_setup)
         except ValueError as e:
             parser.error(str(e))
+
     backend_instance = LudwigDirectBackend()
     orchestrator = WorkflowOrchestrator(args, backend_instance)
+
     exit_code = 0
     try:
         orchestrator.run()
@@ -1439,6 +1519,7 @@
 if __name__ == "__main__":
     try:
         import ludwig
+
         logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}")
     except ImportError:
         logger.error(
@@ -1446,4 +1527,5 @@
             "('pip install ludwig[image]')"
         )
         sys.exit(1)
+
     main()
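
Note: the new --threshold option wired through above is only applied to binary classification, where it is placed in the output feature config and overrides the default 0.5 decision cutoff. A generic sketch of what such a threshold does to predicted positive-class probabilities (not Ludwig's internal implementation):

    import numpy as np

    def apply_threshold(probs: np.ndarray, threshold: float = 0.5) -> np.ndarray:
        """Map positive-class probabilities to 0/1 predictions at the given cutoff."""
        return (probs >= threshold).astype(int)

    apply_threshold(np.array([0.2, 0.55, 0.8]), threshold=0.6)  # -> array([0, 0, 1])
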
--- a/utils.py	Thu Aug 14 14:53:10 2025 +0000
+++ b/utils.py	Wed Aug 27 21:02:48 2025 +0000
@@ -8,8 +8,6 @@
     <head>
         <meta charset="UTF-8">
         <title>Galaxy-Ludwig Report</title>
-
-        <!-- your existing styles -->
         <style>
           body {
               font-family: Arial, sans-serif;
@@ -328,121 +326,121 @@
         '<div id="metricsHelpModal" class="modal">'
         '  <div class="modal-content">'
         '    <span class="close">×</span>'
-        '    <h2>Model Evaluation Metrics — Help Guide</h2>'
+        "    <h2>Model Evaluation Metrics — Help Guide</h2>"
         '    <div class="metrics-guide">'
-        '      <h3>1) General Metrics (Regression and Classification)</h3>'
-        '      <p><strong>Loss (Regression & Classification):</strong> '
-        'Measures the difference between predicted and actual values, '
-        'optimized during training. Lower is better. '
-        'For regression, this is often Mean Squared Error (MSE) or '
-        'Mean Absolute Error (MAE). For classification, it’s typically '
-        'cross-entropy or log loss.</p>'
-        '      <h3>2) Regression Metrics</h3>'
-        '      <p><strong>Mean Absolute Error (MAE):</strong> '
-        'Average of absolute differences between predicted and actual values, '
-        'in the same units as the target. Use for interpretable error measurement '
-        'when all errors are equally important. Less sensitive to outliers than MSE.</p>'
-        '      <p><strong>Mean Squared Error (MSE):</strong> '
-        'Average of squared differences between predicted and actual values. '
-        'Penalizes larger errors more heavily, useful when large deviations are critical. '
-        'Often used as the loss function in regression.</p>'
-        '      <p><strong>Root Mean Squared Error (RMSE):</strong> '
-        'Square root of MSE, in the same units as the target. '
-        'Balances interpretability and sensitivity to large errors. '
-        'Widely used for regression evaluation.</p>'
-        '      <p><strong>Mean Absolute Percentage Error (MAPE):</strong> '
-        'Average absolute error as a percentage of actual values. '
-        'Scale-independent, ideal for comparing relative errors across datasets. '
-        'Avoid when actual values are near zero.</p>'
-        '      <p><strong>Root Mean Squared Percentage Error (RMSPE):</strong> '
-        'Square root of mean squared percentage error. Scale-independent, '
-        'penalizes larger relative errors more than MAPE. Use for forecasting '
-        'or when relative accuracy matters.</p>'
-        '      <p><strong>R² Score:</strong> Proportion of variance in the target '
-        'explained by the model. Ranges from negative infinity to 1 (perfect prediction). '
-        'Use to assess model fit; negative values indicate poor performance '
-        'compared to predicting the mean.</p>'
-        '      <h3>3) Classification Metrics</h3>'
-        '      <p><strong>Accuracy:</strong> Proportion of correct predictions '
-        'among all predictions. Simple but misleading for imbalanced datasets, '
-        'where high accuracy may hide poor performance on minority classes.</p>'
-        '      <p><strong>Micro Accuracy:</strong> Sums true positives and true negatives '
-        'across all classes before computing accuracy. Suitable for multiclass or '
-        'multilabel problems with imbalanced data.</p>'
-        '      <p><strong>Token Accuracy:</strong> Measures how often predicted tokens '
-        '(e.g., in sequences) match true tokens. Common in NLP tasks like text generation '
-        'or token classification.</p>'
-        '      <p><strong>Precision:</strong> Proportion of positive predictions that are '
-        'correct (TP / (TP + FP)). Use when false positives are costly, e.g., spam detection.</p>'
-        '      <p><strong>Recall (Sensitivity):</strong> Proportion of actual positives '
-        'correctly predicted (TP / (TP + FN)). Use when missing positives is risky, '
-        'e.g., disease detection.</p>'
-        '      <p><strong>Specificity:</strong> True negative rate (TN / (TN + FP)). '
-        'Measures ability to identify negatives. Useful in medical testing to avoid '
-        'false alarms.</p>'
-        '      <h3>4) Classification: Macro, Micro, and Weighted Averages</h3>'
-        '      <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric '
-        'across all classes, treating each equally. Best for balanced datasets where '
-        'all classes are equally important.</p>'
-        '      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates true positives, '
-        'false positives, and false negatives across all classes before computing. '
-        'Ideal for imbalanced or multilabel classification.</p>'
-        '      <p><strong>Weighted Precision / Recall / F1:</strong> Averages metrics '
-        'across classes, weighted by the number of true instances per class. Balances '
-        'class importance based on frequency.</p>'
-        '      <h3>5) Classification: Average Precision (PR-AUC Variants)</h3>'
-        '      <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged '
-        'equally across classes. Use for balanced multiclass problems.</p>'
-        '      <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC '
-        'using all instances. Best for imbalanced or multilabel classification.</p>'
-        '      <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged '
-        'across individual samples. Ideal for multilabel tasks where samples have multiple '
-        'labels.</p>'
-        '      <h3>6) Classification: ROC-AUC Variants</h3>'
-        '      <p><strong>ROC-AUC:</strong> Measures ability to distinguish between classes. '
-        'AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>'
-        '      <p><strong>Macro ROC-AUC:</strong> Averages AUC across all classes equally. '
-        'Suitable for balanced multiclass problems.</p>'
-        '      <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions '
-        'across all classes. Useful for imbalanced or multilabel settings.</p>'
-        '      <h3>7) Classification: Confusion Matrix Stats (Per Class)</h3>'
-        '      <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions '
-        'for positives and negatives, respectively.</p>'
-        '      <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions '
-        '— false alarms and missed detections.</p>'
-        '      <h3>8) Classification: Ranking Metrics</h3>'
-        '      <p><strong>Hits at K:</strong> Measures whether the true label is among the '
-        'top-K predictions. Common in recommendation systems and retrieval tasks.</p>'
-        '      <h3>9) Other Metrics (Classification)</h3>'
-        '      <p><strong>Cohen\'s Kappa:</strong> Measures agreement between predicted and '
-        'actual labels, adjusted for chance. Useful for multiclass classification with '
-        'imbalanced data.</p>'
-        '      <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure '
-        'using TP, TN, FP, and FN. Effective for imbalanced datasets.</p>'
-        '      <h3>10) Metric Recommendations</h3>'
-        '      <ul>'
-        '        <li><strong>Regression:</strong> Use <strong>RMSE</strong> or '
-        '<strong>MAE</strong> for general evaluation, <strong>MAPE</strong> for relative '
-        'errors, and <strong>R²</strong> to assess model fit. Use <strong>MSE</strong> or '
-        '<strong>RMSPE</strong> when large errors are critical.</li>'
-        '        <li><strong>Classification (Balanced Data):</strong> Use <strong>Accuracy</strong> '
-        'and <strong>F1</strong> for overall performance.</li>'
-        '        <li><strong>Classification (Imbalanced Data):</strong> Use <strong>Precision</strong>, '
-        '<strong>Recall</strong>, and <strong>ROC-AUC</strong> to focus on minority class '
-        'performance.</li>'
-        '        <li><strong>Multilabel or Imbalanced Classification:</strong> Use '
-        '<strong>Micro Precision/Recall/F1</strong> or <strong>Micro ROC-AUC</strong>.</li>'
-        '        <li><strong>Balanced Multiclass:</strong> Use <strong>Macro Precision/Recall/F1</strong> '
-        'or <strong>Macro ROC-AUC</strong>.</li>'
-        '        <li><strong>Class Frequency Matters:</strong> Use <strong>Weighted Precision/Recall/F1</strong> '
-        'to account for class imbalance.</li>'
-        '        <li><strong>Recommendation/Ranking:</strong> Use <strong>Hits at K</strong> for retrieval tasks.</li>'
-        '        <li><strong>Detailed Analysis:</strong> Use <strong>Confusion Matrix stats</strong> '
-        'for class-wise performance in classification.</li>'
-        '      </ul>'
-        '    </div>'
-        '  </div>'
-        '</div>'
+        "      <h3>1) General Metrics (Regression and Classification)</h3>"
+        "      <p><strong>Loss (Regression & Classification):</strong> "
+        "Measures the difference between predicted and actual values, "
+        "optimized during training. Lower is better. "
+        "For regression, this is often Mean Squared Error (MSE) or "
+        "Mean Absolute Error (MAE). For classification, it’s typically "
+        "cross-entropy or log loss.</p>"
+        "      <h3>2) Regression Metrics</h3>"
+        "      <p><strong>Mean Absolute Error (MAE):</strong> "
+        "Average of absolute differences between predicted and actual values, "
+        "in the same units as the target. Use for interpretable error measurement "
+        "when all errors are equally important. Less sensitive to outliers than MSE.</p>"
+        "      <p><strong>Mean Squared Error (MSE):</strong> "
+        "Average of squared differences between predicted and actual values. "
+        "Penalizes larger errors more heavily, useful when large deviations are critical. "
+        "Often used as the loss function in regression.</p>"
+        "      <p><strong>Root Mean Squared Error (RMSE):</strong> "
+        "Square root of MSE, in the same units as the target. "
+        "Balances interpretability and sensitivity to large errors. "
+        "Widely used for regression evaluation.</p>"
+        "      <p><strong>Mean Absolute Percentage Error (MAPE):</strong> "
+        "Average absolute error as a percentage of actual values. "
+        "Scale-independent, ideal for comparing relative errors across datasets. "
+        "Avoid when actual values are near zero.</p>"
+        "      <p><strong>Root Mean Squared Percentage Error (RMSPE):</strong> "
+        "Square root of mean squared percentage error. Scale-independent, "
+        "penalizes larger relative errors more than MAPE. Use for forecasting "
+        "or when relative accuracy matters.</p>"
+        "      <p><strong>R² Score:</strong> Proportion of variance in the target "
+        "explained by the model. Ranges from negative infinity to 1 (perfect prediction). "
+        "Use to assess model fit; negative values indicate poor performance "
+        "compared to predicting the mean.</p>"
+        "      <h3>3) Classification Metrics</h3>"
+        "      <p><strong>Accuracy:</strong> Proportion of correct predictions "
+        "among all predictions. Simple but misleading for imbalanced datasets, "
+        "where high accuracy may hide poor performance on minority classes.</p>"
+        "      <p><strong>Micro Accuracy:</strong> Sums true positives and true negatives "
+        "across all classes before computing accuracy. Suitable for multiclass or "
+        "multilabel problems with imbalanced data.</p>"
+        "      <p><strong>Token Accuracy:</strong> Measures how often predicted tokens "
+        "(e.g., in sequences) match true tokens. Common in NLP tasks like text generation "
+        "or token classification.</p>"
+        "      <p><strong>Precision:</strong> Proportion of positive predictions that are "
+        "correct (TP / (TP + FP)). Use when false positives are costly, e.g., spam detection.</p>"
+        "      <p><strong>Recall (Sensitivity):</strong> Proportion of actual positives "
+        "correctly predicted (TP / (TP + FN)). Use when missing positives is risky, "
+        "e.g., disease detection.</p>"
+        "      <p><strong>Specificity:</strong> True negative rate (TN / (TN + FP)). "
+        "Measures ability to identify negatives. Useful in medical testing to avoid "
+        "false alarms.</p>"
+        "      <h3>4) Classification: Macro, Micro, and Weighted Averages</h3>"
+        "      <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric "
+        "across all classes, treating each equally. Best for balanced datasets where "
+        "all classes are equally important.</p>"
+        "      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates true positives, "
+        "false positives, and false negatives across all classes before computing. "
+        "Ideal for imbalanced or multilabel classification.</p>"
+        "      <p><strong>Weighted Precision / Recall / F1:</strong> Averages metrics "
+        "across classes, weighted by the number of true instances per class. Balances "
+        "class importance based on frequency.</p>"
+        "      <h3>5) Classification: Average Precision (PR-AUC Variants)</h3>"
+        "      <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged "
+        "equally across classes. Use for balanced multiclass problems.</p>"
+        "      <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC "
+        "using all instances. Best for imbalanced or multilabel classification.</p>"
+        "      <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged "
+        "across individual samples. Ideal for multilabel tasks where samples have multiple "
+        "labels.</p>"
+        "      <h3>6) Classification: ROC-AUC Variants</h3>"
+        "      <p><strong>ROC-AUC:</strong> Measures ability to distinguish between classes. "
+        "AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>"
+        "      <p><strong>Macro ROC-AUC:</strong> Averages AUC across all classes equally. "
+        "Suitable for balanced multiclass problems.</p>"
+        "      <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions "
+        "across all classes. Useful for imbalanced or multilabel settings.</p>"
+        "      <h3>7) Classification: Confusion Matrix Stats (Per Class)</h3>"
+        "      <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions "
+        "for positives and negatives, respectively.</p>"
+        "      <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions "
+        "— false alarms and missed detections.</p>"
+        "      <h3>8) Classification: Ranking Metrics</h3>"
+        "      <p><strong>Hits at K:</strong> Measures whether the true label is among the "
+        "top-K predictions. Common in recommendation systems and retrieval tasks.</p>"
+        "      <h3>9) Other Metrics (Classification)</h3>"
+        "      <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and "
+        "actual labels, adjusted for chance. Useful for multiclass classification with "
+        "imbalanced data.</p>"
+        "      <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure "
+        "using TP, TN, FP, and FN. Effective for imbalanced datasets.</p>"
+        "      <h3>10) Metric Recommendations</h3>"
+        "      <ul>"
+        "        <li><strong>Regression:</strong> Use <strong>RMSE</strong> or "
+        "<strong>MAE</strong> for general evaluation, <strong>MAPE</strong> for relative "
+        "errors, and <strong>R²</strong> to assess model fit. Use <strong>MSE</strong> or "
+        "<strong>RMSPE</strong> when large errors are critical.</li>"
+        "        <li><strong>Classification (Balanced Data):</strong> Use <strong>Accuracy</strong> "
+        "and <strong>F1</strong> for overall performance.</li>"
+        "        <li><strong>Classification (Imbalanced Data):</strong> Use <strong>Precision</strong>, "
+        "<strong>Recall</strong>, and <strong>ROC-AUC</strong> to focus on minority class "
+        "performance.</li>"
+        "        <li><strong>Multilabel or Imbalanced Classification:</strong> Use "
+        "<strong>Micro Precision/Recall/F1</strong> or <strong>Micro ROC-AUC</strong>.</li>"
+        "        <li><strong>Balanced Multiclass:</strong> Use <strong>Macro Precision/Recall/F1</strong> "
+        "or <strong>Macro ROC-AUC</strong>.</li>"
+        "        <li><strong>Class Frequency Matters:</strong> Use <strong>Weighted Precision/Recall/F1</strong> "
+        "to account for class imbalance.</li>"
+        "        <li><strong>Recommendation/Ranking:</strong> Use <strong>Hits at K</strong> for retrieval tasks.</li>"
+        "        <li><strong>Detailed Analysis:</strong> Use <strong>Confusion Matrix stats</strong> "
+        "for class-wise performance in classification.</li>"
+        "      </ul>"
+        "    </div>"
+        "  </div>"
+        "</div>"
     )
     modal_css = (
         "<style>"
@@ -497,17 +495,17 @@
         '  var span = document.getElementsByClassName("close")[0];'
         "  if (openBtn && modal) {"
         "    openBtn.onclick = function() {"
-        "      modal.style.display = \"block\";"
+        '      modal.style.display = "block";'
         "    };"
         "  }"
         "  if (span && modal) {"
         "    span.onclick = function() {"
-        "      modal.style.display = \"none\";"
+        '      modal.style.display = "none";'
         "    };"
         "  }"
         "  window.onclick = function(event) {"
         "    if (event.target == modal) {"
-        "      modal.style.display = \"none\";"
+        '      modal.style.display = "none";'
         "    }"
         "  }"
         "});"