Mercurial > repos > goeckslab > pycaret_predict
diff feature_importance.py @ 6:a32ff7201629 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
author | goeckslab |
---|---|
date | Wed, 02 Jul 2025 19:00:03 +0000 |
parents | ccd798db5abb |
children |
line wrap: on
line diff
--- a/feature_importance.py Sat Jun 21 15:07:04 2025 +0000 +++ b/feature_importance.py Wed Jul 02 19:00:03 2025 +0000 @@ -4,6 +4,7 @@ import matplotlib.pyplot as plt import pandas as pd +import shap from pycaret.classification import ClassificationExperiment from pycaret.regression import RegressionExperiment @@ -18,25 +19,38 @@ output_dir, data_path=None, data=None, - target_col=None): + target_col=None, + exp=None, + best_model=None): - if data is not None: - self.data = data - LOG.info("Data loaded from memory") + self.task_type = task_type + self.output_dir = output_dir + self.exp = exp + self.best_model = best_model + + if exp is not None: + # Assume all configs (data, target) are in exp + self.data = exp.dataset.copy() + self.target = exp.target_param + LOG.info("Using provided experiment object") else: - self.target_col = target_col - self.data = pd.read_csv(data_path, sep=None, engine='python') - self.data.columns = self.data.columns.str.replace('.', '_') - self.data = self.data.fillna(self.data.median(numeric_only=True)) - self.task_type = task_type - self.target = self.data.columns[int(target_col) - 1] - self.exp = ClassificationExperiment() \ - if task_type == 'classification' \ - else RegressionExperiment() + if data is not None: + self.data = data + LOG.info("Data loaded from memory") + else: + self.target_col = target_col + self.data = pd.read_csv(data_path, sep=None, engine='python') + self.data.columns = self.data.columns.str.replace('.', '_') + self.data = self.data.fillna(self.data.median(numeric_only=True)) + self.target = self.data.columns[int(target_col) - 1] + self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment() + self.plots = {} - self.output_dir = output_dir def setup_pycaret(self): + if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup: + LOG.info("Experiment already set up. Skipping PyCaret setup.") + return LOG.info("Initializing PyCaret") setup_params = { 'target': self.target, @@ -45,25 +59,36 @@ 'log_experiment': False, 'system_log': False } - LOG.info(self.task_type) - LOG.info(self.exp) self.exp.setup(self.data, **setup_params) - # def save_coefficients(self): - # model = self.exp.create_model('lr') - # coef_df = pd.DataFrame({ - # 'Feature': self.data.columns.drop(self.target), - # 'Coefficient': model.coef_[0] - # }) - # coef_html = coef_df.to_html(index=False) - # return coef_html + def save_tree_importance(self): + model = self.best_model or self.exp.get_config('best_model') + processed_features = self.exp.get_config('X_transformed').columns + + # Try feature_importances_ or coef_ if available + importances = None + model_type = model.__class__.__name__ + self.tree_model_name = model_type # Store the model name for reporting - def save_tree_importance(self): - model = self.exp.create_model('rf') - importances = model.feature_importances_ - processed_features = self.exp.get_config('X_transformed').columns - LOG.debug(f"Feature importances: {importances}") - LOG.debug(f"Features: {processed_features}") + if hasattr(model, "feature_importances_"): + importances = model.feature_importances_ + elif hasattr(model, "coef_"): + # For linear models, flatten coef_ and take abs (importance as magnitude) + importances = abs(model.coef_).flatten() + else: + # Neither attribute exists; skip the plot + LOG.warning(f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot.") + self.tree_model_name = None # No plot generated + return + + # Defensive: handle mismatch in number of features + if len(importances) != len(processed_features): + LOG.warning( + f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot." + ) + self.tree_model_name = None + return + feature_importances = pd.DataFrame({ 'Feature': processed_features, 'Importance': importances @@ -73,7 +98,7 @@ feature_importances['Feature'], feature_importances['Importance']) plt.xlabel('Importance') - plt.title('Feature Importance (Random Forest)') + plt.title(f'Feature Importance ({model_type})') plot_path = os.path.join( self.output_dir, 'tree_importance.png') @@ -82,53 +107,64 @@ self.plots['tree_importance'] = plot_path def save_shap_values(self): - model = self.exp.create_model('lightgbm') - import shap - explainer = shap.Explainer(model) - shap_values = explainer.shap_values( - self.exp.get_config('X_transformed')) - shap.summary_plot(shap_values, - self.exp.get_config('X_transformed'), show=False) - plt.title('Shap (LightGBM)') - plot_path = os.path.join( - self.output_dir, 'shap_summary.png') + model = self.best_model or self.exp.get_config('best_model') + X_transformed = self.exp.get_config('X_transformed') + tree_classes = ( + "LGBM", "XGB", "CatBoost", "RandomForest", "DecisionTree", "ExtraTrees", "HistGradientBoosting" + ) + model_class_name = model.__class__.__name__ + self.shap_model_name = model_class_name + + # Ensure feature alignment + if hasattr(model, "feature_name_"): + used_features = model.feature_name_ + elif hasattr(model, "booster_") and hasattr(model.booster_, "feature_name"): + used_features = model.booster_.feature_name() + else: + used_features = X_transformed.columns + + if any(tc in model_class_name for tc in tree_classes): + explainer = shap.TreeExplainer(model) + X_shap = X_transformed[used_features] + shap_values = explainer.shap_values(X_shap) + plot_X = X_shap + plot_title = f"SHAP Summary for {model_class_name} (TreeExplainer)" + else: + sampled_X = X_transformed[used_features].sample(100, random_state=42) + explainer = shap.KernelExplainer(model.predict, sampled_X) + shap_values = explainer.shap_values(sampled_X) + plot_X = sampled_X + plot_title = f"SHAP Summary for {model_class_name} (KernelExplainer)" + + shap.summary_plot(shap_values, plot_X, show=False) + plt.title(plot_title) + plot_path = os.path.join(self.output_dir, "shap_summary.png") plt.savefig(plot_path) plt.close() - self.plots['shap_summary'] = plot_path - - def generate_feature_importance(self): - # coef_html = self.save_coefficients() - self.save_tree_importance() - self.save_shap_values() - - def encode_image_to_base64(self, img_path): - with open(img_path, 'rb') as img_file: - return base64.b64encode(img_file.read()).decode('utf-8') + self.plots["shap_summary"] = plot_path def generate_html_report(self): LOG.info("Generating HTML report") - # Read and encode plot images plots_html = "" for plot_name, plot_path in self.plots.items(): + # Special handling for tree importance: skip if no model name (not generated) + if plot_name == 'tree_importance' and not getattr(self, 'tree_model_name', None): + continue encoded_image = self.encode_image_to_base64(plot_path) + if plot_name == 'tree_importance' and getattr(self, 'tree_model_name', None): + section_title = f"Feature importance analysis from a trained {self.tree_model_name}" + elif plot_name == 'shap_summary': + section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}" + else: + section_title = plot_name plots_html += f""" <div class="plot" id="{plot_name}"> - <h2>{'Feature importance analysis from a' - 'trained Random Forest' - if plot_name == 'tree_importance' - else 'SHAP Summary from a trained lightgbm'}</h2> - <h3>{'Use gini impurity for' - 'calculating feature importance for classification' - 'and Variance Reduction for regression' - if plot_name == 'tree_importance' - else ''}</h3> - <img src="data:image/png;base64, - {encoded_image}" alt="{plot_name}"> + <h2>{section_title}</h2> + <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> </div> """ - # Generate HTML content with tabs html_content = f""" <h1>PyCaret Feature Importance Report</h1> {plots_html} @@ -136,34 +172,14 @@ return html_content - def run(self): - LOG.info("Running feature importance analysis") - self.setup_pycaret() - self.generate_feature_importance() - html_content = self.generate_html_report() - LOG.info("Feature importance analysis completed") - return html_content - + def encode_image_to_base64(self, img_path): + with open(img_path, 'rb') as img_file: + return base64.b64encode(img_file.read()).decode('utf-8') -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser(description="Feature Importance Analysis") - parser.add_argument( - "--data_path", type=str, help="Path to the dataset") - parser.add_argument( - "--target_col", type=int, - help="Index of the target column (1-based)") - parser.add_argument( - "--task_type", type=str, - choices=["classification", "regression"], - help="Task type: classification or regression") - parser.add_argument( - "--output_dir", - type=str, - help="Directory to save the outputs") - args = parser.parse_args() - - analyzer = FeatureImportanceAnalyzer( - args.data_path, args.target_col, - args.task_type, args.output_dir) - analyzer.run() + def run(self): + if self.exp is None or not hasattr(self.exp, 'is_setup') or not self.exp.is_setup: + self.setup_pycaret() + self.save_tree_importance() + self.save_shap_values() + html_content = self.generate_html_report() + return html_content