comparison feature_importance.py @ 8:ba45bc057d70 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 1594d503179f28987720594eb49b48a15486f073
author goeckslab
date Mon, 08 Sep 2025 22:38:55 +0000
parents 0afd970bd8ae
children
comparison
equal deleted inserted replaced
7:0afd970bd8ae 8:ba45bc057d70
21 data=None, 21 data=None,
22 target_col=None, 22 target_col=None,
23 exp=None, 23 exp=None,
24 best_model=None, 24 best_model=None,
25 ): 25 ):
26
27 self.task_type = task_type 26 self.task_type = task_type
28 self.output_dir = output_dir 27 self.output_dir = output_dir
29 self.exp = exp 28 self.exp = exp
30 self.best_model = best_model 29 self.best_model = best_model
31 30
38 if data is not None: 37 if data is not None:
39 self.data = data 38 self.data = data
40 LOG.info("Data loaded from memory") 39 LOG.info("Data loaded from memory")
41 else: 40 else:
42 self.target_col = target_col 41 self.target_col = target_col
43 self.data = pd.read_csv(data_path, sep=None, engine='python') 42 self.data = pd.read_csv(data_path, sep=None, engine="python")
44 self.data.columns = self.data.columns.str.replace('.', '_') 43 self.data.columns = self.data.columns.str.replace(".", "_")
45 self.data = self.data.fillna(self.data.median(numeric_only=True)) 44 self.data = self.data.fillna(self.data.median(numeric_only=True))
46 self.target = self.data.columns[int(target_col) - 1] 45 self.target = self.data.columns[int(target_col) - 1]
47 self.exp = ( 46 self.exp = (
48 ClassificationExperiment() 47 ClassificationExperiment()
49 if task_type == "classification" 48 if task_type == "classification"
51 ) 50 )
52 51
53 self.plots = {} 52 self.plots = {}
54 53
def setup_pycaret(self):
    """Initialise the PyCaret experiment, unless one is already set up."""
    already_ready = (
        self.exp is not None
        and hasattr(self.exp, "is_setup")
        and self.exp.is_setup
    )
    if already_ready:
        LOG.info("Experiment already set up. Skipping PyCaret setup.")
        return
    LOG.info("Initializing PyCaret")
    # Fixed session id keeps runs reproducible; HTML output feeds the report.
    self.exp.setup(
        self.data,
        target=self.target,
        session_id=123,
        html=True,
        log_experiment=False,
        system_log=False,
    )
68 67
def save_tree_importance(self):
    """Plot global feature importances for the best model.

    Uses ``feature_importances_`` (tree ensembles) or ``coef_`` (linear
    models).  Fix: a multi-class linear model has a 2-D ``coef_``, and the
    previous ``abs(coef_).flatten()`` produced ``n_classes * n_features``
    values, which always tripped the length-mismatch guard and silently
    skipped the plot — the magnitude is now averaged over classes instead.
    On success the plot path is recorded in ``self.plots``; on any skip,
    ``self.tree_model_name`` is reset to ``None`` so the report omits it.
    """
    model = self.best_model or self.exp.get_config("best_model")
    processed_features = self.exp.get_config("X_transformed").columns

    importances = None
    model_type = model.__class__.__name__
    self.tree_model_name = model_type  # surfaced in the HTML report title

    if hasattr(model, "feature_importances_"):
        importances = model.feature_importances_
    elif hasattr(model, "coef_"):
        import numpy as np  # local import keeps the block self-contained

        coef = np.asarray(model.coef_)
        # Importance as coefficient magnitude; average over classes when
        # coef_ is 2-D so the length matches the number of features.
        importances = (
            np.abs(coef).mean(axis=0) if coef.ndim > 1 else np.abs(coef).ravel()
        )
    else:
        LOG.warning(
            f"Model {model_type} does not have feature_importances_ or coef_. Skipping tree importance."
        )
        self.tree_model_name = None
        return

    # Defensive: importances must line up one-to-one with features.
    if len(importances) != len(processed_features):
        LOG.warning(
            f"Importances ({len(importances)}) != features ({len(processed_features)}). Skipping tree importance."
        )
        self.tree_model_name = None
        return

    feature_importances = pd.DataFrame(
        {"Feature": processed_features, "Importance": importances}
    ).sort_values(by="Importance", ascending=False)
    plt.figure(figsize=(10, 6))
    plt.barh(feature_importances["Feature"], feature_importances["Importance"])
    plt.xlabel("Importance")
    plt.title(f"Feature Importance ({model_type})")
    plot_path = os.path.join(self.output_dir, "tree_importance.png")
    plt.savefig(plot_path, bbox_inches="tight")
    plt.close()
    self.plots["tree_importance"] = plot_path
