comparison feature_importance.py @ 8:1aed7d47c5ec draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author goeckslab
date Fri, 25 Jul 2025 19:02:32 +0000
parents f4cb41f458fd
children
comparison
equal deleted inserted replaced
7:f4cb41f458fd 8:1aed7d47c5ec
12 LOG = logging.getLogger(__name__) 12 LOG = logging.getLogger(__name__)
13 13
14 14
15 class FeatureImportanceAnalyzer: 15 class FeatureImportanceAnalyzer:
16 def __init__( 16 def __init__(
17 self, 17 self,
18 task_type, 18 task_type,
19 output_dir, 19 output_dir,
20 data_path=None, 20 data_path=None,
21 data=None, 21 data=None,
22 target_col=None, 22 target_col=None,
23 exp=None, 23 exp=None,
24 best_model=None): 24 best_model=None,
25 ):
25 26
26 self.task_type = task_type 27 self.task_type = task_type
27 self.output_dir = output_dir 28 self.output_dir = output_dir
28 self.exp = exp 29 self.exp = exp
29 self.best_model = best_model 30 self.best_model = best_model
41 self.target_col = target_col 42 self.target_col = target_col
42 self.data = pd.read_csv(data_path, sep=None, engine='python') 43 self.data = pd.read_csv(data_path, sep=None, engine='python')
43 self.data.columns = self.data.columns.str.replace('.', '_') 44 self.data.columns = self.data.columns.str.replace('.', '_')
44 self.data = self.data.fillna(self.data.median(numeric_only=True)) 45 self.data = self.data.fillna(self.data.median(numeric_only=True))
45 self.target = self.data.columns[int(target_col) - 1] 46 self.target = self.data.columns[int(target_col) - 1]
46 self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment() 47 self.exp = (
48 ClassificationExperiment()
49 if task_type == "classification"
50 else RegressionExperiment()
51 )
47 52
48 self.plots = {} 53 self.plots = {}
49 54
50 def setup_pycaret(self): 55 def setup_pycaret(self):
51 if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup: 56 if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup:
55 setup_params = { 60 setup_params = {
56 'target': self.target, 61 'target': self.target,
57 'session_id': 123, 62 'session_id': 123,
58 'html': True, 63 'html': True,
59 'log_experiment': False, 64 'log_experiment': False,
60 'system_log': False 65 'system_log': False,
61 } 66 }
62 self.exp.setup(self.data, **setup_params) 67 self.exp.setup(self.data, **setup_params)
63 68
64 def save_tree_importance(self): 69 def save_tree_importance(self):
65 model = self.best_model or self.exp.get_config('best_model') 70 model = self.best_model or self.exp.get_config('best_model')
68 # Try feature_importances_ or coef_ if available 73 # Try feature_importances_ or coef_ if available
69 importances = None 74 importances = None
70 model_type = model.__class__.__name__ 75 model_type = model.__class__.__name__
71 self.tree_model_name = model_type # Store the model name for reporting 76 self.tree_model_name = model_type # Store the model name for reporting
72 77
73 if hasattr(model, "feature_importances_"): 78 if hasattr(model, 'feature_importances_'):
74 importances = model.feature_importances_ 79 importances = model.feature_importances_
75 elif hasattr(model, "coef_"): 80 elif hasattr(model, 'coef_'):
76 # For linear models, flatten coef_ and take abs (importance as magnitude) 81 # For linear models, flatten coef_ and take abs (importance as magnitude)
77 importances = abs(model.coef_).flatten() 82 importances = abs(model.coef_).flatten()
78 else: 83 else:
79 # Neither attribute exists; skip the plot 84 # Neither attribute exists; skip the plot
80 LOG.warning(f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot.") 85 LOG.warning(
86 f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot."
87 )
81 self.tree_model_name = None # No plot generated 88 self.tree_model_name = None # No plot generated
82 return 89 return
83 90
84 # Defensive: handle mismatch in number of features 91 # Defensive: handle mismatch in number of features
85 if len(importances) != len(processed_features): 92 if len(importances) != len(processed_features):
87 f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot." 94 f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot."
88 ) 95 )
89 self.tree_model_name = None 96 self.tree_model_name = None
90 return 97 return
91 98
92 feature_importances = pd.DataFrame({ 99 feature_importances = pd.DataFrame(
93 'Feature': processed_features, 100 {'Feature': processed_features, 'Importance': importances}
94 'Importance': importances 101 ).sort_values(by='Importance', ascending=False)
95 }).sort_values(by='Importance', ascending=False)
96 plt.figure(figsize=(10, 6)) 102 plt.figure(figsize=(10, 6))
97 plt.barh( 103 plt.barh(feature_importances['Feature'], feature_importances['Importance'])
98 feature_importances['Feature'],
99 feature_importances['Importance'])
100 plt.xlabel('Importance') 104 plt.xlabel('Importance')
101 plt.title(f'Feature Importance ({model_type})') 105 plt.title(f'Feature Importance ({model_type})')
102 plot_path = os.path.join( 106 plot_path = os.path.join(self.output_dir, 'tree_importance.png')
103 self.output_dir,
104 'tree_importance.png')
105 plt.savefig(plot_path) 107 plt.savefig(plot_path)
106 plt.close() 108 plt.close()
107 self.plots['tree_importance'] = plot_path 109 self.plots['tree_importance'] = plot_path
108 110
def save_shap_values(self):
    """Compute SHAP values for the trained model and save beeswarm plots.

    The data explained is PyCaret's transformed test split, falling back to
    the transformed training split when the test split is unavailable or
    empty. A ``shap.Explainer`` is built over ``predict_proba`` (when the
    model has it) or ``predict``, with a background sample of at most 100
    rows drawn with ``random_state=42``.

    For 3-D SHAP values (one slice per model output, e.g. per class from
    ``predict_proba``) one plot is written per output as
    ``shap_summary_<output>.png``; otherwise a single ``shap_summary.png``
    is written. Each saved path is recorded in ``self.plots``.

    Side effects:
        Sets ``self.shap_model_name`` to the model's class name (used in the
        report's section title) and ``self.shap_explainer_name`` to the class
        name of the SHAP explainer that was constructed.

    Raises:
        RuntimeError: if neither transformed split can be retrieved as a
            non-empty frame from the experiment.
    """
    model = self.best_model or self.exp.get_config("best_model")

    # PyCaret 3's get_config raises ValueError (not KeyError) for unknown
    # variable names; catch both so the train-split fallback actually works.
    X_data = None
    for key in ("X_test_transformed", "X_train_transformed"):
        try:
            candidate = self.exp.get_config(key)
        except (KeyError, ValueError):
            continue
        if candidate is not None and len(candidate) > 0:
            X_data = candidate
            break
    if X_data is None:
        raise RuntimeError(
            "Could not find a non-empty 'X_test_transformed' or "
            "'X_train_transformed' in the experiment. "
            "Make sure PyCaret setup() was run before computing SHAP values."
        )

    # Best-effort: restrict to the columns the model was fit on. LightGBM
    # exposes them via its booster, scikit-learn via feature_names_in_; any
    # failure falls back to the data's own columns.
    try:
        used_features = list(model.booster_.feature_name())
    except Exception:
        used_features = list(
            getattr(model, "feature_names_in_", X_data.columns.tolist())
        )
    missing = [col for col in used_features if col not in X_data.columns]
    if missing:
        LOG.warning(
            "Model features %s not found in transformed data; "
            "using the data's columns as-is for SHAP.",
            missing,
        )
    else:
        X_data = X_data[used_features]

    n_background = min(len(X_data), 100)
    background = X_data.sample(n_background, random_state=42)

    predict_fn = (
        model.predict_proba if hasattr(model, "predict_proba") else model.predict
    )

    explainer = shap.Explainer(predict_fn, background)
    model_name = model.__class__.__name__
    # The report title reads "SHAP Summary from a trained <shap_model_name>",
    # so this must name the model; the explainer type is kept separately.
    self.shap_model_name = model_name
    self.shap_explainer_name = explainer.__class__.__name__

    shap_values = explainer(X_data)
    values = shap_values.values

    if values.ndim == 3:
        n_outputs = values.shape[-1]
        output_names = getattr(shap_values, "output_names", None)
        if output_names is None and hasattr(model, "classes_"):
            output_names = list(model.classes_)
        # Fall back to positional names if none are known or the count does
        # not match the number of explained outputs.
        if output_names is None or len(output_names) != n_outputs:
            output_names = list(range(n_outputs))

        for j, name in enumerate(output_names):
            # Class labels may be non-strings (e.g. numpy ints), so coerce
            # before sanitising for use in a file name.
            label = str(name)
            safe = label.replace(" ", "_").replace("/", "_")
            out_path = os.path.join(self.output_dir, f"shap_summary_{safe}.png")
            plt.figure()
            try:
                shap.plots.beeswarm(shap_values[..., j], show=False)
                plt.title(f"SHAP for {model_name} ⇒ {label}")
                plt.savefig(out_path)
            finally:
                plt.close()
            self.plots[f"shap_summary_{safe}"] = out_path
    else:
        out_path = os.path.join(self.output_dir, "shap_summary.png")
        plt.figure()
        try:
            shap.plots.beeswarm(shap_values, show=False)
            plt.title(f"SHAP Summary for {model_name}")
            plt.savefig(out_path)
        finally:
            plt.close()
        self.plots["shap_summary"] = out_path
