comparison feature_importance.py @ 16:4fee4504646e draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 2b826699ef9518d4610f5cfb6468ce719ec8039d
author goeckslab
date Fri, 28 Nov 2025 22:28:26 +0000
parents e674b9e946fb
children
comparison
equal deleted inserted replaced
15:a2aeeb754d76 16:4fee4504646e
20 data_path=None, 20 data_path=None,
21 data=None, 21 data=None,
22 target_col=None, 22 target_col=None,
23 exp=None, 23 exp=None,
24 best_model=None, 24 best_model=None,
25 max_plot_features=None,
26 processed_data=None,
27 max_shap_rows=None,
25 ): 28 ):
26 self.task_type = task_type 29 self.task_type = task_type
27 self.output_dir = output_dir 30 self.output_dir = output_dir
28 self.exp = exp 31 self.exp = exp
29 self.best_model = best_model 32 self.best_model = best_model
33 self._skip_messages = []
34 self.shap_total_features = None
35 self.shap_used_features = None
36 if isinstance(max_plot_features, int) and max_plot_features > 0:
37 self.max_plot_features = max_plot_features
38 elif max_plot_features is None:
39 self.max_plot_features = 30
40 else:
41 self.max_plot_features = None
30 42
31 if exp is not None: 43 if exp is not None:
32 # Assume all configs (data, target) are in exp 44 # Assume all configs (data, target) are in exp
33 self.data = exp.dataset.copy() 45 self.data = exp.dataset.copy()
34 self.target = exp.target_param 46 self.target = exp.target_param
46 self.exp = ( 58 self.exp = (
47 ClassificationExperiment() 59 ClassificationExperiment()
48 if task_type == "classification" 60 if task_type == "classification"
49 else RegressionExperiment() 61 else RegressionExperiment()
50 ) 62 )
63 if processed_data is not None:
64 self.data = processed_data
51 65
52 self.plots = {} 66 self.plots = {}
67 self.max_shap_rows = max_shap_rows
68
69 def _get_feature_names_from_model(self, model):
70 """Best-effort extraction of feature names seen by the estimator."""
71 if model is None:
72 return None
73
74 candidates = [model]
75 if hasattr(model, "named_steps"):
76 candidates.extend(model.named_steps.values())
77 elif hasattr(model, "steps"):
78 candidates.extend(step for _, step in model.steps)
79
80 for candidate in candidates:
81 names = getattr(candidate, "feature_names_in_", None)
82 if names is not None:
83 return list(names)
84 return None
85
86 def _get_transformed_frame(self, model=None, prefer_test=True):
87 """Return a DataFrame that mirrors the matrix fed to the estimator."""
88 key_order = ["X_test_transformed", "X_train_transformed"]
89 if not prefer_test:
90 key_order.reverse()
91 key_order.append("X_transformed")
92
93 feature_names = self._get_feature_names_from_model(model)
94 for key in key_order:
95 try:
96 frame = self.exp.get_config(key)
97 except KeyError:
98 continue
99 if frame is None:
100 continue
101 if isinstance(frame, pd.DataFrame):
102 return frame.copy()
103 try:
104 n_features = frame.shape[1]
105 except Exception:
106 continue
107 if feature_names and len(feature_names) == n_features:
108 return pd.DataFrame(frame, columns=feature_names)
109 # Fallback to positional names so downstream logic still works
110 return pd.DataFrame(frame, columns=[f"f{i}" for i in range(n_features)])
111 return None
53 112
54 def setup_pycaret(self): 113 def setup_pycaret(self):
55 if self.exp is not None and hasattr(self.exp, "is_setup") and self.exp.is_setup: 114 if self.exp is not None and hasattr(self.exp, "is_setup") and self.exp.is_setup:
56 LOG.info("Experiment already set up. Skipping PyCaret setup.") 115 LOG.info("Experiment already set up. Skipping PyCaret setup.")
57 return 116 return
65 } 124 }
66 self.exp.setup(self.data, **setup_params) 125 self.exp.setup(self.data, **setup_params)
67 126
68 def save_tree_importance(self): 127 def save_tree_importance(self):
69 model = self.best_model or self.exp.get_config("best_model") 128 model = self.best_model or self.exp.get_config("best_model")
70 processed_features = self.exp.get_config("X_transformed").columns 129 processed_frame = self._get_transformed_frame(model, prefer_test=False)
130 if processed_frame is None:
131 LOG.warning(
132 "Unable to determine transformed feature names; skipping tree importance plot."
133 )
134 self.tree_model_name = None
135 return
136 processed_features = list(processed_frame.columns)
71 137
72 importances = None 138 importances = None
73 model_type = model.__class__.__name__ 139 model_type = model.__class__.__name__
74 self.tree_model_name = model_type 140 self.tree_model_name = model_type
75 141
83 ) 149 )
84 self.tree_model_name = None 150 self.tree_model_name = None
85 return 151 return
86 152
87 if len(importances) != len(processed_features): 153 if len(importances) != len(processed_features):
88 LOG.warning( 154 model_feature_names = self._get_feature_names_from_model(model)
89 f"Importances ({len(importances)}) != features ({len(processed_features)}). Skipping tree importance." 155 if model_feature_names and len(model_feature_names) == len(importances):
90 ) 156 processed_features = model_feature_names
91 self.tree_model_name = None 157 else:
92 return 158 LOG.warning(
159 "Importances (%s) != features (%s). Skipping tree importance.",
160 len(importances),
161 len(processed_features),
162 )
163 self.tree_model_name = None
164 return
93 165
94 feature_importances = pd.DataFrame( 166 feature_importances = pd.DataFrame(
95 {"Feature": processed_features, "Importance": importances} 167 {"Feature": processed_features, "Importance": importances}
96 ).sort_values(by="Importance", ascending=False) 168 ).sort_values(by="Importance", ascending=False)
169 cap = (
170 min(self.max_plot_features, len(feature_importances))
171 if self.max_plot_features is not None
172 else len(feature_importances)
173 )
174 plot_importances = feature_importances.head(cap)
175 if cap < len(feature_importances):
176 LOG.info(
177 "Tree importance plot limited to top %s of %s features",
178 cap,
179 len(feature_importances),
180 )
97 plt.figure(figsize=(10, 6)) 181 plt.figure(figsize=(10, 6))
98 plt.barh(feature_importances["Feature"], feature_importances["Importance"]) 182 plt.barh(
183 plot_importances["Feature"],
184 plot_importances["Importance"],
185 )
99 plt.xlabel("Importance") 186 plt.xlabel("Importance")
100 plt.title(f"Feature Importance ({model_type})") 187 plt.title(f"Feature Importance ({model_type}) (top {cap})")
101 plot_path = os.path.join(self.output_dir, "tree_importance.png") 188 plot_path = os.path.join(self.output_dir, "tree_importance.png")
189 plt.tight_layout()
102 plt.savefig(plot_path, bbox_inches="tight") 190 plt.savefig(plot_path, bbox_inches="tight")
103 plt.close() 191 plt.close()
104 self.plots["tree_importance"] = plot_path 192 self.plots["tree_importance"] = plot_path
105 193
106 def save_shap_values(self, max_samples=None, max_display=None, max_features=None): 194 def save_shap_values(self, max_samples=None, max_display=None, max_features=None):
107 model = self.best_model or self.exp.get_config("best_model") 195 model = self.best_model or self.exp.get_config("best_model")
108 196
109 X_data = None 197 X_data = self._get_transformed_frame(model)
110 for key in ("X_test_transformed", "X_train_transformed"):
111 try:
112 X_data = self.exp.get_config(key)
113 break
114 except KeyError:
115 continue
116 if X_data is None: 198 if X_data is None:
117 raise RuntimeError("No transformed dataset found for SHAP.") 199 raise RuntimeError("No transformed dataset found for SHAP.")
118 200
119 # --- Adaptive feature limiting (proportional cap) ---
120 n_rows, n_features = X_data.shape 201 n_rows, n_features = X_data.shape
202 self.shap_total_features = n_features
203 feature_cap = (
204 min(self.max_plot_features, n_features)
205 if self.max_plot_features is not None
206 else n_features
207 )
121 if max_features is None: 208 if max_features is None:
122 if n_features <= 200: 209 max_features = feature_cap
123 max_features = n_features 210 else:
124 else: 211 max_features = min(max_features, feature_cap)
125 max_features = min(200, max(20, int(n_features * 0.1))) 212 display_features = list(X_data.columns)
126 213
127 try: 214 try:
128 if hasattr(model, "feature_importances_"): 215 if hasattr(model, "feature_importances_"):
129 importances = pd.Series( 216 importances = pd.Series(
130 model.feature_importances_, index=X_data.columns 217 model.feature_importances_, index=X_data.columns
136 top_features = importances.nlargest(max_features).index 223 top_features = importances.nlargest(max_features).index
137 else: 224 else:
138 variances = X_data.var() 225 variances = X_data.var()
139 top_features = variances.nlargest(max_features).index 226 top_features = variances.nlargest(max_features).index
140 227
141 if len(top_features) < n_features: 228 candidate_features = list(top_features)
229 missing = [f for f in candidate_features if f not in X_data.columns]
230 display_features = [f for f in candidate_features if f in X_data.columns]
231 if missing:
232 LOG.warning(
233 "Dropping %s transformed feature(s) not present in SHAP frame: %s",
234 len(missing),
235 missing[:5],
236 )
237 if display_features and len(display_features) < n_features:
142 LOG.info( 238 LOG.info(
143 f"Restricted SHAP computation to top {len(top_features)} / {n_features} features" 239 "Restricting SHAP display to top %s of %s features",
144 ) 240 len(display_features),
145 X_data = X_data[top_features] 241 n_features,
242 )
243 elif not display_features:
244 display_features = list(X_data.columns)
146 except Exception as e: 245 except Exception as e:
147 LOG.warning( 246 LOG.warning(
148 f"Feature limiting failed: {e}. Using all {n_features} features." 247 f"Feature limiting failed: {e}. Using all {n_features} features."
149 ) 248 )
249 display_features = list(X_data.columns)
250
251 self.shap_used_features = len(display_features)
252
253 # Apply the column restriction so SHAP only runs on the selected features.
254 if display_features:
255 X_data = X_data[display_features]
256 n_rows, n_features = X_data.shape
150 257
151 # --- Adaptive row subsampling --- 258 # --- Adaptive row subsampling ---
152 if max_samples is None: 259 if max_samples is None:
153 if n_rows <= 500: 260 if n_rows <= 500:
154 max_samples = n_rows 261 max_samples = n_rows
155 elif n_rows <= 5000: 262 elif n_rows <= 5000:
156 max_samples = 500 263 max_samples = 500
157 else: 264 else:
158 max_samples = min(1000, int(n_rows * 0.1)) 265 max_samples = min(1000, int(n_rows * 0.1))
159 266
267 if self.max_shap_rows is not None:
268 max_samples = min(max_samples, self.max_shap_rows)
269
160 if n_rows > max_samples: 270 if n_rows > max_samples:
161 LOG.info(f"Subsampling SHAP rows: {max_samples} of {n_rows}") 271 LOG.info(f"Subsampling SHAP rows: {max_samples} of {n_rows}")
162 X_data = X_data.sample(max_samples, random_state=42) 272 X_data = X_data.sample(max_samples, random_state=42)
163 273
164 # --- Adaptive feature display --- 274 # --- Adaptive feature display ---
275 display_cap = (
276 min(self.max_plot_features, len(display_features))
277 if self.max_plot_features is not None
278 else len(display_features)
279 )
165 if max_display is None: 280 if max_display is None:
166 if X_data.shape[1] <= 20: 281 max_display = display_cap
167 max_display = X_data.shape[1] 282 else:
168 elif X_data.shape[1] <= 100: 283 max_display = min(max_display, display_cap)
169 max_display = 30 284 if not display_features:
170 else: 285 display_features = list(X_data.columns)
171 max_display = 50 286 max_display = len(display_features)
172 287
173 # Background set 288 # Background set
174 bg = X_data.sample(min(len(X_data), 100), random_state=42) 289 bg = X_data.sample(min(len(X_data), 100), random_state=42)
175 predict_fn = ( 290 predict_fn = (
176 model.predict_proba if hasattr(model, "predict_proba") else model.predict 291 model.predict_proba if hasattr(model, "predict_proba") else model.predict
177 ) 292 )
178 293
179 # Optimized explainer 294 # Optimized explainer
295 explainer = None
296 explainer_label = None
180 if hasattr(model, "feature_importances_"): 297 if hasattr(model, "feature_importances_"):
181 explainer = shap.TreeExplainer( 298 explainer = shap.TreeExplainer(
182 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 299 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
183 ) 300 )
301 explainer_label = "tree_path_dependent"
184 elif hasattr(model, "coef_"): 302 elif hasattr(model, "coef_"):
185 explainer = shap.LinearExplainer(model, bg) 303 explainer = shap.LinearExplainer(model, bg)
304 explainer_label = "linear"
186 else: 305 else:
187 explainer = shap.Explainer(predict_fn, bg) 306 explainer = shap.Explainer(predict_fn, bg)
307 explainer_label = explainer.__class__.__name__
188 308
189 try: 309 try:
190 shap_values = explainer(X_data) 310 shap_values = explainer(X_data)
191 self.shap_model_name = explainer.__class__.__name__ 311 self.shap_model_name = explainer.__class__.__name__
192 except Exception as e: 312 except Exception as e:
193 LOG.error(f"SHAP computation failed: {e}") 313 error_message = str(e)
314 needs_tree_fallback = (
315 hasattr(model, "feature_importances_")
316 and "does not cover all the leaves" in error_message.lower()
317 )
318 feature_name_mismatch = "feature names should match" in error_message.lower()
319 if needs_tree_fallback:
320 LOG.warning(
321 "SHAP computation failed using '%s' perturbation (%s). "
322 "Retrying with interventional perturbation.",
323 explainer_label,
324 error_message,
325 )
326 try:
327 explainer = shap.TreeExplainer(
328 model,
329 bg,
330 feature_perturbation="interventional",
331 n_jobs=-1,
332 )
333 shap_values = explainer(X_data)
334 self.shap_model_name = (
335 f"{explainer.__class__.__name__} (interventional)"
336 )
337 except Exception as retry_exc:
338 LOG.error(
339 "SHAP computation failed even after fallback: %s",
340 retry_exc,
341 )
342 self.shap_model_name = None
343 return
344 elif feature_name_mismatch:
345 LOG.warning(
346 "SHAP computation failed due to feature-name mismatch (%s). "
347 "Falling back to model-agnostic SHAP explainer.",
348 error_message,
349 )
350 try:
351 agnostic_explainer = shap.Explainer(predict_fn, bg)
352 shap_values = agnostic_explainer(X_data)
353 self.shap_model_name = (
354 f"{agnostic_explainer.__class__.__name__} (fallback)"
355 )
356 except Exception as fallback_exc:
357 LOG.error(
358 "Model-agnostic SHAP fallback also failed: %s",
359 fallback_exc,
360 )
361 self.shap_model_name = None
362 return
363 else:
364 LOG.error(f"SHAP computation failed: {e}")
365 self.shap_model_name = None
366 return
367
368 def _limit_explanation_features(explanation):
369 if len(display_features) >= n_features:
370 return explanation
371 try:
372 limited = explanation[:, display_features]
373 LOG.info(
374 "SHAP explanation trimmed to %s display features.",
375 len(display_features),
376 )
377 return limited
378 except Exception as exc:
379 LOG.warning(
380 "Failed to restrict SHAP explanation to top features "
381 "(sample=%s); plot will include all features. Error: %s",
382 display_features[:5],
383 exc,
384 )
385 # Keep using full feature list if trimming fails
386 return explanation
387
388 shap_shape = getattr(shap_values, "shape", None)
389 class_labels = list(getattr(model, "classes_", []))
390 shap_outputs = []
391 if shap_shape is not None and len(shap_shape) == 3:
392 output_count = shap_shape[2]
393 LOG.info("Detected multi-output SHAP explanation with %s classes.", output_count)
394 for class_idx in range(output_count):
395 try:
396 class_expl = shap_values[..., class_idx]
397 except Exception as exc:
398 LOG.warning(
399 "Failed to extract SHAP explanation for class index %s: %s",
400 class_idx,
401 exc,
402 )
403 continue
404 label = (
405 class_labels[class_idx]
406 if class_labels and class_idx < len(class_labels)
407 else class_idx
408 )
409 shap_outputs.append((class_idx, label, class_expl))
410 else:
411 shap_outputs.append((None, None, shap_values))
412
413 if not shap_outputs:
414 LOG.error("No SHAP outputs available for plotting.")
194 self.shap_model_name = None 415 self.shap_model_name = None
195 return 416 return
196 417
197 # --- Plot SHAP summary --- 418 # --- Plot SHAP summary (one per class if needed) ---
198 out_path = os.path.join(self.output_dir, "shap_summary.png") 419 for class_idx, class_label, class_expl in shap_outputs:
199 plt.figure() 420 expl_to_plot = _limit_explanation_features(class_expl)
200 shap.plots.beeswarm(shap_values, max_display=max_display, show=False) 421 suffix = ""
201 plt.title( 422 plot_key = "shap_summary"
202 f"SHAP Summary for {model.__class__.__name__} (top {max_display} features)" 423 if class_idx is not None:
203 ) 424 safe_label = str(class_label).replace("/", "_").replace(" ", "_")
204 plt.savefig(out_path, bbox_inches="tight") 425 suffix = f"_class_{safe_label}"
205 plt.close() 426 plot_key = f"shap_summary_class_{safe_label}"
206 self.plots["shap_summary"] = out_path 427 out_filename = f"shap_summary{suffix}.png"
428 out_path = os.path.join(self.output_dir, out_filename)
429 plt.figure()
430 shap.plots.beeswarm(expl_to_plot, max_display=max_display, show=False)
431 title = f"SHAP Summary for {model.__class__.__name__}"
432 if class_idx is not None:
433 title += f" (class {class_label})"
434 plt.title(f"{title} (top {max_display} features)")
435 plt.tight_layout()
436 plt.savefig(out_path, bbox_inches="tight")
437 plt.close()
438 self.plots[plot_key] = out_path
207 439
208 # --- Log summary --- 440 # --- Log summary ---
209 LOG.info( 441 LOG.info(
210 f"SHAP summary completed with {X_data.shape[0]} rows and {X_data.shape[1]} features (displaying top {max_display})." 442 "SHAP summary completed with %s rows and %s features "
443 "(displaying top %s) across %s output(s).",
444 X_data.shape[0],
445 X_data.shape[1],
446 max_display,
447 len(shap_outputs),
211 ) 448 )
212 449
213 def generate_html_report(self): 450 def generate_html_report(self):
214 LOG.info("Generating HTML report") 451 LOG.info("Generating HTML report")
215 plots_html = "" 452 plots_html = ""
225 section_title = f"Feature importance from {self.tree_model_name}" 462 section_title = f"Feature importance from {self.tree_model_name}"
226 elif plot_name == "shap_summary": 463 elif plot_name == "shap_summary":
227 section_title = ( 464 section_title = (
228 f"SHAP Summary from {getattr(self, 'shap_model_name', 'model')}" 465 f"SHAP Summary from {getattr(self, 'shap_model_name', 'model')}"
229 ) 466 )
467 elif plot_name.startswith("shap_summary_class_"):
468 class_label = plot_name.replace("shap_summary_class_", "")
469 section_title = (
470 f"SHAP Summary for class {class_label} "
471 f"({getattr(self, 'shap_model_name', 'model')})"
472 )
230 else: 473 else:
231 section_title = plot_name 474 section_title = plot_name
232 plots_html += f""" 475 plots_html += f"""
233 <div class="plot" id="{plot_name}"> 476 <div class="plot" id="{plot_name}" style="text-align:center;margin-bottom:24px;">
234 <h2>{section_title}</h2> 477 <h2>{section_title}</h2>
235 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> 478 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"
479 style="max-width:95%;height:auto;display:block;margin:0 auto;border:1px solid #ddd;padding:8px;background:#fff;">
236 </div> 480 </div>
237 """ 481 """
238 return f"{plots_html}" 482 return f"{plots_html}"
239 483
240 def encode_image_to_base64(self, img_path): 484 def encode_image_to_base64(self, img_path):