view base_model_trainer.py @ 6:a32ff7201629 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
author goeckslab
date Wed, 02 Jul 2025 19:00:03 +0000
parents ccd798db5abb
children
line wrap: on
line source

import base64
import logging
import os
import tempfile

import h5py
import joblib
import numpy as np
import pandas as pd
from feature_help_modal import get_feature_metrics_help_modal
from feature_importance import FeatureImportanceAnalyzer
from sklearn.metrics import average_precision_score
from utils import get_html_closing, get_html_template

# Request DEBUG-level configuration of the root logger when this module is imported.
logging.basicConfig(level=logging.DEBUG)
# Module-level logger used by every method of BaseModelTrainer below.
LOG = logging.getLogger(__name__)


class BaseModelTrainer:
    def __init__(
        self,
        input_file,
        target_col,
        output_dir,
        task_type,
        random_seed,
        test_file=None,
        **kwargs,
    ):
        self.exp = None  # This will be set in the subclass
        self.input_file = input_file
        self.target_col = target_col
        self.output_dir = output_dir
        self.task_type = task_type
        self.random_seed = random_seed
        self.data = None
        self.target = None
        self.best_model = None
        self.results = None
        self.features_name = None
        self.plots = {}
        self.expaliner = None
        self.plots_explainer_html = None
        self.trees = []
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.setup_params = {}
        self.test_file = test_file
        self.test_data = None

        if not self.output_dir:
            raise ValueError("output_dir must be specified and not None")

        LOG.info(f"Model kwargs: {self.__dict__}")

    def load_data(self):
        LOG.info(f"Loading data from {self.input_file}")
        self.data = pd.read_csv(self.input_file, sep=None, engine="python")
        self.data.columns = self.data.columns.str.replace(".", "_")

        # Remove prediction_label if present
        if "prediction_label" in self.data.columns:
            self.data = self.data.drop(columns=["prediction_label"])

        numeric_cols = self.data.select_dtypes(include=["number"]).columns
        non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns

        self.data[numeric_cols] = self.data[numeric_cols].apply(
            pd.to_numeric, errors="coerce"
        )

        if len(non_numeric_cols) > 0:
            LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")

        names = self.data.columns.to_list()
        target_index = int(self.target_col) - 1
        self.target = names[target_index]
        self.features_name = [name for i, name in enumerate(names) if i != target_index]
        if hasattr(self, "missing_value_strategy"):
            if self.missing_value_strategy == "mean":
                self.data = self.data.fillna(self.data.mean(numeric_only=True))
            elif self.missing_value_strategy == "median":
                self.data = self.data.fillna(self.data.median(numeric_only=True))
            elif self.missing_value_strategy == "drop":
                self.data = self.data.dropna()
        else:
            # Default strategy if not specified
            self.data = self.data.fillna(self.data.median(numeric_only=True))

        if self.test_file:
            LOG.info(f"Loading test data from {self.test_file}")
            self.test_data = pd.read_csv(self.test_file, sep=None, engine="python")
            self.test_data = self.test_data[numeric_cols].apply(
                pd.to_numeric, errors="coerce"
            )
            self.test_data.columns = self.test_data.columns.str.replace(".", "_")

    def setup_pycaret(self):
        LOG.info("Initializing PyCaret")
        self.setup_params = {
            "target": self.target,
            "session_id": self.random_seed,
            "html": True,
            "log_experiment": False,
            "system_log": False,
            "index": False,
        }

        if self.test_data is not None:
            self.setup_params["test_data"] = self.test_data

        if (
            hasattr(self, "train_size")
            and self.train_size is not None
            and self.test_data is None
        ):
            self.setup_params["train_size"] = self.train_size

        if hasattr(self, "normalize") and self.normalize is not None:
            self.setup_params["normalize"] = self.normalize

        if hasattr(self, "feature_selection") and self.feature_selection is not None:
            self.setup_params["feature_selection"] = self.feature_selection

        if (
            hasattr(self, "cross_validation")
            and self.cross_validation is not None
            and self.cross_validation is False
        ):
            self.setup_params["cross_validation"] = self.cross_validation

        if hasattr(self, "cross_validation") and self.cross_validation is not None:
            if hasattr(self, "cross_validation_folds"):
                self.setup_params["fold"] = self.cross_validation_folds

        if hasattr(self, "remove_outliers") and self.remove_outliers is not None:
            self.setup_params["remove_outliers"] = self.remove_outliers

        if (
            hasattr(self, "remove_multicollinearity")
            and self.remove_multicollinearity is not None
        ):
            self.setup_params["remove_multicollinearity"] = (
                self.remove_multicollinearity
            )

        if (
            hasattr(self, "polynomial_features")
            and self.polynomial_features is not None
        ):
            self.setup_params["polynomial_features"] = self.polynomial_features

        if hasattr(self, "fix_imbalance") and self.fix_imbalance is not None:
            self.setup_params["fix_imbalance"] = self.fix_imbalance

        LOG.info(self.setup_params)

        # Solution: instantiate the correct PyCaret experiment based on task_type
        if self.task_type == "classification":
            from pycaret.classification import ClassificationExperiment

            self.exp = ClassificationExperiment()
        elif self.task_type == "regression":
            from pycaret.regression import RegressionExperiment

            self.exp = RegressionExperiment()
        else:
            raise ValueError("task_type must be 'classification' or 'regression'")

        self.exp.setup(self.data, **self.setup_params)

    def train_model(self):
        LOG.info("Training and selecting the best model")
        if self.task_type == "classification":
            average_displayed = "Weighted"
            self.exp.add_metric(
                id=f"PR-AUC-{average_displayed}",
                name=f"PR-AUC-{average_displayed}",
                target="pred_proba",
                score_func=average_precision_score,
                average="weighted",
            )

        if hasattr(self, "models") and self.models is not None:
            self.best_model = self.exp.compare_models(include=self.models)
        else:
            self.best_model = self.exp.compare_models()
        self.results = self.exp.pull()
        if self.task_type == "classification":
            self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)

        _ = self.exp.predict_model(self.best_model)
        self.test_result_df = self.exp.pull()
        if self.task_type == "classification":
            self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)

    def save_model(self):
        """Serialise ``self.best_model`` with joblib and store it in ``pycaret_model.h5``.

        The model is dumped to a temporary file, read back as raw bytes and
        written as the ``"model"`` dataset of an HDF5 file in the current
        working directory. The temporary file is always removed afterwards.
        """
        hdf5_model_path = "pycaret_model.h5"

        # mkstemp + explicit close: joblib re-opens the path by name, which
        # must not happen while another handle still holds the file open.
        temp_fd, temp_path = tempfile.mkstemp()
        os.close(temp_fd)
        try:
            joblib.dump(self.best_model, temp_path)
            with open(temp_path, "rb") as model_file:
                model_bytes = model_file.read()
        finally:
            # Do not leave one stray temp file behind per run.
            os.remove(temp_path)

        with h5py.File(hdf5_model_path, "w") as f:
            # np.void stores the pickle as one opaque binary blob.
            f.create_dataset("model", data=np.void(model_bytes))

    def generate_plots(self):
        raise NotImplementedError("Subclasses should implement this method")

    def encode_image_to_base64(self, img_path):
        with open(img_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode("utf-8")

    def save_html_report(self):
        """Write the CSV artifacts and the tabbed HTML report into ``output_dir``.

        Files written: ``best_model.csv`` (hyperparameters of the best model),
        ``comparison_results.csv``, ``test_results.csv`` and
        ``comparison_result.html``. The HTML page embeds every image in
        ``self.plots`` as base64, the feature-importance section returned by
        ``FeatureImportanceAnalyzer.run()`` and, when
        ``self.plots_explainer_html`` is set, an explainer tab that also holds
        the tree images from ``self.trees``.

        Raises:
            ValueError: If ``output_dir`` is empty or ``None``.
        """
        LOG.info("Saving HTML report")

        if not self.output_dir:
            raise ValueError("output_dir must be specified and not None")

        model_name = type(self.best_model).__name__
        # These setup entries are not shown in the report table.
        excluded_params = ["html", "log_experiment", "system_log", "test_data"]
        filtered_setup_params = {
            k: v for k, v in self.setup_params.items() if k not in excluded_params
        }
        setup_params_table = pd.DataFrame(
            list(filtered_setup_params.items()), columns=["Parameter", "Value"]
        )

        best_model_params = pd.DataFrame(
            self.best_model.get_params().items(), columns=["Parameter", "Value"]
        )
        best_model_params.to_csv(
            os.path.join(self.output_dir, "best_model.csv"), index=False
        )
        self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))
        self.test_result_df.to_csv(os.path.join(self.output_dir, "test_results.csv"))

        # One <div> per plot, separated by <hr> (no rule after the last plot).
        plots_html = ""
        length = len(self.plots)
        for i, (plot_name, plot_path) in enumerate(self.plots.items()):
            encoded_image = self.encode_image_to_base64(plot_path)
            plots_html += (
                f'<div class="plot">'
                f"<h3>{plot_name.capitalize()}</h3>"
                f'<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">'
                f"</div>"
            )
            if i < length - 1:
                plots_html += "<hr>"

        # self.trees holds already-encoded images; falsy entries are skipped.
        tree_plots = ""
        for i, tree in enumerate(self.trees):
            if tree:
                tree_plots += (
                    f'<div class="plot">'
                    f"<h3>Tree {i + 1}</h3>"
                    f'<img src="data:image/png;base64,{tree}" alt="tree {i + 1}">'
                    f"</div>"
                )

        analyzer = FeatureImportanceAnalyzer(
            data=self.data,
            target_col=self.target_col,
            task_type=self.task_type,
            output_dir=self.output_dir,
            exp=self.exp,
            best_model=self.best_model,
        )
        feature_importance_html = analyzer.run()

        # --- Feature Metrics Help Button ---
        feature_metrics_button_html = (
            '<button class="help-modal-btn" id="openFeatureMetricsHelp" style="margin-bottom:12px;">'
            "Help: Metrics Guide"
            "</button>"
            "<style>"
            ".help-modal-btn {"
            "background-color: #17623b;"
            "color: #fff;"
            "border: none;"
            "border-radius: 24px;"
            "padding: 10px 28px;"
            "font-size: 1.1rem;"
            "font-weight: bold;"
            "letter-spacing: 0.03em;"
            "cursor: pointer;"
            "transition: background 0.2s, box-shadow 0.2s;"
            "box-shadow: 0 2px 8px rgba(23,98,59,0.07);"
            "}"
            ".help-modal-btn:hover, .help-modal-btn:focus {"
            "background-color: #21895e;"
            "outline: none;"
            "box-shadow: 0 4px 16px rgba(23,98,59,0.14);"
            "}"
            "</style>"
        )

        html_content = (
            f"{get_html_template()}"
            "<h1>Tabular Learner Model Report</h1>"
            f"{feature_metrics_button_html}"
            '<div class="tabs">'
            '<div class="tab" onclick="openTab(event, \'summary\')">'
            "Validation Result Summary & Config</div>"
            '<div class="tab" onclick="openTab(event, \'plots\')">'
            "Test Results</div>"
            '<div class="tab" onclick="openTab(event, \'feature\')">'
            "Feature Importance</div>"
        )
        if self.plots_explainer_html:
            html_content += (
                '<div class="tab" onclick="openTab(event, \'explainer\')">'
                "Explainer Plots</div>"
            )
        html_content += (
            "</div>"
            '<div id="summary" class="tab-content">'
            "<h2>Model Metrics from Cross-Validation Set</h2>"
            f"<h2>Best Model: {model_name}</h2>"
            "<h5>The best model is selected by: Accuracy (Classification)"
            " or R2 (Regression).</h5>"
            f"{self.results.to_html(index=False, classes='table sortable')}"
            "<h2>Best Model's Hyperparameters</h2>"
            f"{best_model_params.to_html(index=False, header=True, classes='table sortable')}"
            "<h2>Setup Parameters</h2>"
            f"{setup_params_table.to_html(index=False, header=True, classes='table sortable')}"
            "<h5>If you want to know all the experiment setup parameters,"
            " please check the PyCaret documentation for"
            " the classification/regression <code>exp</code> function.</h5>"
            "</div>"
            '<div id="plots" class="tab-content">'
            f"<h2>Best Model: {model_name}</h2>"
            "<h5>The best model is selected by: Accuracy (Classification)"
            " or R2 (Regression).</h5>"
            "<h2>Test Metrics</h2>"
            f"{self.test_result_df.to_html(index=False)}"
            "<h2>Test Results</h2>"
            f"{plots_html}"
            "</div>"
            '<div id="feature" class="tab-content">'
            f"{feature_importance_html}"
            "</div>"
        )
        if self.plots_explainer_html:
            html_content += (
                '<div id="explainer" class="tab-content">'
                f"{self.plots_explainer_html}"
                f"{tree_plots}"
                "</div>"
            )
        html_content += f"<script>\n{self._sortable_tables_script()}\n</script>"
        # --- Add the Feature Metrics Help Modal ---
        html_content += get_feature_metrics_help_modal()
        html_content += f"{get_html_closing()}"
        with open(
            os.path.join(self.output_dir, "comparison_result.html"),
            "w",
            encoding="utf-8",
        ) as file:
            file.write(html_content)

    @staticmethod
    def _sortable_tables_script():
        """Return the JavaScript that makes every ``table.sortable`` click-to-sort."""
        # The statements are joined with newlines on purpose: the script
        # contains "//" line comments, and emitting everything on a single
        # line would turn the remainder of the script into a comment (leaving
        # the functions unclosed and the sorting dead).
        script_lines = (
            'document.addEventListener("DOMContentLoaded", function() {',
            'var tables = document.querySelectorAll("table.sortable");',
            "tables.forEach(function(table) {",
            'var headers = table.querySelectorAll("th");',
            "headers.forEach(function(header, index) {",
            'header.style.cursor = "pointer";',
            "// Add initial arrow (up) to indicate sortability, use Unicode ↑ (U+2191)",
            "header.innerHTML += '<span class=\"sort-arrow\"> ↑</span>';",
            'header.addEventListener("click", function() {',
            'var direction = this.getAttribute("data-sort-direction") || "asc";',
            "// Reset arrows in all headers of this table",
            "headers.forEach(function(h) {",
            'var arrow = h.querySelector(".sort-arrow");',
            'if (arrow) arrow.textContent = " ↑";',
            "});",
            "// Set arrow for clicked header",
            'var arrow = this.querySelector(".sort-arrow");',
            'arrow.textContent = direction === "asc" ? " ↓" : " ↑";',
            "sortTable(table, index, direction);",
            'this.setAttribute("data-sort-direction",'
            ' direction === "asc" ? "desc" : "asc");',
            "});",
            "});",
            "});",
            "});",
            "function sortTable(table, colNum, direction) {",
            "var tb = table.tBodies[0];",
            "var tr = Array.prototype.slice.call(tb.rows, 0);",
            'var multiplier = direction === "asc" ? 1 : -1;',
            "tr = tr.sort(function(a, b) {",
            "var aText = a.cells[colNum].textContent.trim();",
            "var bText = b.cells[colNum].textContent.trim();",
            "// Remove arrow from text comparison",
            "aText = aText.replace(/[↑↓]/g, '').trim();",
            "bText = bText.replace(/[↑↓]/g, '').trim();",
            "if (!isNaN(aText) && !isNaN(bText)) {",
            "return multiplier * (parseFloat(aText) - parseFloat(bText));",
            "} else {",
            "return multiplier * aText.localeCompare(bText);",
            "}",
            "});",
            "for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);",
            "}",
        )
        return "\n".join(script_lines)

    def save_dashboard(self):
        raise NotImplementedError("Subclasses should implement this method")

    def generate_plots_explainer(self):
        raise NotImplementedError("Subclasses should implement this method")

    def generate_tree_plots(self):
        """Collect one encoded decision-tree image per tree of the best model.

        Only random-forest and XGBoost estimators are handled; for any other
        model type a warning is logged and nothing is collected. Each value
        returned by ``decisiontree_encoded`` is appended to ``self.trees``.
        Failures while building the explainer or rendering a tree are logged
        and otherwise ignored.
        """
        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
        from xgboost import XGBClassifier, XGBRegressor
        from explainerdashboard.explainers import RandomForestExplainer

        LOG.info("Generating tree plots")
        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        model = self.best_model
        forest_types = (RandomForestClassifier, RandomForestRegressor)
        boosted_types = (XGBClassifier, XGBRegressor)

        if isinstance(model, forest_types):
            tree_count = model.n_estimators
        elif isinstance(model, boosted_types):
            # The booster dump has one entry per tree.
            tree_count = len(model.get_booster().get_dump())
        else:
            LOG.warning("Tree plots not supported for this model type.")
            return

        try:
            explainer = RandomForestExplainer(model, X_test, y_test)
            for tree_idx in range(tree_count):
                # index=0: the tree is rendered for the first test row.
                encoded_tree = explainer.decisiontree_encoded(
                    tree_idx=tree_idx, index=0
                )
                LOG.info(f"Tree {tree_idx + 1}")
                LOG.info(encoded_tree)
                self.trees.append(encoded_tree)
        except Exception as e:
            # Best effort: tree plots are optional in the report.
            LOG.error(f"Error generating tree plots: {e}")

    def run(self):
        self.load_data()
        self.setup_pycaret()
        self.train_model()
        self.save_model()
        self.generate_plots()
        self.generate_plots_explainer()
        self.generate_tree_plots()
        self.save_html_report()
        # self.save_dashboard()