Mercurial > repos > goeckslab > pycaret_predict
view feature_importance.py @ 16:4fee4504646e draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 2b826699ef9518d4610f5cfb6468ce719ec8039d
| author | goeckslab |
|---|---|
| date | Fri, 28 Nov 2025 22:28:26 +0000 |
| parents | e674b9e946fb |
| children |
line wrap: on
line source
import base64
import html
import logging
import os

import matplotlib.pyplot as plt
import pandas as pd
import shap
from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)


class FeatureImportanceAnalyzer:
    """Produce feature-importance and SHAP summary plots for a PyCaret model.

    The analyzer either reuses an already configured PyCaret experiment
    (``exp``) or builds a new one from in-memory ``data`` / a delimited file
    at ``data_path``.  Plots are written to ``output_dir`` and can be embedded
    into an HTML fragment with :meth:`generate_html_report`.
    """

    def __init__(
        self,
        task_type,
        output_dir,
        data_path=None,
        data=None,
        target_col=None,
        exp=None,
        best_model=None,
        max_plot_features=None,
        processed_data=None,
        max_shap_rows=None,
    ):
        """Store configuration and resolve the dataset, target and experiment.

        Args:
            task_type: ``"classification"`` selects a
                ``ClassificationExperiment``; any other value selects a
                ``RegressionExperiment`` (only used when ``exp`` is None).
            output_dir: Directory the plot images are written to.
            data_path: Path of a delimited text file, read only when neither
                ``exp`` nor ``data`` is given.
            data: In-memory DataFrame used instead of ``data_path``.
            target_col: 1-based position of the target column (required when
                ``exp`` is None).
            exp: Existing experiment; its ``dataset`` and ``target_param``
                are used when provided.
            best_model: Fitted estimator to explain.
            max_plot_features: Positive int caps the number of plotted
                features; ``None`` means the default of 30; any other value
                disables the cap.
            processed_data: Optional DataFrame that replaces the resolved
                dataset.
            max_shap_rows: Optional upper bound on rows used for SHAP.
        """
        self.task_type = task_type
        self.output_dir = output_dir
        self.exp = exp
        self.best_model = best_model
        self.target_col = target_col
        self._skip_messages = []
        self.shap_total_features = None
        self.shap_used_features = None
        # Initialised up front so report generation never depends on which
        # plotting methods happened to run.
        self.tree_model_name = None
        self.shap_model_name = None

        if isinstance(max_plot_features, int) and max_plot_features > 0:
            self.max_plot_features = max_plot_features
        elif max_plot_features is None:
            self.max_plot_features = 30
        else:
            self.max_plot_features = None

        if exp is not None:
            # Assume all configs (data, target) are in exp
            self.data = exp.dataset.copy()
            self.target = exp.target_param
            LOG.info("Using provided experiment object")
        else:
            if data is not None:
                self.data = data
                LOG.info("Data loaded from memory")
            else:
                self.data = pd.read_csv(data_path, sep=None, engine="python")
                # regex=False: "." must be a literal dot.  With the regex
                # default of older pandas it matches every character.
                self.data.columns = self.data.columns.str.replace(
                    ".", "_", regex=False
                )
                self.data = self.data.fillna(self.data.median(numeric_only=True))
            self.target = self.data.columns[int(target_col) - 1]
            self.exp = (
                ClassificationExperiment()
                if task_type == "classification"
                else RegressionExperiment()
            )

        if processed_data is not None:
            self.data = processed_data
        self.plots = {}
        self.max_shap_rows = max_shap_rows

    def _resolve_model(self):
        """Return ``best_model`` if one was supplied, else ask the experiment."""
        # Identity check instead of truthiness: estimators that define
        # ``__len__`` (ensembles, pipelines) must not be evaluated as booleans.
        if self.best_model is not None:
            return self.best_model
        return self.exp.get_config("best_model")

    @staticmethod
    def _get_model_importances(model):
        """Return one importance score per feature, or None if unavailable.

        ``feature_importances_`` is preferred.  Otherwise absolute
        coefficients are used; when ``coef_`` has several rows the absolute
        values are averaged over the rows so the result stays one score per
        feature.
        """
        if hasattr(model, "feature_importances_"):
            return model.feature_importances_
        if hasattr(model, "coef_"):
            coef = abs(model.coef_)
            if getattr(coef, "ndim", 1) > 1:
                coef = coef.reshape(-1, coef.shape[-1]).mean(axis=0)
            return coef
        return None

    def _get_feature_names_from_model(self, model):
        """Best-effort extraction of feature names seen by the estimator."""
        if model is None:
            return None

        candidates = [model]
        if hasattr(model, "named_steps"):
            candidates.extend(model.named_steps.values())
        elif hasattr(model, "steps"):
            candidates.extend(step for _, step in model.steps)

        for candidate in candidates:
            names = getattr(candidate, "feature_names_in_", None)
            if names is not None:
                return list(names)
        return None

    def _get_transformed_frame(self, model=None, prefer_test=True):
        """Return a DataFrame that mirrors the matrix fed to the estimator.

        Experiment config keys are tried in order (test before train when
        ``prefer_test`` is true, ``X_transformed`` last).  Non-DataFrame
        matrices are wrapped using the estimator's feature names when their
        count matches, otherwise positional ``f<i>`` names.  Returns None if
        no key yields a usable matrix.
        """
        key_order = ["X_test_transformed", "X_train_transformed"]
        if not prefer_test:
            key_order.reverse()
        key_order.append("X_transformed")

        feature_names = self._get_feature_names_from_model(model)
        for key in key_order:
            try:
                frame = self.exp.get_config(key)
            except (KeyError, ValueError):
                # NOTE(review): ValueError is caught as well because the
                # experiment may report unknown keys that way - confirm
                # against the installed PyCaret version.
                continue
            if frame is None:
                continue
            if isinstance(frame, pd.DataFrame):
                return frame.copy()
            try:
                n_features = frame.shape[1]
            except Exception:
                continue
            if feature_names and len(feature_names) == n_features:
                return pd.DataFrame(frame, columns=feature_names)
            # Fallback to positional names so downstream logic still works
            return pd.DataFrame(frame, columns=[f"f{i}" for i in range(n_features)])
        return None

    def setup_pycaret(self):
        """Run ``exp.setup`` on the stored data unless it is already set up."""
        if self.exp is not None and hasattr(self.exp, "is_setup") and self.exp.is_setup:
            LOG.info("Experiment already set up. Skipping PyCaret setup.")
            return
        LOG.info("Initializing PyCaret")
        setup_params = {
            "target": self.target,
            "session_id": 123,
            "html": True,
            "log_experiment": False,
            "system_log": False,
        }
        self.exp.setup(self.data, **setup_params)

    def save_tree_importance(self):
        """Plot model-native feature importances to ``tree_importance.png``.

        Uses ``feature_importances_`` or absolute ``coef_`` values.  When no
        usable importances or feature names are available the plot is
        skipped and ``tree_model_name`` is set to None; otherwise the plot
        path is stored in ``self.plots["tree_importance"]``.
        """
        model = self._resolve_model()
        processed_frame = self._get_transformed_frame(model, prefer_test=False)
        if processed_frame is None:
            LOG.warning(
                "Unable to determine transformed feature names; skipping tree importance plot."
            )
            self.tree_model_name = None
            return

        processed_features = list(processed_frame.columns)
        model_type = model.__class__.__name__
        self.tree_model_name = model_type

        importances = self._get_model_importances(model)
        if importances is None:
            LOG.warning(
                f"Model {model_type} does not have feature_importances_ or coef_. Skipping tree importance."
            )
            self.tree_model_name = None
            return

        if len(importances) != len(processed_features):
            # The frame may not match what the estimator saw; fall back to
            # the estimator's own feature names when their count fits.
            model_feature_names = self._get_feature_names_from_model(model)
            if model_feature_names and len(model_feature_names) == len(importances):
                processed_features = model_feature_names
            else:
                LOG.warning(
                    "Importances (%s) != features (%s). Skipping tree importance.",
                    len(importances),
                    len(processed_features),
                )
                self.tree_model_name = None
                return

        feature_importances = pd.DataFrame(
            {"Feature": processed_features, "Importance": importances}
        ).sort_values(by="Importance", ascending=False)

        cap = (
            min(self.max_plot_features, len(feature_importances))
            if self.max_plot_features is not None
            else len(feature_importances)
        )
        plot_importances = feature_importances.head(cap)
        if cap < len(feature_importances):
            LOG.info(
                "Tree importance plot limited to top %s of %s features",
                cap,
                len(feature_importances),
            )

        plot_path = os.path.join(self.output_dir, "tree_importance.png")
        plt.figure(figsize=(10, 6))
        try:
            plt.barh(
                plot_importances["Feature"],
                plot_importances["Importance"],
            )
            plt.xlabel("Importance")
            plt.title(f"Feature Importance ({model_type}) (top {cap})")
            plt.tight_layout()
            plt.savefig(plot_path, bbox_inches="tight")
        finally:
            # Always release the figure, even when saving fails.
            plt.close()
        self.plots["tree_importance"] = plot_path

    def _select_display_features(self, model, X_data, max_features):
        """Pick the feature names shown in SHAP plots (best effort).

        Features are ranked by the model's importances, or by column
        variance when the model exposes none.  Any failure, or an empty
        selection, falls back to every column of ``X_data``.
        """
        all_features = list(X_data.columns)
        n_features = len(all_features)
        try:
            scores = self._get_model_importances(model)
            if scores is not None:
                ranking = pd.Series(scores, index=X_data.columns)
            else:
                ranking = X_data.var()
            selected = list(ranking.nlargest(max_features).index)
        except Exception as e:
            LOG.warning(
                f"Feature limiting failed: {e}. Using all {n_features} features."
            )
            return all_features

        if not selected:
            return all_features
        if len(selected) < n_features:
            LOG.info(
                "Restricting SHAP display to top %s of %s features",
                len(selected),
                n_features,
            )
        return selected

    def _compute_shap_values(self, model, X_data, bg, predict_fn):
        """Build an explainer and explain ``X_data``; return None on failure.

        Sets ``self.shap_model_name`` to the explainer label on success and
        to None when every attempt fails.  Two failures are retried: a
        background set that does not cover all tree leaves (retry with
        interventional perturbation) and a feature-name mismatch (retry with
        the model-agnostic explainer).
        """
        is_tree_like = hasattr(model, "feature_importances_")
        if is_tree_like:
            explainer = shap.TreeExplainer(
                model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1
            )
            explainer_label = "tree_path_dependent"
        elif hasattr(model, "coef_"):
            explainer = shap.LinearExplainer(model, bg)
            explainer_label = "linear"
        else:
            explainer = shap.Explainer(predict_fn, bg)
            explainer_label = explainer.__class__.__name__

        try:
            shap_values = explainer(X_data)
            self.shap_model_name = explainer.__class__.__name__
            return shap_values
        except Exception as e:
            error_message = str(e)

        lowered = error_message.lower()
        if is_tree_like and "does not cover all the leaves" in lowered:
            LOG.warning(
                "SHAP computation failed using '%s' perturbation (%s). "
                "Retrying with interventional perturbation.",
                explainer_label,
                error_message,
            )
            try:
                explainer = shap.TreeExplainer(
                    model,
                    bg,
                    feature_perturbation="interventional",
                    n_jobs=-1,
                )
                shap_values = explainer(X_data)
                self.shap_model_name = (
                    f"{explainer.__class__.__name__} (interventional)"
                )
                return shap_values
            except Exception as retry_exc:
                LOG.error(
                    "SHAP computation failed even after fallback: %s",
                    retry_exc,
                )
                self.shap_model_name = None
                return None

        if "feature names should match" in lowered:
            LOG.warning(
                "SHAP computation failed due to feature-name mismatch (%s). "
                "Falling back to model-agnostic SHAP explainer.",
                error_message,
            )
            try:
                agnostic_explainer = shap.Explainer(predict_fn, bg)
                shap_values = agnostic_explainer(X_data)
                self.shap_model_name = (
                    f"{agnostic_explainer.__class__.__name__} (fallback)"
                )
                return shap_values
            except Exception as fallback_exc:
                LOG.error(
                    "Model-agnostic SHAP fallback also failed: %s",
                    fallback_exc,
                )
                self.shap_model_name = None
                return None

        LOG.error(f"SHAP computation failed: {error_message}")
        self.shap_model_name = None
        return None

    def save_shap_values(self, max_samples=None, max_display=None, max_features=None):
        """Compute SHAP values and save one beeswarm summary plot per output.

        Args:
            max_samples: Row budget for SHAP; when None it is derived from
                the number of rows (all rows up to 500, 500 up to 5000 rows,
                otherwise 10% capped at 1000).  ``max_shap_rows`` lowers it
                further when set.
            max_display: Maximum number of features drawn per plot; capped
                by ``max_plot_features`` and the number of selected features.
            max_features: Maximum number of features selected for display;
                capped by ``max_plot_features``.

        Raises:
            RuntimeError: If the experiment exposes no transformed dataset.

        Plot paths are stored in ``self.plots`` under ``shap_summary`` or
        ``shap_summary_class_<label>``.  If SHAP computation fails the method
        logs the error, sets ``shap_model_name`` to None and returns.
        """
        model = self._resolve_model()
        X_data = self._get_transformed_frame(model)
        if X_data is None:
            raise RuntimeError("No transformed dataset found for SHAP.")

        n_rows, n_features = X_data.shape
        self.shap_total_features = n_features
        feature_cap = (
            min(self.max_plot_features, n_features)
            if self.max_plot_features is not None
            else n_features
        )
        if max_features is None:
            max_features = feature_cap
        else:
            max_features = min(max_features, feature_cap)

        display_features = self._select_display_features(model, X_data, max_features)
        self.shap_used_features = len(display_features)

        # X_data deliberately keeps every column in its original order: the
        # estimator can only be evaluated on the matrix layout it was fitted
        # on.  The feature restriction is applied to the finished explanation
        # instead (see _limit_explanation_features below).
        column_positions = {name: pos for pos, name in enumerate(X_data.columns)}
        display_positions = [column_positions[name] for name in display_features]

        # --- Adaptive row subsampling ---
        if max_samples is None:
            if n_rows <= 500:
                max_samples = n_rows
            elif n_rows <= 5000:
                max_samples = 500
            else:
                max_samples = min(1000, int(n_rows * 0.1))
        if self.max_shap_rows is not None:
            max_samples = min(max_samples, self.max_shap_rows)
        if n_rows > max_samples:
            LOG.info(f"Subsampling SHAP rows: {max_samples} of {n_rows}")
            X_data = X_data.sample(max_samples, random_state=42)

        # --- Adaptive feature display ---
        display_cap = (
            min(self.max_plot_features, len(display_features))
            if self.max_plot_features is not None
            else len(display_features)
        )
        if max_display is None:
            max_display = display_cap
        else:
            max_display = min(max_display, display_cap)

        # Background set
        bg = X_data.sample(min(len(X_data), 100), random_state=42)
        predict_fn = (
            model.predict_proba if hasattr(model, "predict_proba") else model.predict
        )

        shap_values = self._compute_shap_values(model, X_data, bg, predict_fn)
        if shap_values is None:
            return

        def _limit_explanation_features(explanation):
            """Trim an explanation to the selected display features."""
            if len(display_features) >= n_features:
                return explanation
            try:
                # Integer positions rather than a list of names.
                limited = explanation[:, display_positions]
                LOG.info(
                    "SHAP explanation trimmed to %s display features.",
                    len(display_features),
                )
                return limited
            except Exception as exc:
                LOG.warning(
                    "Failed to restrict SHAP explanation to top features "
                    "(sample=%s); plot will include all features. Error: %s",
                    display_features[:5],
                    exc,
                )
                # Keep using full feature list if trimming fails
                return explanation

        shap_shape = getattr(shap_values, "shape", None)
        class_labels = list(getattr(model, "classes_", []))
        shap_outputs = []
        if shap_shape is not None and len(shap_shape) == 3:
            output_count = shap_shape[2]
            LOG.info("Detected multi-output SHAP explanation with %s classes.", output_count)
            for class_idx in range(output_count):
                try:
                    class_expl = shap_values[..., class_idx]
                except Exception as exc:
                    LOG.warning(
                        "Failed to extract SHAP explanation for class index %s: %s",
                        class_idx,
                        exc,
                    )
                    continue
                label = (
                    class_labels[class_idx]
                    if class_labels and class_idx < len(class_labels)
                    else class_idx
                )
                shap_outputs.append((class_idx, label, class_expl))
        else:
            shap_outputs.append((None, None, shap_values))

        if not shap_outputs:
            LOG.error("No SHAP outputs available for plotting.")
            self.shap_model_name = None
            return

        # --- Plot SHAP summary (one per class if needed) ---
        for class_idx, class_label, class_expl in shap_outputs:
            expl_to_plot = _limit_explanation_features(class_expl)
            suffix = ""
            plot_key = "shap_summary"
            if class_idx is not None:
                safe_label = str(class_label).replace("/", "_").replace(" ", "_")
                suffix = f"_class_{safe_label}"
                plot_key = f"shap_summary_class_{safe_label}"
            out_filename = f"shap_summary{suffix}.png"
            out_path = os.path.join(self.output_dir, out_filename)

            plt.figure()
            try:
                shap.plots.beeswarm(expl_to_plot, max_display=max_display, show=False)
                title = f"SHAP Summary for {model.__class__.__name__}"
                if class_idx is not None:
                    title += f" (class {class_label})"
                plt.title(f"{title} (top {max_display} features)")
                plt.tight_layout()
                plt.savefig(out_path, bbox_inches="tight")
            finally:
                # Always release the figure, even when plotting fails.
                plt.close()
            self.plots[plot_key] = out_path

        # --- Log summary ---
        LOG.info(
            "SHAP summary completed with %s rows and %s features "
            "(displaying top %s) across %s output(s).",
            X_data.shape[0],
            X_data.shape[1],
            max_display,
            len(shap_outputs),
        )

    def generate_html_report(self):
        """Return an HTML fragment embedding every saved plot as base64 PNG.

        The tree-importance plot is skipped when no tree model name was
        recorded.  Titles and element ids are HTML-escaped because class
        labels originate from the dataset.
        """
        LOG.info("Generating HTML report")
        tree_name = getattr(self, "tree_model_name", None)
        shap_name = getattr(self, "shap_model_name", None) or "model"

        sections = []
        for plot_name, plot_path in self.plots.items():
            if plot_name == "tree_importance" and not tree_name:
                continue
            encoded_image = self.encode_image_to_base64(plot_path)

            if plot_name == "tree_importance":
                section_title = f"Feature importance from {tree_name}"
            elif plot_name == "shap_summary":
                section_title = f"SHAP Summary from {shap_name}"
            elif plot_name.startswith("shap_summary_class_"):
                class_label = plot_name.replace("shap_summary_class_", "")
                section_title = (
                    f"SHAP Summary for class {class_label} "
                    f"({shap_name})"
                )
            else:
                section_title = plot_name

            safe_id = html.escape(str(plot_name), quote=True)
            safe_title = html.escape(str(section_title))
            sections.append(
                f"""
            <div class="plot" id="{safe_id}" style="text-align:center;margin-bottom:24px;">
                <h2>{safe_title}</h2>
                <img src="data:image/png;base64,{encoded_image}"
                     alt="{safe_id}"
                     style="max-width:95%;height:auto;display:block;margin:0 auto;border:1px solid #ddd;padding:8px;background:#fff;">
            </div>
            """
            )
        return "".join(sections)

    def encode_image_to_base64(self, img_path):
        """Return the file at ``img_path`` as a base64-encoded UTF-8 string."""
        with open(img_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode("utf-8")

    def run(self):
        """Set up the experiment if needed, build all plots, return the HTML."""
        if (
            self.exp is None
            or not hasattr(self.exp, "is_setup")
            or not self.exp.is_setup
        ):
            self.setup_pycaret()
        self.save_tree_importance()
        self.save_shap_values()
        return self.generate_html_report()
