pycaret_classification.py @ 12:e674b9e946fb (draft, default, tip)

planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073

Author:     goeckslab
Date:       Mon, 08 Sep 2025 22:39:12 +0000
Parents:    1aed7d47c5ec
Comparison: 11:4eca9d109de1 -> 12:e674b9e946fb

Changed hunks from revision 12; unchanged code between hunks is omitted.
import logging
import types
from typing import Dict

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from base_model_trainer import BaseModelTrainer
from dashboard import generate_classifier_explainer_dashboard
from pycaret.classification import ClassificationExperiment
from sklearn.metrics import auc, confusion_matrix, precision_recall_curve, roc_curve
from utils import predict_proba

LOG = logging.getLogger(__name__)

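# Shared layout tweak: each threshold-aware figure built below (confusion matrix, ROC, PR)
# is passed through this helper before being stored in self.explainer_plots.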
def _apply_report_layout(fig: go.Figure) -> go.Figure:
    # Give the left side more space for y-axis title/ticks and let axes auto-reserve room
    fig.update_xaxes(automargin=True, title_standoff=12)
    fig.update_yaxes(automargin=True, title_standoff=12)
    fig.update_layout(
        autosize=True,
        margin=dict(l=120, r=40, t=60, b=60),  # bump 'l' if you still see clipping
    )
    return fig


class ClassificationModelTrainer(BaseModelTrainer):
    def __init__(
        self,
# ... unchanged lines omitted from this comparison ...

            LOG.warning(
                f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch."
            )

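        # Names below are PyCaret plot types, rendered via self.exp.plot_model in the loop that follows.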
        # ("confusion_matrix" was dropped from this list in this revision; the
        #  explainer section below still provides one)
        plots = [
            "auc",
            "threshold",
            "pr",
            "error",
            "class_report",
            "learning",
            "calibration",
            "vc",
            "dimension",
            "manifold",
            "rfe",
            "feature",
            "feature_all",
        ]
        for plot_name in plots:
            try:
                if plot_name == "threshold":
                    plot_path = self.exp.plot_model(
# ... unchanged lines omitted from this comparison ...

    def generate_plots_explainer(self):
        from explainerdashboard import ClassifierExplainer

        LOG.info("Generating explainer plots")

        # Ensure predict_proba is available here too
        if not hasattr(self.best_model, "predict_proba"):
            self.best_model.predict_proba = types.MethodType(
                predict_proba, self.best_model
            )
            LOG.warning(
                f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch."
            )

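        # The patched predict_proba lets ClassifierExplainer and the threshold-aware
        # overrides below obtain class probabilities from models that lack the method.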
        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed
        explainer = ClassifierExplainer(self.best_model, X_test, y_test)

        # a dict to hold the raw Figure objects or callables
        self.explainer_plots: Dict[str, go.Figure] = {}

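        # Figures stored here take precedence: the generic explainer loop further down
        # skips any key that is already present in explainer_plots.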
        # --- Threshold-aware overrides for CM / ROC / PR ---
        prob_thresh = getattr(self, "probability_threshold", None)

        # Only for binary classification and when threshold is provided
        if (prob_thresh is not None) and (not self.exp.is_multiclass):
            X = self.exp.X_test_transformed
            y = pd.Series(self.exp.y_test_transformed).reset_index(drop=True)

            # Get positive-class scores (robust defaults)
            classes = list(getattr(self.best_model, "classes_", [0, 1]))
            try:
                pos_idx = classes.index(1) if 1 in classes else 1
            except Exception:
                pos_idx = 1

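            # predict_proba columns follow the order of best_model.classes_, so pos_idx
            # picks out the positive-class column below.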
            proba = self.best_model.predict_proba(X)
            y_scores = proba[:, pos_idx]

            # Derive label names consistently
            pos_label = classes[pos_idx] if len(classes) > pos_idx else 1
            neg_label = classes[1 - pos_idx] if len(classes) > 1 else 0

            # ---- Confusion Matrix @ threshold ----
            try:
                y_pred = np.where(y_scores >= prob_thresh, pos_label, neg_label)
                cm = confusion_matrix(y, y_pred, labels=[neg_label, pos_label])
                fig_cm = go.Figure(
                    data=go.Heatmap(
                        z=cm,
                        x=[f"Pred {neg_label}", f"Pred {pos_label}"],
                        y=[f"True {neg_label}", f"True {pos_label}"],
                        text=cm,
                        texttemplate="%{text}",
                        colorscale="Blues",
                        showscale=False,
                    )
                )
                fig_cm.update_layout(
                    title=f"Confusion Matrix @ threshold={prob_thresh:.2f}",
                    xaxis_title="Predicted label",
                    yaxis_title="True label",
                )
                _apply_report_layout(fig_cm)
                self.explainer_plots["confusion_matrix"] = fig_cm
            except Exception as e:
                LOG.warning(
                    f"Threshold-aware confusion matrix failed; falling back: {e}"
                )

            # ---- ROC with threshold marker ----
            try:
                fpr, tpr, thr = roc_curve(y, y_scores)
                roc_auc = auc(fpr, tpr)
                fig_roc = go.Figure()
                fig_roc.add_scatter(
                    x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={roc_auc:.3f})"
                )
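                # roc_curve can prepend a non-finite sentinel threshold (np.inf in recent
                # scikit-learn); mask it out before finding the point nearest prob_thresh.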
                if len(thr):
                    mask = np.isfinite(thr)
                    if mask.any():
                        idx_local = int(np.argmin(np.abs(thr[mask] - prob_thresh)))
                        idx = np.where(mask)[0][idx_local]
                        if 0 <= idx < len(fpr):
                            fig_roc.add_scatter(
                                x=[fpr[idx]],
                                y=[tpr[idx]],
                                mode="markers",
                                name=f"@ {prob_thresh:.2f}",
                                marker=dict(size=10),
                            )
                fig_roc.update_layout(
                    title=f"ROC Curve (marker at threshold={prob_thresh:.2f})",
                    xaxis_title="False Positive Rate",
                    yaxis_title="True Positive Rate",
                )
                _apply_report_layout(fig_roc)
                self.explainer_plots["roc_auc"] = fig_roc
            except Exception as e:
                LOG.warning(f"Threshold marker on ROC failed; falling back: {e}")

            # ---- PR with threshold marker ----
            try:
                precision, recall, thr_pr = precision_recall_curve(y, y_scores)
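                # Trapezoidal area under the PR curve; a rough stand-in for average
                # precision, used only in the legend label below.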
                pr_auc = auc(recall, precision)
                fig_pr = go.Figure()
                fig_pr.add_scatter(
                    x=recall, y=precision, mode="lines", name=f"PR (AUC={pr_auc:.3f})"
                )
                if len(thr_pr):
                    idx_pr = int(np.argmin(np.abs(thr_pr - prob_thresh)))
                    # note: thr_pr has length = len(precision) - 1
                    idx_pr = max(0, min(idx_pr, len(recall) - 1))
                    fig_pr.add_scatter(
                        x=[recall[idx_pr]],
                        y=[precision[idx_pr]],
                        mode="markers",
                        name=f"@ {prob_thresh:.2f}",
                        marker=dict(size=10),
                    )
                fig_pr.update_layout(
                    title=f"Precision–Recall (marker at threshold={prob_thresh:.2f})",
                    xaxis_title="Recall",
                    yaxis_title="Precision",
                )
                _apply_report_layout(fig_pr)
                self.explainer_plots["pr_auc"] = fig_pr
            except Exception as e:
                LOG.warning(f"Threshold marker on PR failed; falling back: {e}")

        # these go into the Test tab (don't overwrite overrides)
        for key, fn in [
            ("roc_auc", explainer.plot_roc_auc),
            ("pr_auc", explainer.plot_pr_auc),
            ("lift_curve", explainer.plot_lift_curve),
            ("confusion_matrix", explainer.plot_confusion_matrix),
            ("threshold", explainer.plot_precision),  # percentage vs probability
            ("cumulative_precision", explainer.plot_cumulative_precision),
        ]:
            if key in self.explainer_plots:
                continue
            try:
                fig = fn()
                if fig is not None:
                    self.explainer_plots[key] = fig
            except Exception as e:
                LOG.error(f"Error generating explainer plot {key}: {e}")

        # mean SHAP importances
        try:
# ... unchanged lines omitted from this comparison ...

        valid_feats = []
        for feat in self.features_name:
            if feat in explainer.X.columns or feat in explainer.onehot_cols:
                valid_feats.append(feat)
            else:
                LOG.warning(
                    f"Skipping PDP for feature {feat!r}: not found in explainer data"
                )

        for feat in valid_feats:
            # wrap each PDP call to catch any unexpected AssertionErrors
            def make_pdp_plotter(f):
                def _plot():
# ... unchanged lines omitted from this comparison ...
                        LOG.warning(f"PDP AssertionError for {f!r}: {ae}")
                        return None
                    except Exception as e:
                        LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}")
                        return None

                return _plot

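            # Store a deferred plotter for this feature; _plot catches errors and
            # returns None instead of raising.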
            self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)
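
For reference, a minimal, self-contained sketch of the threshold-marker pattern used in the ROC/PR overrides above. It assumes numpy, scikit-learn and plotly are installed; the synthetic dataset, the logistic-regression model and the 0.35 threshold are illustrative only, not part of the tool.

import numpy as np
import plotly.graph_objects as go
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import train_test_split

# Fit a small classifier and score the held-out split.
X, y = make_classification(n_samples=500, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)
scores = LogisticRegression(max_iter=1000).fit(X_tr, y_tr).predict_proba(X_te)[:, 1]

prob_thresh = 0.35  # illustrative decision threshold
fpr, tpr, thr = roc_curve(y_te, scores)

fig = go.Figure()
fig.add_scatter(x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={auc(fpr, tpr):.3f})")

# Ignore the non-finite sentinel threshold and mark the curve point whose
# threshold is closest to prob_thresh.
mask = np.isfinite(thr)
idx = np.where(mask)[0][int(np.argmin(np.abs(thr[mask] - prob_thresh)))]
fig.add_scatter(x=[fpr[idx]], y=[tpr[idx]], mode="markers", name=f"@ {prob_thresh:.2f}")
fig.show()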