comparison feature_importance.py @ 0:209b663a4f62 draft

planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
author goeckslab
date Wed, 18 Jun 2025 15:38:19 +0000
parents
children 77c88226bfde
comparison
equal deleted inserted replaced
-1:000000000000 0:209b663a4f62
1 import base64
2 import logging
3 import os
4
5 import matplotlib.pyplot as plt
6 import pandas as pd
7 from pycaret.classification import ClassificationExperiment
8 from pycaret.regression import RegressionExperiment
9
# Importing this module configures the root logger at DEBUG verbosity and
# creates the module-level logger that every class/function below writes to.
logging.basicConfig(level="DEBUG")
LOG = logging.getLogger(name=__name__)
12
13
class FeatureImportanceAnalyzer:
    """Compute feature importances for a tabular dataset with PyCaret.

    A PyCaret classification or regression experiment is set up on the data
    and two models are trained:

    * a Random Forest, whose ``feature_importances_`` are plotted as a
      horizontal bar chart, and
    * a LightGBM model, whose SHAP values are shown in a SHAP summary plot.

    Both plots are saved as PNG files in ``output_dir`` and embedded (as
    base64 data URIs) in the HTML fragment returned by :meth:`run`.
    """

    # (heading, sub-heading) rendered above the Random Forest plot.
    _TREE_HEADINGS = (
        'Feature importance analysis from a trained Random Forest',
        'Use gini impurity for calculating feature importance for '
        'classification and Variance Reduction for regression',
    )
    # (heading, sub-heading) rendered above every other plot (SHAP summary).
    _SHAP_HEADINGS = ('SHAP Summary from a trained lightgbm', '')

    def __init__(
            self,
            task_type,
            output_dir,
            data_path=None,
            data=None,
            target_col=None):
        """Load the data and create the PyCaret experiment object.

        Args:
            task_type: ``'classification'`` selects a
                ``ClassificationExperiment``; any other value selects a
                ``RegressionExperiment``.
            output_dir: Directory the plot PNG files are written to.
            data_path: Path to a delimited text file (the delimiter is
                sniffed). Only read when ``data`` is None.
            data: Optional in-memory DataFrame used instead of ``data_path``.
            target_col: 1-based index of the target column.
        """
        # Recorded on both code paths so the attribute always exists.
        self.target_col = target_col
        if data is not None:
            self.data = data
            LOG.info("Data loaded from memory")
        else:
            # sep=None with the python engine lets pandas sniff the delimiter.
            self.data = pd.read_csv(data_path, sep=None, engine='python')
            # Replace dots literally; '.' must not be treated as a regex.
            self.data.columns = self.data.columns.str.replace(
                '.', '_', regex=False)
            # Impute missing numeric values with each column's median.
            self.data = self.data.fillna(
                self.data.median(numeric_only=True))
        self.task_type = task_type
        self.target = self.data.columns[int(target_col) - 1]
        self.exp = (ClassificationExperiment()
                    if task_type == 'classification'
                    else RegressionExperiment())
        # Maps plot name -> path of the saved PNG, in generation order.
        self.plots = {}
        self.output_dir = output_dir

    def setup_pycaret(self):
        """Run PyCaret's ``setup`` on the loaded data with ``self.target``."""
        LOG.info("Initializing PyCaret")
        setup_params = {
            'target': self.target,
            'session_id': 123,  # fixed seed for reproducible runs
            'html': True,
            'log_experiment': False,
            'system_log': False
        }
        LOG.info(self.task_type)
        LOG.info(self.exp)
        self.exp.setup(self.data, **setup_params)

    def save_tree_importance(self):
        """Train a Random Forest and save a bar chart of its importances.

        The chart is written to ``<output_dir>/tree_importance.png`` and its
        path is stored in ``self.plots['tree_importance']``.
        """
        model = self.exp.create_model('rf')
        importances = model.feature_importances_
        # Importances are paired with the columns of PyCaret's transformed
        # feature matrix, i.e. what the model was actually trained on.
        processed_features = self.exp.get_config('X_transformed').columns
        LOG.debug(f"Feature importances: {importances}")
        LOG.debug(f"Features: {processed_features}")
        feature_importances = pd.DataFrame({
            'Feature': processed_features,
            'Importance': importances
        }).sort_values(by='Importance', ascending=False)

        plot_path = os.path.join(self.output_dir, 'tree_importance.png')
        fig, ax = plt.subplots(figsize=(10, 6))
        try:
            ax.barh(
                feature_importances['Feature'],
                feature_importances['Importance'])
            # barh draws the first row at the bottom; invert the y axis so
            # the most important feature is shown at the top.
            ax.invert_yaxis()
            ax.set_xlabel('Importance')
            ax.set_title('Feature Importance (Random Forest)')
            # Keep long feature names from being clipped on the left.
            fig.tight_layout()
            fig.savefig(plot_path)
        finally:
            # Always release the figure, even if plotting/saving fails.
            plt.close(fig)
        self.plots['tree_importance'] = plot_path

    def save_shap_values(self):
        """Train a LightGBM model and save a SHAP summary plot.

        The plot is written to ``<output_dir>/shap_summary.png`` and its path
        is stored in ``self.plots['shap_summary']``.
        """
        model = self.exp.create_model('lightgbm')
        # Imported lazily (presumably to keep importing this module cheap).
        import shap
        features = self.exp.get_config('X_transformed')
        explainer = shap.Explainer(model)
        shap_values = explainer.shap_values(features)
        plot_path = os.path.join(self.output_dir, 'shap_summary.png')
        try:
            shap.summary_plot(shap_values, features, show=False)
            plt.title('Shap (LightGBM)')
            # The summary plot's feature labels and colour bar extend past
            # the default canvas; save with a tight bounding box.
            plt.savefig(plot_path, bbox_inches='tight')
        finally:
            plt.close()
        self.plots['shap_summary'] = plot_path

    def generate_feature_importance(self):
        """Produce both importance plots (Random Forest, then SHAP)."""
        self.save_tree_importance()
        self.save_shap_values()

    def encode_image_to_base64(self, img_path):
        """Return the bytes of ``img_path`` base64-encoded as a UTF-8 str."""
        with open(img_path, 'rb') as img_file:
            return base64.b64encode(img_file.read()).decode('utf-8')

    def generate_html_report(self):
        """Render every saved plot into an HTML fragment.

        Each entry of ``self.plots`` becomes a ``<div class="plot">`` with a
        heading, a sub-heading and the PNG inlined as a base64 data URI.

        Returns:
            str: The HTML fragment, headed by an ``<h1>`` title.
        """
        LOG.info("Generating HTML report")

        sections = []
        for plot_name, plot_path in self.plots.items():
            encoded_image = self.encode_image_to_base64(plot_path)
            if plot_name == 'tree_importance':
                title, subtitle = self._TREE_HEADINGS
            else:
                title, subtitle = self._SHAP_HEADINGS
            # The data URI stays on one line so no newline/indentation ends
            # up inside the base64 payload of the src attribute.
            sections.append(f"""
    <div class="plot" id="{plot_name}">
        <h2>{title}</h2>
        <h3>{subtitle}</h3>
        <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
    </div>
    """)
        plots_html = "".join(sections)

        html_content = f"""
    <h1>PyCaret Feature Importance Report</h1>
    {plots_html}
    """

        return html_content

    def run(self):
        """Set up PyCaret, build both plots and return the HTML fragment."""
        LOG.info("Running feature importance analysis")
        self.setup_pycaret()
        self.generate_feature_importance()
        html_content = self.generate_html_report()
        LOG.info("Feature importance analysis completed")
        return html_content
146
147
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Feature Importance Analysis")
    # Every option is needed to run the analysis; requiring them turns a
    # crash deep inside the analyzer into a clear usage error.
    parser.add_argument(
        "--data_path", type=str, required=True, help="Path to the dataset")
    parser.add_argument(
        "--target_col", type=int, required=True,
        help="Index of the target column (1-based)")
    parser.add_argument(
        "--task_type", type=str, required=True,
        choices=["classification", "regression"],
        help="Task type: classification or regression")
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save the outputs")
    args = parser.parse_args()

    # The plot PNGs are written into output_dir, so make sure it exists.
    os.makedirs(args.output_dir, exist_ok=True)

    # Pass arguments by keyword: the constructor's positional order is
    # (task_type, output_dir, data_path, data, target_col), which does not
    # match the order of the CLI options.
    analyzer = FeatureImportanceAnalyzer(
        task_type=args.task_type,
        output_dir=args.output_dir,
        data_path=args.data_path,
        target_col=args.target_col)
    analyzer.run()