Mercurial > repos > goeckslab > pycaret_predict
changeset 17:c5c324ac29fc draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
| field | value |
|---|---|
| author | goeckslab |
| date | Sat, 06 Dec 2025 14:20:36 +0000 |
| parents | 4fee4504646e |
| children | |
| files | base_model_trainer.py feature_importance.py pycaret_classification.py utils.py |
| diffstat | 4 files changed, 1241 insertions(+), 88 deletions(-) |
--- a/base_model_trainer.py Fri Nov 28 22:28:26 2025 +0000 +++ b/base_model_trainer.py Sat Dec 06 14:20:36 2025 +0000 @@ -9,7 +9,16 @@ import pandas as pd from feature_help_modal import get_feature_metrics_help_modal from feature_importance import FeatureImportanceAnalyzer -from sklearn.metrics import average_precision_score +from sklearn.metrics import ( + accuracy_score, + average_precision_score, + confusion_matrix, + f1_score, + matthews_corrcoef, + precision_score, + recall_score, + roc_auc_score, +) from utils import ( add_hr_to_html, add_plot_to_html, @@ -387,6 +396,693 @@ with open(img_path, "rb") as img_file: return base64.b64encode(img_file.read()).decode("utf-8") + def _build_dataset_overview(self): + """ + Build an HTML table showing label counts with labels as rows and splits + (Train / Validation / Test) as columns. Each cell shows count and + percentage of that split. Returns empty string for regression or when + no label data is available. + """ + if self.task_type != "classification": + return "" + + def _safe_series(obj): + try: + return pd.Series(obj).reset_index(drop=True) + except Exception: + return None + + def _get_from_config(keys): + if self.exp is None: + return None + for key in keys: + try: + val = self.exp.get_config(key) + except Exception: + val = getattr(self.exp, key, None) + if val is not None: + return val + return None + + # Prefer PyCaret-configured splits; fall back to raw inputs. + X_train = _get_from_config(["X_train_transformed", "X_train"]) + y_train = _get_from_config(["y_train_transformed", "y_train"]) + y_test_cfg = _get_from_config(["y_test_transformed", "y_test"]) + + if y_train is None and self.data is not None and self.target in self.data.columns: + y_train = self.data[self.target] + + y_train_series = _safe_series(y_train) + + # Build a cross-validation generator to derive a validation subset size. + cv_gen = self._get_cv_generator(y_train_series) + y_train_fold = y_train_series + y_val_fold = None + if cv_gen is not None and y_train_series is not None: + try: + # Use the first fold to approximate Train/Validation split sizes. + splitter = cv_gen.split( + pd.DataFrame(X_train).reset_index(drop=True) + if X_train is not None + else y_train_series, + y_train_series, + ) + train_idx, val_idx = next(iter(splitter)) + y_train_fold = y_train_series.iloc[train_idx].reset_index(drop=True) + y_val_fold = y_train_series.iloc[val_idx].reset_index(drop=True) + except Exception as exc: + LOG.warning("Could not derive validation split for dataset overview: %s", exc) + + # Test labels: prefer PyCaret transformed holdout (single file) or external test. 
+ if self.test_data is not None: + if y_test_cfg is not None: + y_test = y_test_cfg + elif self.target in self.test_data.columns: + y_test = self.test_data[self.target] + else: + y_test = None + else: + y_test = y_test_cfg + + split_map = { + "Train": _safe_series(y_train_fold), + "Validation": _safe_series(y_val_fold), + "Test": _safe_series(y_test), + } + available = {k: v for k, v in split_map.items() if v is not None and not v.empty} + if not available: + return "" + + # Collect all labels across available splits (including NaN) + label_pool = pd.concat( + available.values(), ignore_index=True + ) + labels = pd.unique(label_pool) + + def _count_for_label(series, label): + if series is None or series.empty: + return None, None + total = len(series) + if pd.isna(label): + cnt = series.isna().sum() + else: + cnt = (series == label).sum() + return int(cnt), total + + rows = [] + for label in labels: + row = ["NaN" if pd.isna(label) else str(label)] + for split_name in ["Train", "Validation", "Test"]: + cnt, total = _count_for_label(split_map.get(split_name), label) + if cnt is None or total is None: + cell = "—" + else: + pct = (cnt / total * 100) if total else 0 + cell = f"{cnt} ({pct:.1f}%)" + row.append(cell) + rows.append(row) + + df = pd.DataFrame(rows, columns=["Label", "Train", "Validation", "Test"]) + df.sort_values("Label", inplace=True) + + return ( + "<h2>Dataset Overview</h2>" + + '<div class="table-wrapper">' + + df.to_html( + index=False, + classes=["table", "sortable", "table-dataset-overview"], + ) + + "</div>" + ) + + def _predict_with_thresholds(self, X, y_true): + """ + Generate predictions/probabilities for a split, respecting an optional + probability threshold for binary tasks. Returns a dict with y_true, + y_pred, y_scores (positive-class probs when available), pos_label, + and neg_label. 
+ """ + if X is None or y_true is None: + return None + + y_true_series = pd.Series(y_true).reset_index(drop=True) + classes = list(getattr(self.best_model, "classes_", [])) + if not classes: + try: + classes = pd.unique(y_true_series).tolist() + except Exception: + classes = [] + if len(classes) > 1: + try: + pos_idx = classes.index(1) + except Exception: + pos_idx = 1 + else: + pos_idx = 0 + pos_idx = min(pos_idx, len(classes) - 1) if classes else 0 + pos_label = ( + classes[pos_idx] + if len(classes) > pos_idx and pos_idx >= 0 + else (classes[-1] if classes else 1) + ) + neg_label = None + if len(classes) >= 2: + neg_candidates = [c for c in classes if c != pos_label] + if neg_candidates: + neg_label = neg_candidates[0] + + prob_thresh = getattr(self, "probability_threshold", None) + y_scores = None + try: + proba = self.best_model.predict_proba(X) + y_scores = np.asarray(proba) if proba is not None else None + except Exception: + y_scores = None + + try: + if ( + prob_thresh is not None + and not getattr(self.exp, "is_multiclass", False) + and y_scores is not None + and y_scores.ndim == 2 + and y_scores.shape[1] > 1 + ): + pos_idx = min(pos_idx, y_scores.shape[1] - 1) + neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0 + if neg_label is None and len(classes) > neg_idx: + neg_label = classes[neg_idx] + y_pred = np.where( + y_scores[:, pos_idx] >= prob_thresh, + pos_label, + neg_label if neg_label is not None else 0, + ) + y_scores = y_scores[:, pos_idx] + else: + y_pred = self.best_model.predict(X) + if ( + not getattr(self.exp, "is_multiclass", False) + and y_scores is not None + and y_scores.ndim == 2 + and y_scores.shape[1] > 1 + ): + pos_idx = min(pos_idx, y_scores.shape[1] - 1) + y_scores = y_scores[:, pos_idx] + except Exception as exc: + LOG.warning( + "Falling back to raw predict while computing performance summary: %s", + exc, + ) + try: + y_pred = self.best_model.predict(X) + except Exception as exc_inner: + LOG.warning( + "Unable to score split after fallback prediction: %s", + exc_inner, + ) + return None + y_scores = None + + y_pred_series = pd.Series(y_pred).reset_index(drop=True) + if y_scores is not None: + y_scores = np.asarray(y_scores) + if y_scores.ndim > 1 and y_scores.shape[1] == 1: + y_scores = y_scores.ravel() + if getattr(self.exp, "is_multiclass", False) and y_scores.ndim > 1: + # Avoid passing multiclass score matrices to ROC/PR utilities + y_scores = None + + return { + "y_true": y_true_series, + "y_pred": y_pred_series, + "y_scores": y_scores, + "pos_label": pos_label, + "neg_label": neg_label, + } + + def _get_cv_generator(self, y_series): + """ + Build a cross-validation splitter that mirrors the experiment's + configuration. Returns None when CV is disabled or not applicable. 
+ """ + if self.task_type != "classification": + return None + + if getattr(self, "cross_validation", None) is False: + return None + + try: + cfg_gen = self.exp.get_config("fold_generator") + if cfg_gen is not None: + return cfg_gen + except Exception: + cfg_gen = None + + folds = ( + getattr(self, "cross_validation_folds", None) + or self.setup_params.get("fold") + or getattr(self.exp, "fold", None) + or 10 + ) + try: + folds = int(folds) + except Exception: + folds = 10 + + try: + y_series = pd.Series(y_series).reset_index(drop=True) + except Exception: + y_series = None + if y_series is None or y_series.empty: + return None + + if folds < 2: + return None + if len(y_series) < folds: + folds = len(y_series) + if folds < 2: + return None + + try: + from sklearn.model_selection import KFold, StratifiedKFold + + if self.task_type == "classification": + return StratifiedKFold( + n_splits=folds, + shuffle=True, + random_state=self.random_seed, + ) + return KFold( + n_splits=folds, + shuffle=True, + random_state=self.random_seed, + ) + except Exception as exc: + LOG.warning("Could not build CV generator: %s", exc) + return None + + def _get_cross_validated_predictions(self, X, y): + """ + Generate cross-validated predictions for the validation split so we + can report validation metrics for the selected best model. + """ + if self.task_type != "classification": + return None + if getattr(self, "cross_validation", None) is False: + return None + if X is None or y is None: + return None + + try: + from sklearn.model_selection import cross_val_predict + except Exception as exc: + LOG.warning("cross_val_predict unavailable: %s", exc) + return None + + y_series = pd.Series(y).reset_index(drop=True) + if y_series.empty: + return None + + cv_gen = self._get_cv_generator(y_series) + if cv_gen is None: + return None + + X_df = pd.DataFrame(X).reset_index(drop=True) + if len(X_df) != len(y_series): + X_df = X_df.iloc[: len(y_series)].reset_index(drop=True) + + classes = list(getattr(self.best_model, "classes_", [])) + if len(classes) > 1: + try: + pos_idx = classes.index(1) + except Exception: + pos_idx = 1 + else: + pos_idx = 0 + pos_idx = min(pos_idx, len(classes) - 1) if classes else 0 + pos_label = ( + classes[pos_idx] if len(classes) > pos_idx else 1 + ) + neg_label = None + if len(classes) >= 2: + neg_candidates = [c for c in classes if c != pos_label] + if neg_candidates: + neg_label = neg_candidates[0] + + prob_thresh = getattr(self, "probability_threshold", None) + n_jobs = getattr(self, "n_jobs", None) + + y_scores = None + if not getattr(self.exp, "is_multiclass", False): + try: + proba = cross_val_predict( + self.best_model, + X_df, + y_series, + cv=cv_gen, + method="predict_proba", + n_jobs=n_jobs, + ) + y_scores = np.asarray(proba) + except Exception as exc: + LOG.debug("Could not compute CV probabilities: %s", exc) + + y_pred = None + if ( + prob_thresh is not None + and not getattr(self.exp, "is_multiclass", False) + and y_scores is not None + and y_scores.ndim == 2 + and y_scores.shape[1] > 1 + ): + pos_idx = min(pos_idx, y_scores.shape[1] - 1) + neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0 + if neg_label is None and len(classes) > neg_idx: + neg_label = classes[neg_idx] + y_pred = np.where( + y_scores[:, pos_idx] >= prob_thresh, + pos_label, + neg_label if neg_label is not None else 0, + ) + y_scores = y_scores[:, pos_idx] + else: + try: + y_pred = cross_val_predict( + self.best_model, + X_df, + y_series, + cv=cv_gen, + method="predict", + n_jobs=n_jobs, + ) + except 
Exception as exc: + LOG.warning( + "Could not compute cross-validated predictions: %s", + exc, + ) + return None + if ( + not getattr(self.exp, "is_multiclass", False) + and y_scores is not None + and y_scores.ndim == 2 + and y_scores.shape[1] > 1 + ): + pos_idx = min(pos_idx, y_scores.shape[1] - 1) + y_scores = y_scores[:, pos_idx] + + if y_scores is not None and getattr(self.exp, "is_multiclass", False): + y_scores = None + + return { + "y_true": y_series, + "y_pred": pd.Series(y_pred).reset_index(drop=True), + "y_scores": y_scores, + "pos_label": pos_label, + "neg_label": neg_label, + } + + def _get_split_predictions_for_report(self): + """ + Collect predictions/probabilities for Train/Validation/Test splits so the + performance table can show consistent metrics across splits. + """ + if self.task_type != "classification": + return {} + + def _get_from_config(keys): + for key in keys: + try: + val = self.exp.get_config(key) + except Exception: + val = getattr(self.exp, key, None) + if val is not None: + return val + return None + + X_train = _get_from_config(["X_train_transformed", "X_train"]) + y_train = _get_from_config(["y_train_transformed", "y_train"]) + X_holdout = _get_from_config(["X_test_transformed", "X_test"]) + y_holdout = _get_from_config(["y_test_transformed", "y_test"]) + + predictions = {} + + # Train metrics (best model on training data) + if X_train is not None and y_train is not None: + try: + train_preds = self._predict_with_thresholds(X_train, y_train) + if train_preds is not None: + predictions["Train"] = train_preds + except Exception as exc: + LOG.warning( + "Could not score Train split for performance summary: %s", + exc, + ) + + # Validation metrics via cross-validation on training data + try: + val_preds = self._get_cross_validated_predictions(X_train, y_train) + if val_preds is not None: + predictions["Validation"] = val_preds + except Exception as exc: + LOG.warning( + "Could not score Validation split for performance summary: %s", + exc, + ) + + # Test metrics (holdout from single file, or provided test file) + X_test = X_holdout + y_test = y_holdout + if (X_test is None or y_test is None) and self.test_data is not None: + try: + X_test = self.test_data.drop(columns=[self.target]) + y_test = self.test_data[self.target] + except Exception as exc: + LOG.warning( + "Could not prepare external test data for performance summary: %s", + exc, + ) + + if X_test is not None and y_test is not None: + try: + test_preds = self._predict_with_thresholds(X_test, y_test) + if test_preds is not None: + predictions["Test"] = test_preds + except Exception as exc: + LOG.warning( + "Could not score Test split for performance summary: %s", + exc, + ) + return predictions + + def _compute_metric_value(self, metric_name, preds, split_name): + """ + Compute a single metric for a given split prediction bundle. 
+ """ + if preds is None: + return None + + y_true = preds["y_true"] + y_pred = preds["y_pred"] + y_scores = preds.get("y_scores") + pos_label = preds.get("pos_label") + neg_label = preds.get("neg_label") + is_multiclass = getattr(self.exp, "is_multiclass", False) + + def _format_binary_labels(series): + if pos_label is None: + return series + try: + return (series == pos_label).astype(int) + except Exception: + return series + + try: + if metric_name == "Accuracy": + return accuracy_score(y_true, y_pred) + if metric_name == "ROC-AUC": + if y_scores is None: + return None + y_true_bin = _format_binary_labels(y_true) + if len(pd.unique(y_true_bin)) < 2: + return None + return roc_auc_score(y_true_bin, y_scores) + if metric_name == "Precision": + if is_multiclass: + return precision_score( + y_true, y_pred, average="weighted", zero_division=0 + ) + try: + return precision_score( + y_true, y_pred, pos_label=pos_label, zero_division=0 + ) + except Exception: + return precision_score( + y_true, y_pred, average="weighted", zero_division=0 + ) + if metric_name == "Recall": + if is_multiclass: + return recall_score( + y_true, y_pred, average="weighted", zero_division=0 + ) + try: + return recall_score( + y_true, y_pred, pos_label=pos_label, zero_division=0 + ) + except Exception: + return recall_score( + y_true, y_pred, average="weighted", zero_division=0 + ) + if metric_name == "F1-Score": + if is_multiclass: + return f1_score( + y_true, y_pred, average="weighted", zero_division=0 + ) + try: + return f1_score( + y_true, y_pred, pos_label=pos_label, zero_division=0 + ) + except Exception: + return f1_score( + y_true, y_pred, average="weighted", zero_division=0 + ) + if metric_name == "PR-AUC": + if y_scores is None: + return None + y_true_bin = _format_binary_labels(y_true) + if len(pd.unique(y_true_bin)) < 2: + return None + return average_precision_score(y_true_bin, y_scores) + if metric_name == "Specificity": + labels = pd.unique(pd.concat([y_true, y_pred], ignore_index=True)) + if len(labels) != 2: + return None + if pos_label is None or pos_label not in labels: + pos_label = labels[1] + neg_candidates = [lbl for lbl in labels if lbl != pos_label] + neg_label_final = ( + neg_label if neg_label in labels else (neg_candidates[0] if neg_candidates else None) + ) + if neg_label_final is None: + return None + cm = confusion_matrix( + y_true, y_pred, labels=[neg_label_final, pos_label] + ) + if cm.shape != (2, 2): + return None + tn, fp, fn, tp = cm.ravel() + denom = tn + fp + return (tn / denom) if denom else None + if metric_name == "MCC": + return matthews_corrcoef(y_true, y_pred) + except Exception as exc: + LOG.warning( + "Could not compute %s for %s split: %s", + metric_name, + split_name, + exc, + ) + return None + return None + + def _build_performance_summary_table(self): + """ + Build a Train/Validation/Test metrics table for classification tasks. + Returns empty string when metrics are unavailable or not applicable. 
+ """ + if self.task_type != "classification": + return "" + + split_predictions = self._get_split_predictions_for_report() + validation_best_row = None + try: + if isinstance(self.results, pd.DataFrame) and not self.results.empty: + validation_best_row = self.results.iloc[0] + except Exception: + validation_best_row = None + + if not split_predictions and validation_best_row is None: + return "" + + metric_names = [ + "Accuracy", + "ROC-AUC", + "Precision", + "Recall", + "F1-Score", + "PR-AUC", + "Specificity", + "MCC", + ] + + validation_column_map = { + "Accuracy": ["Accuracy"], + "ROC-AUC": ["ROC-AUC", "AUC"], + "Precision": ["Precision", "Prec.", "Prec"], + "Recall": ["Recall"], + "F1-Score": ["F1-Score", "F1"], + "PR-AUC": ["PR-AUC", "PR-AUC-Weighted", "PRC"], + "Specificity": ["Specificity"], + "MCC": ["MCC"], + } + + def _fmt(value): + if value is None: + return "—" + try: + if isinstance(value, (float, np.floating)) and ( + np.isnan(value) or np.isinf(value) + ): + return "—" + return f"{value:.3f}" + except Exception: + return str(value) + + def _validation_metric(metric_name): + if validation_best_row is None: + return None + cols = validation_column_map.get(metric_name, []) + for col in cols: + if col in validation_best_row: + try: + return validation_best_row[col] + except Exception: + return None + return None + + rows = [] + for metric in metric_names: + row = [metric] + # Train + train_val = self._compute_metric_value( + metric, split_predictions.get("Train"), "Train" + ) + row.append(_fmt(train_val)) + + # Validation from Train & Validation Summary first row; fallback to computed CV. + val_val = _validation_metric(metric) + if val_val is None: + val_val = self._compute_metric_value( + metric, split_predictions.get("Validation"), "Validation" + ) + row.append(_fmt(val_val)) + + # Test + test_val = self._compute_metric_value( + metric, split_predictions.get("Test"), "Test" + ) + row.append(_fmt(test_val)) + rows.append(row) + + df = pd.DataFrame(rows, columns=["Metric", "Train", "Validation", "Test"]) + return ( + "<h2>Model Performance Summary</h2>" + + '<div class="table-wrapper">' + + df.to_html( + index=False, + classes=["table", "sortable", "table-perf-summary"], + ) + + "</div>" + ) + def _resolve_plot_callable(self, key, fig_or_fn, section): """ Safely execute stored plot callables so a single failure does not @@ -521,17 +1217,19 @@ # — Validation Summary & Configuration — val_df = self.results.copy() + dataset_overview_html = self._build_dataset_overview() + performance_summary_html = self._build_performance_summary_table() # mapping raw plot keys to user-friendly titles plot_title_map = { "learning": "Learning Curve", "vc": "Validation Curve", "calibration": "Calibration Curve", "dimension": "Dimensionality Reduction", - "manifold": "Manifold Learning", + "manifold": "t-SNE", "rfe": "Recursive Feature Elimination", "threshold": "Threshold Plot", "percentage_above_below": "Percentage Above vs. 
Below Cutoff", - "class_report": "Classification Report", + "class_report": "Per-Class Metrics", "pr_auc": "Precision-Recall AUC", "roc_auc": "Receiver Operating Characteristic AUC", "residuals": "Residuals Distribution", @@ -560,10 +1258,16 @@ + "</div>" ) - summary_html += ( - "<h2>Setup Parameters</h2>" + config_html = ( + header + + dataset_overview_html + + performance_summary_html + + "<h2>Setup Parameters</h2>" + '<div class="table-wrapper">' - + df_setup.to_html(index=False, classes="table sortable") + + df_setup.to_html( + index=False, + classes=["table", "sortable", "table-setup-params"], + ) + "</div>" # — Hyperparameters + "<h2>Best Model Hyperparameters</h2>" @@ -571,20 +1275,23 @@ + pd.DataFrame( self.best_model.get_params().items(), columns=["Parameter", "Value"] - ).to_html(index=False, classes="table sortable") + ).to_html( + index=False, + classes=["table", "sortable", "table-hyperparams"], + ) + "</div>" ) # choose summary plots based on task type if self.task_type == "classification": summary_plots = [ + "threshold", "learning", + "calibration", + "rfe", "vc", - "calibration", "dimension", "manifold", - "rfe", - "threshold", "percentage_above_below", ] else: @@ -649,11 +1356,13 @@ else: test_order = [ "confusion_matrix", + "class_report", "roc_auc", "pr_auc", "lift_curve", "cumulative_precision", ] + rendered_test_plots = set() for key in test_order: fig_or_fn = self.explainer_plots.pop(key, None) if fig_or_fn is not None: @@ -662,6 +1371,7 @@ ) if fig is None: continue + rendered_test_plots.add(key) title = plot_title_map.get( key, key.replace("_", " ").title() ) @@ -679,6 +1389,8 @@ "class_report", } ): + if name in rendered_test_plots: + continue title = plot_title_map.get( name, name.replace("_", " ").title() ) @@ -750,7 +1462,7 @@ if cap_rows: cap_table = ( "<div class='table-wrapper'>" - "<table class='table sortable'>" + "<table class='table sortable table-fi-scope'>" "<thead><tr><th>Feature Importance Scope</th><th>Count</th></tr></thead>" "<tbody>" + "".join( @@ -803,7 +1515,13 @@ # 7) Assemble final HTML (three tabs) html = get_html_template() html += "<h1>Tabular Learner Model Report</h1>" - html += build_tabbed_html(summary_html, test_html, feature_html) + html += build_tabbed_html( + summary_html, + test_html, + feature_html, + explainer_html=None, + config_html=config_html, + ) html += get_feature_metrics_help_modal() html += get_html_closing() @@ -823,11 +1541,11 @@ raise NotImplementedError("Subclasses should implement this method") def generate_tree_plots(self): + from explainerdashboard.explainers import RandomForestExplainer from sklearn.ensemble import ( RandomForestClassifier, RandomForestRegressor ) from xgboost import XGBClassifier, XGBRegressor - from explainerdashboard.explainers import RandomForestExplainer LOG.info("Generating tree plots") X_test = self.exp.X_test_transformed.copy()
--- a/feature_importance.py Fri Nov 28 22:28:26 2025 +0000 +++ b/feature_importance.py Sat Dec 06 14:20:36 2025 +0000 @@ -287,24 +287,16 @@ # Background set bg = X_data.sample(min(len(X_data), 100), random_state=42) - predict_fn = ( - model.predict_proba if hasattr(model, "predict_proba") else model.predict - ) + predict_fn = self._get_predict_fn(model) - # Optimized explainer - explainer = None - explainer_label = None - if hasattr(model, "feature_importances_"): - explainer = shap.TreeExplainer( - model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 - ) - explainer_label = "tree_path_dependent" - elif hasattr(model, "coef_"): - explainer = shap.LinearExplainer(model, bg) - explainer_label = "linear" - else: - explainer = shap.Explainer(predict_fn, bg) - explainer_label = explainer.__class__.__name__ + # Optimized explainer based on model type + explainer, explainer_label, tree_based = self._choose_shap_explainer( + model, bg, predict_fn + ) + if explainer is None: + LOG.warning("No suitable SHAP explainer for model %s; skipping SHAP.", model) + self.shap_model_name = None + return try: shap_values = explainer(X_data) @@ -312,7 +304,7 @@ except Exception as e: error_message = str(e) needs_tree_fallback = ( - hasattr(model, "feature_importances_") + tree_based and "does not cover all the leaves" in error_message.lower() ) feature_name_mismatch = "feature names should match" in error_message.lower() @@ -348,7 +340,9 @@ error_message, ) try: - agnostic_explainer = shap.Explainer(predict_fn, bg) + agnostic_explainer = shap.Explainer( + predict_fn, bg, algorithm="permutation" + ) shap_values = agnostic_explainer(X_data) self.shap_model_name = ( f"{agnostic_explainer.__class__.__name__} (fallback)" @@ -485,6 +479,241 @@ with open(img_path, "rb") as img_file: return base64.b64encode(img_file.read()).decode("utf-8") + def _get_predict_fn(self, model): + if hasattr(model, "predict_proba"): + return model.predict_proba + if hasattr(model, "decision_function"): + return model.decision_function + return model.predict + + def _choose_shap_explainer(self, model, bg, predict_fn): + """ + Select a SHAP explainer following the prescribed priority order for + algorithms. Returns (explainer, label, is_tree_based). 
+ """ + if model is None: + return None, None, False + + name = model.__class__.__name__ + lname = name.lower() + task = getattr(self, "task_type", None) + + def _permutation(fn): + return shap.Explainer(fn, bg, algorithm="permutation") + + if task == "classification": + # 1) Logistic Regression + if "logisticregression" in lname: + return _permutation(model.predict_proba), "permutation-proba", False + + # 2) Ridge Classifier + if "ridgeclassifier" in lname: + fn = ( + model.decision_function + if hasattr(model, "decision_function") + else predict_fn + ) + return _permutation(fn), "permutation-decision_function", False + + # 3) LDA + if "lineardiscriminantanalysis" in lname: + return _permutation(model.predict_proba), "permutation-proba", False + + # 4) Random Forest + if "randomforestclassifier" in lname: + return ( + shap.TreeExplainer( + model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 + ), + "tree_path_dependent", + True, + ) + + # 5) Gradient Boosting + if "gradientboostingclassifier" in lname: + return ( + shap.TreeExplainer( + model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 + ), + "tree_path_dependent", + True, + ) + + # 6) AdaBoost + if "adaboostclassifier" in lname: + return ( + shap.TreeExplainer( + model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 + ), + "tree_path_dependent", + True, + ) + + # 7) Extra Trees + if "extratreesclassifier" in lname: + return ( + shap.TreeExplainer( + model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 + ), + "tree_path_dependent", + True, + ) + + # 8) LightGBM + if "lgbmclassifier" in lname: + return ( + shap.TreeExplainer( + model, + bg, + model_output="raw", + feature_perturbation="tree_path_dependent", + n_jobs=-1, + ), + "tree_path_dependent", + True, + ) + + # 9) XGBoost + if "xgbclassifier" in lname: + return ( + shap.TreeExplainer( + model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 + ), + "tree_path_dependent", + True, + ) + + # 10) CatBoost (classifier) + if "catboost" in lname: + return ( + shap.TreeExplainer( + model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 + ), + "tree_path_dependent", + True, + ) + + # 11) KNN + if "kneighborsclassifier" in lname: + return _permutation(model.predict_proba), "permutation-proba", False + + # 12) SVM - linear kernel + if "svc" in lname or "svm" in lname: + kernel = getattr(model, "kernel", None) + if kernel == "linear": + return shap.LinearExplainer(model, bg), "linear", False + return _permutation(predict_fn), "permutation-svm", False + + # 13) Decision Tree + if "decisiontreeclassifier" in lname: + return ( + shap.TreeExplainer( + model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 + ), + "tree_path_dependent", + True, + ) + + # 14) Naive Bayes + if "naive_bayes" in lname or lname.endswith("nb"): + fn = model.predict_proba if hasattr(model, "predict_proba") else predict_fn + return _permutation(fn), "permutation-proba", False + + # 15) QDA + if "quadraticdiscriminantanalysis" in lname: + return _permutation(model.predict_proba), "permutation-proba", False + + # 16) Dummy + if "dummyclassifier" in lname: + return None, None, False + + # Default classification: permutation on predict_fn + return _permutation(predict_fn), "permutation-default", False + + # Regression path + # Linear family + linear_keys = [ + "linearregression", + "lasso", + "ridge", + "elasticnet", + "lars", + "lassolars", + "orthogonalmatchingpursuit", + "bayesianridge", + "ardregression", + "passiveaggressiveregressor", + 
"theilsenregressor", + "huberregressor", + ] + if any(k in lname for k in linear_keys): + return shap.LinearExplainer(model, bg), "linear", False + + # Kernel ridge / SVR / KNN / MLP / RANSAC (model-agnostic) + if "kernelridge" in lname: + return _permutation(predict_fn), "permutation-kernelridge", False + if "svr" in lname or "svm" in lname: + kernel = getattr(model, "kernel", None) + if kernel == "linear": + return shap.LinearExplainer(model, bg), "linear", False + return _permutation(predict_fn), "permutation-svr", False + if "kneighborsregressor" in lname: + return _permutation(predict_fn), "permutation-knn", False + if "mlpregressor" in lname: + return _permutation(predict_fn), "permutation-mlp", False + if "ransacregressor" in lname: + return _permutation(predict_fn), "permutation-ransac", False + + # Tree-based regressors + tree_class_names = [ + "decisiontreeregressor", + "randomforestregressor", + "extratreesregressor", + "adaboostregressor", + "gradientboostingregressor", + ] + if any(k in lname for k in tree_class_names): + return ( + shap.TreeExplainer( + model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 + ), + "tree_path_dependent", + True, + ) + + # Boosting libraries + if "lgbmregressor" in lname or "lightgbm" in lname: + return ( + shap.TreeExplainer( + model, + bg, + model_output="raw", + feature_perturbation="tree_path_dependent", + n_jobs=-1, + ), + "tree_path_dependent", + True, + ) + if "xgbregressor" in lname or "xgboost" in lname: + return ( + shap.TreeExplainer( + model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 + ), + "tree_path_dependent", + True, + ) + if "catboost" in lname: + return ( + shap.TreeExplainer( + model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 + ), + "tree_path_dependent", + True, + ) + + # Default regression: model-agnostic permutation explainer + return _permutation(predict_fn), "permutation-default", False + def run(self): if ( self.exp is None
--- a/pycaret_classification.py Fri Nov 28 22:28:26 2025 +0000 +++ b/pycaret_classification.py Sat Dec 06 14:20:36 2025 +0000 @@ -8,7 +8,14 @@ from base_model_trainer import BaseModelTrainer from dashboard import generate_classifier_explainer_dashboard from pycaret.classification import ClassificationExperiment -from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve +from sklearn.metrics import ( + auc, + confusion_matrix, + matthews_corrcoef, + precision_recall_curve, + precision_recall_fscore_support, + roc_curve, +) from utils import predict_proba LOG = logging.getLogger(__name__) @@ -137,58 +144,36 @@ # a dict to hold the raw Figure objects or callables self.explainer_plots: Dict[str, go.Figure] = {} + y_true, y_pred, label_values, y_scores = self._get_test_predictions() + + # — Classification report (Plotly table) — + try: + fig_report = self._build_classification_report_fig( + y_true, y_pred, label_values + ) + if fig_report is not None: + self.explainer_plots["class_report"] = fig_report + except Exception as e: + LOG.warning(f"Could not generate Plotly classification report: {e}") + + # — Confusion matrix with actual labels — + try: + fig_cm = self._build_confusion_matrix_fig(y_true, y_pred, label_values) + if fig_cm is not None: + self.explainer_plots["confusion_matrix"] = fig_cm + except Exception as e: + LOG.warning(f"Could not generate Plotly confusion matrix: {e}") + # --- Threshold-aware overrides for CM / ROC / PR --- prob_thresh = getattr(self, "probability_threshold", None) # Only for binary classification and when threshold is provided if (prob_thresh is not None) and (not self.exp.is_multiclass): - X = self.exp.X_test_transformed - y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True) - - # Get positive-class scores (robust defaults) - classes = list(getattr(self.best_model, "classes_", [0, 1])) - try: - pos_idx = classes.index(1) if 1 in classes else 1 - except Exception: - pos_idx = 1 - - proba = self.best_model.predict_proba(X) - y_scores = proba[:, pos_idx] - - # Derive label names consistently - pos_label = classes[pos_idx] if len(classes) > pos_idx else 1 - neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0 - - # ---- Confusion Matrix @ threshold ---- - try: - y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label) - cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label]) - fig_cm = go.Figure( - data=go.Heatmap( - z=cm, - x=[f"Pred {neg_label}", f"Pred {pos_label}"], - y=[f"True {neg_label}", f"True {pos_label}"], - text=cm, - texttemplate="%{text}", - colorscale="Blues", - showscale=False, - ) - ) - fig_cm.update_layout( - title=f"Confusion Matrix @ threshold={prob_thresh:.2f}", - xaxis_title="Predicted label", - yaxis_title="True label", - ) - _apply_report_layout(fig_cm) - self.explainer_plots["confusion_matrix"] = fig_cm - except Exception as e: - LOG.warning( - f"Threshold-aware confusion matrix failed; falling back: {e}" - ) - # ---- ROC with threshold marker ---- try: - fpr, tpr, thr = roc_curve(y, y_scores) + if y_scores is None: + raise ValueError("Predicted probabilities unavailable") + fpr, tpr, thr = roc_curve(y_true, y_scores) roc_auc = auc(fpr, tpr) fig_roc = go.Figure() fig_roc.add_scatter( @@ -219,7 +204,9 @@ # ---- PR with threshold marker ---- try: - precision, recall, thr_pr = precision_recall_curve(y, y_scores) + if y_scores is None: + raise ValueError("Predicted probabilities unavailable") + precision, recall, thr_pr = precision_recall_curve(y_true, y_scores) pr_auc = 
auc(recall, precision) fig_pr = go.Figure() fig_pr.add_scatter( @@ -304,3 +291,182 @@ return _plot self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat) + + def _get_test_predictions(self): + """ + Return y_true, y_pred, label list, and (optionally) positive-class + probabilities when available. Ensures predictions respect the optional + probability threshold for binary tasks. + """ + y_true = pd.Series(self.exp.y_test_transformed).reset_index(drop=True) + X_test = self.exp.X_test_transformed + prob_thresh = getattr(self, "probability_threshold", None) + + y_scores = None + try: + proba = self.best_model.predict_proba(X_test) + y_scores = proba + except Exception: + LOG.debug("predict_proba unavailable for test predictions.") + + try: + if ( + prob_thresh is not None + and not self.exp.is_multiclass + and y_scores is not None + and y_scores.ndim == 2 + and y_scores.shape[1] > 1 + ): + classes = list(getattr(self.best_model, "classes_", [])) + try: + pos_idx = classes.index(1) if 1 in classes else 1 + except Exception: + pos_idx = 1 + neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0 + pos_label = classes[pos_idx] if len(classes) > pos_idx else 1 + neg_label = classes[neg_idx] if len(classes) > neg_idx else 0 + y_pred = np.where(y_scores[:, pos_idx] >= prob_thresh, pos_label, neg_label) + y_scores = y_scores[:, pos_idx] + else: + y_pred = self.best_model.predict(X_test) + except Exception as exc: + LOG.warning("Falling back to raw predict for test predictions: %s", exc) + y_pred = self.best_model.predict(X_test) + + y_pred = pd.Series(y_pred).reset_index(drop=True) + if y_scores is not None: + y_scores = np.asarray(y_scores) + if y_scores.ndim > 1 and y_scores.shape[1] == 1: + y_scores = y_scores.ravel() + if self.exp.is_multiclass and y_scores.ndim > 1: + # Avoid passing multiclass score matrices to ROC/PR utilities + y_scores = None + label_values = pd.unique(pd.concat([y_true, y_pred], ignore_index=True)) + return y_true, y_pred, label_values.tolist(), y_scores + + def _threshold_suffix(self) -> str: + """ + Build a suffix like ' (threshold=0.50)' for binary tasks; omit for + multiclass where thresholds are not applied. 
+ """ + if getattr(self, "task_type", None) != "classification": + return "" + if getattr(self.exp, "is_multiclass", False): + return "" + prob_thresh = getattr(self, "probability_threshold", None) + if prob_thresh is None: + return " (threshold=0.50)" + try: + return f" (threshold={float(prob_thresh):.2f})" + except Exception: + return f" (threshold={prob_thresh})" + + def _build_confusion_matrix_fig(self, y_true, y_pred, labels): + def _label_sort_key(lbl): + try: + return (0, float(lbl)) + except Exception: + return (1, str(lbl)) + + ordered_labels = sorted(labels, key=_label_sort_key) + cm = confusion_matrix(y_true, y_pred, labels=ordered_labels) + label_names = [str(lbl) for lbl in ordered_labels] + fig_cm = go.Figure( + data=go.Heatmap( + z=cm, + x=[f"Pred {lbl}" for lbl in label_names], + y=[f"True {lbl}" for lbl in label_names], + text=cm, + texttemplate="%{text}", + colorscale="Blues", + showscale=False, + ) + ) + fig_cm.update_layout( + title=f"Confusion Matrix{self._threshold_suffix()}", + xaxis_title=f"Predicted label ({self.target})", + yaxis_title=f"True label ({self.target})", + ) + fig_cm.update_xaxes( + type="category", + categoryorder="array", + categoryarray=[f"Pred {lbl}" for lbl in label_names], + ) + fig_cm.update_yaxes( + type="category", + categoryorder="array", + categoryarray=[f"True {lbl}" for lbl in label_names], + autorange="reversed", + ) + _apply_report_layout(fig_cm) + return fig_cm + + def _build_classification_report_fig(self, y_true, y_pred, labels): + precision, recall, f1, support = precision_recall_fscore_support( + y_true, y_pred, labels=labels, zero_division=0 + ) + mcc_scores = [] + for lbl in labels: + y_true_bin = (y_true == lbl).astype(int) + y_pred_bin = (y_pred == lbl).astype(int) + try: + mcc_val = matthews_corrcoef(y_true_bin, y_pred_bin) + except Exception: + mcc_val = 0.0 + mcc_scores.append(mcc_val) + + label_names = [str(lbl) for lbl in labels] + metrics = ["precision", "recall", "f1", "support"] + + max_support = float(max(support) if len(support) else 0) + z_rows = [] + text_rows = [] + for i, lbl in enumerate(label_names): + norm_support = (support[i] / max_support) if max_support else 0.0 + z_rows.append( + [ + precision[i], + recall[i], + f1[i], + norm_support, + ] + ) + text_rows.append( + [ + f"{precision[i]:.3f}", + f"{recall[i]:.3f}", + f"{f1[i]:.3f}", + f"{int(support[i])}", + ] + ) + + fig = go.Figure( + data=go.Heatmap( + z=z_rows, + x=metrics, + y=label_names, + colorscale="YlOrRd", + zmin=0, + zmax=1, + colorbar=dict(title="Scale"), + text=text_rows, + texttemplate="%{text}", + hovertemplate="Label=%{y}<br>Metric=%{x}<br>Value=%{text}<extra></extra>", + ) + ) + fig.update_yaxes( + title_text=f"Label ({self.target})", + autorange="reversed", + type="category", + tickmode="array", + tickvals=label_names, + ticktext=label_names, + showgrid=False, + ) + fig.update_xaxes(title_text="", tickangle=45) + fig.update_layout( + title=f"Per-Class Metrics{self._threshold_suffix()}", + margin=dict(l=70, r=60, t=70, b=80), + ) + _apply_report_layout(fig) + return fig
--- a/utils.py Fri Nov 28 22:28:26 2025 +0000 +++ b/utils.py Sat Dec 06 14:20:36 2025 +0000 @@ -65,6 +65,28 @@ color: white; } + /* Center specific numeric columns */ + .table-dataset-overview td:nth-child(n+2), + .table-dataset-overview th:nth-child(n+2) { + text-align: center; + } + .table-perf-summary td:nth-child(n+2), + .table-perf-summary th:nth-child(n+2) { + text-align: center; + } + .table-setup-params td:nth-child(2), + .table-setup-params th:nth-child(2) { + text-align: center; + } + .table-hyperparams td:nth-child(2), + .table-hyperparams th:nth-child(2) { + text-align: center; + } + .table-fi-scope td:nth-child(2), + .table-fi-scope th:nth-child(2) { + text-align: center; + } + .plot { text-align: center; margin: 20px 0; @@ -194,6 +216,7 @@ test_html: str, feature_html: str, explainer_html: Optional[str] = None, + config_html: Optional[str] = None, ) -> str: """ Render the tabbed sections and an always-visible Help button. @@ -202,12 +225,24 @@ css = get_html_template().split("<body>")[1].rsplit("</style>", 1)[0] + "</style>" # Tabs header - tabs = [ - '<div class="tabs">', - '<div class="tab active" onclick="showTab(\'summary\')">Validation Summary and Config</div>', + tabs = ['<div class="tabs">'] + default_active = "summary" + if config_html: + default_active = "config" + tabs.append( + '<div class="tab active" onclick="showTab(\'config\')">Model Config Summary</div>' + ) + tabs.append( + '<div class="tab" onclick="showTab(\'summary\')">Validation Summary</div>' + ) + else: + tabs.append( + '<div class="tab active" onclick="showTab(\'summary\')">Validation Summary</div>' + ) + tabs.extend([ '<div class="tab" onclick="showTab(\'test\')">Test Summary</div>', '<div class="tab" onclick="showTab(\'feature\')">Feature Importance</div>', - ] + ]) if explainer_html: tabs.append( '<div class="tab" onclick="showTab(\'explainer\')">Explainer Plots</div>' @@ -217,11 +252,16 @@ tabs_section = "\n".join(tabs) # Content - contents = [ - f'<div id="summary" class="tab-content active">{summary_html}</div>', - f'<div id="test" class="tab-content">{test_html}</div>', - f'<div id="feature" class="tab-content">{feature_html}</div>', - ] + contents = [] + if config_html: + contents.append( + f'<div id="config" class="tab-content {"active" if default_active == "config" else ""}">{config_html}</div>' + ) + contents.append( + f'<div id="summary" class="tab-content {"active" if default_active == "summary" else ""}">{summary_html}</div>' + ) + contents.append(f'<div id="test" class="tab-content">{test_html}</div>') + contents.append(f'<div id="feature" class="tab-content">{feature_html}</div>') if explainer_html: contents.append( f'<div id="explainer" class="tab-content">{explainer_html}</div>'
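On the utils.py side, `build_tabbed_html` gains an optional `config_html` argument: when it is provided, a "Model Config Summary" tab is inserted before "Validation Summary" and becomes the initially active tab. A hedged usage sketch of the new call shape (the HTML fragments are placeholders; the helper names come from the repository's utils.py):

```python
from utils import build_tabbed_html, get_html_template, get_html_closing

html = get_html_template()
html += "<h1>Tabular Learner Model Report</h1>"
html += build_tabbed_html(
    summary_html="<h2>Validation Summary</h2>...",
    test_html="<h2>Test Summary</h2>...",
    feature_html="<h2>Feature Importance</h2>...",
    explainer_html=None,   # optional fourth tab with explainer plots
    config_html="<h2>Dataset Overview</h2>...",  # new: rendered first and active when present
)
html += get_html_closing()
```

Passing `config_html=None` preserves the previous layout, with "Validation Summary" as the active tab.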
