Mercurial > repos > goeckslab > pycaret_compare
comparison feature_importance.py @ 0:915447b14520 draft
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
| author | goeckslab |
|---|---|
| date | Wed, 11 Dec 2024 05:00:00 +0000 |
| parents | |
| children | 4aa511539199 |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:915447b14520 |
|---|---|
| 1 import base64 | |
| 2 import logging | |
| 3 import os | |
| 4 | |
| 5 import matplotlib.pyplot as plt | |
| 6 | |
| 7 import pandas as pd | |
| 8 | |
| 9 from pycaret.classification import ClassificationExperiment | |
| 10 from pycaret.regression import RegressionExperiment | |
| 11 | |
# Module-level logger shared by the analyzer class and the CLI entry point.
LOG = logging.getLogger(__name__)
# Configure the root logger at DEBUG level (no-op if the root logger
# already has handlers).
logging.basicConfig(level=logging.DEBUG)
| 14 | |
| 15 | |
class FeatureImportanceAnalyzer:
    """Compute feature-importance plots with PyCaret and render them as HTML.

    Two views are produced: impurity-based importances from a Random Forest
    (``save_tree_importance``) and a SHAP summary from a LightGBM model
    (``save_shap_values``). ``run`` executes the whole pipeline and returns
    an HTML fragment with the plots embedded as base64-encoded PNGs.
    """

    # (heading, sub-heading) shown above a plot in the HTML report, keyed by
    # the names used in ``self.plots``; any other plot gets the SHAP headings.
    _PLOT_HEADINGS = {
        'tree_importance': (
            'Feature importance analysis from a trained Random Forest',
            'Use gini impurity for calculating feature importance for '
            'classification and Variance Reduction for regression',
        ),
    }
    _DEFAULT_HEADING = ('SHAP Summary from a trained lightgbm', '')

    def __init__(
            self,
            task_type,
            output_dir,
            data_path=None,
            data=None,
            target_col=None):
        """Load the data and create the PyCaret experiment.

        Args:
            task_type: ``'classification'`` selects a
                ``ClassificationExperiment``; any other value selects a
                ``RegressionExperiment``.
            output_dir: Directory the plot PNGs are written to (created on
                demand when a plot is saved).
            data_path: Path to a delimited text file; pandas sniffs the
                delimiter. Only read when ``data`` is None.
            data: In-memory DataFrame. When given it is used as-is: no column
                renaming or missing-value imputation is applied.
            target_col: 1-based index of the target column.

        Raises:
            ValueError: if neither ``data`` nor ``data_path`` is given, or if
                ``target_col`` is missing or outside ``1..n_columns``.
        """
        self.task_type = task_type
        self.output_dir = output_dir
        self.target_col = target_col
        self.plots = {}

        if data is not None:
            self.data = data
            LOG.info("Data loaded from memory")
        else:
            if data_path is None:
                raise ValueError(
                    "Either 'data' or 'data_path' must be provided")
            self.data = pd.read_csv(data_path, sep=None, engine='python')
            # regex=False: replace a literal '.', instead of relying on the
            # version-dependent pandas default (as a regex '.' matches any
            # character).
            self.data.columns = self.data.columns.str.replace(
                '.', '_', regex=False)
            # Impute missing numeric values with each column's median.
            self.data = self.data.fillna(self.data.median(numeric_only=True))

        self.target = self.data.columns[
            self._target_index(target_col, len(self.data.columns))]
        self.exp = (ClassificationExperiment()
                    if task_type == 'classification'
                    else RegressionExperiment())

    @staticmethod
    def _target_index(target_col, n_columns):
        """Convert a 1-based column index into a validated 0-based index."""
        if target_col is None:
            raise ValueError("target_col (1-based column index) is required")
        index = int(target_col)
        # Without this check 0 would silently select the last column.
        if not 1 <= index <= n_columns:
            raise ValueError(
                f"target_col must be between 1 and {n_columns}, got {index}")
        return index - 1

    def setup_pycaret(self):
        """Run PyCaret ``setup`` on the loaded data with a fixed seed."""
        LOG.info("Initializing PyCaret")
        setup_params = {
            'target': self.target,
            'session_id': 123,
            'html': True,
            'log_experiment': False,
            'system_log': False
        }
        LOG.info(self.task_type)
        LOG.info(self.exp)
        self.exp.setup(self.data, **setup_params)

    def _save_current_figure(self, filename):
        """Save the active matplotlib figure into ``output_dir`` and close it.

        Returns:
            The path of the written image.
        """
        os.makedirs(self.output_dir, exist_ok=True)
        plot_path = os.path.join(self.output_dir, filename)
        # bbox_inches='tight' keeps long feature names from being clipped.
        plt.savefig(plot_path, bbox_inches='tight')
        plt.close()
        return plot_path

    def save_tree_importance(self):
        """Train a Random Forest and plot its ``feature_importances_``."""
        model = self.exp.create_model('rf')
        importances = model.feature_importances_
        processed_features = self.exp.get_config('X_transformed').columns
        LOG.debug(f"Feature importances: {importances}")
        LOG.debug(f"Features: {processed_features}")
        feature_importances = pd.DataFrame({
            'Feature': processed_features,
            'Importance': importances,
        }).sort_values(by='Importance', ascending=False)
        plt.figure(figsize=(10, 6))
        plt.barh(
            feature_importances['Feature'],
            feature_importances['Importance'])
        # barh draws the first row at the bottom; flip the axis so the most
        # important feature appears at the top.
        plt.gca().invert_yaxis()
        plt.xlabel('Importance')
        plt.title('Feature Importance (Random Forest)')
        self.plots['tree_importance'] = self._save_current_figure(
            'tree_importance.png')

    def save_shap_values(self):
        """Train a LightGBM model and save a SHAP summary plot."""
        model = self.exp.create_model('lightgbm')
        # Imported here rather than at module top, presumably so shap is
        # only required when this plot is produced.
        import shap
        x_transformed = self.exp.get_config('X_transformed')
        explainer = shap.Explainer(model)
        shap_values = explainer.shap_values(x_transformed)
        shap.summary_plot(shap_values, x_transformed, show=False)
        plt.title('Shap (LightGBM)')
        self.plots['shap_summary'] = self._save_current_figure(
            'shap_summary.png')

    def generate_feature_importance(self):
        """Produce every feature-importance plot."""
        self.save_tree_importance()
        self.save_shap_values()

    def encode_image_to_base64(self, img_path):
        """Return the file at ``img_path`` as a base64 ASCII string."""
        with open(img_path, 'rb') as img_file:
            return base64.b64encode(img_file.read()).decode('utf-8')

    def generate_html_report(self):
        """Return an HTML fragment embedding every saved plot inline."""
        LOG.info("Generating HTML report")

        sections = []
        for plot_name, plot_path in self.plots.items():
            encoded_image = self.encode_image_to_base64(plot_path)
            title, subtitle = self._PLOT_HEADINGS.get(
                plot_name, self._DEFAULT_HEADING)
            # The data URI is kept on one line: no whitespace may follow
            # "base64," inside the src attribute.
            sections.append(f"""
            <div class="plot" id="{plot_name}">
                <h2>{title}</h2>
                <h3>{subtitle}</h3>
                <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
            </div>
            """)
        plots_html = "".join(sections)

        return f"""
        <h1>PyCaret Feature Importance Report</h1>
        {plots_html}
        """

    def run(self):
        """Set up PyCaret, create all plots and return the HTML report."""
        LOG.info("Running feature importance analysis")
        self.setup_pycaret()
        self.generate_feature_importance()
        html_content = self.generate_html_report()
        LOG.info("Feature importance analysis completed")
        return html_content
| 149 | |
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Feature Importance Analysis")
    parser.add_argument(
        "--data_path", type=str, required=True,
        help="Path to the dataset")
    parser.add_argument(
        "--target_col", type=int, required=True,
        help="Index of the target column (1-based)")
    parser.add_argument(
        "--task_type", type=str, required=True,
        choices=["classification", "regression"],
        help="Task type: classification or regression")
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save the outputs")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)
    # Pass by keyword so the CLI values cannot be mis-bound to the
    # constructor's (task_type, output_dir, data_path, data, target_col)
    # parameter order.
    analyzer = FeatureImportanceAnalyzer(
        task_type=args.task_type,
        output_dir=args.output_dir,
        data_path=args.data_path,
        target_col=args.target_col)
    html_content = analyzer.run()

    # Persist the report instead of discarding run()'s return value.
    report_path = os.path.join(
        args.output_dir, "feature_importance_report.html")
    with open(report_path, "w", encoding="utf-8") as report_file:
        report_file.write(html_content)
    LOG.info(f"HTML report written to {report_path}")