155 170
156 def generate_html_report(self): 171 def generate_html_report(self):
157 LOG.info("Generating HTML report") 172 LOG.info("Generating HTML report")
158 173
159 plots_html = "" 174 plots_html = ""
160 for plot_name, plot_path in self.plots.items(): 175 for plot_name, plot_path in self.plots.items():
161 # Special handling for tree importance: skip if no model name (not generated) 176 # Special handling for tree importance: skip if no model name (not generated)
162 if plot_name == 'tree_importance' and not getattr(self, 'tree_model_name', None): 177 if plot_name == 'tree_importance' and not getattr(
178 self, 'tree_model_name', None
179 ):
163 continue 180 continue
164 encoded_image = self.encode_image_to_base64(plot_path) 181 encoded_image = self.encode_image_to_base64(plot_path)
165 if plot_name == 'tree_importance' and getattr(self, 'tree_model_name', None): 182 if plot_name == 'tree_importance' and getattr(
166 section_title = f"Feature importance analysis from a trained {self.tree_model_name}" 183 self, 'tree_model_name', None
184 ):
185 section_title = (
186 f"Feature importance analysis from a trained {self.tree_model_name}"
187 )
167 elif plot_name == 'shap_summary': 188 elif plot_name == 'shap_summary':
168 section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}" 189 section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}"
169 else: 190 else:
170 section_title = plot_name 191 section_title = plot_name
171 plots_html += f""" 192 plots_html += f"""
174 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> 195 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
175 </div> 196 </div>
176 """ 197 """
177 198
178 html_content = f""" 199 html_content = f"""
179 <h1>PyCaret Feature Importance Report</h1>
180 {plots_html} 200 {plots_html}
181 """ 201 """
182 202
183 return html_content 203 return html_content
184 204
def encode_image_to_base64(self, img_path):
    """Return the bytes of the file at ``img_path`` as a base64 text string.

    The file is opened in binary mode; the base64-encoded bytes are decoded
    as UTF-8 so the result can be embedded directly in a ``data:`` URI.
    """
    with open(img_path, 'rb') as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
188 208
def run(self):
    """Run the full analysis pipeline and return the HTML report string.

    PyCaret setup is performed only when there is no experiment, or the
    experiment lacks an ``is_setup`` attribute, or ``is_setup`` is falsy.
    Then the tree-importance and SHAP plots are produced and the HTML report
    built from them is returned.
    """
    already_set_up = (
        self.exp is not None
        and hasattr(self.exp, 'is_setup')
        and self.exp.is_setup
    )
    if not already_set_up:
        self.setup_pycaret()

    self.save_tree_importance()
    self.save_shap_values()
    return self.generate_html_report()