Mercurial > repos > goeckslab > pycaret_predict
view feature_importance.py @ 6:a32ff7201629 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
author | goeckslab |
---|---|
date | Wed, 02 Jul 2025 19:00:03 +0000 |
parents | ccd798db5abb |
children |
line wrap: on
line source
import base64 import logging import os import matplotlib.pyplot as plt import pandas as pd import shap from pycaret.classification import ClassificationExperiment from pycaret.regression import RegressionExperiment logging.basicConfig(level=logging.DEBUG) LOG = logging.getLogger(__name__) class FeatureImportanceAnalyzer: def __init__( self, task_type, output_dir, data_path=None, data=None, target_col=None, exp=None, best_model=None): self.task_type = task_type self.output_dir = output_dir self.exp = exp self.best_model = best_model if exp is not None: # Assume all configs (data, target) are in exp self.data = exp.dataset.copy() self.target = exp.target_param LOG.info("Using provided experiment object") else: if data is not None: self.data = data LOG.info("Data loaded from memory") else: self.target_col = target_col self.data = pd.read_csv(data_path, sep=None, engine='python') self.data.columns = self.data.columns.str.replace('.', '_') self.data = self.data.fillna(self.data.median(numeric_only=True)) self.target = self.data.columns[int(target_col) - 1] self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment() self.plots = {} def setup_pycaret(self): if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup: LOG.info("Experiment already set up. Skipping PyCaret setup.") return LOG.info("Initializing PyCaret") setup_params = { 'target': self.target, 'session_id': 123, 'html': True, 'log_experiment': False, 'system_log': False } self.exp.setup(self.data, **setup_params) def save_tree_importance(self): model = self.best_model or self.exp.get_config('best_model') processed_features = self.exp.get_config('X_transformed').columns # Try feature_importances_ or coef_ if available importances = None model_type = model.__class__.__name__ self.tree_model_name = model_type # Store the model name for reporting if hasattr(model, "feature_importances_"): importances = model.feature_importances_ elif hasattr(model, "coef_"): # For linear models, flatten coef_ and take abs (importance as magnitude) importances = abs(model.coef_).flatten() else: # Neither attribute exists; skip the plot LOG.warning(f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot.") self.tree_model_name = None # No plot generated return # Defensive: handle mismatch in number of features if len(importances) != len(processed_features): LOG.warning( f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot." ) self.tree_model_name = None return feature_importances = pd.DataFrame({ 'Feature': processed_features, 'Importance': importances }).sort_values(by='Importance', ascending=False) plt.figure(figsize=(10, 6)) plt.barh( feature_importances['Feature'], feature_importances['Importance']) plt.xlabel('Importance') plt.title(f'Feature Importance ({model_type})') plot_path = os.path.join( self.output_dir, 'tree_importance.png') plt.savefig(plot_path) plt.close() self.plots['tree_importance'] = plot_path def save_shap_values(self): model = self.best_model or self.exp.get_config('best_model') X_transformed = self.exp.get_config('X_transformed') tree_classes = ( "LGBM", "XGB", "CatBoost", "RandomForest", "DecisionTree", "ExtraTrees", "HistGradientBoosting" ) model_class_name = model.__class__.__name__ self.shap_model_name = model_class_name # Ensure feature alignment if hasattr(model, "feature_name_"): used_features = model.feature_name_ elif hasattr(model, "booster_") and hasattr(model.booster_, "feature_name"): used_features = model.booster_.feature_name() else: used_features = X_transformed.columns if any(tc in model_class_name for tc in tree_classes): explainer = shap.TreeExplainer(model) X_shap = X_transformed[used_features] shap_values = explainer.shap_values(X_shap) plot_X = X_shap plot_title = f"SHAP Summary for {model_class_name} (TreeExplainer)" else: sampled_X = X_transformed[used_features].sample(100, random_state=42) explainer = shap.KernelExplainer(model.predict, sampled_X) shap_values = explainer.shap_values(sampled_X) plot_X = sampled_X plot_title = f"SHAP Summary for {model_class_name} (KernelExplainer)" shap.summary_plot(shap_values, plot_X, show=False) plt.title(plot_title) plot_path = os.path.join(self.output_dir, "shap_summary.png") plt.savefig(plot_path) plt.close() self.plots["shap_summary"] = plot_path def generate_html_report(self): LOG.info("Generating HTML report") plots_html = "" for plot_name, plot_path in self.plots.items(): # Special handling for tree importance: skip if no model name (not generated) if plot_name == 'tree_importance' and not getattr(self, 'tree_model_name', None): continue encoded_image = self.encode_image_to_base64(plot_path) if plot_name == 'tree_importance' and getattr(self, 'tree_model_name', None): section_title = f"Feature importance analysis from a trained {self.tree_model_name}" elif plot_name == 'shap_summary': section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}" else: section_title = plot_name plots_html += f""" <div class="plot" id="{plot_name}"> <h2>{section_title}</h2> <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> </div> """ html_content = f""" <h1>PyCaret Feature Importance Report</h1> {plots_html} """ return html_content def encode_image_to_base64(self, img_path): with open(img_path, 'rb') as img_file: return base64.b64encode(img_file.read()).decode('utf-8') def run(self): if self.exp is None or not hasattr(self.exp, 'is_setup') or not self.exp.is_setup: self.setup_pycaret() self.save_tree_importance() self.save_shap_values() html_content = self.generate_html_report() return html_content