goeckslab / image_learner (Mercurial repository)
changeset 9:9e912fce264c (phase: draft, branch: default, tag: tip)
planemo upload for repository https://github.com/goeckslab/gleam.git commit eace0d7c2b2939029c052991d238a54947d2e191
author   | goeckslab
date     | Wed, 27 Aug 2025 21:02:48 +0000
parents  | 85e6f4b2ad18
children | (none)
files    | image_learner.xml image_learner_cli.py utils.py
diffstat | 3 files changed, 232 insertions(+), 149 deletions(-)
--- a/image_learner.xml	Thu Aug 14 14:53:10 2025 +0000
+++ b/image_learner.xml	Wed Aug 27 21:02:48 2025 +0000
@@ -1,5 +1,5 @@
-<tool id="image_learner" name="Image Learner for Classification" version="0.1.2" profile="22.05">
-    <description>trains and evaluates a image classification model</description>
+<tool id="image_learner" name="Image Learner" version="0.1.2" profile="22.05">
+    <description>trains and evaluates an image classification/regression model</description>
     <requirements>
         <container type="docker">quay.io/goeckslab/galaxy-ludwig-gpu:latest</container>
     </requirements>
@@ -46,6 +46,9 @@
             --batch-size "$batch_size"
         #end if
         --split-probabilities "$train_split" "$val_split" "$test_split"
+        #if $threshold
+            --threshold "$threshold"
+        #end if
     #end if
     #if $augmentation
         --augmentation "$augmentation"
@@ -144,8 +147,7 @@
         <conditional name="scratch_fine_tune">
             <param name="use_pretrained" type="select"
                    label="Use pretrained weights?"
-                   help="If select no, the encoder, combiner, and decoder will all be initialized and trained from scratch.
-                   (e.g. when your images are very different from ImageNet or no suitable pretrained model exists.)">
+                   help="If select no, the encoder, combiner, and decoder will all be initialized and trained from scratch. (e.g. when your images are very different from ImageNet or no suitable pretrained model exists.)">
                 <option value="false">No</option>
                 <option value="true" selected="true">Yes</option>
             </param>
@@ -317,16 +319,17 @@
             </element>
         </output_collection>
     </test>
-    </tests>
+    </tests>
     <help>
         <![CDATA[
 **What it does**
 
-Image Learner for Classification: trains and evaluates a image classification model.
+Image Learner for Classification/regression: trains and evaluates a image classification/regression model.
 It uses the metadata csv to find the image paths and labels.
 The metadata csv should contain a column with the name 'image_path' and a column with the name 'label'.
 Optionally, you can also add a column with the name 'split' to specify which split each row belongs to (train, val, test).
 If you do not provide a split column, the tool will automatically split the data into train, val, and test sets based on the proportions you specify or [0.7, 0.1, 0.2] by default.
+**If the selected label column has more than 10 unique values, the tool will automatically treat the task as a regression problem and apply appropriate metrics (e.g., MSE, RMSE, R²).**
 
 **Outputs**
 The tool will output a trained model in the form of a ludwig_model file,
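Reviewer's note on the help text above: a minimal metadata CSV with the expected 'image_path' and 'label' columns, plus the optional 'split' column (0 = train, 1 = validation, 2 = test, matching the split codes used in image_learner_cli.py below), could be produced with a short pandas sketch like the following. The image paths and label values are made up for illustration; only the column names, split codes, and default proportions come from the tool.

# Sketch only: a metadata CSV of the shape described in the tool help.
# File names and labels are hypothetical; the column names 'image_path',
# 'label', and the optional 'split' (0/1/2) are what the tool expects.
import pandas as pd

metadata = pd.DataFrame(
    {
        "image_path": ["images/img_0001.png", "images/img_0002.png", "images/img_0003.png"],
        "label": ["cat", "dog", "cat"],
        # Optional; omit it to let the tool split randomly, [0.7, 0.1, 0.2] by default.
        "split": [0, 1, 2],
    }
)
metadata.to_csv("metadata.csv", index=False)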
--- a/image_learner_cli.py Thu Aug 14 14:53:10 2025 +0000 +++ b/image_learner_cli.py Wed Aug 27 21:02:48 2025 +0000 @@ -21,7 +21,7 @@ SPLIT_COLUMN_NAME, TEMP_CONFIG_FILENAME, TEMP_CSV_FILENAME, - TEMP_DIR_PREFIX + TEMP_DIR_PREFIX, ) from ludwig.globals import ( DESCRIPTION_FILE_NAME, @@ -38,13 +38,13 @@ encode_image_to_base64, get_html_closing, get_html_template, - get_metrics_help_modal + get_metrics_help_modal, ) # --- Logging Setup --- logging.basicConfig( level=logging.INFO, - format='%(asctime)s %(levelname)s %(name)s: %(message)s', + format="%(asctime)s %(levelname)s %(name)s: %(message)s", ) logger = logging.getLogger("ImageLearner") @@ -67,7 +67,9 @@ "early_stop", "threshold", ] + rows = [] + for key in display_keys: val = config.get(key, None) if key == "threshold": @@ -134,7 +136,9 @@ val_str = val else: val_str = val if val is not None else "N/A" - if val_str == "N/A" and key not in ["task_type"]: # Skip if N/A for non-essential + if val_str == "N/A" and key not in [ + "task_type" + ]: # Skip if N/A for non-essential continue rows.append( f"<tr>" @@ -166,7 +170,7 @@ <th style="padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap; overflow: hidden; text-overflow: ellipsis;">Value</th> </tr></thead> <tbody> - {''.join(rows)} + {"".join(rows)} </tbody> </table> </div><br> @@ -251,6 +255,7 @@ "roc_auc": get_last_value(label_stats, "roc_auc"), "hits_at_k": get_last_value(label_stats, "hits_at_k"), } + # Test metrics: dynamic extraction according to exclusions test_label_stats = test_stats.get("label", {}) if not test_label_stats: @@ -258,11 +263,13 @@ else: combined_stats = test_stats.get("combined", {}) overall_stats = test_label_stats.get("overall_stats", {}) + # Define exclusions if output_type == "binary": exclude = {"per_class_stats", "precision_recall_curve", "roc_curve"} else: exclude = {"per_class_stats", "confusion_matrix"} + # 1. Get all scalar test_label_stats not excluded test_metrics = {} for k, v in test_label_stats.items(): @@ -272,9 +279,11 @@ continue if isinstance(v, (int, float, str, bool)): test_metrics[k] = v + # 2. Add overall_stats (flattened) for k, v in overall_stats.items(): test_metrics[k] = v + # 3. 
Optionally include combined/loss if present and not already if "loss" in combined_stats and "loss" not in test_metrics: test_metrics["loss"] = combined_stats["loss"] @@ -315,8 +324,10 @@ te = all_metrics["test"].get(metric_key) if all(x is not None for x in [t, v, te]): rows.append([display_name, f"{t:.4f}", f"{v:.4f}", f"{te:.4f}"]) + if not rows: return "<table><tr><td>No metric values found.</td></tr></table>" + html = ( "<h2 style='text-align: center;'>Model Performance Summary</h2>" "<div style='display: flex; justify-content: center;'>" @@ -331,7 +342,7 @@ for row in rows: html += generate_table_row( row, - "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" + "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", ) html += "</tbody></table></div><br>" return html @@ -357,8 +368,10 @@ v = all_metrics["validation"].get(metric_key) if t is not None and v is not None: rows.append([display_name, f"{t:.4f}", f"{v:.4f}"]) + if not rows: return "<table><tr><td>No metric values found for Train/Validation.</td></tr></table>" + html = ( "<h2 style='text-align: center;'>Train/Validation Performance Summary</h2>" "<div style='display: flex; justify-content: center;'>" @@ -372,7 +385,7 @@ for row in rows: html += generate_table_row( row, - "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" + "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", ) html += "</tbody></table></div><br>" return html @@ -393,8 +406,10 @@ value = test_metrics[key] if value is not None: rows.append([display_name, f"{value:.4f}"]) + if not rows: return "<table><tr><td>No test metric values found.</td></tr></table>" + html = ( "<h2 style='text-align: center;'>Test Performance Summary</h2>" "<div style='display: flex; justify-content: center;'>" @@ -407,7 +422,7 @@ for row in rows: html += generate_table_row( row, - "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;" + "padding: 10px; border: 1px solid #ccc; text-align: center; white-space: nowrap;", ) html += "</tbody></table></div><br>" return html @@ -436,10 +451,14 @@ min_samples_per_class = label_counts.min() if min_samples_per_class * validation_size < 1: # Adjust validation_size to ensure at least 1 sample per class, but do not exceed original validation_size - adjusted_validation_size = min(validation_size, 1.0 / min_samples_per_class) + adjusted_validation_size = min( + validation_size, 1.0 / min_samples_per_class + ) if adjusted_validation_size != validation_size: validation_size = adjusted_validation_size - logger.info(f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation") + logger.info( + f"Adjusted validation_size to {validation_size:.3f} to ensure at least one sample per class in validation" + ) stratify_arr = out.loc[idx_train, label_column] logger.info("Using stratified split for validation set") else: @@ -486,7 +505,9 @@ # initialize split column out[split_column] = 0 if not label_column or label_column not in out.columns: - logger.warning("No label column found; using random split without stratification") + logger.warning( + "No label column found; using random split without stratification" + ) # fall back to simple random assignment indices = out.index.tolist() np.random.seed(random_state) @@ -529,7 +550,9 @@ stratify=out[label_column], ) # second split: separate training and validation from remaining data - val_size_adjusted = split_probabilities[1] / 
(split_probabilities[0] + split_probabilities[1]) + val_size_adjusted = split_probabilities[1] / ( + split_probabilities[0] + split_probabilities[1] + ) train_idx, val_idx = train_test_split( train_val_idx, test_size=val_size_adjusted, @@ -541,12 +564,15 @@ out.loc[val_idx, split_column] = 1 out.loc[test_idx, split_column] = 2 logger.info("Successfully applied stratified random split") - logger.info(f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}") + logger.info( + f"Split counts: Train={len(train_idx)}, Val={len(val_idx)}, Test={len(test_idx)}" + ) return out.astype({split_column: int}) class Backend(Protocol): """Interface for a machine learning backend.""" + def prepare_config( self, config_params: Dict[str, Any], @@ -578,12 +604,14 @@ class LudwigDirectBackend: """Backend for running Ludwig experiments directly via the internal experiment_cli function.""" + def prepare_config( self, config_params: Dict[str, Any], split_config: Dict[str, Any], ) -> str: logger.info("LudwigDirectBackend: Preparing YAML configuration.") + model_name = config_params.get("model_name", "resnet18") use_pretrained = config_params.get("use_pretrained", False) fine_tune = config_params.get("fine_tune", False) @@ -606,7 +634,9 @@ } else: encoder_config = {"type": raw_encoder} + batch_size_cfg = batch_size or "auto" + label_column_path = config_params.get("label_column_data_path") label_series = None if label_column_path is not None and Path(label_column_path).exists(): @@ -614,6 +644,7 @@ label_series = pd.read_csv(label_column_path)[LABEL_COLUMN_NAME] except Exception as e: logger.warning(f"Could not read label column for task detection: {e}") + if ( label_series is not None and ptypes.is_numeric_dtype(label_series.dtype) @@ -622,7 +653,9 @@ task_type = "regression" else: task_type = "classification" + config_params["task_type"] = task_type + image_feat: Dict[str, Any] = { "name": IMAGE_PATH_COLUMN_NAME, "type": "image", @@ -630,6 +663,7 @@ } if config_params.get("augmentation") is not None: image_feat["augmentation"] = config_params["augmentation"] + if task_type == "regression": output_feat = { "name": LABEL_COLUMN_NAME, @@ -645,6 +679,7 @@ }, } val_metric = config_params.get("validation_metric", "mean_squared_error") + else: num_unique_labels = ( label_series.nunique() if label_series is not None else 2 @@ -654,6 +689,7 @@ if output_type == "binary" and config_params.get("threshold") is not None: output_feat["threshold"] = float(config_params["threshold"]) val_metric = None + conf: Dict[str, Any] = { "model_type": "ecd", "input_features": [image_feat], @@ -673,6 +709,7 @@ "in_memory": False, }, } + logger.debug("LudwigDirectBackend: Config dict built.") try: yaml_str = yaml.dump(conf, sort_keys=False, indent=2) @@ -694,6 +731,7 @@ ) -> None: """Invoke Ludwig's internal experiment_cli function to run the experiment.""" logger.info("LudwigDirectBackend: Starting experiment execution.") + try: from ludwig.experiment import experiment_cli except ImportError as e: @@ -702,7 +740,9 @@ exc_info=True, ) raise RuntimeError("Ludwig import failed.") from e + output_dir.mkdir(parents=True, exist_ok=True) + try: experiment_cli( dataset=str(dataset_path), @@ -733,13 +773,16 @@ output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime, ) + if not exp_dirs: logger.warning(f"No experiment run directories found in {output_dir}") return None + progress_file = exp_dirs[-1] / "model" / "training_progress.json" if not progress_file.exists(): logger.warning(f"No training_progress.json 
found in {progress_file}") return None + try: with progress_file.open("r", encoding="utf-8") as f: data = json.load(f) @@ -775,6 +818,7 @@ def generate_plots(self, output_dir: Path) -> None: """Generate all registered Ludwig visualizations for the latest experiment run.""" logger.info("Generating all Ludwig visualizations…") + test_plots = { "compare_performance", "compare_classifiers_performance_from_prob", @@ -798,6 +842,7 @@ "learning_curves", "compare_classifiers_performance_subset", } + output_dir = Path(output_dir) exp_dirs = sorted( output_dir.glob("experiment_run*"), @@ -807,6 +852,7 @@ logger.warning(f"No experiment run dirs found in {output_dir}") return exp_dir = exp_dirs[-1] + viz_dir = exp_dir / "visualizations" viz_dir.mkdir(exist_ok=True) train_viz = viz_dir / "train" @@ -821,6 +867,7 @@ test_stats = _check(exp_dir / TEST_STATISTICS_FILE_NAME) probs_path = _check(exp_dir / PREDICTIONS_PARQUET_FILE_NAME) gt_metadata = _check(exp_dir / "model" / TRAIN_SET_METADATA_FILE_NAME) + dataset_path = None split_file = None desc = exp_dir / DESCRIPTION_FILE_NAME @@ -829,6 +876,7 @@ cfg = json.load(f) dataset_path = _check(Path(cfg.get("dataset", ""))) split_file = _check(Path(get_split_path(cfg.get("dataset", "")))) + output_feature = "" if desc.exists(): try: @@ -839,6 +887,7 @@ with open(test_stats, "r") as f: stats = json.load(f) output_feature = next(iter(stats.keys()), "") + viz_registry = get_visualizations_registry() for viz_name, viz_func in viz_registry.items(): if viz_name in train_plots: @@ -847,6 +896,7 @@ viz_dir_plot = test_viz else: continue + try: viz_func( training_statistics=[training_stats] if training_stats else [], @@ -866,6 +916,7 @@ logger.info(f"✔ Generated {viz_name}") except Exception as e: logger.warning(f"✘ Skipped {viz_name}: {e}") + logger.info(f"All visualizations written to {viz_dir}") def generate_html_report( @@ -881,6 +932,7 @@ report_path = cwd / report_name output_dir = Path(output_dir) output_type = None + exp_dirs = sorted( output_dir.glob("experiment_run*"), key=lambda p: p.stat().st_mtime, @@ -888,11 +940,14 @@ if not exp_dirs: raise RuntimeError(f"No 'experiment*' dirs found in {output_dir}") exp_dir = exp_dirs[-1] + base_viz_dir = exp_dir / "visualizations" train_viz_dir = base_viz_dir / "train" test_viz_dir = base_viz_dir / "test" + html = get_html_template() html += f"<h1>{title}</h1>" + metrics_html = "" train_val_metrics_html = "" test_metrics_html = "" @@ -918,11 +973,12 @@ logger.warning( f"Could not load stats for HTML report: {type(e).__name__}: {e}" ) + config_html = "" training_progress = self.get_training_process(output_dir) try: config_html = format_config_table_html( - config, split_info, training_progress + config, split_info, training_progress, output_type ) except Exception as e: logger.warning(f"Could not load config for HTML report: {e}") @@ -936,7 +992,8 @@ imgs = list(dir_path.glob("*.png")) # --- EXCLUDE Ludwig's base confusion matrix and any top-N confusion_matrix files --- imgs = [ - img for img in imgs + img + for img in imgs if not ( img.name == "confusion_matrix.png" or img.name.startswith("confusion_matrix__label_top") @@ -972,7 +1029,9 @@ valid_imgs = [img for img in imgs if img.name not in unwanted] img_map = {img.name: img for img in valid_imgs} ordered = [img_map[n] for n in display_order if n in img_map] - others = sorted(img for img in valid_imgs if img.name not in display_order) + others = sorted( + img for img in valid_imgs if img.name not in display_order + ) imgs = ordered + others else: # regression: just 
sort whatever's left @@ -1012,7 +1071,9 @@ df_pred = df_preds[[pred_col]].rename(columns={pred_col: "prediction"}) # 2) load ground truth for the test split from prepared CSV df_all = pd.read_csv(config["label_column_data_path"]) - df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][LABEL_COLUMN_NAME].reset_index(drop=True) + df_gt = df_all[df_all[SPLIT_COLUMN_NAME] == 2][ + LABEL_COLUMN_NAME + ].reset_index(drop=True) # 3) concatenate side-by-side df_table = pd.concat([df_gt, df_pred], axis=1) df_table.columns = [LABEL_COLUMN_NAME, "prediction"] @@ -1036,7 +1097,9 @@ for plot in interactive_plots: # 2) inject the static "roc_curves_from_prediction_statistics.png" if plot["title"] == "ROC-AUC": - static_img = test_viz_dir / "roc_curves_from_prediction_statistics.png" + static_img = ( + test_viz_dir / "roc_curves_from_prediction_statistics.png" + ) if static_img.exists(): b64 = encode_image_to_base64(str(static_img)) tab3_content += ( @@ -1054,14 +1117,13 @@ + plot["html"] ) tab3_content += render_img_section( - "Test Visualizations", - test_viz_dir, - output_type + "Test Visualizations", test_viz_dir, output_type ) # assemble the tabs and help modal tabbed_html = build_tabbed_html(tab1_content, tab2_content, tab3_content) modal_html = get_metrics_help_modal() html += tabbed_html + modal_html + get_html_closing() + try: with open(report_path, "w") as f: f.write(html) @@ -1069,11 +1131,13 @@ except Exception as e: logger.error(f"Failed to write HTML report: {e}") raise + return report_path class WorkflowOrchestrator: """Manages the image-classification workflow.""" + def __init__(self, args: argparse.Namespace, backend: Backend): self.args = args self.backend = backend @@ -1113,16 +1177,19 @@ """Load CSV, update image paths, handle splits, and write prepared CSV.""" if not self.temp_dir or not self.image_extract_dir: raise RuntimeError("Temp dirs not initialized before data prep.") + try: df = pd.read_csv(self.args.csv_file) logger.info(f"Loaded CSV: {self.args.csv_file}") except Exception: logger.error("Error loading CSV file", exc_info=True) raise + required = {IMAGE_PATH_COLUMN_NAME, LABEL_COLUMN_NAME} missing = required - set(df.columns) if missing: raise ValueError(f"Missing CSV columns: {', '.join(missing)}") + try: df[IMAGE_PATH_COLUMN_NAME] = df[IMAGE_PATH_COLUMN_NAME].apply( lambda p: str((self.image_extract_dir / p).resolve()) @@ -1150,13 +1217,16 @@ f"{[int(p * 100) for p in self.args.split_probabilities]}% " f"for train/val/test with balanced label distribution." 
) + final_csv = self.temp_dir / TEMP_CSV_FILENAME + try: df.to_csv(final_csv, index=False) logger.info(f"Saved prepared data to {final_csv}") except Exception: logger.error("Error saving prepared CSV", exc_info=True) raise + return final_csv, split_config, split_info def _process_fixed_split( @@ -1171,6 +1241,7 @@ ) if df[SPLIT_COLUMN_NAME].isna().any(): logger.warning("Split column contains non-numeric/missing values.") + unique = set(df[SPLIT_COLUMN_NAME].dropna().unique()) logger.info(f"Unique split values: {unique}") if unique == {0, 2}: @@ -1193,7 +1264,9 @@ logger.info("Using fixed split as-is.") else: raise ValueError(f"Unexpected split values: {unique}") + return df, {"type": "fixed", "column": SPLIT_COLUMN_NAME}, split_info + except Exception: logger.error("Error processing fixed split", exc_info=True) raise @@ -1209,11 +1282,14 @@ """Execute the full workflow end-to-end.""" logger.info("Starting workflow...") self.args.output_dir.mkdir(parents=True, exist_ok=True) + try: self._create_temp_dirs() self._extract_images() csv_path, split_cfg, split_info = self._prepare_data() + use_pretrained = self.args.use_pretrained or self.args.fine_tune + backend_args = { "model_name": self.args.model_name, "fine_tune": self.args.fine_tune, @@ -1230,9 +1306,11 @@ "threshold": self.args.threshold, } yaml_str = self.backend.prepare_config(backend_args, split_cfg) + config_file = self.temp_dir / TEMP_CONFIG_FILENAME config_file.write_text(yaml_str) logger.info(f"Wrote backend config: {config_file}") + self.backend.run_experiment( csv_path, config_file, @@ -1374,8 +1452,7 @@ action=SplitProbAction, default=[0.7, 0.1, 0.2], help=( - "Random split proportions (e.g., 0.7 0.1 0.2)." - "Only used if no split column." + "Random split proportions (e.g., 0.7 0.1 0.2).Only used if no split column." ), ) parser.add_argument( @@ -1408,9 +1485,10 @@ help=( "Decision threshold for binary classification (0.0–1.0)." "Overrides default 0.5." - ) + ), ) args = parser.parse_args() + if not 0.0 <= args.validation_size <= 1.0: parser.error("validation-size must be between 0.0 and 1.0") if not args.csv_file.is_file(): @@ -1423,8 +1501,10 @@ setattr(args, "augmentation", augmentation_setup) except ValueError as e: parser.error(str(e)) + backend_instance = LudwigDirectBackend() orchestrator = WorkflowOrchestrator(args, backend_instance) + exit_code = 0 try: orchestrator.run() @@ -1439,6 +1519,7 @@ if __name__ == "__main__": try: import ludwig + logger.debug(f"Found Ludwig version: {ludwig.globals.LUDWIG_VERSION}") except ImportError: logger.error( @@ -1446,4 +1527,5 @@ "('pip install ludwig[image]')" ) sys.exit(1) + main()
--- a/utils.py Thu Aug 14 14:53:10 2025 +0000 +++ b/utils.py Wed Aug 27 21:02:48 2025 +0000 @@ -8,8 +8,6 @@ <head> <meta charset="UTF-8"> <title>Galaxy-Ludwig Report</title> - - <!-- your existing styles --> <style> body { font-family: Arial, sans-serif; @@ -328,121 +326,121 @@ '<div id="metricsHelpModal" class="modal">' ' <div class="modal-content">' ' <span class="close">×</span>' - ' <h2>Model Evaluation Metrics — Help Guide</h2>' + " <h2>Model Evaluation Metrics — Help Guide</h2>" ' <div class="metrics-guide">' - ' <h3>1) General Metrics (Regression and Classification)</h3>' - ' <p><strong>Loss (Regression & Classification):</strong> ' - 'Measures the difference between predicted and actual values, ' - 'optimized during training. Lower is better. ' - 'For regression, this is often Mean Squared Error (MSE) or ' - 'Mean Absolute Error (MAE). For classification, it’s typically ' - 'cross-entropy or log loss.</p>' - ' <h3>2) Regression Metrics</h3>' - ' <p><strong>Mean Absolute Error (MAE):</strong> ' - 'Average of absolute differences between predicted and actual values, ' - 'in the same units as the target. Use for interpretable error measurement ' - 'when all errors are equally important. Less sensitive to outliers than MSE.</p>' - ' <p><strong>Mean Squared Error (MSE):</strong> ' - 'Average of squared differences between predicted and actual values. ' - 'Penalizes larger errors more heavily, useful when large deviations are critical. ' - 'Often used as the loss function in regression.</p>' - ' <p><strong>Root Mean Squared Error (RMSE):</strong> ' - 'Square root of MSE, in the same units as the target. ' - 'Balances interpretability and sensitivity to large errors. ' - 'Widely used for regression evaluation.</p>' - ' <p><strong>Mean Absolute Percentage Error (MAPE):</strong> ' - 'Average absolute error as a percentage of actual values. ' - 'Scale-independent, ideal for comparing relative errors across datasets. ' - 'Avoid when actual values are near zero.</p>' - ' <p><strong>Root Mean Squared Percentage Error (RMSPE):</strong> ' - 'Square root of mean squared percentage error. Scale-independent, ' - 'penalizes larger relative errors more than MAPE. Use for forecasting ' - 'or when relative accuracy matters.</p>' - ' <p><strong>R² Score:</strong> Proportion of variance in the target ' - 'explained by the model. Ranges from negative infinity to 1 (perfect prediction). ' - 'Use to assess model fit; negative values indicate poor performance ' - 'compared to predicting the mean.</p>' - ' <h3>3) Classification Metrics</h3>' - ' <p><strong>Accuracy:</strong> Proportion of correct predictions ' - 'among all predictions. Simple but misleading for imbalanced datasets, ' - 'where high accuracy may hide poor performance on minority classes.</p>' - ' <p><strong>Micro Accuracy:</strong> Sums true positives and true negatives ' - 'across all classes before computing accuracy. Suitable for multiclass or ' - 'multilabel problems with imbalanced data.</p>' - ' <p><strong>Token Accuracy:</strong> Measures how often predicted tokens ' - '(e.g., in sequences) match true tokens. Common in NLP tasks like text generation ' - 'or token classification.</p>' - ' <p><strong>Precision:</strong> Proportion of positive predictions that are ' - 'correct (TP / (TP + FP)). Use when false positives are costly, e.g., spam detection.</p>' - ' <p><strong>Recall (Sensitivity):</strong> Proportion of actual positives ' - 'correctly predicted (TP / (TP + FN)). 
Use when missing positives is risky, ' - 'e.g., disease detection.</p>' - ' <p><strong>Specificity:</strong> True negative rate (TN / (TN + FP)). ' - 'Measures ability to identify negatives. Useful in medical testing to avoid ' - 'false alarms.</p>' - ' <h3>4) Classification: Macro, Micro, and Weighted Averages</h3>' - ' <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric ' - 'across all classes, treating each equally. Best for balanced datasets where ' - 'all classes are equally important.</p>' - ' <p><strong>Micro Precision / Recall / F1:</strong> Aggregates true positives, ' - 'false positives, and false negatives across all classes before computing. ' - 'Ideal for imbalanced or multilabel classification.</p>' - ' <p><strong>Weighted Precision / Recall / F1:</strong> Averages metrics ' - 'across classes, weighted by the number of true instances per class. Balances ' - 'class importance based on frequency.</p>' - ' <h3>5) Classification: Average Precision (PR-AUC Variants)</h3>' - ' <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged ' - 'equally across classes. Use for balanced multiclass problems.</p>' - ' <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC ' - 'using all instances. Best for imbalanced or multilabel classification.</p>' - ' <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged ' - 'across individual samples. Ideal for multilabel tasks where samples have multiple ' - 'labels.</p>' - ' <h3>6) Classification: ROC-AUC Variants</h3>' - ' <p><strong>ROC-AUC:</strong> Measures ability to distinguish between classes. ' - 'AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>' - ' <p><strong>Macro ROC-AUC:</strong> Averages AUC across all classes equally. ' - 'Suitable for balanced multiclass problems.</p>' - ' <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions ' - 'across all classes. Useful for imbalanced or multilabel settings.</p>' - ' <h3>7) Classification: Confusion Matrix Stats (Per Class)</h3>' - ' <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions ' - 'for positives and negatives, respectively.</p>' - ' <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions ' - '— false alarms and missed detections.</p>' - ' <h3>8) Classification: Ranking Metrics</h3>' - ' <p><strong>Hits at K:</strong> Measures whether the true label is among the ' - 'top-K predictions. Common in recommendation systems and retrieval tasks.</p>' - ' <h3>9) Other Metrics (Classification)</h3>' - ' <p><strong>Cohen\'s Kappa:</strong> Measures agreement between predicted and ' - 'actual labels, adjusted for chance. Useful for multiclass classification with ' - 'imbalanced data.</p>' - ' <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure ' - 'using TP, TN, FP, and FN. Effective for imbalanced datasets.</p>' - ' <h3>10) Metric Recommendations</h3>' - ' <ul>' - ' <li><strong>Regression:</strong> Use <strong>RMSE</strong> or ' - '<strong>MAE</strong> for general evaluation, <strong>MAPE</strong> for relative ' - 'errors, and <strong>R²</strong> to assess model fit. 
Use <strong>MSE</strong> or ' - '<strong>RMSPE</strong> when large errors are critical.</li>' - ' <li><strong>Classification (Balanced Data):</strong> Use <strong>Accuracy</strong> ' - 'and <strong>F1</strong> for overall performance.</li>' - ' <li><strong>Classification (Imbalanced Data):</strong> Use <strong>Precision</strong>, ' - '<strong>Recall</strong>, and <strong>ROC-AUC</strong> to focus on minority class ' - 'performance.</li>' - ' <li><strong>Multilabel or Imbalanced Classification:</strong> Use ' - '<strong>Micro Precision/Recall/F1</strong> or <strong>Micro ROC-AUC</strong>.</li>' - ' <li><strong>Balanced Multiclass:</strong> Use <strong>Macro Precision/Recall/F1</strong> ' - 'or <strong>Macro ROC-AUC</strong>.</li>' - ' <li><strong>Class Frequency Matters:</strong> Use <strong>Weighted Precision/Recall/F1</strong> ' - 'to account for class imbalance.</li>' - ' <li><strong>Recommendation/Ranking:</strong> Use <strong>Hits at K</strong> for retrieval tasks.</li>' - ' <li><strong>Detailed Analysis:</strong> Use <strong>Confusion Matrix stats</strong> ' - 'for class-wise performance in classification.</li>' - ' </ul>' - ' </div>' - ' </div>' - '</div>' + " <h3>1) General Metrics (Regression and Classification)</h3>" + " <p><strong>Loss (Regression & Classification):</strong> " + "Measures the difference between predicted and actual values, " + "optimized during training. Lower is better. " + "For regression, this is often Mean Squared Error (MSE) or " + "Mean Absolute Error (MAE). For classification, it’s typically " + "cross-entropy or log loss.</p>" + " <h3>2) Regression Metrics</h3>" + " <p><strong>Mean Absolute Error (MAE):</strong> " + "Average of absolute differences between predicted and actual values, " + "in the same units as the target. Use for interpretable error measurement " + "when all errors are equally important. Less sensitive to outliers than MSE.</p>" + " <p><strong>Mean Squared Error (MSE):</strong> " + "Average of squared differences between predicted and actual values. " + "Penalizes larger errors more heavily, useful when large deviations are critical. " + "Often used as the loss function in regression.</p>" + " <p><strong>Root Mean Squared Error (RMSE):</strong> " + "Square root of MSE, in the same units as the target. " + "Balances interpretability and sensitivity to large errors. " + "Widely used for regression evaluation.</p>" + " <p><strong>Mean Absolute Percentage Error (MAPE):</strong> " + "Average absolute error as a percentage of actual values. " + "Scale-independent, ideal for comparing relative errors across datasets. " + "Avoid when actual values are near zero.</p>" + " <p><strong>Root Mean Squared Percentage Error (RMSPE):</strong> " + "Square root of mean squared percentage error. Scale-independent, " + "penalizes larger relative errors more than MAPE. Use for forecasting " + "or when relative accuracy matters.</p>" + " <p><strong>R² Score:</strong> Proportion of variance in the target " + "explained by the model. Ranges from negative infinity to 1 (perfect prediction). " + "Use to assess model fit; negative values indicate poor performance " + "compared to predicting the mean.</p>" + " <h3>3) Classification Metrics</h3>" + " <p><strong>Accuracy:</strong> Proportion of correct predictions " + "among all predictions. 
Simple but misleading for imbalanced datasets, " + "where high accuracy may hide poor performance on minority classes.</p>" + " <p><strong>Micro Accuracy:</strong> Sums true positives and true negatives " + "across all classes before computing accuracy. Suitable for multiclass or " + "multilabel problems with imbalanced data.</p>" + " <p><strong>Token Accuracy:</strong> Measures how often predicted tokens " + "(e.g., in sequences) match true tokens. Common in NLP tasks like text generation " + "or token classification.</p>" + " <p><strong>Precision:</strong> Proportion of positive predictions that are " + "correct (TP / (TP + FP)). Use when false positives are costly, e.g., spam detection.</p>" + " <p><strong>Recall (Sensitivity):</strong> Proportion of actual positives " + "correctly predicted (TP / (TP + FN)). Use when missing positives is risky, " + "e.g., disease detection.</p>" + " <p><strong>Specificity:</strong> True negative rate (TN / (TN + FP)). " + "Measures ability to identify negatives. Useful in medical testing to avoid " + "false alarms.</p>" + " <h3>4) Classification: Macro, Micro, and Weighted Averages</h3>" + " <p><strong>Macro Precision / Recall / F1:</strong> Averages the metric " + "across all classes, treating each equally. Best for balanced datasets where " + "all classes are equally important.</p>" + " <p><strong>Micro Precision / Recall / F1:</strong> Aggregates true positives, " + "false positives, and false negatives across all classes before computing. " + "Ideal for imbalanced or multilabel classification.</p>" + " <p><strong>Weighted Precision / Recall / F1:</strong> Averages metrics " + "across classes, weighted by the number of true instances per class. Balances " + "class importance based on frequency.</p>" + " <h3>5) Classification: Average Precision (PR-AUC Variants)</h3>" + " <p><strong>Average Precision Macro:</strong> Precision-Recall AUC averaged " + "equally across classes. Use for balanced multiclass problems.</p>" + " <p><strong>Average Precision Micro:</strong> Global Precision-Recall AUC " + "using all instances. Best for imbalanced or multilabel classification.</p>" + " <p><strong>Average Precision Samples:</strong> Precision-Recall AUC averaged " + "across individual samples. Ideal for multilabel tasks where samples have multiple " + "labels.</p>" + " <h3>6) Classification: ROC-AUC Variants</h3>" + " <p><strong>ROC-AUC:</strong> Measures ability to distinguish between classes. " + "AUC = 1 is perfect; 0.5 is random guessing. Use for binary classification.</p>" + " <p><strong>Macro ROC-AUC:</strong> Averages AUC across all classes equally. " + "Suitable for balanced multiclass problems.</p>" + " <p><strong>Micro ROC-AUC:</strong> Computes AUC from aggregated predictions " + "across all classes. Useful for imbalanced or multilabel settings.</p>" + " <h3>7) Classification: Confusion Matrix Stats (Per Class)</h3>" + " <p><strong>True Positives / Negatives (TP / TN):</strong> Correct predictions " + "for positives and negatives, respectively.</p>" + " <p><strong>False Positives / Negatives (FP / FN):</strong> Incorrect predictions " + "— false alarms and missed detections.</p>" + " <h3>8) Classification: Ranking Metrics</h3>" + " <p><strong>Hits at K:</strong> Measures whether the true label is among the " + "top-K predictions. 
Common in recommendation systems and retrieval tasks.</p>" + " <h3>9) Other Metrics (Classification)</h3>" + " <p><strong>Cohen's Kappa:</strong> Measures agreement between predicted and " + "actual labels, adjusted for chance. Useful for multiclass classification with " + "imbalanced data.</p>" + " <p><strong>Matthews Correlation Coefficient (MCC):</strong> Balanced measure " + "using TP, TN, FP, and FN. Effective for imbalanced datasets.</p>" + " <h3>10) Metric Recommendations</h3>" + " <ul>" + " <li><strong>Regression:</strong> Use <strong>RMSE</strong> or " + "<strong>MAE</strong> for general evaluation, <strong>MAPE</strong> for relative " + "errors, and <strong>R²</strong> to assess model fit. Use <strong>MSE</strong> or " + "<strong>RMSPE</strong> when large errors are critical.</li>" + " <li><strong>Classification (Balanced Data):</strong> Use <strong>Accuracy</strong> " + "and <strong>F1</strong> for overall performance.</li>" + " <li><strong>Classification (Imbalanced Data):</strong> Use <strong>Precision</strong>, " + "<strong>Recall</strong>, and <strong>ROC-AUC</strong> to focus on minority class " + "performance.</li>" + " <li><strong>Multilabel or Imbalanced Classification:</strong> Use " + "<strong>Micro Precision/Recall/F1</strong> or <strong>Micro ROC-AUC</strong>.</li>" + " <li><strong>Balanced Multiclass:</strong> Use <strong>Macro Precision/Recall/F1</strong> " + "or <strong>Macro ROC-AUC</strong>.</li>" + " <li><strong>Class Frequency Matters:</strong> Use <strong>Weighted Precision/Recall/F1</strong> " + "to account for class imbalance.</li>" + " <li><strong>Recommendation/Ranking:</strong> Use <strong>Hits at K</strong> for retrieval tasks.</li>" + " <li><strong>Detailed Analysis:</strong> Use <strong>Confusion Matrix stats</strong> " + "for class-wise performance in classification.</li>" + " </ul>" + " </div>" + " </div>" + "</div>" ) modal_css = ( "<style>" @@ -497,17 +495,17 @@ ' var span = document.getElementsByClassName("close")[0];' " if (openBtn && modal) {" " openBtn.onclick = function() {" - " modal.style.display = \"block\";" + ' modal.style.display = "block";' " };" " }" " if (span && modal) {" " span.onclick = function() {" - " modal.style.display = \"none\";" + ' modal.style.display = "none";' " };" " }" " window.onclick = function(event) {" " if (event.target == modal) {" - " modal.style.display = \"none\";" + ' modal.style.display = "none";' " }" " }" "});"
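Reviewer's note on the report helpers touched in utils.py and image_learner_cli.py: the HTML report inlines its plots as base64 data URIs (see encode_image_to_base64 and the handling of roc_curves_from_prediction_statistics.png above). A generic sketch of that embedding pattern is shown below; it is not the actual utils.py helper, and the example path is assumed from the visualization step.

# Generic sketch of the base64 image-embedding pattern used by the report;
# not the actual encode_image_to_base64 helper from utils.py.
import base64
from pathlib import Path


def embed_png(path: str) -> str:
    """Return an <img> tag with the PNG inlined as a base64 data URI."""
    data = base64.b64encode(Path(path).read_bytes()).decode("ascii")
    return f'<img src="data:image/png;base64,{data}" style="max-width:90%;" />'


# Example (path assumed from the visualizations directory created by the tool):
# html += embed_png("visualizations/test/roc_curves_from_prediction_statistics.png")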