diff base_model_trainer.py @ 8:1aed7d47c5ec (draft, default branch, tip)
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author:   goeckslab
date:     Fri, 25 Jul 2025 19:02:32 +0000
parents:  f4cb41f458fd
children: (none)
--- a/base_model_trainer.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/base_model_trainer.py	Fri Jul 25 19:02:32 2025 +0000
@@ -1,7 +1,7 @@
import base64
import logging
-import os
import tempfile
+from pathlib import Path

import h5py
import joblib
@@ -10,7 +10,14 @@
from feature_help_modal import get_feature_metrics_help_modal
from feature_importance import FeatureImportanceAnalyzer
from sklearn.metrics import average_precision_score
-from utils import get_html_closing, get_html_template
+from utils import (
+    add_hr_to_html,
+    add_plot_to_html,
+    build_tabbed_html,
+    encode_image_to_base64,
+    get_html_closing,
+    get_html_template,
+)

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)
@@ -27,7 +34,7 @@
        test_file=None,
        **kwargs,
    ):
-        self.exp = None  # This will be set in the subclass
+        self.exp = None
        self.input_file = input_file
        self.target_col = target_col
        self.output_dir = output_dir
@@ -39,10 +46,11 @@
        self.results = None
        self.features_name = None
        self.plots = {}
-        self.expaliner = None
+        self.explainer_plots = {}
        self.plots_explainer_html = None
        self.trees = []
-        for key, value in kwargs.items():
+        self.user_kwargs = kwargs.copy()
+        for key, value in self.user_kwargs.items():
            setattr(self, key, value)
        self.setup_params = {}
        self.test_file = test_file
@@ -57,43 +65,38 @@
        LOG.info(f"Loading data from {self.input_file}")
        self.data = pd.read_csv(self.input_file, sep=None, engine="python")
        self.data.columns = self.data.columns.str.replace(".", "_")
-
-        # Remove prediction_label if present
        if "prediction_label" in self.data.columns:
            self.data = self.data.drop(columns=["prediction_label"])

        numeric_cols = self.data.select_dtypes(include=["number"]).columns
        non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns
-
        self.data[numeric_cols] = self.data[numeric_cols].apply(
            pd.to_numeric, errors="coerce"
        )
-
        if len(non_numeric_cols) > 0:
            LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")

        names = self.data.columns.to_list()
        target_index = int(self.target_col) - 1
        self.target = names[target_index]
-        self.features_name = [name for i, name in enumerate(names) if i != target_index]
-        if hasattr(self, "missing_value_strategy"):
-            if self.missing_value_strategy == "mean":
+        self.features_name = [n for i, n in enumerate(names) if i != target_index]
+
+        if getattr(self, "missing_value_strategy", None):
+            strat = self.missing_value_strategy
+            if strat == "mean":
                self.data = self.data.fillna(self.data.mean(numeric_only=True))
-            elif self.missing_value_strategy == "median":
+            elif strat == "median":
                self.data = self.data.fillna(self.data.median(numeric_only=True))
-            elif self.missing_value_strategy == "drop":
+            elif strat == "drop":
                self.data = self.data.dropna()
        else:
-            # Default strategy if not specified
            self.data = self.data.fillna(self.data.median(numeric_only=True))

        if self.test_file:
            LOG.info(f"Loading test data from {self.test_file}")
-            self.test_data = pd.read_csv(self.test_file, sep=None, engine="python")
-            self.test_data = self.test_data[numeric_cols].apply(
-                pd.to_numeric, errors="coerce"
-            )
-            self.test_data.columns = self.test_data.columns.str.replace(".", "_")
+            df_test = pd.read_csv(self.test_file, sep=None, engine="python")
+            df_test.columns = df_test.columns.str.replace(".", "_")
+            self.test_data = df_test

    def setup_pycaret(self):
        LOG.info("Initializing PyCaret")
@@ -105,59 +108,26 @@
            "system_log": False,
            "index": False,
        }
-
        if self.test_data is not None:
            self.setup_params["test_data"] = self.test_data
-
-        if (
-            hasattr(self, "train_size")
-            and self.train_size is not None
-            and self.test_data is None
-        ):
-            self.setup_params["train_size"] = self.train_size
-
-        if hasattr(self, "normalize") and self.normalize is not None:
-            self.setup_params["normalize"] = self.normalize
-
-        if hasattr(self, "feature_selection") and self.feature_selection is not None:
-            self.setup_params["feature_selection"] = self.feature_selection
-
-        if (
-            hasattr(self, "cross_validation")
-            and self.cross_validation is not None
-            and self.cross_validation is False
-        ):
-            logging.info(
-                "cross_validation is set to False. This will disable cross-validation."
-            )
-
-        if hasattr(self, "cross_validation") and self.cross_validation:
-            if hasattr(self, "cross_validation_folds"):
-                self.setup_params["fold"] = self.cross_validation_folds
-
-        if hasattr(self, "remove_outliers") and self.remove_outliers is not None:
-            self.setup_params["remove_outliers"] = self.remove_outliers
-
-        if (
-            hasattr(self, "remove_multicollinearity")
-            and self.remove_multicollinearity is not None
-        ):
-            self.setup_params["remove_multicollinearity"] = (
-                self.remove_multicollinearity
-            )
-
-        if (
-            hasattr(self, "polynomial_features")
-            and self.polynomial_features is not None
-        ):
-            self.setup_params["polynomial_features"] = self.polynomial_features
-
-        if hasattr(self, "fix_imbalance") and self.fix_imbalance is not None:
-            self.setup_params["fix_imbalance"] = self.fix_imbalance
-
+        for attr in [
+            "train_size",
+            "normalize",
+            "feature_selection",
+            "remove_outliers",
+            "remove_multicollinearity",
+            "polynomial_features",
+            "feature_interaction",
+            "feature_ratio",
+            "fix_imbalance",
+        ]:
+            val = getattr(self, attr, None)
+            if val is not None:
+                self.setup_params[attr] = val
+        if getattr(self, "cross_validation_folds", None) is not None:
+            self.setup_params["fold"] = self.cross_validation_folds
        LOG.info(self.setup_params)

-        # Solution: instantiate the correct PyCaret experiment based on task_type
        if self.task_type == "classification":
            from pycaret.classification import ClassificationExperiment
@@ -170,246 +140,371 @@
            raise ValueError("task_type must be 'classification' or 'regression'")

        self.exp.setup(self.data, **self.setup_params)
+        self.setup_params.update(self.user_kwargs)

    def train_model(self):
        LOG.info("Training and selecting the best model")
        if self.task_type == "classification":
-            average_displayed = "Weighted"
            self.exp.add_metric(
-                id=f"PR-AUC-{average_displayed}",
-                name=f"PR-AUC-{average_displayed}",
+                id="PR-AUC-Weighted",
+                name="PR-AUC-Weighted",
                target="pred_proba",
                score_func=average_precision_score,
                average="weighted",
            )
+        # Build arguments for compare_models()
+        compare_kwargs = {}
+        if getattr(self, "models", None):
+            compare_kwargs["include"] = self.models

-        if hasattr(self, "models") and self.models is not None:
-            self.best_model = self.exp.compare_models(include=self.models, cross_validation=self.cross_validation)
-        else:
-            self.best_model = self.exp.compare_models(cross_validation=self.cross_validation)
+        # Respect explicit cross-validation flag
+        if getattr(self, "cross_validation", None) is not None:
+            compare_kwargs["cross_validation"] = self.cross_validation
+
+        # Respect explicit fold count
+        if getattr(self, "cross_validation_folds", None) is not None:
+            compare_kwargs["fold"] = self.cross_validation_folds
+
+        LOG.info(f"compare_models kwargs: {compare_kwargs}")
+        self.best_model = self.exp.compare_models(**compare_kwargs)
        self.results = self.exp.pull()
+        if getattr(self, "tune_model", False):
+            LOG.info("Tuning hyperparameters of the best model")
+            self.best_model = self.exp.tune_model(self.best_model)
+            self.results = self.exp.pull()

        if self.task_type == "classification":
            self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
-
        _ = self.exp.predict_model(self.best_model)
        self.test_result_df = self.exp.pull()
        if self.task_type == "classification":
            self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)

    def save_model(self):
-        hdf5_model_path = "pycaret_model.h5"
-        with h5py.File(hdf5_model_path, "w") as f:
-            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-                joblib.dump(self.best_model, temp_file.name)
-                temp_file.seek(0)
-                model_bytes = temp_file.read()
+        hdf5_path = Path(self.output_dir) / "pycaret_model.h5"
+        with h5py.File(hdf5_path, "w") as f:
+            with tempfile.NamedTemporaryFile(delete=False) as tmp:
+                joblib.dump(self.best_model, tmp.name)
+                tmp.seek(0)
+                model_bytes = tmp.read()
            f.create_dataset("model", data=np.void(model_bytes))

    def generate_plots(self):
-        raise NotImplementedError("Subclasses should implement this method")
+        LOG.info("Generating PyCaret diagnostic plots")

-    def encode_image_to_base64(self, img_path):
+        # choose the right plots based on task
+        if self.task_type == "classification":
+            plot_names = [
+                "learning",
+                "vc",
+                "calibration",
+                "dimension",
+                "manifold",
+                "rfe",
+                "threshold",
+                "percentage_above_below",
+                "class_report",
+                "pr_auc",
+                "roc_auc",
+            ]
+        else:
+            plot_names = ["residuals", "vc", "parameter", "error", "learning"]
+        for name in plot_names:
+            try:
+                ax = self.exp.plot_model(self.best_model, plot=name, save=False)
+                out_path = Path(self.output_dir) / f"plot_{name}.png"
+                fig = ax.get_figure()
+                fig.savefig(out_path, bbox_inches="tight")
+                self.plots[name] = str(out_path)
+            except Exception as e:
+                LOG.warning(f"Could not generate {name} plot: {e}")
+
+    def encode_image_to_base64(self, img_path: str) -> str:
        with open(img_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode("utf-8")

    def save_html_report(self):
        LOG.info("Saving HTML report")

-        if not self.output_dir:
-            raise ValueError("output_dir must be specified and not None")
+        # 1) Determine best model name
+        try:
+            best_model_name = str(self.results.iloc[0]["Model"])
+        except Exception:
+            best_model_name = type(self.best_model).__name__
+        LOG.info(f"Best model determined as: {best_model_name}")
+
+        # 2) Compute training sample count
+        try:
+            n_train = self.exp.X_train.shape[0]
+        except Exception:
+            n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0]
+        total_rows = self.data.shape[0]

-        model_name = type(self.best_model).__name__
-        excluded_params = ["html", "log_experiment", "system_log", "test_data"]
-        filtered_setup_params = {
-            k: v for k, v in self.setup_params.items() if k not in excluded_params
+        # 3) Build setup parameters table
+        all_params = self.setup_params
+        display_keys = [
+            "Target",
+            "Session ID",
+            "Train Size",
+            "Normalize",
+            "Feature Selection",
+            "Cross Validation",
+            "Cross Validation Folds",
+            "Remove Outliers",
+            "Remove Multicollinearity",
+            "Polynomial Features",
+            "Fix Imbalance",
+            "Models",
+        ]
+        setup_rows = []
+        for key in display_keys:
+            pk = key.lower().replace(" ", "_")
+            v = all_params.get(pk)
+            if key == "Train Size":
+                frac = (
+                    float(v)
+                    if v is not None
+                    else (n_train / total_rows if total_rows else 0)
+                )
+                dv = f"{frac:.2f} ({n_train} rows)"
+            elif key in {
+                "Normalize",
+                "Feature Selection",
+                "Cross Validation",
+                "Remove Outliers",
+                "Remove Multicollinearity",
+                "Polynomial Features",
+                "Fix Imbalance",
+            }:
+                dv = bool(v)
+            elif key == "Cross Validation Folds":
+                dv = v if v is not None else "None"
+            elif key == "Models":
+                dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None"
+            else:
+                dv = v if v is not None else "None"
+            setup_rows.append([key, dv])
+        if hasattr(self.exp, "_fold_metric"):
+            setup_rows.append(["best_model_metric", self.exp._fold_metric])
+
+        df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
+        df_setup.to_csv(Path(self.output_dir) / "setup_params.csv", index=False)
+
+        # 4) Persist CSVs
+        self.results.to_csv(
+            Path(self.output_dir) / "comparison_results.csv", index=False
+        )
+        self.test_result_df.to_csv(
+            Path(self.output_dir) / "test_results.csv", index=False
+        )
+        pd.DataFrame(
+            self.best_model.get_params().items(), columns=["Parameter", "Value"]
+        ).to_csv(Path(self.output_dir) / "best_model.csv", index=False)
+
+        # 5) Header
+        header = f"<h2>Best Model: {best_model_name}</h2>"
+
+        # — Validation Summary & Configuration —
+        val_df = self.results.copy()
+        # mapping raw plot keys to user-friendly titles
+        plot_title_map = {
+            "learning": "Learning Curve",
+            "vc": "Validation Curve",
+            "calibration": "Calibration Curve",
+            "dimension": "Dimensionality Reduction",
+            "manifold": "Manifold Learning",
+            "rfe": "Recursive Feature Elimination",
+            "threshold": "Threshold Plot",
+            "percentage_above_below": "Percentage Above vs. Below Cutoff",
+            "class_report": "Classification Report",
+            "pr_auc": "Precision-Recall AUC",
+            "roc_auc": "Receiver Operating Characteristic AUC",
+            "residuals": "Residuals Distribution",
+            "error": "Prediction Error Distribution",
        }
-        setup_params_table = pd.DataFrame(
-            list(filtered_setup_params.items()), columns=["Parameter", "Value"]
+        val_df.drop(columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True)
+        summary_html = (
+            header
+            + "<h2>Train & Validation Summary</h2>"
+            + '<div class="table-wrapper">'
+            + val_df.to_html(index=False, classes="table sortable")
+            + "</div>"
+            + "<h2>Setup Parameters</h2>"
+            + '<div class="table-wrapper">'
+            + df_setup.to_html(index=False, classes="table sortable")
+            + "</div>"
+            # — Hyperparameters
+            + "<h2>Best Model Hyperparameters</h2>"
+            + '<div class="table-wrapper">'
+            + pd.DataFrame(
+                self.best_model.get_params().items(), columns=["Parameter", "Value"]
+            ).to_html(index=False, classes="table sortable")
+            + "</div>"
        )
-        best_model_params = pd.DataFrame(
-            self.best_model.get_params().items(), columns=["Parameter", "Value"]
-        )
-        best_model_params.to_csv(
-            os.path.join(self.output_dir, "best_model.csv"), index=False
-        )
-        self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))
-        self.test_result_df.to_csv(os.path.join(self.output_dir, "test_results.csv"))
+        # choose summary plots based on task type
+        if self.task_type == "classification":
+            summary_plots = [
+                "learning",
+                "vc",
+                "calibration",
+                "dimension",
+                "manifold",
+                "rfe",
+                "threshold",
+                "percentage_above_below",
+            ]
+        else:
+            summary_plots = ["learning", "vc", "parameter", "residuals"]

-        plots_html = ""
-        length = len(self.plots)
-        for i, (plot_name, plot_path) in enumerate(self.plots.items()):
-            encoded_image = self.encode_image_to_base64(plot_path)
-            plots_html += (
-                f'<div class="plot">'
-                f"<h3>{plot_name.capitalize()}</h3>"
-                f'<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">'
-                f"</div>"
-            )
-            if i < length - 1:
-                plots_html += "<hr>"
-
-        tree_plots = ""
-        for i, tree in enumerate(self.trees):
-            if tree:
-                tree_plots += (
-                    f'<div class="plot">'
f"<h3>Tree {i + 1}</h3>" - f'<img src="data:image/png;base64,{tree}" alt="tree {i + 1}">' - f"</div>" + for name in summary_plots: + if name in self.plots: + summary_html += "<hr>" + b64 = encode_image_to_base64(self.plots[name]) + title = plot_title_map.get(name, name.replace("_", " ").title()) + summary_html += ( + '<div class="plot">' + f"<h2>{title}</h2>" + f'<img src="data:image/png;base64,{b64}" ' + 'style="max-width:90%;max-height:600px;border:1px solid #ddd;"/>' + "</div>" ) - analyzer = FeatureImportanceAnalyzer( + # — Test Summary — + test_html = ( + header + + '<div class="table-wrapper">' + + self.test_result_df.to_html(index=False, classes="table sortable") + + "</div>" + ) + if self.task_type == "regression": + try: + y_true = ( + pd.Series(self.exp.y_test_transformed) + .reset_index(drop=True) + .rename("True") + ) + y_pred = pd.Series( + self.best_model.predict(self.exp.X_test_transformed) + ).rename("Predicted") + df_tp = pd.concat([y_true, y_pred], axis=1) + test_html += "<h2>True vs Predicted Values</h2>" + test_html += ( + '<div class="table-wrapper" style="max-height:400px; overflow-y:auto;">' + + df_tp.head(50).to_html(index=False, classes="table sortable") + + "</div>" + + add_hr_to_html() + ) + except Exception as e: + LOG.warning(f"Could not generate True vs Predicted table: {e}") + + # 5a) Explainer-substituted plots in order + if self.task_type == "regression": + test_order = ["residuals"] + else: + test_order = [ + "confusion_matrix", + "roc_auc", + "pr_auc", + "lift_curve", + "threshold", + "cumulative_precision", + ] + for key in test_order: + fig_or_fn = self.explainer_plots.pop(key, None) + if fig_or_fn is not None: + fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn + title = plot_title_map.get(key, key.replace("_", " ").title()) + test_html += ( + f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html() + ) + # 5b) Remaining PyCaret test plots + for name, path in self.plots.items(): + # classification: include only the small extras, before skipping anything + if self.task_type == "classification" and name in { + "threshold", + "pr_auc", + "class_report", + }: + title = plot_title_map.get(name, name.replace("_", " ").title()) + b64 = encode_image_to_base64(path) + test_html += ( + f"<h2>{title}</h2>" + "<div class='plot'>" + f"<img src='data:image/png;base64,{b64}' " + "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>" + "</div>" + add_hr_to_html() + ) + continue + + # regression: explicitly include the 'error' plot, before skipping + if self.task_type == "regression" and name == "error": + title = plot_title_map.get("error", "Prediction Error Distribution") + b64 = encode_image_to_base64(path) + test_html += ( + f"<h2>{title}</h2>" + "<div class='plot'>" + f"<img src='data:image/png;base64,{b64}' " + "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>" + "</div>" + add_hr_to_html() + ) + continue + + # now skip any plots already rendered via test_order + if name in test_order: + continue + + # — Feature Importance — + feature_html = header + + # 6a) PyCaret’s default feature importances + feature_html += FeatureImportanceAnalyzer( data=self.data, target_col=self.target_col, task_type=self.task_type, output_dir=self.output_dir, exp=self.exp, best_model=self.best_model, - ) - feature_importance_html = analyzer.run() + ).run() - # --- Feature Metrics Help Button --- - feature_metrics_button_html = ( - '<button class="help-modal-btn" id="openFeatureMetricsHelp" style="margin-bottom:12px;">' - "Help: Metrics Guide" - 
"</button>" - "<style>" - ".help-modal-btn {" - "background-color: #17623b;" - "color: #fff;" - "border: none;" - "border-radius: 24px;" - "padding: 10px 28px;" - "font-size: 1.1rem;" - "font-weight: bold;" - "letter-spacing: 0.03em;" - "cursor: pointer;" - "transition: background 0.2s, box-shadow 0.2s;" - "box-shadow: 0 2px 8px rgba(23,98,59,0.07);" - "}" - ".help-modal-btn:hover, .help-modal-btn:focus {" - "background-color: #21895e;" - "outline: none;" - "box-shadow: 0 4px 16px rgba(23,98,59,0.14);" - "}" - "</style>" - ) + # 6b) Explainer SHAP importances + for key in ["shap_mean", "shap_perm"]: + fig_or_fn = self.explainer_plots.pop(key, None) + if fig_or_fn is not None: + fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn + # give SHAP plots explicit titles + title = ( + "Mean Absolute SHAP Value Impact" + if key == "shap_mean" + else "Permutation Feature Importance" + ) + feature_html += ( + f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html() + ) - html_content = ( - f"{get_html_template()}" - "<h1>Tabular Learner Model Report</h1>" - f"{feature_metrics_button_html}" - '<div class="tabs">' - '<div class="tab" onclick="openTab(event, \'summary\')">' - "Validation Result Summary & Config</div>" - '<div class="tab" onclick="openTab(event, \'plots\')">' - "Test Results</div>" - '<div class="tab" onclick="openTab(event, \'feature\')">' - "Feature Importance</div>" - ) - if self.plots_explainer_html: - html_content += ( - '<div class="tab" onclick="openTab(event, \'explainer\')">' - "Explainer Plots</div>" + # 6c) PDPs last + pdp_keys = sorted(k for k in self.explainer_plots if k.startswith("pdp__")) + for k in pdp_keys: + fig_or_fn = self.explainer_plots[k] + fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn + # extract feature name + feature = k.split("__", 1)[1] + title = f"Partial Dependence for {feature}" + feature_html += ( + f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html() ) - html_content += ( - "</div>" - '<div id="summary" class="tab-content">' - f"<h2>Model Metrics from {'Cross-Validation Set' if self.cross_validation else 'Validation set'}</h2>" - f"<h2>Best Model: {model_name}</h2>" - "<h5>The best model is selected by: Accuracy (Classification)" - " or R2 (Regression).</h5>" - f"{self.results.to_html(index=False, classes='table sortable')}" - "<h2>Best Model's Hyperparameters</h2>" - f"{best_model_params.to_html(index=False, header=True, classes='table sortable')}" - "<h2>Setup Parameters</h2>" - f"{setup_params_table.to_html(index=False, header=True, classes='table sortable')}" - "<h5>If you want to know all the experiment setup parameters," - " please check the PyCaret documentation for" - " the classification/regression <code>exp</code> function.</h5>" - "</div>" - '<div id="plots" class="tab-content">' - f"<h2>Best Model: {model_name}</h2>" - "<h5>The best model is selected by: Accuracy (Classification)" - " or R2 (Regression).</h5>" - "<h2>Test Metrics</h2>" - f"{self.test_result_df.to_html(index=False)}" - "<h2>Test Results</h2>" - f"{plots_html}" - "</div>" - '<div id="feature" class="tab-content">' - f"{feature_importance_html}" - "</div>" + # 7) Assemble final HTML (three tabs) + html = get_html_template() + html += "<h1>Tabular Learner Model Report</h1>" + html += build_tabbed_html(summary_html, test_html, feature_html) + html += get_feature_metrics_help_modal() + html += get_html_closing() + + # 8) Write out + (Path(self.output_dir) / "comparison_result.html").write_text( + html, encoding="utf-8" ) - if 
-        if self.plots_explainer_html:
-            html_content += (
-                '<div id="explainer" class="tab-content">'
-                f"{self.plots_explainer_html}"
-                f"{tree_plots}"
-                "</div>"
-            )
-        html_content += (
-            "<script>"
-            "document.addEventListener(\"DOMContentLoaded\", function() {"
-            "var tables = document.querySelectorAll(\"table.sortable\");"
-            "tables.forEach(function(table) {"
-            "var headers = table.querySelectorAll(\"th\");"
-            "headers.forEach(function(header, index) {"
-            "header.style.cursor = \"pointer\";"
-            "// Add initial arrow (up) to indicate sortability, use Unicode ↑ (U+2191)"
-            "header.innerHTML += '<span class=\"sort-arrow\"> ↑</span>';"
-            "header.addEventListener(\"click\", function() {"
-            "var direction = this.getAttribute("
-            "\"data-sort-direction\""
-            ") || \"asc\";"
-            "// Reset arrows in all headers of this table"
-            "headers.forEach(function(h) {"
-            "var arrow = h.querySelector(\".sort-arrow\");"
-            "if (arrow) arrow.textContent = \" ↑\";"
-            "});"
-            "// Set arrow for clicked header"
-            "var arrow = this.querySelector(\".sort-arrow\");"
-            "arrow.textContent = direction === \"asc\" ? \" ↓\" : \" ↑\";"
-            "sortTable(table, index, direction);"
-            "this.setAttribute(\"data-sort-direction\","
-            "direction === \"asc\" ? \"desc\" : \"asc\");"
-            "});"
-            "});"
-            "});"
-            "});"
-            "function sortTable(table, colNum, direction) {"
-            "var tb = table.tBodies[0];"
-            "var tr = Array.prototype.slice.call(tb.rows, 0);"
-            "var multiplier = direction === \"asc\" ? 1 : -1;"
-            "tr = tr.sort(function(a, b) {"
-            "var aText = a.cells[colNum].textContent.trim();"
-            "var bText = b.cells[colNum].textContent.trim();"
-            "// Remove arrow from text comparison"
-            "aText = aText.replace(/[↑↓]/g, '').trim();"
-            "bText = bText.replace(/[↑↓]/g, '').trim();"
-            "if (!isNaN(aText) && !isNaN(bText)) {"
-            "return multiplier * ("
-            "parseFloat(aText) - parseFloat(bText)"
-            ");"
-            "} else {"
-            "return multiplier * aText.localeCompare(bText);"
-            "}"
-            "});"
-            "for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);"
-            "}"
-            "</script>"
-        )
-        # --- Add the Feature Metrics Help Modal ---
-        html_content += get_feature_metrics_help_modal()
-        html_content += f"{get_html_closing()}"
-        with open(
-            os.path.join(self.output_dir, "comparison_result.html"),
-            "w",
-            encoding="utf-8",
-        ) as file:
-            file.write(html_content)
+        LOG.info(f"HTML report generated at: {self.output_dir}/comparison_result.html")

    def save_dashboard(self):
        raise NotImplementedError("Subclasses should implement this method")
@@ -426,29 +521,18 @@
        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

-        is_rf = isinstance(
-            self.best_model, (RandomForestClassifier, RandomForestRegressor)
-        )
-        is_xgb = isinstance(self.best_model, (XGBClassifier, XGBRegressor))
-
-        num_trees = None
-        if is_rf:
-            num_trees = self.best_model.n_estimators
-        elif is_xgb:
-            num_trees = len(self.best_model.get_booster().get_dump())
+        if isinstance(self.best_model, (RandomForestClassifier, RandomForestRegressor)):
+            n_trees = self.best_model.n_estimators
+        elif isinstance(self.best_model, (XGBClassifier, XGBRegressor)):
+            n_trees = len(self.best_model.get_booster().get_dump())
        else:
            LOG.warning("Tree plots not supported for this model type.")
            return

-        try:
-            explainer = RandomForestExplainer(self.best_model, X_test, y_test)
-            for i in range(num_trees):
-                fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
-                LOG.info(f"Tree {i + 1}")
-                LOG.info(fig)
-                self.trees.append(fig)
-        except Exception as e:
-            LOG.error(f"Error generating tree plots: {e}")
+        explainer = RandomForestExplainer(self.best_model, X_test, y_test)
+        for i in range(n_trees):
+            fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
+            self.trees.append(fig)

    def run(self):
        self.load_data()
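
Note on the revised save_model: it serializes the selected estimator with joblib and stores the raw bytes in an HDF5 file as an opaque dataset named "model". A minimal sketch of how such a file could be read back, assuming only what the code above writes (the file name and dataset key come from the diff; the loader function itself is illustrative and not part of this changeset):

import io

import h5py
import joblib


def load_model_from_hdf5(path="pycaret_model.h5"):
    # Read the opaque byte payload written by save_model() and rebuild the estimator.
    with h5py.File(path, "r") as f:
        model_bytes = f["model"][()].tobytes()  # np.void scalar -> raw joblib bytes
    return joblib.load(io.BytesIO(model_bytes))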