comparison pycaret_classification.py @ 13:bf0df21a1ea3 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 7fc20c9ddc2b641975138c9d67b5da240af0484c
author goeckslab
date Sat, 06 Dec 2025 14:20:23 +0000
parents a76dfceb62e0
comparison
12:15707141e7da 13:bf0df21a1ea3

@@ -6,11 +6,18 @@
 import pandas as pd
 import plotly.graph_objects as go
 from base_model_trainer import BaseModelTrainer
 from dashboard import generate_classifier_explainer_dashboard
 from pycaret.classification import ClassificationExperiment
-from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve
+from sklearn.metrics import (
+    auc,
+    confusion_matrix,
+    matthews_corrcoef,
+    precision_recall_curve,
+    precision_recall_fscore_support,
+    roc_curve,
+)
 from utils import predict_proba

 LOG = logging.getLogger(__name__)

@@ -135,62 +142,40 @@
         explainer = ClassifierExplainer(self.best_model, X_test, y_test)

         # a dict to hold the raw Figure objects or callables
         self.explainer_plots: Dict[str, go.Figure] = {}

+        y_true, y_pred, label_values, y_scores = self._get_test_predictions()
+
+        # — Classification report (Plotly table) —
+        try:
+            fig_report = self._build_classification_report_fig(
+                y_true, y_pred, label_values
+            )
+            if fig_report is not None:
+                self.explainer_plots["class_report"] = fig_report
+        except Exception as e:
+            LOG.warning(f"Could not generate Plotly classification report: {e}")
+
+        # — Confusion matrix with actual labels —
+        try:
+            fig_cm = self._build_confusion_matrix_fig(y_true, y_pred, label_values)
+            if fig_cm is not None:
+                self.explainer_plots["confusion_matrix"] = fig_cm
+        except Exception as e:
+            LOG.warning(f"Could not generate Plotly confusion matrix: {e}")
+
         # --- Threshold-aware overrides for CM / ROC / PR ---
         prob_thresh = getattr(self, "probability_threshold", None)

         # Only for binary classification and when threshold is provided
         if (prob_thresh is not None) and (not self.exp.is_multiclass):
-            X = self.exp.X_test_transformed
-            y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)
-
-            # Get positive-class scores (robust defaults)
-            classes = list(getattr(self.best_model, "classes_", [0, 1]))
-            try:
-                pos_idx = classes.index(1) if 1 in classes else 1
-            except Exception:
-                pos_idx = 1
-
-            proba = self.best_model.predict_proba(X)
-            y_scores = proba[:, pos_idx]
-
-            # Derive label names consistently
-            pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
-            neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0
-
-            # ---- Confusion Matrix @ threshold ----
-            try:
-                y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label)
-                cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label])
-                fig_cm = go.Figure(
-                    data=go.Heatmap(
-                        z=cm,
-                        x=[f"Pred {neg_label}", f"Pred {pos_label}"],
-                        y=[f"True {neg_label}", f"True {pos_label}"],
-                        text=cm,
-                        texttemplate="%{text}",
-                        colorscale="Blues",
-                        showscale=False,
-                    )
-                )
-                fig_cm.update_layout(
-                    title=f"Confusion Matrix @ threshold={prob_thresh:.2f}",
-                    xaxis_title="Predicted label",
-                    yaxis_title="True label",
-                )
-                _apply_report_layout(fig_cm)
-                self.explainer_plots["confusion_matrix"] = fig_cm
-            except Exception as e:
-                LOG.warning(
-                    f"Threshold-aware confusion matrix failed; falling back: {e}"
-                )
-
             # ---- ROC with threshold marker ----
             try:
-                fpr, tpr, thr = roc_curve(y, y_scores)
+                if y_scores is None:
+                    raise ValueError("Predicted probabilities unavailable")
+                fpr, tpr, thr = roc_curve(y_true, y_scores)
                 roc_auc = auc(fpr, tpr)
                 fig_roc = go.Figure()
                 fig_roc.add_scatter(
                     x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})"
                 )
@@ -217,11 +202,13 @@
             except Exception as e:
                 LOG.warning(f"Threshold marker on ROC failed; falling back: {e}")

             # ---- PR with threshold marker ----
             try:
-                precision, recall, thr_pr = precision_recall_curve(y, y_scores)
+                if y_scores is None:
+                    raise ValueError("Predicted probabilities unavailable")
+                precision, recall, thr_pr = precision_recall_curve(y_true, y_scores)
                 pr_auc = auc(recall, precision)
                 fig_pr = go.Figure()
                 fig_pr.add_scatter(
                     x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})"
                 )
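
The lines that actually place the threshold marker on the ROC and PR figures fall between the hunks shown here and are not part of this comparison view. Purely as a rough sketch of the idea, and not the code in this revision, one way to locate the operating point for a chosen probability threshold is to pick the roc_curve threshold nearest to it and mark the corresponding (FPR, TPR) pair; the PR case is analogous with precision_recall_curve. The helper name below is an assumption made for illustration.

import numpy as np
from sklearn.metrics import roc_curve

def roc_operating_point(y_true, y_scores, prob_thresh):
    # Find the ROC point reached when classifying at prob_thresh.
    fpr, tpr, thr = roc_curve(y_true, y_scores)
    # Thresholds are returned in decreasing order; take the one nearest the
    # requested cutoff and return its coordinates for a marker trace,
    # e.g. fig_roc.add_scatter(x=[fpr_pt], y=[tpr_pt], mode="markers").
    idx = int(np.argmin(np.abs(thr - prob_thresh)))
    return fpr[idx], tpr[idx]
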
@@ -302,5 +289,184 @@
                 return None

             return _plot

         self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)
+
+    def _get_test_predictions(self):
+        """
+        Return y_true, y_pred, label list, and (optionally) positive-class
+        probabilities when available. Ensures predictions respect the optional
+        probability threshold for binary tasks.
+        """
+        y_true = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)
+        X_test = self.exp.X_test_transformed
+        prob_thresh = getattr(self, "probability_threshold", None)
+
+        y_scores = None
+        try:
+            proba = self.best_model.predict_proba(X_test)
+            y_scores = proba
+        except Exception:
+            LOG.debug("predict_proba unavailable for test predictions.")
+
+        try:
+            if (
+                prob_thresh is not None
+                and not self.exp.is_multiclass
+                and y_scores is not None
+                and y_scores.ndim == 2
+                and y_scores.shape[1] > 1
+            ):
+                classes = list(getattr(self.best_model, "classes_", []))
+                try:
+                    pos_idx = classes.index(1) if 1 in classes else 1
+                except Exception:
+                    pos_idx = 1
+                neg_idx = 1 - pos_idx if y_scores.shape[1] > 1 else 0
+                pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
+                neg_label = classes[neg_idx] if len(classes) > neg_idx else 0
+                y_pred = np.where(y_scores[:, pos_idx] >= prob_thresh, pos_label, neg_label)
+                y_scores = y_scores[:, pos_idx]
+            else:
+                y_pred = self.best_model.predict(X_test)
+        except Exception as exc:
+            LOG.warning("Falling back to raw predict for test predictions: %s", exc)
+            y_pred = self.best_model.predict(X_test)
+
+        y_pred = pd.Series(y_pred).reset_index(drop=True)
+        if y_scores is not None:
+            y_scores = np.asarray(y_scores)
+            if y_scores.ndim > 1 and y_scores.shape[1] == 1:
+                y_scores = y_scores.ravel()
+            if self.exp.is_multiclass and y_scores.ndim > 1:
+                # Avoid passing multiclass score matrices to ROC/PR utilities
+                y_scores = None
+        label_values = pd.unique(pd.concat([y_true, y_pred], ignore_index=True))
+        return y_true, y_pred, label_values.tolist(), y_scores
+
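
As a quick standalone sketch of the thresholding rule applied in the binary branch above (toy probabilities and a hypothetical 0.35 cutoff, not values from the tool):

import numpy as np

classes = [0, 1]
# Made-up predict_proba output: columns are [negative, positive] scores.
proba = np.array([[0.8, 0.2], [0.4, 0.6], [0.7, 0.3], [0.1, 0.9]])
pos_idx = classes.index(1)
# Label a sample positive only when its positive-class score clears the cutoff.
y_pred = np.where(proba[:, pos_idx] >= 0.35, classes[pos_idx], classes[1 - pos_idx])
print(y_pred)  # [0 1 0 1]
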
+    def _threshold_suffix(self) -> str:
+        """
+        Build a suffix like ' (threshold=0.50)' for binary tasks; omit for
+        multiclass where thresholds are not applied.
+        """
+        if getattr(self, "task_type", None) != "classification":
+            return ""
+        if getattr(self.exp, "is_multiclass", False):
+            return ""
+        prob_thresh = getattr(self, "probability_threshold", None)
+        if prob_thresh is None:
+            return " (threshold=0.50)"
+        try:
+            return f" (threshold={float(prob_thresh):.2f})"
+        except Exception:
+            return f" (threshold={prob_thresh})"
+
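
The suffix is only cosmetic: it feeds the figure titles built below. With a hypothetical probability_threshold of 0.35 the confusion-matrix title becomes "Confusion Matrix (threshold=0.35)"; with no threshold set, the default 0.50 is reported instead.

prob_thresh = 0.35  # hypothetical value, not from the tool
print(f"Confusion Matrix (threshold={float(prob_thresh):.2f})")
# Confusion Matrix (threshold=0.35)
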
+    def _build_confusion_matrix_fig(self, y_true, y_pred, labels):
+        def _label_sort_key(lbl):
+            try:
+                return (0, float(lbl))
+            except Exception:
+                return (1, str(lbl))
+
+        ordered_labels = sorted(labels, key=_label_sort_key)
+        cm = confusion_matrix(y_true, y_pred, labels=ordered_labels)
+        label_names = [str(lbl) for lbl in ordered_labels]
+        fig_cm = go.Figure(
+            data=go.Heatmap(
+                z=cm,
+                x=[f"Pred {lbl}" for lbl in label_names],
+                y=[f"True {lbl}" for lbl in label_names],
+                text=cm,
+                texttemplate="%{text}",
+                colorscale="Blues",
+                showscale=False,
+            )
+        )
+        fig_cm.update_layout(
+            title=f"Confusion Matrix{self._threshold_suffix()}",
+            xaxis_title=f"Predicted label ({self.target})",
+            yaxis_title=f"True label ({self.target})",
+        )
+        fig_cm.update_xaxes(
+            type="category",
+            categoryorder="array",
+            categoryarray=[f"Pred {lbl}" for lbl in label_names],
+        )
+        fig_cm.update_yaxes(
+            type="category",
+            categoryorder="array",
+            categoryarray=[f"True {lbl}" for lbl in label_names],
+            autorange="reversed",
+        )
+        _apply_report_layout(fig_cm)
+        return fig_cm
+
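
For reference, a minimal example of the label-ordering behaviour this helper relies on: passing labels= to confusion_matrix pins the row and column order, and the numeric-first sort key above decides that order. The animal labels are invented for illustration.

from sklearn.metrics import confusion_matrix

y_true = ["cat", "dog", "dog", "cat", "dog"]
y_pred = ["cat", "dog", "cat", "cat", "dog"]
# Rows are true labels, columns are predicted labels, in the order given.
print(confusion_matrix(y_true, y_pred, labels=["cat", "dog"]))
# [[2 0]
#  [1 2]]
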
+    def _build_classification_report_fig(self, y_true, y_pred, labels):
+        precision, recall, f1, support = precision_recall_fscore_support(
+            y_true, y_pred, labels=labels, zero_division=0
+        )
+        mcc_scores = []
+        for lbl in labels:
+            y_true_bin = (y_true == lbl).astype(int)
+            y_pred_bin = (y_pred == lbl).astype(int)
+            try:
+                mcc_val = matthews_corrcoef(y_true_bin, y_pred_bin)
+            except Exception:
+                mcc_val = 0.0
+            mcc_scores.append(mcc_val)
+
+        label_names = [str(lbl) for lbl in labels]
+        metrics = ["precision", "recall", "f1", "support"]
+
+        max_support = float(max(support) if len(support) else 0)
+        z_rows = []
+        text_rows = []
+        for i, lbl in enumerate(label_names):
+            norm_support = (support[i] / max_support) if max_support else 0.0
+            z_rows.append(
+                [
+                    precision[i],
+                    recall[i],
+                    f1[i],
+                    norm_support,
+                ]
+            )
+            text_rows.append(
+                [
+                    f"{precision[i]:.3f}",
+                    f"{recall[i]:.3f}",
+                    f"{f1[i]:.3f}",
+                    f"{int(support[i])}",
+                ]
+            )
+
+        fig = go.Figure(
+            data=go.Heatmap(
+                z=z_rows,
+                x=metrics,
+                y=label_names,
+                colorscale="YlOrRd",
+                zmin=0,
+                zmax=1,
+                colorbar=dict(title="Scale"),
+                text=text_rows,
+                texttemplate="%{text}",
+                hovertemplate="Label=%{y}<br>Metric=%{x}<br>Value=%{text}<extra></extra>",
+            )
+        )
+        fig.update_yaxes(
+            title_text=f"Label ({self.target})",
+            autorange="reversed",
+            type="category",
+            tickmode="array",
+            tickvals=label_names,
+            ticktext=label_names,
+            showgrid=False,
+        )
+        fig.update_xaxes(title_text="", tickangle=45)
+        fig.update_layout(
+            title=f"Per-Class Metrics{self._threshold_suffix()}",
+            margin=dict(l=70, r=60, t=70, b=80),
+        )
+        _apply_report_layout(fig)
+        return fig
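
To make the quantities behind the per-class heatmap concrete, here is a small self-contained computation that mirrors the precision_recall_fscore_support call and the one-vs-rest MCC loop above; the labels and predictions are invented for illustration and are not part of the tool.

import pandas as pd
from sklearn.metrics import matthews_corrcoef, precision_recall_fscore_support

y_true = pd.Series([0, 0, 1, 1, 1, 2, 2])
y_pred = pd.Series([0, 1, 1, 1, 0, 2, 2])
labels = [0, 1, 2]

precision, recall, f1, support = precision_recall_fscore_support(
    y_true, y_pred, labels=labels, zero_division=0
)
# One-vs-rest MCC per label, as in the loop above.
mcc = [
    matthews_corrcoef((y_true == lbl).astype(int), (y_pred == lbl).astype(int))
    for lbl in labels
]
for lbl, p, r, f, s, m in zip(labels, precision, recall, f1, support, mcc):
    print(f"label={lbl} precision={p:.2f} recall={r:.2f} f1={f:.2f} support={s} mcc={m:.2f}")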