Mercurial > repos > goeckslab > pycaret_predict
comparison feature_importance.py @ 6:a32ff7201629 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
author | goeckslab |
---|---|
date | Wed, 02 Jul 2025 19:00:03 +0000 |
parents | ccd798db5abb |
children |
comparison
equal
deleted
inserted
replaced
5:c846405830eb | 6:a32ff7201629 |
---|---|
2 import logging | 2 import logging |
3 import os | 3 import os |
4 | 4 |
5 import matplotlib.pyplot as plt | 5 import matplotlib.pyplot as plt |
6 import pandas as pd | 6 import pandas as pd |
7 import shap | |
7 from pycaret.classification import ClassificationExperiment | 8 from pycaret.classification import ClassificationExperiment |
8 from pycaret.regression import RegressionExperiment | 9 from pycaret.regression import RegressionExperiment |
9 | 10 |
10 logging.basicConfig(level=logging.DEBUG) | 11 logging.basicConfig(level=logging.DEBUG) |
11 LOG = logging.getLogger(__name__) | 12 LOG = logging.getLogger(__name__) |
16 self, | 17 self, |
17 task_type, | 18 task_type, |
18 output_dir, | 19 output_dir, |
19 data_path=None, | 20 data_path=None, |
20 data=None, | 21 data=None, |
21 target_col=None): | 22 target_col=None, |
23 exp=None, | |
24 best_model=None): | |
22 | 25 |
23 if data is not None: | 26 self.task_type = task_type |
24 self.data = data | 27 self.output_dir = output_dir |
25 LOG.info("Data loaded from memory") | 28 self.exp = exp |
29 self.best_model = best_model | |
30 | |
31 if exp is not None: | |
32 # Assume all configs (data, target) are in exp | |
33 self.data = exp.dataset.copy() | |
34 self.target = exp.target_param | |
35 LOG.info("Using provided experiment object") | |
26 else: | 36 else: |
27 self.target_col = target_col | 37 if data is not None: |
28 self.data = pd.read_csv(data_path, sep=None, engine='python') | 38 self.data = data |
29 self.data.columns = self.data.columns.str.replace('.', '_') | 39 LOG.info("Data loaded from memory") |
30 self.data = self.data.fillna(self.data.median(numeric_only=True)) | 40 else: |
31 self.task_type = task_type | 41 self.target_col = target_col |
32 self.target = self.data.columns[int(target_col) - 1] | 42 self.data = pd.read_csv(data_path, sep=None, engine='python') |
33 self.exp = ClassificationExperiment() \ | 43 self.data.columns = self.data.columns.str.replace('.', '_') |
34 if task_type == 'classification' \ | 44 self.data = self.data.fillna(self.data.median(numeric_only=True)) |
35 else RegressionExperiment() | 45 self.target = self.data.columns[int(target_col) - 1] |
46 self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment() | |
47 | |
36 self.plots = {} | 48 self.plots = {} |
37 self.output_dir = output_dir | |
38 | 49 |
39 def setup_pycaret(self): | 50 def setup_pycaret(self): |
51 if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup: | |
52 LOG.info("Experiment already set up. Skipping PyCaret setup.") | |
53 return | |
40 LOG.info("Initializing PyCaret") | 54 LOG.info("Initializing PyCaret") |
41 setup_params = { | 55 setup_params = { |
42 'target': self.target, | 56 'target': self.target, |
43 'session_id': 123, | 57 'session_id': 123, |
44 'html': True, | 58 'html': True, |
45 'log_experiment': False, | 59 'log_experiment': False, |
46 'system_log': False | 60 'system_log': False |
47 } | 61 } |
48 LOG.info(self.task_type) | |
49 LOG.info(self.exp) | |
50 self.exp.setup(self.data, **setup_params) | 62 self.exp.setup(self.data, **setup_params) |
51 | 63 |
52 # def save_coefficients(self): | 64 def save_tree_importance(self): |
53 # model = self.exp.create_model('lr') | 65 model = self.best_model or self.exp.get_config('best_model') |
54 # coef_df = pd.DataFrame({ | 66 processed_features = self.exp.get_config('X_transformed').columns |
55 # 'Feature': self.data.columns.drop(self.target), | |
56 # 'Coefficient': model.coef_[0] | |
57 # }) | |
58 # coef_html = coef_df.to_html(index=False) | |
59 # return coef_html | |
60 | 67 |
61 def save_tree_importance(self): | 68 # Try feature_importances_ or coef_ if available |
62 model = self.exp.create_model('rf') | 69 importances = None |
63 importances = model.feature_importances_ | 70 model_type = model.__class__.__name__ |
64 processed_features = self.exp.get_config('X_transformed').columns | 71 self.tree_model_name = model_type # Store the model name for reporting |
65 LOG.debug(f"Feature importances: {importances}") | 72 |
66 LOG.debug(f"Features: {processed_features}") | 73 if hasattr(model, "feature_importances_"): |
74 importances = model.feature_importances_ | |
75 elif hasattr(model, "coef_"): | |
76 # For linear models, flatten coef_ and take abs (importance as magnitude) | |
77 importances = abs(model.coef_).flatten() | |
78 else: | |
79 # Neither attribute exists; skip the plot | |
80 LOG.warning(f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot.") | |
81 self.tree_model_name = None # No plot generated | |
82 return | |
83 | |
84 # Defensive: handle mismatch in number of features | |
85 if len(importances) != len(processed_features): | |
86 LOG.warning( | |
87 f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot." | |
88 ) | |
89 self.tree_model_name = None | |
90 return | |
91 | |
67 feature_importances = pd.DataFrame({ | 92 feature_importances = pd.DataFrame({ |
68 'Feature': processed_features, | 93 'Feature': processed_features, |
69 'Importance': importances | 94 'Importance': importances |
70 }).sort_values(by='Importance', ascending=False) | 95 }).sort_values(by='Importance', ascending=False) |
71 plt.figure(figsize=(10, 6)) | 96 plt.figure(figsize=(10, 6)) |
72 plt.barh( | 97 plt.barh( |
73 feature_importances['Feature'], | 98 feature_importances['Feature'], |
74 feature_importances['Importance']) | 99 feature_importances['Importance']) |
75 plt.xlabel('Importance') | 100 plt.xlabel('Importance') |
76 plt.title('Feature Importance (Random Forest)') | 101 plt.title(f'Feature Importance ({model_type})') |
77 plot_path = os.path.join( | 102 plot_path = os.path.join( |
78 self.output_dir, | 103 self.output_dir, |
79 'tree_importance.png') | 104 'tree_importance.png') |
80 plt.savefig(plot_path) | 105 plt.savefig(plot_path) |
81 plt.close() | 106 plt.close() |
82 self.plots['tree_importance'] = plot_path | 107 self.plots['tree_importance'] = plot_path |
83 | 108 |
84 def save_shap_values(self): | 109 def save_shap_values(self): |
85 model = self.exp.create_model('lightgbm') | 110 model = self.best_model or self.exp.get_config('best_model') |
86 import shap | 111 X_transformed = self.exp.get_config('X_transformed') |
87 explainer = shap.Explainer(model) | 112 tree_classes = ( |
88 shap_values = explainer.shap_values( | 113 "LGBM", "XGB", "CatBoost", "RandomForest", "DecisionTree", "ExtraTrees", "HistGradientBoosting" |
89 self.exp.get_config('X_transformed')) | 114 ) |
90 shap.summary_plot(shap_values, | 115 model_class_name = model.__class__.__name__ |
91 self.exp.get_config('X_transformed'), show=False) | 116 self.shap_model_name = model_class_name |
92 plt.title('Shap (LightGBM)') | 117 |
93 plot_path = os.path.join( | 118 # Ensure feature alignment |
94 self.output_dir, 'shap_summary.png') | 119 if hasattr(model, "feature_name_"): |
120 used_features = model.feature_name_ | |
121 elif hasattr(model, "booster_") and hasattr(model.booster_, "feature_name"): | |
122 used_features = model.booster_.feature_name() | |
123 else: | |
124 used_features = X_transformed.columns | |
125 | |
126 if any(tc in model_class_name for tc in tree_classes): | |
127 explainer = shap.TreeExplainer(model) | |
128 X_shap = X_transformed[used_features] | |
129 shap_values = explainer.shap_values(X_shap) | |
130 plot_X = X_shap | |
131 plot_title = f"SHAP Summary for {model_class_name} (TreeExplainer)" | |
132 else: | |
133 sampled_X = X_transformed[used_features].sample(100, random_state=42) | |
134 explainer = shap.KernelExplainer(model.predict, sampled_X) | |
135 shap_values = explainer.shap_values(sampled_X) | |
136 plot_X = sampled_X | |
137 plot_title = f"SHAP Summary for {model_class_name} (KernelExplainer)" | |
138 | |
139 shap.summary_plot(shap_values, plot_X, show=False) | |
140 plt.title(plot_title) | |
141 plot_path = os.path.join(self.output_dir, "shap_summary.png") | |
95 plt.savefig(plot_path) | 142 plt.savefig(plot_path) |
96 plt.close() | 143 plt.close() |
97 self.plots['shap_summary'] = plot_path | 144 self.plots["shap_summary"] = plot_path |
98 | |
99 def generate_feature_importance(self): | |
100 # coef_html = self.save_coefficients() | |
101 self.save_tree_importance() | |
102 self.save_shap_values() | |
103 | |
104 def encode_image_to_base64(self, img_path): | |
105 with open(img_path, 'rb') as img_file: | |
106 return base64.b64encode(img_file.read()).decode('utf-8') | |
107 | 145 |
108 def generate_html_report(self): | 146 def generate_html_report(self): |
109 LOG.info("Generating HTML report") | 147 LOG.info("Generating HTML report") |
110 | 148 |
111 # Read and encode plot images | |
112 plots_html = "" | 149 plots_html = "" |
113 for plot_name, plot_path in self.plots.items(): | 150 for plot_name, plot_path in self.plots.items(): |
151 # Special handling for tree importance: skip if no model name (not generated) | |
152 if plot_name == 'tree_importance' and not getattr(self, 'tree_model_name', None): | |
153 continue | |
114 encoded_image = self.encode_image_to_base64(plot_path) | 154 encoded_image = self.encode_image_to_base64(plot_path) |
155 if plot_name == 'tree_importance' and getattr(self, 'tree_model_name', None): | |
156 section_title = f"Feature importance analysis from a trained {self.tree_model_name}" | |
157 elif plot_name == 'shap_summary': | |
158 section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}" | |
159 else: | |
160 section_title = plot_name | |
115 plots_html += f""" | 161 plots_html += f""" |
116 <div class="plot" id="{plot_name}"> | 162 <div class="plot" id="{plot_name}"> |
117 <h2>{'Feature importance analysis from a' | 163 <h2>{section_title}</h2> |
118 'trained Random Forest' | 164 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> |
119 if plot_name == 'tree_importance' | |
120 else 'SHAP Summary from a trained lightgbm'}</h2> | |
121 <h3>{'Use gini impurity for' | |
122 'calculating feature importance for classification' | |
123 'and Variance Reduction for regression' | |
124 if plot_name == 'tree_importance' | |
125 else ''}</h3> | |
126 <img src="data:image/png;base64, | |
127 {encoded_image}" alt="{plot_name}"> | |
128 </div> | 165 </div> |
129 """ | 166 """ |
130 | 167 |
131 # Generate HTML content with tabs | |
132 html_content = f""" | 168 html_content = f""" |
133 <h1>PyCaret Feature Importance Report</h1> | 169 <h1>PyCaret Feature Importance Report</h1> |
134 {plots_html} | 170 {plots_html} |
135 """ | 171 """ |
136 | 172 |
137 return html_content | 173 return html_content |
138 | 174 |
175 def encode_image_to_base64(self, img_path): | |
176 with open(img_path, 'rb') as img_file: | |
177 return base64.b64encode(img_file.read()).decode('utf-8') | |
178 | |
139 def run(self): | 179 def run(self): |
140 LOG.info("Running feature importance analysis") | 180 if self.exp is None or not hasattr(self.exp, 'is_setup') or not self.exp.is_setup: |
141 self.setup_pycaret() | 181 self.setup_pycaret() |
142 self.generate_feature_importance() | 182 self.save_tree_importance() |
183 self.save_shap_values() | |
143 html_content = self.generate_html_report() | 184 html_content = self.generate_html_report() |
144 LOG.info("Feature importance analysis completed") | |
145 return html_content | 185 return html_content |
146 | |
147 | |
148 if __name__ == "__main__": | |
149 import argparse | |
150 parser = argparse.ArgumentParser(description="Feature Importance Analysis") | |
151 parser.add_argument( | |
152 "--data_path", type=str, help="Path to the dataset") | |
153 parser.add_argument( | |
154 "--target_col", type=int, | |
155 help="Index of the target column (1-based)") | |
156 parser.add_argument( | |
157 "--task_type", type=str, | |
158 choices=["classification", "regression"], | |
159 help="Task type: classification or regression") | |
160 parser.add_argument( | |
161 "--output_dir", | |
162 type=str, | |
163 help="Directory to save the outputs") | |
164 args = parser.parse_args() | |
165 | |
166 analyzer = FeatureImportanceAnalyzer( | |
167 args.data_path, args.target_col, | |
168 args.task_type, args.output_dir) | |
169 analyzer.run() |