comparison feature_importance.py @ 6:a32ff7201629 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
author goeckslab
date Wed, 02 Jul 2025 19:00:03 +0000
parents ccd798db5abb
children
comparison
comparing revisions 5:c846405830eb and 6:a32ff7201629
 import logging
 import os

 import matplotlib.pyplot as plt
 import pandas as pd
+import shap
 from pycaret.classification import ClassificationExperiment
 from pycaret.regression import RegressionExperiment

 logging.basicConfig(level=logging.DEBUG)
 LOG = logging.getLogger(__name__)
[...]
         self,
         task_type,
         output_dir,
         data_path=None,
         data=None,
-        target_col=None):
+        target_col=None,
+        exp=None,
+        best_model=None):

-        if data is not None:
-            self.data = data
-            LOG.info("Data loaded from memory")
+        self.task_type = task_type
+        self.output_dir = output_dir
+        self.exp = exp
+        self.best_model = best_model
+
+        if exp is not None:
+            # Assume all configs (data, target) are in exp
+            self.data = exp.dataset.copy()
+            self.target = exp.target_param
+            LOG.info("Using provided experiment object")
         else:
-            self.target_col = target_col
-            self.data = pd.read_csv(data_path, sep=None, engine='python')
-            self.data.columns = self.data.columns.str.replace('.', '_')
-            self.data = self.data.fillna(self.data.median(numeric_only=True))
-        self.task_type = task_type
-        self.target = self.data.columns[int(target_col) - 1]
-        self.exp = ClassificationExperiment() \
-            if task_type == 'classification' \
-            else RegressionExperiment()
+            if data is not None:
+                self.data = data
+                LOG.info("Data loaded from memory")
+            else:
+                self.target_col = target_col
+                self.data = pd.read_csv(data_path, sep=None, engine='python')
+                self.data.columns = self.data.columns.str.replace('.', '_')
+                self.data = self.data.fillna(self.data.median(numeric_only=True))
+            self.target = self.data.columns[int(target_col) - 1]
+            self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment()
+
         self.plots = {}
-        self.output_dir = output_dir

     def setup_pycaret(self):
+        if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup:
+            LOG.info("Experiment already set up. Skipping PyCaret setup.")
+            return
         LOG.info("Initializing PyCaret")
         setup_params = {
             'target': self.target,
             'session_id': 123,
             'html': True,
             'log_experiment': False,
             'system_log': False
         }
-        LOG.info(self.task_type)
-        LOG.info(self.exp)
         self.exp.setup(self.data, **setup_params)

-    # def save_coefficients(self):
-    #     model = self.exp.create_model('lr')
-    #     coef_df = pd.DataFrame({
-    #         'Feature': self.data.columns.drop(self.target),
-    #         'Coefficient': model.coef_[0]
-    #     })
-    #     coef_html = coef_df.to_html(index=False)
-    #     return coef_html
+    def save_tree_importance(self):
+        model = self.best_model or self.exp.get_config('best_model')
+        processed_features = self.exp.get_config('X_transformed').columns

-    def save_tree_importance(self):
-        model = self.exp.create_model('rf')
-        importances = model.feature_importances_
-        processed_features = self.exp.get_config('X_transformed').columns
-        LOG.debug(f"Feature importances: {importances}")
-        LOG.debug(f"Features: {processed_features}")
+        # Try feature_importances_ or coef_ if available
+        importances = None
+        model_type = model.__class__.__name__
+        self.tree_model_name = model_type  # Store the model name for reporting
+
+        if hasattr(model, "feature_importances_"):
+            importances = model.feature_importances_
+        elif hasattr(model, "coef_"):
+            # For linear models, flatten coef_ and take abs (importance as magnitude)
+            importances = abs(model.coef_).flatten()
+        else:
+            # Neither attribute exists; skip the plot
+            LOG.warning(f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot.")
+            self.tree_model_name = None  # No plot generated
+            return
+
+        # Defensive: handle mismatch in number of features
+        if len(importances) != len(processed_features):
+            LOG.warning(
+                f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot."
+            )
+            self.tree_model_name = None
+            return
+
         feature_importances = pd.DataFrame({
             'Feature': processed_features,
             'Importance': importances
         }).sort_values(by='Importance', ascending=False)
         plt.figure(figsize=(10, 6))
         plt.barh(
             feature_importances['Feature'],
             feature_importances['Importance'])
         plt.xlabel('Importance')
-        plt.title('Feature Importance (Random Forest)')
+        plt.title(f'Feature Importance ({model_type})')
         plot_path = os.path.join(
             self.output_dir,
             'tree_importance.png')
         plt.savefig(plot_path)
         plt.close()
         self.plots['tree_importance'] = plot_path

     def save_shap_values(self):
-        model = self.exp.create_model('lightgbm')
-        import shap
-        explainer = shap.Explainer(model)
-        shap_values = explainer.shap_values(
-            self.exp.get_config('X_transformed'))
-        shap.summary_plot(shap_values,
-                          self.exp.get_config('X_transformed'), show=False)
-        plt.title('Shap (LightGBM)')
-        plot_path = os.path.join(
-            self.output_dir, 'shap_summary.png')
+        model = self.best_model or self.exp.get_config('best_model')
+        X_transformed = self.exp.get_config('X_transformed')
+        tree_classes = (
+            "LGBM", "XGB", "CatBoost", "RandomForest", "DecisionTree", "ExtraTrees", "HistGradientBoosting"
+        )
+        model_class_name = model.__class__.__name__
+        self.shap_model_name = model_class_name
+
+        # Ensure feature alignment
+        if hasattr(model, "feature_name_"):
+            used_features = model.feature_name_
+        elif hasattr(model, "booster_") and hasattr(model.booster_, "feature_name"):
+            used_features = model.booster_.feature_name()
+        else:
+            used_features = X_transformed.columns
+
+        if any(tc in model_class_name for tc in tree_classes):
+            explainer = shap.TreeExplainer(model)
+            X_shap = X_transformed[used_features]
+            shap_values = explainer.shap_values(X_shap)
+            plot_X = X_shap
+            plot_title = f"SHAP Summary for {model_class_name} (TreeExplainer)"
+        else:
+            sampled_X = X_transformed[used_features].sample(100, random_state=42)
+            explainer = shap.KernelExplainer(model.predict, sampled_X)
+            shap_values = explainer.shap_values(sampled_X)
+            plot_X = sampled_X
+            plot_title = f"SHAP Summary for {model_class_name} (KernelExplainer)"
+
+        shap.summary_plot(shap_values, plot_X, show=False)
+        plt.title(plot_title)
+        plot_path = os.path.join(self.output_dir, "shap_summary.png")
         plt.savefig(plot_path)
         plt.close()
-        self.plots['shap_summary'] = plot_path
+        self.plots["shap_summary"] = plot_path
-
-    def generate_feature_importance(self):
-        # coef_html = self.save_coefficients()
-        self.save_tree_importance()
-        self.save_shap_values()
-
-    def encode_image_to_base64(self, img_path):
-        with open(img_path, 'rb') as img_file:
-            return base64.b64encode(img_file.read()).decode('utf-8')

     def generate_html_report(self):
         LOG.info("Generating HTML report")

-        # Read and encode plot images
         plots_html = ""
         for plot_name, plot_path in self.plots.items():
+            # Special handling for tree importance: skip if no model name (not generated)
+            if plot_name == 'tree_importance' and not getattr(self, 'tree_model_name', None):
+                continue
             encoded_image = self.encode_image_to_base64(plot_path)
+            if plot_name == 'tree_importance' and getattr(self, 'tree_model_name', None):
+                section_title = f"Feature importance analysis from a trained {self.tree_model_name}"
+            elif plot_name == 'shap_summary':
+                section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}"
+            else:
+                section_title = plot_name
             plots_html += f"""
             <div class="plot" id="{plot_name}">
-                <h2>{'Feature importance analysis from a'
-                     'trained Random Forest'
-                     if plot_name == 'tree_importance'
-                     else 'SHAP Summary from a trained lightgbm'}</h2>
-                <h3>{'Use gini impurity for'
-                     'calculating feature importance for classification'
-                     'and Variance Reduction for regression'
-                     if plot_name == 'tree_importance'
-                     else ''}</h3>
-                <img src="data:image/png;base64,
-                {encoded_image}" alt="{plot_name}">
+                <h2>{section_title}</h2>
+                <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
             </div>
             """

-        # Generate HTML content with tabs
         html_content = f"""
         <h1>PyCaret Feature Importance Report</h1>
         {plots_html}
         """

         return html_content

+    def encode_image_to_base64(self, img_path):
+        with open(img_path, 'rb') as img_file:
+            return base64.b64encode(img_file.read()).decode('utf-8')
+
     def run(self):
-        LOG.info("Running feature importance analysis")
-        self.setup_pycaret()
-        self.generate_feature_importance()
+        if self.exp is None or not hasattr(self.exp, 'is_setup') or not self.exp.is_setup:
+            self.setup_pycaret()
+        self.save_tree_importance()
+        self.save_shap_values()
         html_content = self.generate_html_report()
-        LOG.info("Feature importance analysis completed")
         return html_content
-
-
-if __name__ == "__main__":
-    import argparse
-    parser = argparse.ArgumentParser(description="Feature Importance Analysis")
-    parser.add_argument(
-        "--data_path", type=str, help="Path to the dataset")
-    parser.add_argument(
-        "--target_col", type=int,
-        help="Index of the target column (1-based)")
-    parser.add_argument(
-        "--task_type", type=str,
-        choices=["classification", "regression"],
-        help="Task type: classification or regression")
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        help="Directory to save the outputs")
-    args = parser.parse_args()
-
-    analyzer = FeatureImportanceAnalyzer(
-        args.data_path, args.target_col,
-        args.task_type, args.output_dir)
-    analyzer.run()
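
The argparse entry point removed above reflects the new calling convention: the analyzer is now meant to be handed an already-configured PyCaret experiment and its best model rather than a raw CSV path. A minimal usage sketch under that assumption follows; the DataFrame df, the target name 'label', and the output directory are placeholders, while the exp/best_model keywords and run() come from the revised code above.

    from pycaret.classification import ClassificationExperiment

    # Hypothetical caller: df is any pandas DataFrame with a 'label' column.
    exp = ClassificationExperiment()
    exp.setup(df, target='label', session_id=123, html=False, log_experiment=False)
    best_model = exp.compare_models()

    # FeatureImportanceAnalyzer is the class defined in feature_importance.py.
    analyzer = FeatureImportanceAnalyzer(
        task_type='classification',
        output_dir='plots',        # placeholder output directory for the PNGs
        exp=exp,                   # analyzer reuses exp.dataset and exp.target_param
        best_model=best_model,     # importance/SHAP plots are drawn for this model
    )
    html = analyzer.run()          # returns the HTML report as a string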