Mercurial > repos > goeckslab > tabular_learner
comparison feature_importance.py @ 8:ba45bc057d70 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
author | goeckslab |
---|---|
date | Mon, 08 Sep 2025 22:38:55 +0000 |
parents | 0afd970bd8ae |
children |
comparison
equal
deleted
inserted
replaced
7:0afd970bd8ae | 8:ba45bc057d70 |
---|---|
21 data=None, | 21 data=None, |
22 target_col=None, | 22 target_col=None, |
23 exp=None, | 23 exp=None, |
24 best_model=None, | 24 best_model=None, |
25 ): | 25 ): |
26 | |
27 self.task_type = task_type | 26 self.task_type = task_type |
28 self.output_dir = output_dir | 27 self.output_dir = output_dir |
29 self.exp = exp | 28 self.exp = exp |
30 self.best_model = best_model | 29 self.best_model = best_model |
31 | 30 |
38 if data is not None: | 37 if data is not None: |
39 self.data = data | 38 self.data = data |
40 LOG.info("Data loaded from memory") | 39 LOG.info("Data loaded from memory") |
41 else: | 40 else: |
42 self.target_col = target_col | 41 self.target_col = target_col |
43 self.data = pd.read_csv(data_path, sep=None, engine='python') | 42 self.data = pd.read_csv(data_path, sep=None, engine="python") |
44 self.data.columns = self.data.columns.str.replace('.', '_') | 43 self.data.columns = self.data.columns.str.replace(".", "_") |
45 self.data = self.data.fillna(self.data.median(numeric_only=True)) | 44 self.data = self.data.fillna(self.data.median(numeric_only=True)) |
46 self.target = self.data.columns[int(target_col) - 1] | 45 self.target = self.data.columns[int(target_col) - 1] |
47 self.exp = ( | 46 self.exp = ( |
48 ClassificationExperiment() | 47 ClassificationExperiment() |
49 if task_type == "classification" | 48 if task_type == "classification" |
51 ) | 50 ) |
52 | 51 |
53 self.plots = {} | 52 self.plots = {} |
54 | 53 |
55 def setup_pycaret(self): | 54 def setup_pycaret(self): |
56 if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup: | 55 if self.exp is not None and hasattr(self.exp, "is_setup") and self.exp.is_setup: |
57 LOG.info("Experiment already set up. Skipping PyCaret setup.") | 56 LOG.info("Experiment already set up. Skipping PyCaret setup.") |
58 return | 57 return |
59 LOG.info("Initializing PyCaret") | 58 LOG.info("Initializing PyCaret") |
60 setup_params = { | 59 setup_params = { |
61 'target': self.target, | 60 "target": self.target, |
62 'session_id': 123, | 61 "session_id": 123, |
63 'html': True, | 62 "html": True, |
64 'log_experiment': False, | 63 "log_experiment": False, |
65 'system_log': False, | 64 "system_log": False, |
66 } | 65 } |
67 self.exp.setup(self.data, **setup_params) | 66 self.exp.setup(self.data, **setup_params) |
68 | 67 |
69 def save_tree_importance(self): | 68 def save_tree_importance(self): |
70 model = self.best_model or self.exp.get_config('best_model') | 69 model = self.best_model or self.exp.get_config("best_model") |
71 processed_features = self.exp.get_config('X_transformed').columns | 70 processed_features = self.exp.get_config("X_transformed").columns |
72 | 71 |
73 # Try feature_importances_ or coef_ if available | |
74 importances = None | 72 importances = None |
75 model_type = model.__class__.__name__ | 73 model_type = model.__class__.__name__ |
76 self.tree_model_name = model_type # Store the model name for reporting | 74 self.tree_model_name = model_type |
77 | 75 |
78 if hasattr(model, 'feature_importances_'): | 76 if hasattr(model, "feature_importances_"): |
79 importances = model.feature_importances_ | 77 importances = model.feature_importances_ |
80 elif hasattr(model, 'coef_'): | 78 elif hasattr(model, "coef_"): |
81 # For linear models, flatten coef_ and take abs (importance as magnitude) | |
82 importances = abs(model.coef_).flatten() | 79 importances = abs(model.coef_).flatten() |
83 else: | 80 else: |
84 # Neither attribute exists; skip the plot | |
85 LOG.warning( | 81 LOG.warning( |
86 f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot." | 82 f"Model {model_type} does not have feature_importances_ or coef_. Skipping tree importance." |
87 ) | 83 ) |
88 self.tree_model_name = None # No plot generated | 84 self.tree_model_name = None |
89 return | 85 return |
90 | 86 |
91 # Defensive: handle mismatch in number of features | |
92 if len(importances) != len(processed_features): | 87 if len(importances) != len(processed_features): |
93 LOG.warning( | 88 LOG.warning( |
94 f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot." | 89 f"Importances ({len(importances)}) != features ({len(processed_features)}). Skipping tree importance." |
95 ) | 90 ) |
96 self.tree_model_name = None | 91 self.tree_model_name = None |
97 return | 92 return |
98 | 93 |
99 feature_importances = pd.DataFrame( | 94 feature_importances = pd.DataFrame( |
100 {'Feature': processed_features, 'Importance': importances} | 95 {"Feature": processed_features, "Importance": importances} |
101 ).sort_values(by='Importance', ascending=False) | 96 ).sort_values(by="Importance", ascending=False) |
102 plt.figure(figsize=(10, 6)) | 97 plt.figure(figsize=(10, 6)) |
103 plt.barh(feature_importances['Feature'], feature_importances['Importance']) | 98 plt.barh(feature_importances["Feature"], feature_importances["Importance"]) |
104 plt.xlabel('Importance') | 99 plt.xlabel("Importance") |
105 plt.title(f'Feature Importance ({model_type})') | 100 plt.title(f"Feature Importance ({model_type})") |
106 plot_path = os.path.join(self.output_dir, 'tree_importance.png') | 101 plot_path = os.path.join(self.output_dir, "tree_importance.png") |
107 plt.savefig(plot_path) | 102 plt.savefig(plot_path, bbox_inches="tight") |
108 plt.close() | 103 plt.close() |
109 self.plots['tree_importance'] = plot_path | 104 self.plots["tree_importance"] = plot_path |
110 | 105 |
111 def save_shap_values(self): | 106 def save_shap_values(self, max_samples=None, max_display=None, max_features=None): |
112 | |
113 model = self.best_model or self.exp.get_config("best_model") | 107 model = self.best_model or self.exp.get_config("best_model") |
114 | 108 |
115 X_data = None | 109 X_data = None |
116 for key in ("X_test_transformed", "X_train_transformed"): | 110 for key in ("X_test_transformed", "X_train_transformed"): |
117 try: | 111 try: |
118 X_data = self.exp.get_config(key) | 112 X_data = self.exp.get_config(key) |
119 break | 113 break |
120 except KeyError: | 114 except KeyError: |
121 continue | 115 continue |
122 if X_data is None: | 116 if X_data is None: |
123 raise RuntimeError( | 117 raise RuntimeError("No transformed dataset found for SHAP.") |
124 "Could not find 'X_test_transformed' or 'X_train_transformed' in the experiment. " | 118 |
125 "Make sure PyCaret setup/compare_models was run with feature_selection=True." | 119 # --- Adaptive feature limiting (proportional cap) --- |
126 ) | 120 n_rows, n_features = X_data.shape |
121 if max_features is None: | |
122 if n_features <= 200: | |
123 max_features = n_features | |
124 else: | |
125 max_features = min(200, max(20, int(n_features * 0.1))) | |
127 | 126 |
128 try: | 127 try: |
129 used_features = model.booster_.feature_name() | 128 if hasattr(model, "feature_importances_"): |
130 except Exception: | 129 importances = pd.Series( |
131 used_features = getattr(model, "feature_names_in_", X_data.columns.tolist()) | 130 model.feature_importances_, index=X_data.columns |
132 X_data = X_data[used_features] | 131 ) |
133 | 132 top_features = importances.nlargest(max_features).index |
134 max_bg = min(len(X_data), 100) | 133 elif hasattr(model, "coef_"): |
135 bg = X_data.sample(max_bg, random_state=42) | 134 coef = abs(model.coef_).flatten() |
136 | 135 importances = pd.Series(coef, index=X_data.columns) |
137 predict_fn = model.predict_proba if hasattr(model, "predict_proba") else model.predict | 136 top_features = importances.nlargest(max_features).index |
137 else: | |
138 variances = X_data.var() | |
139 top_features = variances.nlargest(max_features).index | |
140 | |
141 if len(top_features) < n_features: | |
142 LOG.info( | |
143 f"Restricted SHAP computation to top {len(top_features)} / {n_features} features" | |
144 ) | |
145 X_data = X_data[top_features] | |
146 except Exception as e: | |
147 LOG.warning( | |
148 f"Feature limiting failed: {e}. Using all {n_features} features." | |
149 ) | |
150 | |
151 # --- Adaptive row subsampling --- | |
152 if max_samples is None: | |
153 if n_rows <= 500: | |
154 max_samples = n_rows | |
155 elif n_rows <= 5000: | |
156 max_samples = 500 | |
157 else: | |
158 max_samples = min(1000, int(n_rows * 0.1)) | |
159 | |
160 if n_rows > max_samples: | |
161 LOG.info(f"Subsampling SHAP rows: {max_samples} of {n_rows}") | |
162 X_data = X_data.sample(max_samples, random_state=42) | |
163 | |
164 # --- Adaptive feature display --- | |
165 if max_display is None: | |
166 if X_data.shape[1] <= 20: | |
167 max_display = X_data.shape[1] | |
168 elif X_data.shape[1] <= 100: | |
169 max_display = 30 | |
170 else: | |
171 max_display = 50 | |
172 | |
173 # Background set | |
174 bg = X_data.sample(min(len(X_data), 100), random_state=42) | |
175 predict_fn = ( | |
176 model.predict_proba if hasattr(model, "predict_proba") else model.predict | |
177 ) | |
178 | |
179 # Optimized explainer | |
180 if hasattr(model, "feature_importances_"): | |
181 explainer = shap.TreeExplainer( | |
182 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 | |
183 ) | |
184 elif hasattr(model, "coef_"): | |
185 explainer = shap.LinearExplainer(model, bg) | |
186 else: | |
187 explainer = shap.Explainer(predict_fn, bg) | |
138 | 188 |
139 try: | 189 try: |
140 explainer = shap.Explainer(predict_fn, bg) | 190 shap_values = explainer(X_data) |
141 self.shap_model_name = explainer.__class__.__name__ | 191 self.shap_model_name = explainer.__class__.__name__ |
142 | |
143 shap_values = explainer(X_data) | |
144 except Exception as e: | 192 except Exception as e: |
145 LOG.error(f"SHAP computation failed: {e}") | 193 LOG.error(f"SHAP computation failed: {e}") |
146 self.shap_model_name = None | 194 self.shap_model_name = None |
147 return | 195 return |
148 | 196 |
149 output_names = getattr(shap_values, "output_names", None) | 197 # --- Plot SHAP summary --- |
150 if output_names is None and hasattr(model, "classes_"): | 198 out_path = os.path.join(self.output_dir, "shap_summary.png") |
151 output_names = list(model.classes_) | 199 plt.figure() |
152 if output_names is None: | 200 shap.plots.beeswarm(shap_values, max_display=max_display, show=False) |
153 n_out = shap_values.values.shape[-1] | 201 plt.title( |
154 output_names = list(map(str, range(n_out))) | 202 f"SHAP Summary for {model.__class__.__name__} (top {max_display} features)" |
155 | 203 ) |
156 values = shap_values.values | 204 plt.savefig(out_path, bbox_inches="tight") |
157 if values.ndim == 3: | 205 plt.close() |
158 for j, name in enumerate(output_names): | 206 self.plots["shap_summary"] = out_path |
159 safe = name.replace(" ", "_").replace("/", "_") | 207 |
160 out_path = os.path.join(self.output_dir, f"shap_summary_{safe}.png") | 208 # --- Log summary --- |
161 plt.figure() | 209 LOG.info( |
162 shap.plots.beeswarm(shap_values[..., j], show=False) | 210 f"SHAP summary completed with {X_data.shape[0]} rows and {X_data.shape[1]} features (displaying top {max_display})." |
163 plt.title(f"SHAP for {model.__class__.__name__} ⇒ {name}") | 211 ) |
164 plt.savefig(out_path) | |
165 plt.close() | |
166 self.plots[f"shap_summary_{safe}"] = out_path | |
167 else: | |
168 plt.figure() | |
169 shap.plots.beeswarm(shap_values, show=False) | |
170 plt.title(f"SHAP Summary for {model.__class__.__name__}") | |
171 out_path = os.path.join(self.output_dir, "shap_summary.png") | |
172 plt.savefig(out_path) | |
173 plt.close() | |
174 self.plots["shap_summary"] = out_path | |
175 | 212 |
176 def generate_html_report(self): | 213 def generate_html_report(self): |
177 LOG.info("Generating HTML report") | 214 LOG.info("Generating HTML report") |
178 | |
179 plots_html = "" | 215 plots_html = "" |
180 for plot_name, plot_path in self.plots.items(): | 216 for plot_name, plot_path in self.plots.items(): |
181 # Special handling for tree importance: skip if no model name (not generated) | 217 if plot_name == "tree_importance" and not getattr( |
182 if plot_name == 'tree_importance' and not getattr( | 218 self, "tree_model_name", None |
183 self, 'tree_model_name', None | |
184 ): | 219 ): |
185 continue | 220 continue |
186 encoded_image = self.encode_image_to_base64(plot_path) | 221 encoded_image = self.encode_image_to_base64(plot_path) |
187 if plot_name == 'tree_importance' and getattr( | 222 if plot_name == "tree_importance" and getattr( |
188 self, 'tree_model_name', None | 223 self, "tree_model_name", None |
189 ): | 224 ): |
225 section_title = f"Feature importance from {self.tree_model_name}" | |
226 elif plot_name == "shap_summary": | |
190 section_title = ( | 227 section_title = ( |
191 f"Feature importance analysis from a trained {self.tree_model_name}" | 228 f"SHAP Summary from {getattr(self, 'shap_model_name', 'model')}" |
192 ) | 229 ) |
193 elif plot_name == 'shap_summary': | |
194 section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}" | |
195 else: | 230 else: |
196 section_title = plot_name | 231 section_title = plot_name |
197 plots_html += f""" | 232 plots_html += f""" |
198 <div class="plot" id="{plot_name}"> | 233 <div class="plot" id="{plot_name}"> |
199 <h2>{section_title}</h2> | 234 <h2>{section_title}</h2> |
200 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> | 235 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> |
201 </div> | 236 </div> |
202 """ | 237 """ |
203 | 238 return f"{plots_html}" |
204 html_content = f""" | |
205 {plots_html} | |
206 """ | |
207 | |
208 return html_content | |
209 | 239 |
210 def encode_image_to_base64(self, img_path): | 240 def encode_image_to_base64(self, img_path): |
211 with open(img_path, 'rb') as img_file: | 241 with open(img_path, "rb") as img_file: |
212 return base64.b64encode(img_file.read()).decode('utf-8') | 242 return base64.b64encode(img_file.read()).decode("utf-8") |
213 | 243 |
214 def run(self): | 244 def run(self): |
215 if ( | 245 if ( |
216 self.exp is None | 246 self.exp is None |
217 or not hasattr(self.exp, 'is_setup') | 247 or not hasattr(self.exp, "is_setup") |
218 or not self.exp.is_setup | 248 or not self.exp.is_setup |
219 ): | 249 ): |
220 self.setup_pycaret() | 250 self.setup_pycaret() |
221 self.save_tree_importance() | 251 self.save_tree_importance() |
222 self.save_shap_values() | 252 self.save_shap_values() |
223 html_content = self.generate_html_report() | 253 return self.generate_html_report() |
224 return html_content |