Mercurial > repos > goeckslab > tabular_learner
comparison feature_importance.py @ 8:ba45bc057d70 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
| author | goeckslab |
|---|---|
| date | Mon, 08 Sep 2025 22:38:55 +0000 |
| parents | 0afd970bd8ae |
| children |
comparison
equal
deleted
inserted
replaced
| 7:0afd970bd8ae | 8:ba45bc057d70 |
|---|---|
| 21 data=None, | 21 data=None, |
| 22 target_col=None, | 22 target_col=None, |
| 23 exp=None, | 23 exp=None, |
| 24 best_model=None, | 24 best_model=None, |
| 25 ): | 25 ): |
| 26 | |
| 27 self.task_type = task_type | 26 self.task_type = task_type |
| 28 self.output_dir = output_dir | 27 self.output_dir = output_dir |
| 29 self.exp = exp | 28 self.exp = exp |
| 30 self.best_model = best_model | 29 self.best_model = best_model |
| 31 | 30 |
| 38 if data is not None: | 37 if data is not None: |
| 39 self.data = data | 38 self.data = data |
| 40 LOG.info("Data loaded from memory") | 39 LOG.info("Data loaded from memory") |
| 41 else: | 40 else: |
| 42 self.target_col = target_col | 41 self.target_col = target_col |
| 43 self.data = pd.read_csv(data_path, sep=None, engine='python') | 42 self.data = pd.read_csv(data_path, sep=None, engine="python") |
| 44 self.data.columns = self.data.columns.str.replace('.', '_') | 43 self.data.columns = self.data.columns.str.replace(".", "_") |
| 45 self.data = self.data.fillna(self.data.median(numeric_only=True)) | 44 self.data = self.data.fillna(self.data.median(numeric_only=True)) |
| 46 self.target = self.data.columns[int(target_col) - 1] | 45 self.target = self.data.columns[int(target_col) - 1] |
| 47 self.exp = ( | 46 self.exp = ( |
| 48 ClassificationExperiment() | 47 ClassificationExperiment() |
| 49 if task_type == "classification" | 48 if task_type == "classification" |
| 51 ) | 50 ) |
| 52 | 51 |
| 53 self.plots = {} | 52 self.plots = {} |
| 54 | 53 |
| 55 def setup_pycaret(self): | 54 def setup_pycaret(self): |
| 56 if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup: | 55 if self.exp is not None and hasattr(self.exp, "is_setup") and self.exp.is_setup: |
| 57 LOG.info("Experiment already set up. Skipping PyCaret setup.") | 56 LOG.info("Experiment already set up. Skipping PyCaret setup.") |
| 58 return | 57 return |
| 59 LOG.info("Initializing PyCaret") | 58 LOG.info("Initializing PyCaret") |
| 60 setup_params = { | 59 setup_params = { |
| 61 'target': self.target, | 60 "target": self.target, |
| 62 'session_id': 123, | 61 "session_id": 123, |
| 63 'html': True, | 62 "html": True, |
| 64 'log_experiment': False, | 63 "log_experiment": False, |
| 65 'system_log': False, | 64 "system_log": False, |
| 66 } | 65 } |
| 67 self.exp.setup(self.data, **setup_params) | 66 self.exp.setup(self.data, **setup_params) |
| 68 | 67 |
| 69 def save_tree_importance(self): | 68 def save_tree_importance(self): |
| 70 model = self.best_model or self.exp.get_config('best_model') | 69 model = self.best_model or self.exp.get_config("best_model") |
| 71 processed_features = self.exp.get_config('X_transformed').columns | 70 processed_features = self.exp.get_config("X_transformed").columns |
| 72 | 71 |
| 73 # Try feature_importances_ or coef_ if available | |
| 74 importances = None | 72 importances = None |
| 75 model_type = model.__class__.__name__ | 73 model_type = model.__class__.__name__ |
| 76 self.tree_model_name = model_type # Store the model name for reporting | 74 self.tree_model_name = model_type |
| 77 | 75 |
| 78 if hasattr(model, 'feature_importances_'): | 76 if hasattr(model, "feature_importances_"): |
| 79 importances = model.feature_importances_ | 77 importances = model.feature_importances_ |
| 80 elif hasattr(model, 'coef_'): | 78 elif hasattr(model, "coef_"): |
| 81 # For linear models, flatten coef_ and take abs (importance as magnitude) | |
| 82 importances = abs(model.coef_).flatten() | 79 importances = abs(model.coef_).flatten() |
| 83 else: | 80 else: |
| 84 # Neither attribute exists; skip the plot | |
| 85 LOG.warning( | 81 LOG.warning( |
| 86 f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot." | 82 f"Model {model_type} does not have feature_importances_ or coef_. Skipping tree importance." |
| 87 ) | 83 ) |
| 88 self.tree_model_name = None # No plot generated | 84 self.tree_model_name = None |
| 89 return | 85 return |
| 90 | 86 |
| 91 # Defensive: handle mismatch in number of features | |
| 92 if len(importances) != len(processed_features): | 87 if len(importances) != len(processed_features): |
| 93 LOG.warning( | 88 LOG.warning( |
| 94 f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot." | 89 f"Importances ({len(importances)}) != features ({len(processed_features)}). Skipping tree importance." |
| 95 ) | 90 ) |
| 96 self.tree_model_name = None | 91 self.tree_model_name = None |
| 97 return | 92 return |
| 98 | 93 |
| 99 feature_importances = pd.DataFrame( | 94 feature_importances = pd.DataFrame( |
| 100 {'Feature': processed_features, 'Importance': importances} | 95 {"Feature": processed_features, "Importance": importances} |
| 101 ).sort_values(by='Importance', ascending=False) | 96 ).sort_values(by="Importance", ascending=False) |
| 102 plt.figure(figsize=(10, 6)) | 97 plt.figure(figsize=(10, 6)) |
| 103 plt.barh(feature_importances['Feature'], feature_importances['Importance']) | 98 plt.barh(feature_importances["Feature"], feature_importances["Importance"]) |
| 104 plt.xlabel('Importance') | 99 plt.xlabel("Importance") |
| 105 plt.title(f'Feature Importance ({model_type})') | 100 plt.title(f"Feature Importance ({model_type})") |
| 106 plot_path = os.path.join(self.output_dir, 'tree_importance.png') | 101 plot_path = os.path.join(self.output_dir, "tree_importance.png") |
| 107 plt.savefig(plot_path) | 102 plt.savefig(plot_path, bbox_inches="tight") |
| 108 plt.close() | 103 plt.close() |
| 109 self.plots['tree_importance'] = plot_path | 104 self.plots["tree_importance"] = plot_path |
| 110 | 105 |
| 111 def save_shap_values(self): | 106 def save_shap_values(self, max_samples=None, max_display=None, max_features=None): |
| 112 | |
| 113 model = self.best_model or self.exp.get_config("best_model") | 107 model = self.best_model or self.exp.get_config("best_model") |
| 114 | 108 |
| 115 X_data = None | 109 X_data = None |
| 116 for key in ("X_test_transformed", "X_train_transformed"): | 110 for key in ("X_test_transformed", "X_train_transformed"): |
| 117 try: | 111 try: |
| 118 X_data = self.exp.get_config(key) | 112 X_data = self.exp.get_config(key) |
| 119 break | 113 break |
| 120 except KeyError: | 114 except KeyError: |
| 121 continue | 115 continue |
| 122 if X_data is None: | 116 if X_data is None: |
| 123 raise RuntimeError( | 117 raise RuntimeError("No transformed dataset found for SHAP.") |
| 124 "Could not find 'X_test_transformed' or 'X_train_transformed' in the experiment. " | 118 |
| 125 "Make sure PyCaret setup/compare_models was run with feature_selection=True." | 119 # --- Adaptive feature limiting (proportional cap) --- |
| 126 ) | 120 n_rows, n_features = X_data.shape |
| 121 if max_features is None: | |
| 122 if n_features <= 200: | |
| 123 max_features = n_features | |
| 124 else: | |
| 125 max_features = min(200, max(20, int(n_features * 0.1))) | |
| 127 | 126 |
| 128 try: | 127 try: |
| 129 used_features = model.booster_.feature_name() | 128 if hasattr(model, "feature_importances_"): |
| 130 except Exception: | 129 importances = pd.Series( |
| 131 used_features = getattr(model, "feature_names_in_", X_data.columns.tolist()) | 130 model.feature_importances_, index=X_data.columns |
| 132 X_data = X_data[used_features] | 131 ) |
| 133 | 132 top_features = importances.nlargest(max_features).index |
| 134 max_bg = min(len(X_data), 100) | 133 elif hasattr(model, "coef_"): |
| 135 bg = X_data.sample(max_bg, random_state=42) | 134 coef = abs(model.coef_).flatten() |
| 136 | 135 importances = pd.Series(coef, index=X_data.columns) |
| 137 predict_fn = model.predict_proba if hasattr(model, "predict_proba") else model.predict | 136 top_features = importances.nlargest(max_features).index |
| 137 else: | |
| 138 variances = X_data.var() | |
| 139 top_features = variances.nlargest(max_features).index | |
| 140 | |
| 141 if len(top_features) < n_features: | |
| 142 LOG.info( | |
| 143 f"Restricted SHAP computation to top {len(top_features)} / {n_features} features" | |
| 144 ) | |
| 145 X_data = X_data[top_features] | |
| 146 except Exception as e: | |
| 147 LOG.warning( | |
| 148 f"Feature limiting failed: {e}. Using all {n_features} features." | |
| 149 ) | |
| 150 | |
| 151 # --- Adaptive row subsampling --- | |
| 152 if max_samples is None: | |
| 153 if n_rows <= 500: | |
| 154 max_samples = n_rows | |
| 155 elif n_rows <= 5000: | |
| 156 max_samples = 500 | |
| 157 else: | |
| 158 max_samples = min(1000, int(n_rows * 0.1)) | |
| 159 | |
| 160 if n_rows > max_samples: | |
| 161 LOG.info(f"Subsampling SHAP rows: {max_samples} of {n_rows}") | |
| 162 X_data = X_data.sample(max_samples, random_state=42) | |
| 163 | |
| 164 # --- Adaptive feature display --- | |
| 165 if max_display is None: | |
| 166 if X_data.shape[1] <= 20: | |
| 167 max_display = X_data.shape[1] | |
| 168 elif X_data.shape[1] <= 100: | |
| 169 max_display = 30 | |
| 170 else: | |
| 171 max_display = 50 | |
| 172 | |
| 173 # Background set | |
| 174 bg = X_data.sample(min(len(X_data), 100), random_state=42) | |
| 175 predict_fn = ( | |
| 176 model.predict_proba if hasattr(model, "predict_proba") else model.predict | |
| 177 ) | |
| 178 | |
| 179 # Optimized explainer | |
| 180 if hasattr(model, "feature_importances_"): | |
| 181 explainer = shap.TreeExplainer( | |
| 182 model, bg, feature_perturbation="tree_path_dependent", n_jobs=-1 | |
| 183 ) | |
| 184 elif hasattr(model, "coef_"): | |
| 185 explainer = shap.LinearExplainer(model, bg) | |
| 186 else: | |
| 187 explainer = shap.Explainer(predict_fn, bg) | |
| 138 | 188 |
| 139 try: | 189 try: |
| 140 explainer = shap.Explainer(predict_fn, bg) | 190 shap_values = explainer(X_data) |
| 141 self.shap_model_name = explainer.__class__.__name__ | 191 self.shap_model_name = explainer.__class__.__name__ |
| 142 | |
| 143 shap_values = explainer(X_data) | |
| 144 except Exception as e: | 192 except Exception as e: |
| 145 LOG.error(f"SHAP computation failed: {e}") | 193 LOG.error(f"SHAP computation failed: {e}") |
| 146 self.shap_model_name = None | 194 self.shap_model_name = None |
| 147 return | 195 return |
| 148 | 196 |
| 149 output_names = getattr(shap_values, "output_names", None) | 197 # --- Plot SHAP summary --- |
| 150 if output_names is None and hasattr(model, "classes_"): | 198 out_path = os.path.join(self.output_dir, "shap_summary.png") |
| 151 output_names = list(model.classes_) | 199 plt.figure() |
| 152 if output_names is None: | 200 shap.plots.beeswarm(shap_values, max_display=max_display, show=False) |
| 153 n_out = shap_values.values.shape[-1] | 201 plt.title( |
| 154 output_names = list(map(str, range(n_out))) | 202 f"SHAP Summary for {model.__class__.__name__} (top {max_display} features)" |
| 155 | 203 ) |
| 156 values = shap_values.values | 204 plt.savefig(out_path, bbox_inches="tight") |
| 157 if values.ndim == 3: | 205 plt.close() |
| 158 for j, name in enumerate(output_names): | 206 self.plots["shap_summary"] = out_path |
| 159 safe = name.replace(" ", "_").replace("/", "_") | 207 |
| 160 out_path = os.path.join(self.output_dir, f"shap_summary_{safe}.png") | 208 # --- Log summary --- |
| 161 plt.figure() | 209 LOG.info( |
| 162 shap.plots.beeswarm(shap_values[..., j], show=False) | 210 f"SHAP summary completed with {X_data.shape[0]} rows and {X_data.shape[1]} features (displaying top {max_display})." |
| 163 plt.title(f"SHAP for {model.__class__.__name__} ⇒ {name}") | 211 ) |
| 164 plt.savefig(out_path) | |
| 165 plt.close() | |
| 166 self.plots[f"shap_summary_{safe}"] = out_path | |
| 167 else: | |
| 168 plt.figure() | |
| 169 shap.plots.beeswarm(shap_values, show=False) | |
| 170 plt.title(f"SHAP Summary for {model.__class__.__name__}") | |
| 171 out_path = os.path.join(self.output_dir, "shap_summary.png") | |
| 172 plt.savefig(out_path) | |
| 173 plt.close() | |
| 174 self.plots["shap_summary"] = out_path | |
| 175 | 212 |
| 176 def generate_html_report(self): | 213 def generate_html_report(self): |
| 177 LOG.info("Generating HTML report") | 214 LOG.info("Generating HTML report") |
| 178 | |
| 179 plots_html = "" | 215 plots_html = "" |
| 180 for plot_name, plot_path in self.plots.items(): | 216 for plot_name, plot_path in self.plots.items(): |
| 181 # Special handling for tree importance: skip if no model name (not generated) | 217 if plot_name == "tree_importance" and not getattr( |
| 182 if plot_name == 'tree_importance' and not getattr( | 218 self, "tree_model_name", None |
| 183 self, 'tree_model_name', None | |
| 184 ): | 219 ): |
| 185 continue | 220 continue |
| 186 encoded_image = self.encode_image_to_base64(plot_path) | 221 encoded_image = self.encode_image_to_base64(plot_path) |
| 187 if plot_name == 'tree_importance' and getattr( | 222 if plot_name == "tree_importance" and getattr( |
| 188 self, 'tree_model_name', None | 223 self, "tree_model_name", None |
| 189 ): | 224 ): |
| 225 section_title = f"Feature importance from {self.tree_model_name}" | |
| 226 elif plot_name == "shap_summary": | |
| 190 section_title = ( | 227 section_title = ( |
| 191 f"Feature importance analysis from a trained {self.tree_model_name}" | 228 f"SHAP Summary from {getattr(self, 'shap_model_name', 'model')}" |
| 192 ) | 229 ) |
| 193 elif plot_name == 'shap_summary': | |
| 194 section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}" | |
| 195 else: | 230 else: |
| 196 section_title = plot_name | 231 section_title = plot_name |
| 197 plots_html += f""" | 232 plots_html += f""" |
| 198 <div class="plot" id="{plot_name}"> | 233 <div class="plot" id="{plot_name}"> |
| 199 <h2>{section_title}</h2> | 234 <h2>{section_title}</h2> |
| 200 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> | 235 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> |
| 201 </div> | 236 </div> |
| 202 """ | 237 """ |
| 203 | 238 return f"{plots_html}" |
| 204 html_content = f""" | |
| 205 {plots_html} | |
| 206 """ | |
| 207 | |
| 208 return html_content | |
| 209 | 239 |
| 210 def encode_image_to_base64(self, img_path): | 240 def encode_image_to_base64(self, img_path): |
| 211 with open(img_path, 'rb') as img_file: | 241 with open(img_path, "rb") as img_file: |
| 212 return base64.b64encode(img_file.read()).decode('utf-8') | 242 return base64.b64encode(img_file.read()).decode("utf-8") |
| 213 | 243 |
| 214 def run(self): | 244 def run(self): |
| 215 if ( | 245 if ( |
| 216 self.exp is None | 246 self.exp is None |
| 217 or not hasattr(self.exp, 'is_setup') | 247 or not hasattr(self.exp, "is_setup") |
| 218 or not self.exp.is_setup | 248 or not self.exp.is_setup |
| 219 ): | 249 ): |
| 220 self.setup_pycaret() | 250 self.setup_pycaret() |
| 221 self.save_tree_importance() | 251 self.save_tree_importance() |
| 222 self.save_shap_values() | 252 self.save_shap_values() |
| 223 html_content = self.generate_html_report() | 253 return self.generate_html_report() |
| 224 return html_content |
