# base_model_trainer.py
# planemo upload for repository https://github.com/goeckslab/gleam
# commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
# author: goeckslab
# date: Fri, 25 Jul 2025 19:02:32 +0000

import base64
import logging
import tempfile
from pathlib import Path

import h5py
import joblib
import numpy as np
import pandas as pd
from feature_help_modal import get_feature_metrics_help_modal
from feature_importance import FeatureImportanceAnalyzer
from sklearn.metrics import average_precision_score
from utils import (
    add_hr_to_html,
    add_plot_to_html,
    build_tabbed_html,
    encode_image_to_base64,
    get_html_closing,
    get_html_template,
)

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)


class BaseModelTrainer:
    def __init__(
        self,
        input_file,
        target_col,
        output_dir,
        task_type,
        random_seed,
        test_file=None,
        **kwargs,
    ):
        self.exp = None
        self.input_file = input_file
        self.target_col = target_col
        self.output_dir = output_dir
        self.task_type = task_type
        self.random_seed = random_seed
        self.data = None
        self.target = None
        self.best_model = None
        self.results = None
        self.features_name = None
        self.plots = {}
        self.explainer_plots = {}
        self.plots_explainer_html = None
        self.trees = []
        self.user_kwargs = kwargs.copy()
        for key, value in self.user_kwargs.items():
            setattr(self, key, value)
        self.setup_params = {}
        self.test_file = test_file
        self.test_data = None

        if not self.output_dir:
            raise ValueError("output_dir must be specified and not None")

        LOG.info(f"Model kwargs: {self.__dict__}")

    def load_data(self):
        LOG.info(f"Loading data from {self.input_file}")
        self.data = pd.read_csv(self.input_file, sep=None, engine="python")
        self.data.columns = self.data.columns.str.replace(".", "_", regex=False)
        if "prediction_label" in self.data.columns:
            self.data = self.data.drop(columns=["prediction_label"])

        numeric_cols = self.data.select_dtypes(include=["number"]).columns
        non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns
        self.data[numeric_cols] = self.data[numeric_cols].apply(
            pd.to_numeric, errors="coerce"
        )
        if len(non_numeric_cols) > 0:
            LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")

        names = self.data.columns.to_list()
        # target_col is a 1-based column index
        target_index = int(self.target_col) - 1
        self.target = names[target_index]
        self.features_name = [n for i, n in enumerate(names) if i != target_index]

        if getattr(self, "missing_value_strategy", None):
            strat = self.missing_value_strategy
            if strat == "mean":
                self.data = self.data.fillna(self.data.mean(numeric_only=True))
            elif strat == "median":
                self.data = self.data.fillna(self.data.median(numeric_only=True))
            elif strat == "drop":
                self.data = self.data.dropna()
        else:
            # default: median imputation for numeric columns
            self.data = self.data.fillna(self.data.median(numeric_only=True))

        if self.test_file:
            LOG.info(f"Loading test data from {self.test_file}")
            df_test = pd.read_csv(self.test_file, sep=None, engine="python")
            df_test.columns = df_test.columns.str.replace(".", "_", regex=False)
            self.test_data = df_test

    def setup_pycaret(self):
        LOG.info("Initializing PyCaret")
        self.setup_params = {
            "target": self.target,
            "session_id": self.random_seed,
            "html": True,
            "log_experiment": False,
            "system_log": False,
            "index": False,
        }
        if self.test_data is not None:
            self.setup_params["test_data"] = self.test_data
        for attr in [
            "train_size",
            "normalize",
            "feature_selection",
            "remove_outliers",
            "remove_multicollinearity",
            "polynomial_features",
            "feature_interaction",
            "feature_ratio",
            "fix_imbalance",
        ]:
            val = getattr(self, attr, None)
            if val is not None:
                self.setup_params[attr] = val
        if getattr(self, "cross_validation_folds", None) is not None:
            self.setup_params["fold"] = self.cross_validation_folds
        LOG.info(self.setup_params)

        if self.task_type == "classification":
            from pycaret.classification import ClassificationExperiment

            self.exp = ClassificationExperiment()
        elif self.task_type == "regression":
            from pycaret.regression import RegressionExperiment

            self.exp = RegressionExperiment()
        else:
            raise ValueError("task_type must be 'classification' or 'regression'")

        self.exp.setup(self.data, **self.setup_params)
        # fold user-supplied kwargs back in so they show up in the report's setup table
        self.setup_params.update(self.user_kwargs)

    def train_model(self):
        LOG.info("Training and selecting the best model")
        if self.task_type == "classification":
            self.exp.add_metric(
                id="PR-AUC-Weighted",
                name="PR-AUC-Weighted",
                target="pred_proba",
                score_func=average_precision_score,
                average="weighted",
            )
        # Build arguments for compare_models()
        compare_kwargs = {}
        if getattr(self, "models", None):
            compare_kwargs["include"] = self.models

        # Respect explicit cross-validation flag
        if getattr(self, "cross_validation", None) is not None:
            compare_kwargs["cross_validation"] = self.cross_validation

        # Respect explicit fold count
        if getattr(self, "cross_validation_folds", None) is not None:
            compare_kwargs["fold"] = self.cross_validation_folds

        LOG.info(f"compare_models kwargs: {compare_kwargs}")
        self.best_model = self.exp.compare_models(**compare_kwargs)
        self.results = self.exp.pull()
        if getattr(self, "tune_model", False):
            LOG.info("Tuning hyperparameters of the best model")
            self.best_model = self.exp.tune_model(self.best_model)
            self.results = self.exp.pull()

        if self.task_type == "classification":
            self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
        _ = self.exp.predict_model(self.best_model)
        self.test_result_df = self.exp.pull()
        if self.task_type == "classification":
            self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)

    def save_model(self):
        hdf5_path = Path(self.output_dir) / "pycaret_model.h5"
        with h5py.File(hdf5_path, "w") as f:
            # serialize the best model with joblib, then embed the raw bytes
            # as an opaque dataset inside the HDF5 container
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                joblib.dump(self.best_model, tmp.name)
                tmp.seek(0)
                model_bytes = tmp.read()
            Path(tmp.name).unlink()  # clean up the temporary file
            f.create_dataset("model", data=np.void(model_bytes))
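
    # A minimal load-back sketch (not part of the original class; the name
    # "load_model_from_hdf5" is illustrative only). The bytes written by
    # save_model() are read back out of the "model" dataset and handed to joblib.
    @staticmethod
    def load_model_from_hdf5(hdf5_path):
        import io

        with h5py.File(hdf5_path, "r") as f:
            model_bytes = f["model"][()].tobytes()
        return joblib.load(io.BytesIO(model_bytes))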

    def generate_plots(self):
        LOG.info("Generating PyCaret diagnostic pltos")

        # choose the right plots based on task
        if self.task_type == "classification":
            plot_names = [
                "learning",
                "vc",
                "calibration",
                "dimension",
                "manifold",
                "rfe",
                "threshold",
                "percentage_above_below",
                "class_report",
                "pr_auc",
                "roc_auc",
            ]
        else:
            plot_names = ["residuals", "vc", "parameter", "error", "learning"]
        for name in plot_names:
            try:
                # plot_model is expected to return a matplotlib Axes here; plots
                # unsupported for the selected model raise and are skipped below
                ax = self.exp.plot_model(self.best_model, plot=name, save=False)
                out_path = Path(self.output_dir) / f"plot_{name}.png"
                fig = ax.get_figure()
                fig.savefig(out_path, bbox_inches="tight")
                self.plots[name] = str(out_path)
            except Exception as e:
                LOG.warning(f"Could not generate {name} plot: {e}")

    def encode_image_to_base64(self, img_path: str) -> str:
        with open(img_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode("utf-8")

    def save_html_report(self):
        LOG.info("Saving HTML report")

        # 1) Determine best model name
        try:
            best_model_name = str(self.results.iloc[0]["Model"])
        except Exception:
            best_model_name = type(self.best_model).__name__
        LOG.info(f"Best model determined as: {best_model_name}")

        # 2) Compute training sample count
        try:
            n_train = self.exp.X_train.shape[0]
        except Exception:
            n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0]
        total_rows = self.data.shape[0]

        # 3) Build setup parameters table
        all_params = self.setup_params
        display_keys = [
            "Target",
            "Session ID",
            "Train Size",
            "Normalize",
            "Feature Selection",
            "Cross Validation",
            "Cross Validation Folds",
            "Remove Outliers",
            "Remove Multicollinearity",
            "Polynomial Features",
            "Fix Imbalance",
            "Models",
        ]
        setup_rows = []
        for key in display_keys:
            pk = key.lower().replace(" ", "_")
            v = all_params.get(pk)
            if key == "Train Size":
                frac = (
                    float(v)
                    if v is not None
                    else (n_train / total_rows if total_rows else 0)
                )
                dv = f"{frac:.2f} ({n_train} rows)"
            elif key in {
                "Normalize",
                "Feature Selection",
                "Cross Validation",
                "Remove Outliers",
                "Remove Multicollinearity",
                "Polynomial Features",
                "Fix Imbalance",
            }:
                dv = bool(v)
            elif key == "Cross Validation Folds":
                dv = v if v is not None else "None"
            elif key == "Models":
                dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None"
            else:
                dv = v if v is not None else "None"
            setup_rows.append([key, dv])
        if hasattr(self.exp, "_fold_metric"):
            setup_rows.append(["best_model_metric", self.exp._fold_metric])

        df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
        df_setup.to_csv(Path(self.output_dir) / "setup_params.csv", index=False)

        # 4) Persist CSVs
        self.results.to_csv(
            Path(self.output_dir) / "comparison_results.csv", index=False
        )
        self.test_result_df.to_csv(
            Path(self.output_dir) / "test_results.csv", index=False
        )
        pd.DataFrame(
            self.best_model.get_params().items(), columns=["Parameter", "Value"]
        ).to_csv(Path(self.output_dir) / "best_model.csv", index=False)

        # 5) Header
        header = f"<h2>Best Model: {best_model_name}</h2>"

        # — Validation Summary & Configuration —
        val_df = self.results.copy()
        # mapping raw plot keys to user-friendly titles
        plot_title_map = {
            "learning": "Learning Curve",
            "vc": "Validation Curve",
            "calibration": "Calibration Curve",
            "dimension": "Dimensionality Reduction",
            "manifold": "Manifold Learning",
            "rfe": "Recursive Feature Elimination",
            "threshold": "Threshold Plot",
            "percentage_above_below": "Percentage Above vs. Below Cutoff",
            "class_report": "Classification Report",
            "pr_auc": "Precision-Recall AUC",
            "roc_auc": "Receiver Operating Characteristic AUC",
            "residuals": "Residuals Distribution",
            "error": "Prediction Error Distribution",
        }
        val_df.drop(columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True)
        summary_html = (
            header
            + "<h2>Train & Validation Summary</h2>"
            + '<div class="table-wrapper">'
            + val_df.to_html(index=False, classes="table sortable")
            + "</div>"
            + "<h2>Setup Parameters</h2>"
            + '<div class="table-wrapper">'
            + df_setup.to_html(index=False, classes="table sortable")
            + "</div>"
            # — Hyperparameters
            + "<h2>Best Model Hyperparameters</h2>"
            + '<div class="table-wrapper">'
            + pd.DataFrame(
                self.best_model.get_params().items(), columns=["Parameter", "Value"]
            ).to_html(index=False, classes="table sortable")
            + "</div>"
        )

        # choose summary plots based on task type
        if self.task_type == "classification":
            summary_plots = [
                "learning",
                "vc",
                "calibration",
                "dimension",
                "manifold",
                "rfe",
                "threshold",
                "percentage_above_below",
            ]
        else:
            summary_plots = ["learning", "vc", "parameter", "residuals"]

        for name in summary_plots:
            if name in self.plots:
                summary_html += "<hr>"
                b64 = encode_image_to_base64(self.plots[name])
                title = plot_title_map.get(name, name.replace("_", " ").title())
                summary_html += (
                    '<div class="plot">'
                    f"<h2>{title}</h2>"
                    f'<img src="data:image/png;base64,{b64}" '
                    'style="max-width:90%;max-height:600px;border:1px solid #ddd;"/>'
                    "</div>"
                )

        # — Test Summary —
        test_html = (
            header
            + '<div class="table-wrapper">'
            + self.test_result_df.to_html(index=False, classes="table sortable")
            + "</div>"
        )
        if self.task_type == "regression":
            try:
                y_true = (
                    pd.Series(self.exp.y_test_transformed)
                    .reset_index(drop=True)
                    .rename("True")
                )
                y_pred = pd.Series(
                    self.best_model.predict(self.exp.X_test_transformed)
                ).rename("Predicted")
                df_tp = pd.concat([y_true, y_pred], axis=1)
                test_html += "<h2>True vs Predicted Values</h2>"
                test_html += (
                    '<div class="table-wrapper" style="max-height:400px; overflow-y:auto;">'
                    + df_tp.head(50).to_html(index=False, classes="table sortable")
                    + "</div>"
                    + add_hr_to_html()
                )
            except Exception as e:
                LOG.warning(f"Could not generate True vs Predicted table: {e}")

        # 5a) Explainer-substituted plots in order
        if self.task_type == "regression":
            test_order = ["residuals"]
        else:
            test_order = [
                "confusion_matrix",
                "roc_auc",
                "pr_auc",
                "lift_curve",
                "threshold",
                "cumulative_precision",
            ]
        for key in test_order:
            fig_or_fn = self.explainer_plots.pop(key, None)
            if fig_or_fn is not None:
                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
                title = plot_title_map.get(key, key.replace("_", " ").title())
                test_html += (
                    f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
                )
        # 5b) Remaining PyCaret test plots
        for name, path in self.plots.items():
            # classification: render these extra static plots before the skip below
            if self.task_type == "classification" and name in {
                "threshold",
                "pr_auc",
                "class_report",
            }:
                title = plot_title_map.get(name, name.replace("_", " ").title())
                b64 = encode_image_to_base64(path)
                test_html += (
                    f"<h2>{title}</h2>"
                    "<div class='plot'>"
                    f"<img src='data:image/png;base64,{b64}' "
                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
                    "</div>" + add_hr_to_html()
                )
                continue

            # regression: explicitly include the 'error' plot, before skipping
            if self.task_type == "regression" and name == "error":
                title = plot_title_map.get("error", "Prediction Error Distribution")
                b64 = encode_image_to_base64(path)
                test_html += (
                    f"<h2>{title}</h2>"
                    "<div class='plot'>"
                    f"<img src='data:image/png;base64,{b64}' "
                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
                    "</div>" + add_hr_to_html()
                )
                continue

            # now skip any plots already rendered via test_order
            if name in test_order:
                continue

        # — Feature Importance —
        feature_html = header

        # 6a) PyCaret’s default feature importances
        feature_html += FeatureImportanceAnalyzer(
            data=self.data,
            target_col=self.target_col,
            task_type=self.task_type,
            output_dir=self.output_dir,
            exp=self.exp,
            best_model=self.best_model,
        ).run()

        # 6b) Explainer SHAP importances
        for key in ["shap_mean", "shap_perm"]:
            fig_or_fn = self.explainer_plots.pop(key, None)
            if fig_or_fn is not None:
                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
                # give SHAP plots explicit titles
                title = (
                    "Mean Absolute SHAP Value Impact"
                    if key == "shap_mean"
                    else "Permutation Feature Importance"
                )
                feature_html += (
                    f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
                )

        # 6c) PDPs last
        pdp_keys = sorted(k for k in self.explainer_plots if k.startswith("pdp__"))
        for k in pdp_keys:
            fig_or_fn = self.explainer_plots[k]
            fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
            # extract feature name
            feature = k.split("__", 1)[1]
            title = f"Partial Dependence for {feature}"
            feature_html += (
                f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
            )
        # 7) Assemble final HTML (three tabs)
        html = get_html_template()
        html += "<h1>Tabular Learner Model Report</h1>"
        html += build_tabbed_html(summary_html, test_html, feature_html)
        html += get_feature_metrics_help_modal()
        html += get_html_closing()

        # 8) Write out
        (Path(self.output_dir) / "comparison_result.html").write_text(
            html, encoding="utf-8"
        )
        LOG.info(f"HTML report generated at: {self.output_dir}/comparison_result.html")

    def save_dashboard(self):
        raise NotImplementedError("Subclasses should implement this method")

    def generate_plots_explainer(self):
        raise NotImplementedError("Subclasses should implement this method")

    def generate_tree_plots(self):
        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
        from xgboost import XGBClassifier, XGBRegressor
        from explainerdashboard.explainers import RandomForestExplainer

        LOG.info("Generating tree plots")
        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        if isinstance(self.best_model, (RandomForestClassifier, RandomForestRegressor)):
            n_trees = self.best_model.n_estimators
        elif isinstance(self.best_model, (XGBClassifier, XGBRegressor)):
            n_trees = len(self.best_model.get_booster().get_dump())
        else:
            LOG.warning("Tree plots not supported for this model type.")
            return

        explainer = RandomForestExplainer(self.best_model, X_test, y_test)
        for i in range(n_trees):
            fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
            self.trees.append(fig)

    def run(self):
        self.load_data()
        self.setup_pycaret()
        self.train_model()
        self.save_model()
        self.generate_plots()
        self.generate_plots_explainer()
        self.generate_tree_plots()
        self.save_html_report()
        # self.save_dashboard()
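

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not the tool's real entry point).
    # A concrete subclass overrides the NotImplemented hooks; the file paths and
    # column index below are hypothetical placeholders.

    class _ExampleTrainer(BaseModelTrainer):
        def generate_plots_explainer(self):
            # skip explainer-based plots in this sketch
            pass

        def save_dashboard(self):
            # no dashboard output in this sketch
            pass

    _ExampleTrainer(
        input_file="train.csv",       # hypothetical training CSV
        target_col="5",               # 1-based index of the target column
        output_dir="report_output",   # hypothetical output directory
        task_type="classification",
        random_seed=42,
    ).run()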