Mercurial > repos > goeckslab > tabular_learner
diff feature_importance.py @ 8:ba45bc057d70 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
| author | goeckslab | 
|---|---|
| date | Mon, 08 Sep 2025 22:38:55 +0000 | 
| parents | 0afd970bd8ae | 
| children | 
line wrap: on
 line diff
--- a/feature_importance.py Fri Aug 22 21:13:44 2025 +0000 +++ b/feature_importance.py Mon Sep 08 22:38:55 2025 +0000 @@ -23,7 +23,6 @@ exp=None, best_model=None, ): - self.task_type = task_type self.output_dir = output_dir self.exp = exp @@ -40,8 +39,8 @@ LOG.info("Data loaded from memory") else: self.target_col = target_col - self.data = pd.read_csv(data_path, sep=None, engine='python') - self.data.columns = self.data.columns.str.replace('.', '_') + self.data = pd.read_csv(data_path, sep=None, engine="python") + self.data.columns = self.data.columns.str.replace(".", "_") self.data = self.data.fillna(self.data.median(numeric_only=True)) self.target = self.data.columns[int(target_col) - 1] self.exp = ( @@ -53,63 +52,58 @@ self.plots = {} def setup_pycaret(self): - if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup: + if self.exp is not None and hasattr(self.exp, "is_setup") and self.exp.is_setup: LOG.info("Experiment already set up. Skipping PyCaret setup.") return LOG.info("Initializing PyCaret") setup_params = { - 'target': self.target, - 'session_id': 123, - 'html': True, - 'log_experiment': False, - 'system_log': False, + "target": self.target, + "session_id": 123, + "html": True, + "log_experiment": False, + "system_log": False, } self.exp.setup(self.data, **setup_params) def save_tree_importance(self): - model = self.best_model or self.exp.get_config('best_model') - processed_features = self.exp.get_config('X_transformed').columns + model = self.best_model or self.exp.get_config("best_model") + processed_features = self.exp.get_config("X_transformed").columns - # Try feature_importances_ or coef_ if available importances = None model_type = model.__class__.__name__ - self.tree_model_name = model_type # Store the model name for reporting + self.tree_model_name = model_type - if hasattr(model, 'feature_importances_'): + if hasattr(model, "feature_importances_"): importances = model.feature_importances_ - elif hasattr(model, 'coef_'): - # For linear models, flatten coef_ and take abs (importance as magnitude) + elif hasattr(model, "coef_"): importances = abs(model.coef_).flatten() else: - # Neither attribute exists; skip the plot LOG.warning( - f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot." + f"Model {model_type} does not have feature_importances_ or coef_. Skipping tree importance." ) - self.tree_model_name = None # No plot generated + self.tree_model_name = None return - # Defensive: handle mismatch in number of features if len(importances) != len(processed_features): LOG.warning( - f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot." + f"Importances ({len(importances)}) != features ({len(processed_features)}). Skipping tree importance." ) self.tree_model_name = None return feature_importances = pd.DataFrame( - {'Feature': processed_features, 'Importance': importances} - ).sort_values(by='Importance', ascending=False) + {"Feature": processed_features, "Importance": importances} + ).sort_values(by="Importance", ascending=False) plt.figure(figsize=(10, 6)) - plt.barh(feature_importances['Feature'], feature_importances['Importance']) - plt.xlabel('Importance') - plt.title(f'Feature Importance ({model_type})') - plot_path = os.path.join(self.output_dir, 'tree_importance.png') - plt.savefig(plot_path) + plt.barh(feature_importances["Feature"], feature_importances["Importance"]) + plt.xlabel("Importance") + plt.title(f"Feature Importance ({model_type})") + plot_path = os.path.join(self.output_dir, "tree_importance.png") + plt.savefig(plot_path, bbox_inches="tight") plt.close() - self.plots['tree_importance'] = plot_path + self.plots["tree_importance"] = plot_path - def save_shap_values(self): - + def save_shap_values(self, max_samples=None, max_display=None, max_features=None): model = self.best_model or self.exp.get_config("best_model") X_data = None @@ -120,78 +114,119 @@ except KeyError: continue if X_data is None: - raise RuntimeError( - "Could not find 'X_test_transformed' or 'X_train_transformed' in the experiment. " - "Make sure PyCaret setup/compare_models was run with feature_selection=True." - ) + raise RuntimeError("No transformed dataset found for SHAP.") + + # --- Adaptive feature limiting (proportional cap) --- + n_rows, n_features = X_data.shape + if max_features is None: + if n_features <= 200: + max_features = n_features + else: + max_features = min(200, max(20, int(n_features * 0.1))) try: - used_features = model.booster_.feature_name() - except Exception: - used_features = getattr(model, "feature_names_in_", X_data.columns.tolist()) - X_data = X_data[used_features] + if hasattr(model, "feature_importances_"): + importances = pd.Series( + model.feature_importances_, index=X_data.columns + ) + top_features = importances.nlargest(max_features).index + elif hasattr(model, "coef_"): + coef = abs(model.coef_).flatten() + importances = pd.Series(coef, index=X_data.columns) + top_features = importances.nlargest(max_features).index + else: + variances = X_data.var() + top_features = variances.nlargest(max_features).index + + if len(top_features) < n_features: + LOG.info( + f"Restricted SHAP computation to top {len(top_features)} / {n_features} features" + ) + X_data = X_data[top_features] + except Exception as e: + LOG.warning( + f"Feature limiting failed: {e}. Using all {n_features} features." + ) - max_bg = min(len(X_data), 100) - bg = X_data.sample(max_bg, random_state=42) + # --- Adaptive row subsampling --- + if max_samples is None: + if n_rows <= 500: + max_samples = n_rows + elif n_rows <= 5000: + max_samples = 500 + else: + max_samples = min(1000, int(n_rows * 0.1)) + + if n_rows > max_samples: + LOG.info(f"Subsampling SHAP rows: {max_samples} of {n_rows}") + X_data = X_data.sample(max_samples, random_state=42) - predict_fn = model.predict_proba if hasattr(model, "predict_proba") else model.predict + # --- Adaptive feature display --- + if max_display is None: + if X_data.shape[1] <= 20: + max_display = X_data.shape[1] + elif X_data.shape[1] <= 100: + max_display = 30 + else: + max_display = 50 + + # Background set + bg = X_data.sample(min(len(X_data), 100), random_state=42) + predict_fn = ( + model.predict_proba if hasattr(model, "predict_proba") else model.predict + ) + + # Optimized explainer + if hasattr(model, "feature_importances_"): + explainer = shap.TreeExplainer( + model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 + ) + elif hasattr(model, "coef_"): + explainer = shap.LinearExplainer(model, bg) + else: + explainer = shap.Explainer(predict_fn, bg) try: - explainer = shap.Explainer(predict_fn, bg) + shap_values = explainer(X_data) self.shap_model_name = explainer.__class__.__name__ - - shap_values = explainer(X_data) except Exception as e: LOG.error(f"SHAP computation failed: {e}") self.shap_model_name = None return - output_names = getattr(shap_values, "output_names", None) - if output_names is None and hasattr(model, "classes_"): - output_names = list(model.classes_) - if output_names is None: - n_out = shap_values.values.shape[-1] - output_names = list(map(str, range(n_out))) + # --- Plot SHAP summary --- + out_path = os.path.join(self.output_dir, "shap_summary.png") + plt.figure() + shap.plots.beeswarm(shap_values, max_display=max_display, show=False) + plt.title( + f"SHAP Summary for {model.__class__.__name__} (top {max_display} features)" + ) + plt.savefig(out_path, bbox_inches="tight") + plt.close() + self.plots["shap_summary"] = out_path - values = shap_values.values - if values.ndim == 3: - for j, name in enumerate(output_names): - safe = name.replace(" ", "_").replace("/", "_") - out_path = os.path.join(self.output_dir, f"shap_summary_{safe}.png") - plt.figure() - shap.plots.beeswarm(shap_values[..., j], show=False) - plt.title(f"SHAP for {model.__class__.__name__} ⇒ {name}") - plt.savefig(out_path) - plt.close() - self.plots[f"shap_summary_{safe}"] = out_path - else: - plt.figure() - shap.plots.beeswarm(shap_values, show=False) - plt.title(f"SHAP Summary for {model.__class__.__name__}") - out_path = os.path.join(self.output_dir, "shap_summary.png") - plt.savefig(out_path) - plt.close() - self.plots["shap_summary"] = out_path + # --- Log summary --- + LOG.info( + f"SHAP summary completed with {X_data.shape[0]} rows and {X_data.shape[1]} features (displaying top {max_display})." + ) def generate_html_report(self): LOG.info("Generating HTML report") - plots_html = "" for plot_name, plot_path in self.plots.items(): - # Special handling for tree importance: skip if no model name (not generated) - if plot_name == 'tree_importance' and not getattr( - self, 'tree_model_name', None + if plot_name == "tree_importance" and not getattr( + self, "tree_model_name", None ): continue encoded_image = self.encode_image_to_base64(plot_path) - if plot_name == 'tree_importance' and getattr( - self, 'tree_model_name', None + if plot_name == "tree_importance" and getattr( + self, "tree_model_name", None ): + section_title = f"Feature importance from {self.tree_model_name}" + elif plot_name == "shap_summary": section_title = ( - f"Feature importance analysis from a trained {self.tree_model_name}" + f"SHAP Summary from {getattr(self, 'shap_model_name', 'model')}" ) - elif plot_name == 'shap_summary': - section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}" else: section_title = plot_name plots_html += f""" @@ -200,25 +235,19 @@ <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> </div> """ - - html_content = f""" - {plots_html} - """ - - return html_content + return f"{plots_html}" def encode_image_to_base64(self, img_path): - with open(img_path, 'rb') as img_file: - return base64.b64encode(img_file.read()).decode('utf-8') + with open(img_path, "rb") as img_file: + return base64.b64encode(img_file.read()).decode("utf-8") def run(self): if ( self.exp is None - or not hasattr(self.exp, 'is_setup') + or not hasattr(self.exp, "is_setup") or not self.exp.is_setup ): self.setup_pycaret() self.save_tree_importance() self.save_shap_values() - html_content = self.generate_html_report() - return html_content + return self.generate_html_report()