110 105
def save_shap_values(self, max_samples=None, max_display=None, max_features=None):
    """Compute a SHAP beeswarm summary plot for the best model.

    Rows, candidate features and displayed features are each capped
    adaptively (or explicitly via the keyword arguments, all of which
    default to ``None`` = auto) so SHAP stays tractable on large data.
    The plot path is stored under ``self.plots["shap_summary"]``.
    """
    model = self.best_model or self.exp.get_config("best_model")

    # Prefer the held-out split; fall back to the training split.
    X_data = None
    for cfg_key in ("X_test_transformed", "X_train_transformed"):
        try:
            X_data = self.exp.get_config(cfg_key)
            break
        except KeyError:
            continue
    if X_data is None:
        raise RuntimeError("No transformed dataset found for SHAP.")

    # --- Adaptive feature limiting (proportional cap) ---
    n_rows, n_features = X_data.shape
    if max_features is None:
        if n_features <= 200:
            max_features = n_features
        else:
            max_features = min(200, max(20, int(n_features * 0.1)))

    try:
        # Rank features by model importance when exposed, else by variance.
        if hasattr(model, "feature_importances_"):
            ranking = pd.Series(model.feature_importances_, index=X_data.columns)
            top_features = ranking.nlargest(max_features).index
        elif hasattr(model, "coef_"):
            magnitudes = abs(model.coef_).flatten()
            ranking = pd.Series(magnitudes, index=X_data.columns)
            top_features = ranking.nlargest(max_features).index
        else:
            ranking = X_data.var()
            top_features = ranking.nlargest(max_features).index

        if len(top_features) < n_features:
            LOG.info(
                f"Restricted SHAP computation to top {len(top_features)} / {n_features} features"
            )
        X_data = X_data[top_features]
    except Exception as e:
        # Best effort: fall back to the full feature set.
        LOG.warning(
            f"Feature limiting failed: {e}. Using all {n_features} features."
        )

    # --- Adaptive row subsampling ---
    if max_samples is None:
        if n_rows <= 500:
            max_samples = n_rows
        elif n_rows <= 5000:
            max_samples = 500
        else:
            max_samples = min(1000, int(n_rows * 0.1))

    if n_rows > max_samples:
        LOG.info(f"Subsampling SHAP rows: {max_samples} of {n_rows}")
        X_data = X_data.sample(max_samples, random_state=42)

    # --- Adaptive feature display ---
    if max_display is None:
        if X_data.shape[1] <= 20:
            max_display = X_data.shape[1]
        elif X_data.shape[1] <= 100:
            max_display = 30
        else:
            max_display = 50

    # Background sample for the explainer (at most 100 rows).
    background = X_data.sample(min(len(X_data), 100), random_state=42)
    predict_fn = (
        model.predict_proba if hasattr(model, "predict_proba") else model.predict
    )

    # Pick the cheapest explainer the model supports.
    if hasattr(model, "feature_importances_"):
        explainer = shap.TreeExplainer(
            model, background, feature_perturbation="tree_path_dependent", n_jobs=-1
        )
    elif hasattr(model, "coef_"):
        explainer = shap.LinearExplainer(model, background)
    else:
        explainer = shap.Explainer(predict_fn, background)

    try:
        shap_values = explainer(X_data)
        self.shap_model_name = explainer.__class__.__name__
    except Exception as e:
        LOG.error(f"SHAP computation failed: {e}")
        self.shap_model_name = None
        return

    # --- Plot SHAP summary ---
    out_path = os.path.join(self.output_dir, "shap_summary.png")
    plt.figure()
    shap.plots.beeswarm(shap_values, max_display=max_display, show=False)
    plt.title(
        f"SHAP Summary for {model.__class__.__name__} (top {max_display} features)"
    )
    plt.savefig(out_path, bbox_inches="tight")
    plt.close()
    self.plots["shap_summary"] = out_path

    # --- Log summary ---
    LOG.info(
        f"SHAP summary completed with {X_data.shape[0]} rows and {X_data.shape[1]} features (displaying top {max_display})."
    )
175 212
def generate_html_report(self):
    """Assemble an HTML fragment embedding every generated plot as base64."""
    LOG.info("Generating HTML report")
    sections = []
    for plot_name, plot_path in self.plots.items():
        tree_name = getattr(self, "tree_model_name", None)
        # Tree-importance plot may have been skipped upstream.
        if plot_name == "tree_importance" and not tree_name:
            continue
        encoded_image = self.encode_image_to_base64(plot_path)
        if plot_name == "tree_importance" and tree_name:
            section_title = f"Feature importance from {self.tree_model_name}"
        elif plot_name == "shap_summary":
            section_title = (
                f"SHAP Summary from {getattr(self, 'shap_model_name', 'model')}"
            )
        else:
            section_title = plot_name
        sections.append(f"""
        <div class="plot" id="{plot_name}">
            <h2>{section_title}</h2>
            <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
        </div>
        """)
    return "".join(sections)
209 239
def encode_image_to_base64(self, img_path):
    """Read *img_path* as binary and return its base64 text representation."""
    with open(img_path, "rb") as handle:
        raw = handle.read()
    return base64.b64encode(raw).decode("utf-8")
213 243
def run(self):
    """Run the pipeline: ensure PyCaret setup, build plots, return HTML."""
    # getattr default covers both a missing attribute and a falsy flag.
    needs_setup = self.exp is None or not getattr(self.exp, "is_setup", False)
    if needs_setup:
        self.setup_pycaret()
    self.save_tree_importance()
    self.save_shap_values()
    return self.generate_html_report()