Mercurial > repos > goeckslab > tabular_learner
comparison feature_importance.py @ 0:209b663a4f62 draft
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
author | goeckslab |
---|---|
date | Wed, 18 Jun 2025 15:38:19 +0000 |
parents | |
children | 77c88226bfde |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:209b663a4f62 |
---|---|
1 import base64 | |
2 import logging | |
3 import os | |
4 | |
5 import matplotlib.pyplot as plt | |
6 import pandas as pd | |
7 from pycaret.classification import ClassificationExperiment | |
8 from pycaret.regression import RegressionExperiment | |
9 | |
# NOTE(review): DEBUG on the root logger is a process-wide side effect of
# importing this module; presumably intentional for this Galaxy tool.
logging.basicConfig(level=logging.DEBUG)
# Module-level logger used throughout this file.
LOG = logging.getLogger(__name__)
12 | |
13 | |
class FeatureImportanceAnalyzer:
    """Train quick PyCaret models and report feature importances as HTML.

    Fits a Random Forest (importance bar chart) and a LightGBM model
    (SHAP summary plot), saves both as PNGs under ``output_dir``, and
    embeds them base64-encoded in a returned HTML report.
    """

    def __init__(
            self,
            task_type,
            output_dir,
            data_path=None,
            data=None,
            target_col=None):
        """
        Args:
            task_type: 'classification' or 'regression'.
            output_dir: directory where the plot PNGs are written.
            data_path: CSV/TSV path, read only when ``data`` is not given.
            data: pre-loaded DataFrame; takes precedence over ``data_path``.
            target_col: 1-based index of the target column.
        """
        # Fix: keep target_col on the instance in BOTH branches (the
        # original only set it when reading from disk).
        self.target_col = target_col
        if data is not None:
            self.data = data
            LOG.info("Data loaded from memory")
        else:
            # sep=None + python engine sniffs the delimiter (csv or tsv).
            self.data = pd.read_csv(data_path, sep=None, engine='python')
            # PyCaret dislikes dots in column names; normalize them.
            self.data.columns = self.data.columns.str.replace('.', '_')
            # Impute numeric NaNs with each column's median.
            self.data = self.data.fillna(self.data.median(numeric_only=True))
        self.task_type = task_type
        # target_col is 1-based; translate it to a column label.
        self.target = self.data.columns[int(target_col) - 1]
        self.exp = ClassificationExperiment() \
            if task_type == 'classification' \
            else RegressionExperiment()
        # Maps plot name -> saved PNG path; filled by the save_* methods.
        self.plots = {}
        self.output_dir = output_dir

    def setup_pycaret(self):
        """Initialize the PyCaret experiment on ``self.data``."""
        LOG.info("Initializing PyCaret")
        setup_params = {
            'target': self.target,
            'session_id': 123,  # fixed seed for reproducible runs
            'html': True,
            'log_experiment': False,
            'system_log': False
        }
        LOG.info(self.task_type)
        LOG.info(self.exp)
        self.exp.setup(self.data, **setup_params)

    def save_tree_importance(self):
        """Fit a Random Forest and save a horizontal importance bar chart."""
        model = self.exp.create_model('rf')
        importances = model.feature_importances_
        # Use the transformed feature names so they line up with the
        # importances of the fitted pipeline.
        processed_features = self.exp.get_config('X_transformed').columns
        LOG.debug(f"Feature importances: {importances}")
        LOG.debug(f"Features: {processed_features}")
        feature_importances = pd.DataFrame({
            'Feature': processed_features,
            'Importance': importances
        }).sort_values(by='Importance', ascending=False)
        plt.figure(figsize=(10, 6))
        plt.barh(
            feature_importances['Feature'],
            feature_importances['Importance'])
        plt.xlabel('Importance')
        plt.title('Feature Importance (Random Forest)')
        plot_path = os.path.join(
            self.output_dir,
            'tree_importance.png')
        plt.savefig(plot_path)
        plt.close()
        self.plots['tree_importance'] = plot_path

    def save_shap_values(self):
        """Fit a LightGBM model and save a SHAP summary plot."""
        model = self.exp.create_model('lightgbm')
        # Lazy import: shap is heavy and only needed for this plot.
        import shap
        explainer = shap.Explainer(model)
        shap_values = explainer.shap_values(
            self.exp.get_config('X_transformed'))
        shap.summary_plot(shap_values,
                          self.exp.get_config('X_transformed'), show=False)
        plt.title('Shap (LightGBM)')
        plot_path = os.path.join(
            self.output_dir, 'shap_summary.png')
        plt.savefig(plot_path)
        plt.close()
        self.plots['shap_summary'] = plot_path

    def generate_feature_importance(self):
        """Produce all importance plots (tree importances + SHAP)."""
        self.save_tree_importance()
        self.save_shap_values()

    def encode_image_to_base64(self, img_path):
        """Return the file at ``img_path`` as a base64 string."""
        with open(img_path, 'rb') as img_file:
            return base64.b64encode(img_file.read()).decode('utf-8')

    def generate_html_report(self):
        """Assemble and return the HTML report embedding all saved plots."""
        LOG.info("Generating HTML report")

        plots_html = ""
        for plot_name, plot_path in self.plots.items():
            encoded_image = self.encode_image_to_base64(plot_path)
            if plot_name == 'tree_importance':
                # Fix: the original relied on implicit adjacent-string
                # concatenation with no separating spaces, producing
                # e.g. "from atrained Random Forest".
                title = ('Feature importance analysis from a '
                         'trained Random Forest')
                subtitle = ('Use gini impurity for '
                            'calculating feature importance for '
                            'classification and Variance Reduction '
                            'for regression')
            else:
                title = 'SHAP Summary from a trained lightgbm'
                subtitle = ''
            # Fix: keep the data URI on one line — the original f-string
            # injected a newline and indentation into the base64 payload.
            plots_html += (
                f'<div class="plot" id="{plot_name}">\n'
                f'<h2>{title}</h2>\n'
                f'<h3>{subtitle}</h3>\n'
                f'<img src="data:image/png;base64,{encoded_image}" '
                f'alt="{plot_name}">\n'
                f'</div>\n'
            )

        html_content = f"""
        <h1>PyCaret Feature Importance Report</h1>
        {plots_html}
        """
        return html_content

    def run(self):
        """Execute the full pipeline and return the HTML report string."""
        LOG.info("Running feature importance analysis")
        self.setup_pycaret()
        self.generate_feature_importance()
        html_content = self.generate_html_report()
        LOG.info("Feature importance analysis completed")
        return html_content
146 | |
147 | |
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Feature Importance Analysis")
    parser.add_argument(
        "--data_path", type=str, help="Path to the dataset")
    parser.add_argument(
        "--target_col", type=int,
        help="Index of the target column (1-based)")
    parser.add_argument(
        "--task_type", type=str,
        choices=["classification", "regression"],
        help="Task type: classification or regression")
    parser.add_argument(
        "--output_dir",
        type=str,
        help="Directory to save the outputs")
    args = parser.parse_args()

    # BUG FIX: the original passed these positionally, which bound
    # data_path -> task_type and target_col -> output_dir against the
    # constructor signature (task_type, output_dir, data_path=None,
    # data=None, target_col=None). Keyword arguments route each value
    # to the intended parameter.
    analyzer = FeatureImportanceAnalyzer(
        task_type=args.task_type,
        output_dir=args.output_dir,
        data_path=args.data_path,
        target_col=args.target_col)
    analyzer.run()