view feature_importance.py @ 2:009b18a75dc3 draft default tip

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit 9497c4faca7063bcbb6b201ab6d0dd1570f22acb
author goeckslab
date Sat, 14 Dec 2024 23:18:02 +0000
parents 915447b14520
children
line wrap: on
line source

import base64
import logging
import os

import matplotlib.pyplot as plt

import pandas as pd

from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment

# Configure the root logger at DEBUG level when this module is loaded.
logging.basicConfig(level=logging.DEBUG)
# Module-level logger shared by the code in this file.
LOG = logging.getLogger(__name__)


class FeatureImportanceAnalyzer:
    """Produce feature-importance plots and an HTML fragment with PyCaret.

    A Random Forest importance bar chart and a LightGBM SHAP summary plot
    are written as PNG files into ``output_dir`` and then embedded,
    base64-encoded, into the HTML string returned by :meth:`run`.
    """

    def __init__(
            self,
            task_type,
            output_dir,
            data_path=None,
            data=None,
            target_col=None):
        """Load the data and create the matching PyCaret experiment.

        Args:
            task_type: ``'classification'`` selects a
                ``ClassificationExperiment``; any other value selects a
                ``RegressionExperiment``.
            output_dir: Directory the plot images are written to.
            data_path: Path to a delimited text file. Only read when
                ``data`` is None.
            data: Already-loaded table exposing ``.columns`` (e.g. a
                pandas DataFrame). Used as-is, without column renaming
                or imputation.
            target_col: 1-based position of the target column.

        Raises:
            ValueError: If neither ``data`` nor ``data_path`` is given.
        """
        if data is None and data_path is None:
            raise ValueError(
                "Either 'data' or 'data_path' must be provided")

        # Record the requested index on both code paths so the attribute
        # always exists, not only when the data comes from a file.
        self.target_col = target_col
        if data is not None:
            self.data = data
            LOG.info("Data loaded from memory")
        else:
            self.data = self._load_data(data_path)
        self.task_type = task_type
        # target_col is 1-based; convert to a 0-based column position.
        self.target = self.data.columns[int(target_col) - 1]
        self.exp = ClassificationExperiment() \
            if task_type == 'classification' \
            else RegressionExperiment()
        # Maps plot name -> path of the PNG written for it.
        self.plots = {}
        self.output_dir = output_dir

    @staticmethod
    def _load_data(data_path):
        """Read a delimited file, sanitise column names, impute medians.

        Args:
            data_path: Path of the file to read.

        Returns:
            DataFrame whose column names have every literal ``.``
            replaced by ``_`` and whose missing numeric values are
            filled with the column median.
        """
        # sep=None with the python engine lets pandas sniff the delimiter.
        frame = pd.read_csv(data_path, sep=None, engine='python')
        # regex=False: replace literal dots only. Interpreted as a regular
        # expression, '.' would match (and replace) every character.
        frame.columns = frame.columns.str.replace('.', '_', regex=False)
        return frame.fillna(frame.median(numeric_only=True))

    def setup_pycaret(self):
        """Initialise the PyCaret experiment on the loaded data."""
        LOG.info("Initializing PyCaret")
        setup_params = {
            'target': self.target,
            'session_id': 123,
            'html': True,
            'log_experiment': False,
            'system_log': False
        }
        LOG.info(self.task_type)
        LOG.info(self.exp)
        self.exp.setup(self.data, **setup_params)

    def save_tree_importance(self):
        """Train an ``'rf'`` model and save its importances as a bar chart.

        The image is written to ``tree_importance.png`` inside
        ``output_dir`` and its path is stored under
        ``self.plots['tree_importance']``.
        """
        model = self.exp.create_model('rf')
        importances = model.feature_importances_
        # Importances refer to the columns after PyCaret's preprocessing,
        # so label them with the transformed feature names.
        processed_features = self.exp.get_config('X_transformed').columns
        LOG.debug("Feature importances: %s", importances)
        LOG.debug("Features: %s", processed_features)
        feature_importances = pd.DataFrame({
            'Feature': processed_features,
            'Importance': importances
        }).sort_values(by='Importance', ascending=False)
        plt.figure(figsize=(10, 6))
        plt.barh(
            feature_importances['Feature'],
            feature_importances['Importance'])
        plt.xlabel('Importance')
        plt.title('Feature Importance (Random Forest)')
        plot_path = os.path.join(
            self.output_dir,
            'tree_importance.png')
        plt.savefig(plot_path)
        plt.close()
        self.plots['tree_importance'] = plot_path

    def save_shap_values(self):
        """Train a ``'lightgbm'`` model and save its SHAP summary plot.

        The image is written to ``shap_summary.png`` inside
        ``output_dir`` and its path is stored under
        ``self.plots['shap_summary']``.
        """
        model = self.exp.create_model('lightgbm')
        # Imported here so shap is only loaded when this plot is requested.
        import shap
        explainer = shap.Explainer(model)
        # Fetch the transformed features once and reuse them for both the
        # SHAP computation and the plot.
        transformed = self.exp.get_config('X_transformed')
        shap_values = explainer.shap_values(transformed)
        shap.summary_plot(shap_values, transformed, show=False)
        plt.title('Shap (LightGBM)')
        plot_path = os.path.join(
            self.output_dir, 'shap_summary.png')
        plt.savefig(plot_path)
        plt.close()
        self.plots['shap_summary'] = plot_path

    def generate_feature_importance(self):
        """Create every plot that goes into the report."""
        self.save_tree_importance()
        self.save_shap_values()

    def encode_image_to_base64(self, img_path):
        """Return the file at ``img_path`` as a base64 text string."""
        with open(img_path, 'rb') as img_file:
            return base64.b64encode(img_file.read()).decode('utf-8')

    @staticmethod
    def _render_plot_section(plot_name, encoded_image):
        """Return the HTML ``<div>`` for one plot.

        Args:
            plot_name: Key of the plot; ``'tree_importance'`` gets the
                Random Forest headings, anything else the SHAP heading.
            encoded_image: Base64 text of the PNG image.

        Returns:
            HTML fragment with a heading, a sub-heading and the image
            embedded as a data URI.
        """
        if plot_name == 'tree_importance':
            heading = (
                'Feature importance analysis from a '
                'trained Random Forest')
            sub_heading = (
                'Use gini impurity for '
                'calculating feature importance for classification '
                'and Variance Reduction for regression')
        else:
            heading = 'SHAP Summary from a trained lightgbm'
            sub_heading = ''
        # Build the data URI on one line: the payload must follow
        # "base64," directly, without embedded whitespace.
        image_uri = f"data:image/png;base64,{encoded_image}"
        return f"""
            <div class="plot" id="{plot_name}">
                <h2>{heading}</h2>
                <h3>{sub_heading}</h3>
                <img src="{image_uri}" alt="{plot_name}">
            </div>
            """

    def generate_html_report(self):
        """Return an HTML fragment embedding every generated plot.

        Each image recorded in ``self.plots`` is read from disk,
        base64-encoded and rendered in insertion order.
        """
        LOG.info("Generating HTML report")

        # Read, encode and render each plot image.
        plots_html = "".join(
            self._render_plot_section(
                plot_name, self.encode_image_to_base64(plot_path))
            for plot_name, plot_path in self.plots.items())

        # Wrap the plot sections under a single report title.
        html_content = f"""
            <h1>PyCaret Feature Importance Report</h1>
            {plots_html}
        """

        return html_content

    def run(self):
        """Run setup, plotting and reporting; return the HTML fragment."""
        LOG.info("Running feature importance analysis")
        self.setup_pycaret()
        self.generate_feature_importance()
        html_content = self.generate_html_report()
        LOG.info("Feature importance analysis completed")
        return html_content


if __name__ == "__main__":
    # Command-line entry point: parse the options, run the analysis and
    # write the plot images into the requested output directory.
    import argparse
    parser = argparse.ArgumentParser(description="Feature Importance Analysis")
    # Every option is required: none of them has a usable default.
    parser.add_argument(
        "--data_path", type=str, required=True,
        help="Path to the dataset")
    parser.add_argument(
        "--target_col", type=int, required=True,
        help="Index of the target column (1-based)")
    parser.add_argument(
        "--task_type", type=str, required=True,
        choices=["classification", "regression"],
        help="Task type: classification or regression")
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to save the outputs")
    args = parser.parse_args()

    # Pass by keyword: the constructor's positional order is
    # (task_type, output_dir, data_path, data, target_col), so positional
    # passing in CLI order would bind every value to the wrong parameter.
    analyzer = FeatureImportanceAnalyzer(
        task_type=args.task_type,
        output_dir=args.output_dir,
        data_path=args.data_path,
        target_col=args.target_col)
    analyzer.run()