goeckslab/pycaret_predict: comparison of feature_importance.py @ 8:1aed7d47c5ec (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author:   goeckslab
date:     Fri, 25 Jul 2025 19:02:32 +0000
parents:  f4cb41f458fd
children: (none)
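Before the line-by-line comparison, here is a self-contained sketch of the importance-extraction fallback that the revised save_tree_importance() implements: prefer feature_importances_ when the fitted model exposes it, otherwise fall back to the magnitude of coef_. The estimator, dataset, and output file name below are arbitrary stand-ins for illustration, not part of this repository.

# Illustrative sketch only; estimator and dataset are arbitrary stand-ins.
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
model = LogisticRegression(max_iter=5000).fit(X, y)

if hasattr(model, "feature_importances_"):      # tree ensembles
    importances = model.feature_importances_
elif hasattr(model, "coef_"):                   # linear models: coefficient magnitude
    importances = abs(model.coef_).flatten()
else:
    raise SystemExit("model exposes neither feature_importances_ nor coef_")

fi = pd.DataFrame({"Feature": X.columns, "Importance": importances})
fi = fi.sort_values(by="Importance", ascending=False)
plt.figure(figsize=(10, 6))
plt.barh(fi["Feature"], fi["Importance"])
plt.xlabel("Importance")
plt.savefig("tree_importance.png")
plt.close()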
7:f4cb41f458fd (parent) | 8:1aed7d47c5ec (this revision) |
---|---|
12 LOG = logging.getLogger(__name__) | 12 LOG = logging.getLogger(__name__) |
13 | 13 |
14 | 14 |
15 class FeatureImportanceAnalyzer: | 15 class FeatureImportanceAnalyzer: |
16 def __init__( | 16 def __init__( |
17 self, | 17 self, |
18 task_type, | 18 task_type, |
19 output_dir, | 19 output_dir, |
20 data_path=None, | 20 data_path=None, |
21 data=None, | 21 data=None, |
22 target_col=None, | 22 target_col=None, |
23 exp=None, | 23 exp=None, |
24 best_model=None): | 24 best_model=None, |
| 25 ): |
25 | 26 |
26 self.task_type = task_type | 27 self.task_type = task_type |
27 self.output_dir = output_dir | 28 self.output_dir = output_dir |
28 self.exp = exp | 29 self.exp = exp |
29 self.best_model = best_model | 30 self.best_model = best_model |
41 self.target_col = target_col | 42 self.target_col = target_col |
42 self.data = pd.read_csv(data_path, sep=None, engine='python') | 43 self.data = pd.read_csv(data_path, sep=None, engine='python') |
43 self.data.columns = self.data.columns.str.replace('.', '_') | 44 self.data.columns = self.data.columns.str.replace('.', '_') |
44 self.data = self.data.fillna(self.data.median(numeric_only=True)) | 45 self.data = self.data.fillna(self.data.median(numeric_only=True)) |
45 self.target = self.data.columns[int(target_col) - 1] | 46 self.target = self.data.columns[int(target_col) - 1] |
46 self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment() | 47 self.exp = ( |
| 48 ClassificationExperiment() |
| 49 if task_type == "classification" |
| 50 else RegressionExperiment() |
| 51 ) |
47 | 52 |
48 self.plots = {} | 53 self.plots = {} |
49 | 54 |
50 def setup_pycaret(self): | 55 def setup_pycaret(self): |
51 if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup: | 56 if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup: |
55 setup_params = { | 60 setup_params = { |
56 'target': self.target, | 61 'target': self.target, |
57 'session_id': 123, | 62 'session_id': 123, |
58 'html': True, | 63 'html': True, |
59 'log_experiment': False, | 64 'log_experiment': False, |
60 'system_log': False | 65 'system_log': False, |
61 } | 66 } |
62 self.exp.setup(self.data, **setup_params) | 67 self.exp.setup(self.data, **setup_params) |
63 | 68 |
64 def save_tree_importance(self): | 69 def save_tree_importance(self): |
65 model = self.best_model or self.exp.get_config('best_model') | 70 model = self.best_model or self.exp.get_config('best_model') |
68 # Try feature_importances_ or coef_ if available | 73 # Try feature_importances_ or coef_ if available |
69 importances = None | 74 importances = None |
70 model_type = model.__class__.__name__ | 75 model_type = model.__class__.__name__ |
71 self.tree_model_name = model_type # Store the model name for reporting | 76 self.tree_model_name = model_type # Store the model name for reporting |
72 | 77 |
73 if hasattr(model, "feature_importances_"): | 78 if hasattr(model, 'feature_importances_'): |
74 importances = model.feature_importances_ | 79 importances = model.feature_importances_ |
75 elif hasattr(model, "coef_"): | 80 elif hasattr(model, 'coef_'): |
76 # For linear models, flatten coef_ and take abs (importance as magnitude) | 81 # For linear models, flatten coef_ and take abs (importance as magnitude) |
77 importances = abs(model.coef_).flatten() | 82 importances = abs(model.coef_).flatten() |
78 else: | 83 else: |
79 # Neither attribute exists; skip the plot | 84 # Neither attribute exists; skip the plot |
80 LOG.warning(f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot.") | 85 LOG.warning( |
86 f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot." | |
87 ) | |
81 self.tree_model_name = None # No plot generated | 88 self.tree_model_name = None # No plot generated |
82 return | 89 return |
83 | 90 |
84 # Defensive: handle mismatch in number of features | 91 # Defensive: handle mismatch in number of features |
85 if len(importances) != len(processed_features): | 92 if len(importances) != len(processed_features): |
87 f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot." | 94 f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot." |
88 ) | 95 ) |
89 self.tree_model_name = None | 96 self.tree_model_name = None |
90 return | 97 return |
91 | 98 |
92 feature_importances = pd.DataFrame({ | 99 feature_importances = pd.DataFrame( |
93 'Feature': processed_features, | 100 {'Feature': processed_features, 'Importance': importances} |
94 'Importance': importances | 101 ).sort_values(by='Importance', ascending=False) |
95 }).sort_values(by='Importance', ascending=False) | |
96 plt.figure(figsize=(10, 6)) | 102 plt.figure(figsize=(10, 6)) |
97 plt.barh( | 103 plt.barh(feature_importances['Feature'], feature_importances['Importance']) |
98 feature_importances['Feature'], | |
99 feature_importances['Importance']) | |
100 plt.xlabel('Importance') | 104 plt.xlabel('Importance') |
101 plt.title(f'Feature Importance ({model_type})') | 105 plt.title(f'Feature Importance ({model_type})') |
102 plot_path = os.path.join( | 106 plot_path = os.path.join(self.output_dir, 'tree_importance.png') |
103 self.output_dir, | |
104 'tree_importance.png') | |
105 plt.savefig(plot_path) | 107 plt.savefig(plot_path) |
106 plt.close() | 108 plt.close() |
107 self.plots['tree_importance'] = plot_path | 109 self.plots['tree_importance'] = plot_path |
108 | 110 |
109 def save_shap_values(self): | 111 def save_shap_values(self): |
110 model = self.best_model or self.exp.get_config('best_model') | 112 |
111 X_transformed = self.exp.get_config('X_transformed') | 113 model = self.best_model or self.exp.get_config("best_model") |
112 tree_classes = ( | 114 |
113 "LGBM", "XGB", "CatBoost", "RandomForest", "DecisionTree", "ExtraTrees", "HistGradientBoosting" | 115 X_data = None |
114 ) | 116 for key in ("X_test_transformed", "X_train_transformed"): |
115 model_class_name = model.__class__.__name__ | 117 try: |
116 self.shap_model_name = model_class_name | 118 X_data = self.exp.get_config(key) |
117 | 119 break |
118 # Ensure feature alignment | 120 except KeyError: |
119 if hasattr(model, "feature_name_"): | 121 continue |
120 used_features = model.feature_name_ | 122 if X_data is None: |
121 elif hasattr(model, "booster_") and hasattr(model.booster_, "feature_name"): | 123 raise RuntimeError( |
124 "Could not find 'X_test_transformed' or 'X_train_transformed' in the experiment. " | |
125 "Make sure PyCaret setup/compare_models was run with feature_selection=True." | |
126 ) | |
127 | |
128 try: | |
122 used_features = model.booster_.feature_name() | 129 used_features = model.booster_.feature_name() |
123 elif hasattr(model, "feature_names_in_"): | 130 except Exception: |
124 # scikit‐learn's standard attribute for the names of features used during fit | 131 used_features = getattr(model, "feature_names_in_", X_data.columns.tolist()) |
125 used_features = list(model.feature_names_in_) | 132 X_data = X_data[used_features] |
| 133 |
| 134 max_bg = min(len(X_data), 100) |
| 135 bg = X_data.sample(max_bg, random_state=42) |
| 136 |
| 137 predict_fn = model.predict_proba if hasattr(model, "predict_proba") else model.predict |
| 138 |
| 139 explainer = shap.Explainer(predict_fn, bg) |
| 140 self.shap_model_name = explainer.__class__.__name__ |
| 141 |
| 142 shap_values = explainer(X_data) |
| 143 |
| 144 output_names = getattr(shap_values, "output_names", None) |
| 145 if output_names is None and hasattr(model, "classes_"): |
| 146 output_names = list(model.classes_) |
| 147 if output_names is None: |
| 148 n_out = shap_values.values.shape[-1] |
| 149 output_names = list(map(str, range(n_out))) |
| 150 |
| 151 values = shap_values.values |
| 152 if values.ndim == 3: |
| 153 for j, name in enumerate(output_names): |
| 154 safe = name.replace(" ", "_").replace("/", "_") |
| 155 out_path = os.path.join(self.output_dir, f"shap_summary_{safe}.png") |
| 156 plt.figure() |
| 157 shap.plots.beeswarm(shap_values[..., j], show=False) |
| 158 plt.title(f"SHAP for {model.__class__.__name__} ⇒ {name}") |
| 159 plt.savefig(out_path) |
| 160 plt.close() |
| 161 self.plots[f"shap_summary_{safe}"] = out_path |
126 else: | 162 else: |
127 used_features = X_transformed.columns | 163 plt.figure() |
128 | 164 shap.plots.beeswarm(shap_values, show=False) |
129 if any(tc in model_class_name for tc in tree_classes): | 165 plt.title(f"SHAP Summary for {model.__class__.__name__}") |
130 explainer = shap.TreeExplainer(model) | 166 out_path = os.path.join(self.output_dir, "shap_summary.png") |
131 X_shap = X_transformed[used_features] | 167 plt.savefig(out_path) |
132 shap_values = explainer.shap_values(X_shap) | 168 plt.close() |
133 plot_X = X_shap | 169 self.plots["shap_summary"] = out_path |
134 plot_title = f"SHAP Summary for {model_class_name} (TreeExplainer)" | |
135 else: | |
136 logging.warning(f"len(X_transformed) = {len(X_transformed)}") | |
137 max_samples = 100 | |
138 n_samples = min(max_samples, len(X_transformed)) | |
139 sampled_X = X_transformed[used_features].sample( | |
140 n=n_samples, | |
141 replace=False, | |
142 random_state=42 | |
143 ) | |
144 explainer = shap.KernelExplainer(model.predict, sampled_X) | |
145 shap_values = explainer.shap_values(sampled_X) | |
146 plot_X = sampled_X | |
147 plot_title = f"SHAP Summary for {model_class_name} (KernelExplainer)" | |
148 | |
149 shap.summary_plot(shap_values, plot_X, show=False) | |
150 plt.title(plot_title) | |
151 plot_path = os.path.join(self.output_dir, "shap_summary.png") | |
152 plt.savefig(plot_path) | |
153 plt.close() | |
154 self.plots["shap_summary"] = plot_path | |
155 | 170 |
156 def generate_html_report(self): | 171 def generate_html_report(self): |
157 LOG.info("Generating HTML report") | 172 LOG.info("Generating HTML report") |
158 | 173 |
159 plots_html = "" | 174 plots_html = "" |
160 for plot_name, plot_path in self.plots.items(): | 175 for plot_name, plot_path in self.plots.items(): |
161 # Special handling for tree importance: skip if no model name (not generated) | 176 # Special handling for tree importance: skip if no model name (not generated) |
162 if plot_name == 'tree_importance' and not getattr(self, 'tree_model_name', None): | 177 if plot_name == 'tree_importance' and not getattr( |
| 178 self, 'tree_model_name', None |
| 179 ): |
163 continue | 180 continue |
164 encoded_image = self.encode_image_to_base64(plot_path) | 181 encoded_image = self.encode_image_to_base64(plot_path) |
165 if plot_name == 'tree_importance' and getattr(self, 'tree_model_name', None): | 182 if plot_name == 'tree_importance' and getattr( |
166 section_title = f"Feature importance analysis from a trained {self.tree_model_name}" | 183 self, 'tree_model_name', None |
| 184 ): |
| 185 section_title = ( |
| 186 f"Feature importance analysis from a trained {self.tree_model_name}" |
| 187 ) |
167 elif plot_name == 'shap_summary': | 188 elif plot_name == 'shap_summary': |
168 section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}" | 189 section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}" |
169 else: | 190 else: |
170 section_title = plot_name | 191 section_title = plot_name |
171 plots_html += f""" | 192 plots_html += f""" |
174 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> | 195 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> |
175 </div> | 196 </div> |
176 """ | 197 """ |
177 | 198 |
178 html_content = f""" | 199 html_content = f""" |
179 <h1>PyCaret Feature Importance Report</h1> | |
180 {plots_html} | 200 {plots_html} |
181 """ | 201 """ |
182 | 202 |
183 return html_content | 203 return html_content |
184 | 204 |
185 def encode_image_to_base64(self, img_path): | 205 def encode_image_to_base64(self, img_path): |
186 with open(img_path, 'rb') as img_file: | 206 with open(img_path, 'rb') as img_file: |
187 return base64.b64encode(img_file.read()).decode('utf-8') | 207 return base64.b64encode(img_file.read()).decode('utf-8') |
188 | 208 |
189 def run(self): | 209 def run(self): |
190 if self.exp is None or not hasattr(self.exp, 'is_setup') or not self.exp.is_setup: | 210 if ( |
| 211 self.exp is None |
| 212 or not hasattr(self.exp, 'is_setup') |
| 213 or not self.exp.is_setup |
| 214 ): |
191 self.setup_pycaret() | 215 self.setup_pycaret() |
192 self.save_tree_importance() | 216 self.save_tree_importance() |
193 self.save_shap_values() | 217 self.save_shap_values() |
194 html_content = self.generate_html_report() | 218 html_content = self.generate_html_report() |
195 return html_content | 219 return html_content |
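For readers unfamiliar with the unified shap.Explainer API that this revision switches to (replacing the earlier TreeExplainer/KernelExplainer branching on the left side of the diff), the pattern is: wrap the model's prediction function together with a small background sample, call the explainer on the data to obtain an Explanation object, then draw one beeswarm plot per output for multi-class models. A minimal standalone sketch follows; the classifier, dataset, and file names are arbitrary stand-ins and assume a recent shap release, not the tool's actual inputs.

# Minimal sketch of the explainer pattern adopted in revision 8 (illustrative only).
import shap
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True, as_frame=True)
model = RandomForestClassifier(random_state=42).fit(X, y)

bg = X.sample(min(len(X), 100), random_state=42)        # background sample, capped at 100 rows
predict_fn = model.predict_proba if hasattr(model, "predict_proba") else model.predict

explainer = shap.Explainer(predict_fn, bg)              # shap selects a suitable algorithm
shap_values = explainer(X)

if shap_values.values.ndim == 3:                        # multi-output: one plot per class
    for j, name in enumerate(model.classes_):
        plt.figure()
        shap.plots.beeswarm(shap_values[..., j], show=False)
        plt.title(f"SHAP summary for class {name}")
        plt.savefig(f"shap_summary_{name}.png")
        plt.close()
else:                                                   # single output (e.g. regression)
    plt.figure()
    shap.plots.beeswarm(shap_values, show=False)
    plt.savefig("shap_summary.png")
    plt.close()

Compared with the previous per-model branching, this lets shap choose an appropriate algorithm for any estimator exposed through a prediction function and handles multi-class output explicitly, one summary plot per class.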