repository: goeckslab/pycaret_predict (Mercurial)
file:       base_model_trainer.py @ changeset 8:1aed7d47c5ec (draft, default, tip)
commit:     planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author:     goeckslab
date:       Fri, 25 Jul 2025 19:02:32 +0000
parents:    f4cb41f458fd
children:   none
import base64
import logging
import tempfile
from pathlib import Path

import h5py
import joblib
import numpy as np
import pandas as pd
from feature_help_modal import get_feature_metrics_help_modal
from feature_importance import FeatureImportanceAnalyzer
from sklearn.metrics import average_precision_score
from utils import (
    add_hr_to_html,
    add_plot_to_html,
    build_tabbed_html,
    encode_image_to_base64,
    get_html_closing,
    get_html_template,
)

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)


class BaseModelTrainer:
    def __init__(
        self,
        input_file,
        target_col,
        output_dir,
        task_type,
        random_seed,
        test_file=None,
        **kwargs,
    ):
        self.exp = None
        self.input_file = input_file
        self.target_col = target_col
        self.output_dir = output_dir
        self.task_type = task_type
        self.random_seed = random_seed
        self.data = None
        self.target = None
        self.best_model = None
        self.results = None
        self.features_name = None
        self.plots = {}
        self.explainer_plots = {}
        self.plots_explainer_html = None
        self.trees = []
        self.user_kwargs = kwargs.copy()
        for key, value in self.user_kwargs.items():
            setattr(self, key, value)
        self.setup_params = {}
        self.test_file = test_file
        self.test_data = None

        if not self.output_dir:
            raise ValueError("output_dir must be specified and not None")

        LOG.info(f"Model kwargs: {self.__dict__}")

    def load_data(self):
        LOG.info(f"Loading data from {self.input_file}")
        self.data = pd.read_csv(self.input_file, sep=None, engine="python")
        self.data.columns = self.data.columns.str.replace(".", "_")

        if "prediction_label" in self.data.columns:
            self.data = self.data.drop(columns=["prediction_label"])

        numeric_cols = self.data.select_dtypes(include=["number"]).columns
        non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns
        self.data[numeric_cols] = self.data[numeric_cols].apply(
            pd.to_numeric, errors="coerce"
        )
        if len(non_numeric_cols) > 0:
            LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")

        names = self.data.columns.to_list()
        target_index = int(self.target_col) - 1
        self.target = names[target_index]
        self.features_name = [n for i, n in enumerate(names) if i != target_index]

        if getattr(self, "missing_value_strategy", None):
            strat = self.missing_value_strategy
            if strat == "mean":
                self.data = self.data.fillna(self.data.mean(numeric_only=True))
            elif strat == "median":
                self.data = self.data.fillna(self.data.median(numeric_only=True))
            elif strat == "drop":
                self.data = self.data.dropna()
        else:
            self.data = self.data.fillna(self.data.median(numeric_only=True))

        if self.test_file:
            LOG.info(f"Loading test data from {self.test_file}")
            df_test = pd.read_csv(self.test_file, sep=None, engine="python")
            df_test.columns = df_test.columns.str.replace(".", "_")
            self.test_data = df_test
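
    # Note: optional keyword arguments passed to __init__ (e.g. normalize,
    # train_size, cross_validation_folds) become instance attributes and are
    # picked up by the setup below via getattr(), ending up as PyCaret setup()
    # parameters ("normalize", "train_size", "fold", ...).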
    def setup_pycaret(self):
        LOG.info("Initializing PyCaret")
        self.setup_params = {
            "target": self.target,
            "session_id": self.random_seed,
            "html": True,
            "log_experiment": False,
            "system_log": False,
            "index": False,
        }
        if self.test_data is not None:
            self.setup_params["test_data"] = self.test_data

        for attr in [
            "train_size",
            "normalize",
            "feature_selection",
            "remove_outliers",
            "remove_multicollinearity",
            "polynomial_features",
            "feature_interaction",
            "feature_ratio",
            "fix_imbalance",
        ]:
            val = getattr(self, attr, None)
            if val is not None:
                self.setup_params[attr] = val

        if getattr(self, "cross_validation_folds", None) is not None:
            self.setup_params["fold"] = self.cross_validation_folds

        LOG.info(self.setup_params)

        if self.task_type == "classification":
            from pycaret.classification import ClassificationExperiment

            self.exp = ClassificationExperiment()
        elif self.task_type == "regression":
            from pycaret.regression import RegressionExperiment

            self.exp = RegressionExperiment()
        else:
            raise ValueError("task_type must be 'classification' or 'regression'")

        self.exp.setup(self.data, **self.setup_params)
        self.setup_params.update(self.user_kwargs)

    def train_model(self):
        LOG.info("Training and selecting the best model")
        if self.task_type == "classification":
            self.exp.add_metric(
                id="PR-AUC-Weighted",
                name="PR-AUC-Weighted",
                target="pred_proba",
                score_func=average_precision_score,
                average="weighted",
            )

        # Build arguments for compare_models()
        compare_kwargs = {}
        if getattr(self, "models", None):
            compare_kwargs["include"] = self.models

        # Respect explicit cross-validation flag
        if getattr(self, "cross_validation", None) is not None:
            compare_kwargs["cross_validation"] = self.cross_validation

        # Respect explicit fold count
        if getattr(self, "cross_validation_folds", None) is not None:
            compare_kwargs["fold"] = self.cross_validation_folds

        LOG.info(f"compare_models kwargs: {compare_kwargs}")
        self.best_model = self.exp.compare_models(**compare_kwargs)
        self.results = self.exp.pull()

        if getattr(self, "tune_model", False):
            LOG.info("Tuning hyperparameters of the best model")
            self.best_model = self.exp.tune_model(self.best_model)
            self.results = self.exp.pull()

        if self.task_type == "classification":
            self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)

        _ = self.exp.predict_model(self.best_model)
        self.test_result_df = self.exp.pull()
        if self.task_type == "classification":
            self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)

    def save_model(self):
        hdf5_path = Path(self.output_dir) / "pycaret_model.h5"
        with h5py.File(hdf5_path, "w") as f:
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                joblib.dump(self.best_model, tmp.name)
                tmp.seek(0)
                model_bytes = tmp.read()
            f.create_dataset("model", data=np.void(model_bytes))
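
    # Illustrative sketch (hypothetical path): the artefact written by
    # save_model() can be restored with the same h5py/joblib pairing, e.g.
    #
    #     import io
    #     with h5py.File("pycaret_model.h5", "r") as f:
    #         model_bytes = f["model"][()].tobytes()
    #     restored_model = joblib.load(io.BytesIO(model_bytes))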

    def generate_plots(self):
        LOG.info("Generating PyCaret diagnostic plots")

        # choose the right plots based on task
        if self.task_type == "classification":
            plot_names = [
                "learning",
                "vc",
                "calibration",
                "dimension",
                "manifold",
                "rfe",
                "threshold",
                "percentage_above_below",
                "class_report",
                "pr_auc",
                "roc_auc",
            ]
        else:
            plot_names = ["residuals", "vc", "parameter", "error", "learning"]

        for name in plot_names:
            try:
                ax = self.exp.plot_model(self.best_model, plot=name, save=False)
                out_path = Path(self.output_dir) / f"plot_{name}.png"
                fig = ax.get_figure()
                fig.savefig(out_path, bbox_inches="tight")
                self.plots[name] = str(out_path)
            except Exception as e:
                LOG.warning(f"Could not generate {name} plot: {e}")

    def encode_image_to_base64(self, img_path: str) -> str:
        with open(img_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode("utf-8")

    def save_html_report(self):
        LOG.info("Saving HTML report")

        # 1) Determine best model name
        try:
            best_model_name = str(self.results.iloc[0]["Model"])
        except Exception:
            best_model_name = type(self.best_model).__name__
        LOG.info(f"Best model determined as: {best_model_name}")

        # 2) Compute training sample count
        try:
            n_train = self.exp.X_train.shape[0]
        except Exception:
            n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0]
        total_rows = self.data.shape[0]

        # 3) Build setup parameters table
        all_params = self.setup_params
        display_keys = [
            "Target",
            "Session ID",
            "Train Size",
            "Normalize",
            "Feature Selection",
            "Cross Validation",
            "Cross Validation Folds",
            "Remove Outliers",
            "Remove Multicollinearity",
            "Polynomial Features",
            "Fix Imbalance",
            "Models",
        ]
        setup_rows = []
        for key in display_keys:
            pk = key.lower().replace(" ", "_")
            v = all_params.get(pk)
            if key == "Train Size":
                frac = (
                    float(v)
                    if v is not None
                    else (n_train / total_rows if total_rows else 0)
                )
                dv = f"{frac:.2f} ({n_train} rows)"
            elif key in {
                "Normalize",
                "Feature Selection",
                "Cross Validation",
                "Remove Outliers",
                "Remove Multicollinearity",
                "Polynomial Features",
                "Fix Imbalance",
            }:
                dv = bool(v)
            elif key == "Cross Validation Folds":
                dv = v if v is not None else "None"
            elif key == "Models":
                dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None"
            else:
                dv = v if v is not None else "None"
            setup_rows.append([key, dv])
        if hasattr(self.exp, "_fold_metric"):
            setup_rows.append(["best_model_metric", self.exp._fold_metric])

        df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
        df_setup.to_csv(Path(self.output_dir) / "setup_params.csv", index=False)

        # 4) Persist CSVs
        self.results.to_csv(
            Path(self.output_dir) / "comparison_results.csv", index=False
        )
        self.test_result_df.to_csv(
            Path(self.output_dir) / "test_results.csv", index=False
        )
        pd.DataFrame(
            self.best_model.get_params().items(), columns=["Parameter", "Value"]
        ).to_csv(Path(self.output_dir) / "best_model.csv", index=False)

        # 5) Header
        header = f"<h2>Best Model: {best_model_name}</h2>"

        # — Validation Summary & Configuration —
        val_df = self.results.copy()
        # mapping raw plot keys to user-friendly titles
        plot_title_map = {
            "learning": "Learning Curve",
            "vc": "Validation Curve",
            "calibration": "Calibration Curve",
            "dimension": "Dimensionality Reduction",
            "manifold": "Manifold Learning",
            "rfe": "Recursive Feature Elimination",
            "threshold": "Threshold Plot",
            "percentage_above_below": "Percentage Above vs. Below Cutoff",
            "class_report": "Classification Report",
            "pr_auc": "Precision-Recall AUC",
            "roc_auc": "Receiver Operating Characteristic AUC",
            "residuals": "Residuals Distribution",
            "error": "Prediction Error Distribution",
        }
        val_df.drop(columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True)

        summary_html = (
            header
            + "<h2>Train & Validation Summary</h2>"
            + '<div class="table-wrapper">'
            + val_df.to_html(index=False, classes="table sortable")
            + "</div>"
            + "<h2>Setup Parameters</h2>"
            + '<div class="table-wrapper">'
            + df_setup.to_html(index=False, classes="table sortable")
            + "</div>"
            # — Hyperparameters
            + "<h2>Best Model Hyperparameters</h2>"
            + '<div class="table-wrapper">'
            + pd.DataFrame(
                self.best_model.get_params().items(), columns=["Parameter", "Value"]
            ).to_html(index=False, classes="table sortable")
            + "</div>"
        )

        # choose summary plots based on task type
        if self.task_type == "classification":
            summary_plots = [
                "learning",
                "vc",
                "calibration",
                "dimension",
                "manifold",
                "rfe",
                "threshold",
                "percentage_above_below",
            ]
        else:
            summary_plots = ["learning", "vc", "parameter", "residuals"]

        for name in summary_plots:
            if name in self.plots:
                summary_html += "<hr>"
                b64 = encode_image_to_base64(self.plots[name])
                title = plot_title_map.get(name, name.replace("_", " ").title())
                summary_html += (
                    '<div class="plot">'
                    f"<h2>{title}</h2>"
                    f'<img src="data:image/png;base64,{b64}" '
                    'style="max-width:90%;max-height:600px;border:1px solid #ddd;"/>'
                    "</div>"
                )

        # — Test Summary —
        test_html = (
            header
            + '<div class="table-wrapper">'
            + self.test_result_df.to_html(index=False, classes="table sortable")
            + "</div>"
        )
        if self.task_type == "regression":
            try:
                y_true = (
                    pd.Series(self.exp.y_test_transformed)
                    .reset_index(drop=True)
                    .rename("True")
                )
                y_pred = pd.Series(
                    self.best_model.predict(self.exp.X_test_transformed)
                ).rename("Predicted")
                df_tp = pd.concat([y_true, y_pred], axis=1)
                test_html += "<h2>True vs Predicted Values</h2>"
                test_html += (
                    '<div class="table-wrapper" style="max-height:400px; overflow-y:auto;">'
                    + df_tp.head(50).to_html(index=False, classes="table sortable")
                    + "</div>"
                    + add_hr_to_html()
                )
            except Exception as e:
                LOG.warning(f"Could not generate True vs Predicted table: {e}")

        # 5a) Explainer-substituted plots in order
        if self.task_type == "regression":
            test_order = ["residuals"]
        else:
            test_order = [
                "confusion_matrix",
                "roc_auc",
                "pr_auc",
                "lift_curve",
                "threshold",
                "cumulative_precision",
            ]
        for key in test_order:
            fig_or_fn = self.explainer_plots.pop(key, None)
            if fig_or_fn is not None:
                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
                title = plot_title_map.get(key, key.replace("_", " ").title())
                test_html += (
                    f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
                )

        # 5b) Remaining PyCaret test plots
        for name, path in self.plots.items():
            # classification: include only the small extras, before skipping anything
            if self.task_type == "classification" and name in {
                "threshold",
                "pr_auc",
                "class_report",
            }:
                title = plot_title_map.get(name, name.replace("_", " ").title())
                b64 = encode_image_to_base64(path)
                test_html += (
                    f"<h2>{title}</h2>"
                    "<div class='plot'>"
                    f"<img src='data:image/png;base64,{b64}' "
                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
                    "</div>" + add_hr_to_html()
                )
                continue

            # regression: explicitly include the 'error' plot, before skipping
            if self.task_type == "regression" and name == "error":
                title = plot_title_map.get("error", "Prediction Error Distribution")
                b64 = encode_image_to_base64(path)
                test_html += (
                    f"<h2>{title}</h2>"
                    "<div class='plot'>"
                    f"<img src='data:image/png;base64,{b64}' "
                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
                    "</div>" + add_hr_to_html()
                )
                continue

            # now skip any plots already rendered via test_order
            if name in test_order:
                continue

        # — Feature Importance —
        feature_html = header

        # 6a) PyCaret’s default feature importances
        feature_html += FeatureImportanceAnalyzer(
            data=self.data,
            target_col=self.target_col,
            task_type=self.task_type,
            output_dir=self.output_dir,
            exp=self.exp,
            best_model=self.best_model,
        ).run()

        # 6b) Explainer SHAP importances
        for key in ["shap_mean", "shap_perm"]:
            fig_or_fn = self.explainer_plots.pop(key, None)
            if fig_or_fn is not None:
                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
                # give SHAP plots explicit titles
                title = (
                    "Mean Absolute SHAP Value Impact"
                    if key == "shap_mean"
                    else "Permutation Feature Importance"
                )
                feature_html += (
                    f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
                )

        # 6c) PDPs last
        pdp_keys = sorted(k for k in self.explainer_plots if k.startswith("pdp__"))
        for k in pdp_keys:
            fig_or_fn = self.explainer_plots[k]
            fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
            # extract feature name
            feature = k.split("__", 1)[1]
            title = f"Partial Dependence for {feature}"
            feature_html += (
                f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
            )

        # 7) Assemble final HTML (three tabs)
        html = get_html_template()
        html += "<h1>Tabular Learner Model Report</h1>"
        html += build_tabbed_html(summary_html, test_html, feature_html)
        html += get_feature_metrics_help_modal()
        html += get_html_closing()

        # 8) Write out
        (Path(self.output_dir) / "comparison_result.html").write_text(
            html, encoding="utf-8"
        )
        LOG.info(f"HTML report generated at: {self.output_dir}/comparison_result.html")
).rename("Predicted") df_tp = pd.concat([y_true, y_pred], axis=1) test_html += "<h2>True vs Predicted Values</h2>" test_html += ( '<div class="table-wrapper" style="max-height:400px; overflow-y:auto;">' + df_tp.head(50).to_html(index=False, classes="table sortable") + "</div>" + add_hr_to_html() ) except Exception as e: LOG.warning(f"Could not generate True vs Predicted table: {e}") # 5a) Explainer-substituted plots in order if self.task_type == "regression": test_order = ["residuals"] else: test_order = [ "confusion_matrix", "roc_auc", "pr_auc", "lift_curve", "threshold", "cumulative_precision", ] for key in test_order: fig_or_fn = self.explainer_plots.pop(key, None) if fig_or_fn is not None: fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn title = plot_title_map.get(key, key.replace("_", " ").title()) test_html += ( f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html() ) # 5b) Remaining PyCaret test plots for name, path in self.plots.items(): # classification: include only the small extras, before skipping anything if self.task_type == "classification" and name in { "threshold", "pr_auc", "class_report", }: title = plot_title_map.get(name, name.replace("_", " ").title()) b64 = encode_image_to_base64(path) test_html += ( f"<h2>{title}</h2>" "<div class='plot'>" f"<img src='data:image/png;base64,{b64}' " "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>" "</div>" + add_hr_to_html() ) continue # regression: explicitly include the 'error' plot, before skipping if self.task_type == "regression" and name == "error": title = plot_title_map.get("error", "Prediction Error Distribution") b64 = encode_image_to_base64(path) test_html += ( f"<h2>{title}</h2>" "<div class='plot'>" f"<img src='data:image/png;base64,{b64}' " "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>" "</div>" + add_hr_to_html() ) continue # now skip any plots already rendered via test_order if name in test_order: continue # — Feature Importance — feature_html = header # 6a) PyCaret’s default feature importances feature_html += FeatureImportanceAnalyzer( data=self.data, target_col=self.target_col, task_type=self.task_type, output_dir=self.output_dir, exp=self.exp, best_model=self.best_model, ).run() # 6b) Explainer SHAP importances for key in ["shap_mean", "shap_perm"]: fig_or_fn = self.explainer_plots.pop(key, None) if fig_or_fn is not None: fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn # give SHAP plots explicit titles title = ( "Mean Absolute SHAP Value Impact" if key == "shap_mean" else "Permutation Feature Importance" ) feature_html += ( f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html() ) # 6c) PDPs last pdp_keys = sorted(k for k in self.explainer_plots if k.startswith("pdp__")) for k in pdp_keys: fig_or_fn = self.explainer_plots[k] fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn # extract feature name feature = k.split("__", 1)[1] title = f"Partial Dependence for {feature}" feature_html += ( f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html() ) # 7) Assemble final HTML (three tabs) html = get_html_template() html += "<h1>Tabular Learner Model Report</h1>" html += build_tabbed_html(summary_html, test_html, feature_html) html += get_feature_metrics_help_modal() html += get_html_closing() # 8) Write out (Path(self.output_dir) / "comparison_result.html").write_text( html, encoding="utf-8" ) LOG.info(f"HTML report generated at: {self.output_dir}/comparison_result.html") def save_dashboard(self): raise NotImplementedError("Subclasses should 
implement this method") def generate_plots_explainer(self): raise NotImplementedError("Subclasses should implement this method") def generate_tree_plots(self): from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor from xgboost import XGBClassifier, XGBRegressor from explainerdashboard.explainers import RandomForestExplainer LOG.info("Generating tree plots") X_test = self.exp.X_test_transformed.copy() y_test = self.exp.y_test_transformed if isinstance(self.best_model, (RandomForestClassifier, RandomForestRegressor)): n_trees = self.best_model.n_estimators elif isinstance(self.best_model, (XGBClassifier, XGBRegressor)): n_trees = len(self.best_model.get_booster().get_dump()) else: LOG.warning("Tree plots not supported for this model type.") return explainer = RandomForestExplainer(self.best_model, X_test, y_test) for i in range(n_trees): fig = explainer.decisiontree_encoded(tree_idx=i, index=0) self.trees.append(fig) def run(self): self.load_data() self.setup_pycaret() self.train_model() self.save_model() self.generate_plots() self.generate_plots_explainer() self.generate_tree_plots() self.save_html_report() # self.save_dashboard()