Mercurial > repos > goeckslab > pycaret_predict
diff feature_importance.py @ 8:1aed7d47c5ec draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author | goeckslab |
---|---|
date | Fri, 25 Jul 2025 19:02:32 +0000 |
parents | f4cb41f458fd |
children |
line wrap: on
line diff
--- a/feature_importance.py Wed Jul 09 01:13:01 2025 +0000 +++ b/feature_importance.py Fri Jul 25 19:02:32 2025 +0000 @@ -14,14 +14,15 @@ class FeatureImportanceAnalyzer: def __init__( - self, - task_type, - output_dir, - data_path=None, - data=None, - target_col=None, - exp=None, - best_model=None): + self, + task_type, + output_dir, + data_path=None, + data=None, + target_col=None, + exp=None, + best_model=None, + ): self.task_type = task_type self.output_dir = output_dir @@ -43,7 +44,11 @@ self.data.columns = self.data.columns.str.replace('.', '_') self.data = self.data.fillna(self.data.median(numeric_only=True)) self.target = self.data.columns[int(target_col) - 1] - self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment() + self.exp = ( + ClassificationExperiment() + if task_type == "classification" + else RegressionExperiment() + ) self.plots = {} @@ -57,7 +62,7 @@ 'session_id': 123, 'html': True, 'log_experiment': False, - 'system_log': False + 'system_log': False, } self.exp.setup(self.data, **setup_params) @@ -70,14 +75,16 @@ model_type = model.__class__.__name__ self.tree_model_name = model_type # Store the model name for reporting - if hasattr(model, "feature_importances_"): + if hasattr(model, 'feature_importances_'): importances = model.feature_importances_ - elif hasattr(model, "coef_"): + elif hasattr(model, 'coef_'): # For linear models, flatten coef_ and take abs (importance as magnitude) importances = abs(model.coef_).flatten() else: # Neither attribute exists; skip the plot - LOG.warning(f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot.") + LOG.warning( + f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot." + ) self.tree_model_name = None # No plot generated return @@ -89,69 +96,77 @@ self.tree_model_name = None return - feature_importances = pd.DataFrame({ - 'Feature': processed_features, - 'Importance': importances - }).sort_values(by='Importance', ascending=False) + feature_importances = pd.DataFrame( + {'Feature': processed_features, 'Importance': importances} + ).sort_values(by='Importance', ascending=False) plt.figure(figsize=(10, 6)) - plt.barh( - feature_importances['Feature'], - feature_importances['Importance']) + plt.barh(feature_importances['Feature'], feature_importances['Importance']) plt.xlabel('Importance') plt.title(f'Feature Importance ({model_type})') - plot_path = os.path.join( - self.output_dir, - 'tree_importance.png') + plot_path = os.path.join(self.output_dir, 'tree_importance.png') plt.savefig(plot_path) plt.close() self.plots['tree_importance'] = plot_path def save_shap_values(self): - model = self.best_model or self.exp.get_config('best_model') - X_transformed = self.exp.get_config('X_transformed') - tree_classes = ( - "LGBM", "XGB", "CatBoost", "RandomForest", "DecisionTree", "ExtraTrees", "HistGradientBoosting" - ) - model_class_name = model.__class__.__name__ - self.shap_model_name = model_class_name + + model = self.best_model or self.exp.get_config("best_model") - # Ensure feature alignment - if hasattr(model, "feature_name_"): - used_features = model.feature_name_ - elif hasattr(model, "booster_") and hasattr(model.booster_, "feature_name"): + X_data = None + for key in ("X_test_transformed", "X_train_transformed"): + try: + X_data = self.exp.get_config(key) + break + except KeyError: + continue + if X_data is None: + raise RuntimeError( + "Could not find 'X_test_transformed' or 'X_train_transformed' in the experiment. " + "Make sure PyCaret setup/compare_models was run with feature_selection=True." + ) + + try: used_features = model.booster_.feature_name() - elif hasattr(model, "feature_names_in_"): - # scikit‐learn's standard attribute for the names of features used during fit - used_features = list(model.feature_names_in_) - else: - used_features = X_transformed.columns + except Exception: + used_features = getattr(model, "feature_names_in_", X_data.columns.tolist()) + X_data = X_data[used_features] + + max_bg = min(len(X_data), 100) + bg = X_data.sample(max_bg, random_state=42) + + predict_fn = model.predict_proba if hasattr(model, "predict_proba") else model.predict - if any(tc in model_class_name for tc in tree_classes): - explainer = shap.TreeExplainer(model) - X_shap = X_transformed[used_features] - shap_values = explainer.shap_values(X_shap) - plot_X = X_shap - plot_title = f"SHAP Summary for {model_class_name} (TreeExplainer)" + explainer = shap.Explainer(predict_fn, bg) + self.shap_model_name = explainer.__class__.__name__ + + shap_values = explainer(X_data) + + output_names = getattr(shap_values, "output_names", None) + if output_names is None and hasattr(model, "classes_"): + output_names = list(model.classes_) + if output_names is None: + n_out = shap_values.values.shape[-1] + output_names = list(map(str, range(n_out))) + + values = shap_values.values + if values.ndim == 3: + for j, name in enumerate(output_names): + safe = name.replace(" ", "_").replace("/", "_") + out_path = os.path.join(self.output_dir, f"shap_summary_{safe}.png") + plt.figure() + shap.plots.beeswarm(shap_values[..., j], show=False) + plt.title(f"SHAP for {model.__class__.__name__} ⇒ {name}") + plt.savefig(out_path) + plt.close() + self.plots[f"shap_summary_{safe}"] = out_path else: - logging.warning(f"len(X_transformed) = {len(X_transformed)}") - max_samples = 100 - n_samples = min(max_samples, len(X_transformed)) - sampled_X = X_transformed[used_features].sample( - n=n_samples, - replace=False, - random_state=42 - ) - explainer = shap.KernelExplainer(model.predict, sampled_X) - shap_values = explainer.shap_values(sampled_X) - plot_X = sampled_X - plot_title = f"SHAP Summary for {model_class_name} (KernelExplainer)" - - shap.summary_plot(shap_values, plot_X, show=False) - plt.title(plot_title) - plot_path = os.path.join(self.output_dir, "shap_summary.png") - plt.savefig(plot_path) - plt.close() - self.plots["shap_summary"] = plot_path + plt.figure() + shap.plots.beeswarm(shap_values, show=False) + plt.title(f"SHAP Summary for {model.__class__.__name__}") + out_path = os.path.join(self.output_dir, "shap_summary.png") + plt.savefig(out_path) + plt.close() + self.plots["shap_summary"] = out_path def generate_html_report(self): LOG.info("Generating HTML report") @@ -159,11 +174,17 @@ plots_html = "" for plot_name, plot_path in self.plots.items(): # Special handling for tree importance: skip if no model name (not generated) - if plot_name == 'tree_importance' and not getattr(self, 'tree_model_name', None): + if plot_name == 'tree_importance' and not getattr( + self, 'tree_model_name', None + ): continue encoded_image = self.encode_image_to_base64(plot_path) - if plot_name == 'tree_importance' and getattr(self, 'tree_model_name', None): - section_title = f"Feature importance analysis from a trained {self.tree_model_name}" + if plot_name == 'tree_importance' and getattr( + self, 'tree_model_name', None + ): + section_title = ( + f"Feature importance analysis from a trained {self.tree_model_name}" + ) elif plot_name == 'shap_summary': section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}" else: @@ -176,7 +197,6 @@ """ html_content = f""" - <h1>PyCaret Feature Importance Report</h1> {plots_html} """ @@ -187,7 +207,11 @@ return base64.b64encode(img_file.read()).decode('utf-8') def run(self): - if self.exp is None or not hasattr(self.exp, 'is_setup') or not self.exp.is_setup: + if ( + self.exp is None + or not hasattr(self.exp, 'is_setup') + or not self.exp.is_setup + ): self.setup_pycaret() self.save_tree_importance() self.save_shap_values()