diff base_model_trainer.py @ 6:a32ff7201629 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
author goeckslab
date Wed, 02 Jul 2025 19:00:03 +0000
parents ccd798db5abb
children
line wrap: on
line diff
--- a/base_model_trainer.py	Sat Jun 21 15:07:04 2025 +0000
+++ b/base_model_trainer.py	Wed Jul 02 19:00:03 2025 +0000
@@ -7,6 +7,7 @@
 import joblib
 import numpy as np
 import pandas as pd
+from feature_help_modal import get_feature_metrics_help_modal
 from feature_importance import FeatureImportanceAnalyzer
 from sklearn.metrics import average_precision_score
 from utils import get_html_closing, get_html_template
@@ -16,16 +17,16 @@
 
 
 class BaseModelTrainer:
-
     def __init__(
-            self,
-            input_file,
-            target_col,
-            output_dir,
-            task_type,
-            random_seed,
-            test_file=None,
-            **kwargs):
+        self,
+        input_file,
+        target_col,
+        output_dir,
+        task_type,
+        random_seed,
+        test_file=None,
+        **kwargs,
+    ):
         self.exp = None  # This will be set in the subclass
         self.input_file = input_file
         self.target_col = target_col
@@ -47,18 +48,26 @@
         self.test_file = test_file
         self.test_data = None
 
+        if not self.output_dir:
+            raise ValueError("output_dir must be specified and not None")
+
         LOG.info(f"Model kwargs: {self.__dict__}")
 
     def load_data(self):
         LOG.info(f"Loading data from {self.input_file}")
-        self.data = pd.read_csv(self.input_file, sep=None, engine='python')
-        self.data.columns = self.data.columns.str.replace('.', '_')
+        self.data = pd.read_csv(self.input_file, sep=None, engine="python")
+        self.data.columns = self.data.columns.str.replace(".", "_")
 
-        numeric_cols = self.data.select_dtypes(include=['number']).columns
-        non_numeric_cols = self.data.select_dtypes(exclude=['number']).columns
+        # Remove prediction_label if present
+        if "prediction_label" in self.data.columns:
+            self.data = self.data.drop(columns=["prediction_label"])
+
+        numeric_cols = self.data.select_dtypes(include=["number"]).columns
+        non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns
 
         self.data[numeric_cols] = self.data[numeric_cols].apply(
-            pd.to_numeric, errors='coerce')
+            pd.to_numeric, errors="coerce"
+        )
 
         if len(non_numeric_cols) > 0:
             LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")
@@ -66,17 +75,13 @@
         names = self.data.columns.to_list()
         target_index = int(self.target_col) - 1
         self.target = names[target_index]
-        self.features_name = [name
-                              for i, name in enumerate(names)
-                              if i != target_index]
-        if hasattr(self, 'missing_value_strategy'):
-            if self.missing_value_strategy == 'mean':
-                self.data = self.data.fillna(
-                    self.data.mean(numeric_only=True))
-            elif self.missing_value_strategy == 'median':
-                self.data = self.data.fillna(
-                    self.data.median(numeric_only=True))
-            elif self.missing_value_strategy == 'drop':
+        self.features_name = [name for i, name in enumerate(names) if i != target_index]
+        if hasattr(self, "missing_value_strategy"):
+            if self.missing_value_strategy == "mean":
+                self.data = self.data.fillna(self.data.mean(numeric_only=True))
+            elif self.missing_value_strategy == "median":
+                self.data = self.data.fillna(self.data.median(numeric_only=True))
+            elif self.missing_value_strategy == "drop":
                 self.data = self.data.dropna()
         else:
             # Default strategy if not specified
@@ -84,287 +89,322 @@
 
         if self.test_file:
             LOG.info(f"Loading test data from {self.test_file}")
-            self.test_data = pd.read_csv(
-                self.test_file, sep=None, engine='python')
+            self.test_data = pd.read_csv(self.test_file, sep=None, engine="python")
             self.test_data = self.test_data[numeric_cols].apply(
-                pd.to_numeric, errors='coerce')
-            self.test_data.columns = self.test_data.columns.str.replace(
-                '.', '_'
+                pd.to_numeric, errors="coerce"
             )
+            self.test_data.columns = self.test_data.columns.str.replace(".", "_")
 
     def setup_pycaret(self):
         LOG.info("Initializing PyCaret")
         self.setup_params = {
-            'target': self.target,
-            'session_id': self.random_seed,
-            'html': True,
-            'log_experiment': False,
-            'system_log': False,
-            'index': False,
+            "target": self.target,
+            "session_id": self.random_seed,
+            "html": True,
+            "log_experiment": False,
+            "system_log": False,
+            "index": False,
         }
 
         if self.test_data is not None:
-            self.setup_params['test_data'] = self.test_data
+            self.setup_params["test_data"] = self.test_data
 
-        if hasattr(self, 'train_size') and self.train_size is not None \
-                and self.test_data is None:
-            self.setup_params['train_size'] = self.train_size
-
-        if hasattr(self, 'normalize') and self.normalize is not None:
-            self.setup_params['normalize'] = self.normalize
+        if (
+            hasattr(self, "train_size")
+            and self.train_size is not None
+            and self.test_data is None
+        ):
+            self.setup_params["train_size"] = self.train_size
 
-        if hasattr(self, 'feature_selection') and \
-                self.feature_selection is not None:
-            self.setup_params['feature_selection'] = self.feature_selection
+        if hasattr(self, "normalize") and self.normalize is not None:
+            self.setup_params["normalize"] = self.normalize
+
+        if hasattr(self, "feature_selection") and self.feature_selection is not None:
+            self.setup_params["feature_selection"] = self.feature_selection
 
-        if hasattr(self, 'cross_validation') and \
-                self.cross_validation is not None \
-                and self.cross_validation is False:
-            self.setup_params['cross_validation'] = self.cross_validation
+        if (
+            hasattr(self, "cross_validation")
+            and self.cross_validation is not None
+            and self.cross_validation is False
+        ):
+            self.setup_params["cross_validation"] = self.cross_validation
 
-        if hasattr(self, 'cross_validation') and \
-                self.cross_validation is not None:
-            if hasattr(self, 'cross_validation_folds'):
-                self.setup_params['fold'] = self.cross_validation_folds
+        if hasattr(self, "cross_validation") and self.cross_validation is not None:
+            if hasattr(self, "cross_validation_folds"):
+                self.setup_params["fold"] = self.cross_validation_folds
 
-        if hasattr(self, 'remove_outliers') and \
-                self.remove_outliers is not None:
-            self.setup_params['remove_outliers'] = self.remove_outliers
+        if hasattr(self, "remove_outliers") and self.remove_outliers is not None:
+            self.setup_params["remove_outliers"] = self.remove_outliers
 
-        if hasattr(self, 'remove_multicollinearity') and \
-                self.remove_multicollinearity is not None:
-            self.setup_params['remove_multicollinearity'] = \
+        if (
+            hasattr(self, "remove_multicollinearity")
+            and self.remove_multicollinearity is not None
+        ):
+            self.setup_params["remove_multicollinearity"] = (
                 self.remove_multicollinearity
+            )
 
-        if hasattr(self, 'polynomial_features') and \
-                self.polynomial_features is not None:
-            self.setup_params['polynomial_features'] = self.polynomial_features
+        if (
+            hasattr(self, "polynomial_features")
+            and self.polynomial_features is not None
+        ):
+            self.setup_params["polynomial_features"] = self.polynomial_features
 
-        if hasattr(self, 'fix_imbalance') and \
-                self.fix_imbalance is not None:
-            self.setup_params['fix_imbalance'] = self.fix_imbalance
+        if hasattr(self, "fix_imbalance") and self.fix_imbalance is not None:
+            self.setup_params["fix_imbalance"] = self.fix_imbalance
 
         LOG.info(self.setup_params)
+
+        # Solution: instantiate the correct PyCaret experiment based on task_type
+        if self.task_type == "classification":
+            from pycaret.classification import ClassificationExperiment
+
+            self.exp = ClassificationExperiment()
+        elif self.task_type == "regression":
+            from pycaret.regression import RegressionExperiment
+
+            self.exp = RegressionExperiment()
+        else:
+            raise ValueError("task_type must be 'classification' or 'regression'")
+
         self.exp.setup(self.data, **self.setup_params)
 
     def train_model(self):
         LOG.info("Training and selecting the best model")
         if self.task_type == "classification":
             average_displayed = "Weighted"
-            self.exp.add_metric(id=f'PR-AUC-{average_displayed}',
-                                name=f'PR-AUC-{average_displayed}',
-                                target='pred_proba',
-                                score_func=average_precision_score,
-                                average='weighted'
-                                )
+            self.exp.add_metric(
+                id=f"PR-AUC-{average_displayed}",
+                name=f"PR-AUC-{average_displayed}",
+                target="pred_proba",
+                score_func=average_precision_score,
+                average="weighted",
+            )
 
-        if hasattr(self, 'models') and self.models is not None:
-            self.best_model = self.exp.compare_models(
-                include=self.models)
+        if hasattr(self, "models") and self.models is not None:
+            self.best_model = self.exp.compare_models(include=self.models)
         else:
             self.best_model = self.exp.compare_models()
         self.results = self.exp.pull()
         if self.task_type == "classification":
-            self.results.rename(columns={'AUC': 'ROC-AUC'}, inplace=True)
+            self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
 
         _ = self.exp.predict_model(self.best_model)
         self.test_result_df = self.exp.pull()
         if self.task_type == "classification":
-            self.test_result_df.rename(
-                columns={'AUC': 'ROC-AUC'}, inplace=True)
+            self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
 
     def save_model(self):
         hdf5_model_path = "pycaret_model.h5"
-        with h5py.File(hdf5_model_path, 'w') as f:
+        with h5py.File(hdf5_model_path, "w") as f:
             with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                 joblib.dump(self.best_model, temp_file.name)
                 temp_file.seek(0)
                 model_bytes = temp_file.read()
-            f.create_dataset('model', data=np.void(model_bytes))
+            f.create_dataset("model", data=np.void(model_bytes))
 
     def generate_plots(self):
         raise NotImplementedError("Subclasses should implement this method")
 
     def encode_image_to_base64(self, img_path):
-        with open(img_path, 'rb') as img_file:
-            return base64.b64encode(img_file.read()).decode('utf-8')
+        with open(img_path, "rb") as img_file:
+            return base64.b64encode(img_file.read()).decode("utf-8")
 
     def save_html_report(self):
         LOG.info("Saving HTML report")
 
+        if not self.output_dir:
+            raise ValueError("output_dir must be specified and not None")
+
         model_name = type(self.best_model).__name__
-        excluded_params = ['html', 'log_experiment', 'system_log', 'test_data']
+        excluded_params = ["html", "log_experiment", "system_log", "test_data"]
         filtered_setup_params = {
-            k: v
-            for k, v in self.setup_params.items() if k not in excluded_params
+            k: v for k, v in self.setup_params.items() if k not in excluded_params
         }
         setup_params_table = pd.DataFrame(
-            list(filtered_setup_params.items()), columns=['Parameter', 'Value']
+            list(filtered_setup_params.items()), columns=["Parameter", "Value"]
         )
 
         best_model_params = pd.DataFrame(
-            self.best_model.get_params().items(),
-            columns=['Parameter', 'Value']
+            self.best_model.get_params().items(), columns=["Parameter", "Value"]
         )
         best_model_params.to_csv(
             os.path.join(self.output_dir, "best_model.csv"), index=False
         )
-        self.results.to_csv(
-            os.path.join(self.output_dir, "comparison_results.csv")
-        )
-        self.test_result_df.to_csv(
-            os.path.join(self.output_dir, "test_results.csv")
-        )
+        self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))
+        self.test_result_df.to_csv(os.path.join(self.output_dir, "test_results.csv"))
 
         plots_html = ""
         length = len(self.plots)
         for i, (plot_name, plot_path) in enumerate(self.plots.items()):
             encoded_image = self.encode_image_to_base64(plot_path)
-            plots_html += f"""
-            <div class="plot">
-                <h3>{plot_name.capitalize()}</h3>
-                <img src="data:image/png;base64,{encoded_image}"
-                    alt="{plot_name}">
-            </div>
-            """
+            plots_html += (
+                f'<div class="plot">'
+                f"<h3>{plot_name.capitalize()}</h3>"
+                f'<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">'
+                f"</div>"
+            )
             if i < length - 1:
                 plots_html += "<hr>"
 
         tree_plots = ""
         for i, tree in enumerate(self.trees):
             if tree:
-                tree_plots += f"""
-                <div class="plot">
-                    <h3>Tree {i+1}</h3>
-                    <img src="data:image/png;base64,
-                    {tree}"
-                    alt="tree {i+1}">
-                </div>
-                """
+                tree_plots += (
+                    f'<div class="plot">'
+                    f"<h3>Tree {i + 1}</h3>"
+                    f'<img src="data:image/png;base64,{tree}" alt="tree {i + 1}">'
+                    f"</div>"
+                )
 
         analyzer = FeatureImportanceAnalyzer(
             data=self.data,
             target_col=self.target_col,
             task_type=self.task_type,
             output_dir=self.output_dir,
+            exp=self.exp,
+            best_model=self.best_model,
         )
         feature_importance_html = analyzer.run()
 
-        html_content = f"""
-        {get_html_template()}
-            <h1>PyCaret Model Training Report</h1>
-            <div class="tabs">
-                <div class="tab" onclick="openTab(event, 'summary')">
-                Setup & Best Model</div>
-                <div class="tab" onclick="openTab(event, 'plots')">
-                Best Model Plots</div>
-                <div class="tab" onclick="openTab(event, 'feature')">
-                Feature Importance</div>
-        """
-        if self.plots_explainer_html:
-            html_content += """
-                <div class="tab" onclick="openTab(event, 'explainer')">
-                Explainer Plots</div>
-            """
-        html_content += f"""
-            </div>
-            <div id="summary" class="tab-content">
-                <h2>Setup Parameters</h2>
-                {setup_params_table.to_html(
-                    index=False,
-                    header=True,
-                    classes='table sortable'
-                )}
-                <h5>If you want to know all the experiment setup parameters,
-                  please check the PyCaret documentation for
-                  the classification/regression <code>exp</code> function.</h5>
-                <h2>Best Model: {model_name}</h2>
-                {best_model_params.to_html(
-                    index=False,
-                    header=True,
-                    classes='table sortable'
-                )}
-                <h2>Comparison Results on the Cross-Validation Set</h2>
-                {self.results.to_html(index=False, classes='table sortable')}
-                <h2>Results on the Test Set for the best model</h2>
-                {self.test_result_df.to_html(
-                    index=False,
-                    classes='table sortable'
-                )}
-            </div>
-            <div id="plots" class="tab-content">
-                <h2>Best Model Plots on the testing set</h2>
-                {plots_html}
-            </div>
-            <div id="feature" class="tab-content">
-                {feature_importance_html}
-            </div>
-        """
+        # --- Feature Metrics Help Button ---
+        feature_metrics_button_html = (
+            '<button class="help-modal-btn" id="openFeatureMetricsHelp" style="margin-bottom:12px;">'
+            "Help: Metrics Guide"
+            "</button>"
+            "<style>"
+            ".help-modal-btn {"
+            "background-color: #17623b;"
+            "color: #fff;"
+            "border: none;"
+            "border-radius: 24px;"
+            "padding: 10px 28px;"
+            "font-size: 1.1rem;"
+            "font-weight: bold;"
+            "letter-spacing: 0.03em;"
+            "cursor: pointer;"
+            "transition: background 0.2s, box-shadow 0.2s;"
+            "box-shadow: 0 2px 8px rgba(23,98,59,0.07);"
+            "}"
+            ".help-modal-btn:hover, .help-modal-btn:focus {"
+            "background-color: #21895e;"
+            "outline: none;"
+            "box-shadow: 0 4px 16px rgba(23,98,59,0.14);"
+            "}"
+            "</style>"
+        )
+
+        html_content = (
+            f"{get_html_template()}"
+            "<h1>Tabular Learner Model Report</h1>"
+            f"{feature_metrics_button_html}"
+            '<div class="tabs">'
+            '<div class="tab" onclick="openTab(event, \'summary\')">'
+            "Validation Result Summary & Config</div>"
+            '<div class="tab" onclick="openTab(event, \'plots\')">'
+            "Test Results</div>"
+            '<div class="tab" onclick="openTab(event, \'feature\')">'
+            "Feature Importance</div>"
+        )
         if self.plots_explainer_html:
-            html_content += f"""
-            <div id="explainer" class="tab-content">
-                {self.plots_explainer_html}
-                {tree_plots}
-            </div>
-            """
-        html_content += """
-        <script>
-        document.addEventListener("DOMContentLoaded", function() {
-            var tables = document.querySelectorAll("table.sortable");
-            tables.forEach(function(table) {
-                var headers = table.querySelectorAll("th");
-                headers.forEach(function(header, index) {
-                    header.style.cursor = "pointer";
-                    // Add initial arrow (up) to indicate sortability
-                    header.innerHTML += '<span class="sort-arrow"> ↑</span>';
-                    header.addEventListener("click", function() {
-                        var direction = this.getAttribute(
-                            "data-sort-direction"
-                        ) || "asc";
-                        // Reset arrows in all headers of this table
-                        headers.forEach(function(h) {
-                            var arrow = h.querySelector(".sort-arrow");
-                            if (arrow) arrow.textContent = " ↑";
-                        });
-                        // Set arrow for clicked header
-                        var arrow = this.querySelector(".sort-arrow");
-                        arrow.textContent = direction === "asc" ? " ↓" : " ↑";
-                        sortTable(table, index, direction);
-                        this.setAttribute("data-sort-direction",
-                        direction === "asc" ? "desc" : "asc");
-                    });
-                });
-            });
-        });
-
-        function sortTable(table, colNum, direction) {
-            var tb = table.tBodies[0];
-            var tr = Array.prototype.slice.call(tb.rows, 0);
-            var multiplier = direction === "asc" ? 1 : -1;
-            tr = tr.sort(function(a, b) {
-                var aText = a.cells[colNum].textContent.trim();
-                var bText = b.cells[colNum].textContent.trim();
-                // Remove arrow from text comparison
-                aText = aText.replace(/[↑↓]/g, '').trim();
-                bText = bText.replace(/[↑↓]/g, '').trim();
-                if (!isNaN(aText) && !isNaN(bText)) {
-                    return multiplier * (
-                        parseFloat(aText) - parseFloat(bText)
-                    );
-                } else {
-                    return multiplier * aText.localeCompare(bText);
-                }
-            });
-            for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);
-        }
-        </script>
-        """
-        html_content += f"""
-        {get_html_closing()}
-        """
+            html_content += (
+                '<div class="tab" onclick="openTab(event, \'explainer\')">'
+                "Explainer Plots</div>"
+            )
+        html_content += (
+            "</div>"
+            '<div id="summary" class="tab-content">'
+            "<h2>Model Metrics from Cross-Validation Set</h2>"
+            f"<h2>Best Model: {model_name}</h2>"
+            "<h5>The best model is selected by: Accuracy (Classification)"
+            " or R2 (Regression).</h5>"
+            f"{self.results.to_html(index=False, classes='table sortable')}"
+            "<h2>Best Model's Hyperparameters</h2>"
+            f"{best_model_params.to_html(index=False, header=True, classes='table sortable')}"
+            "<h2>Setup Parameters</h2>"
+            f"{setup_params_table.to_html(index=False, header=True, classes='table sortable')}"
+            "<h5>If you want to know all the experiment setup parameters,"
+            " please check the PyCaret documentation for"
+            " the classification/regression <code>exp</code> function.</h5>"
+            "</div>"
+            '<div id="plots" class="tab-content">'
+            f"<h2>Best Model: {model_name}</h2>"
+            "<h5>The best model is selected by: Accuracy (Classification)"
+            " or R2 (Regression).</h5>"
+            "<h2>Test Metrics</h2>"
+            f"{self.test_result_df.to_html(index=False)}"
+            "<h2>Test Results</h2>"
+            f"{plots_html}"
+            "</div>"
+            '<div id="feature" class="tab-content">'
+            f"{feature_importance_html}"
+            "</div>"
+        )
+        if self.plots_explainer_html:
+            html_content += (
+                '<div id="explainer" class="tab-content">'
+                f"{self.plots_explainer_html}"
+                f"{tree_plots}"
+                "</div>"
+            )
+        html_content += (
+            "<script>"
+            "document.addEventListener(\"DOMContentLoaded\", function() {"
+            "var tables = document.querySelectorAll(\"table.sortable\");"
+            "tables.forEach(function(table) {"
+            "var headers = table.querySelectorAll(\"th\");"
+            "headers.forEach(function(header, index) {"
+            "header.style.cursor = \"pointer\";"
+            "// Add initial arrow (up) to indicate sortability, use Unicode ↑ (U+2191)"
+            "header.innerHTML += '<span class=\"sort-arrow\"> ↑</span>';"
+            "header.addEventListener(\"click\", function() {"
+            "var direction = this.getAttribute("
+            "\"data-sort-direction\""
+            ") || \"asc\";"
+            "// Reset arrows in all headers of this table"
+            "headers.forEach(function(h) {"
+            "var arrow = h.querySelector(\".sort-arrow\");"
+            "if (arrow) arrow.textContent = \" ↑\";"
+            "});"
+            "// Set arrow for clicked header"
+            "var arrow = this.querySelector(\".sort-arrow\");"
+            "arrow.textContent = direction === \"asc\" ? \" ↓\" : \" ↑\";"
+            "sortTable(table, index, direction);"
+            "this.setAttribute(\"data-sort-direction\","
+            "direction === \"asc\" ? \"desc\" : \"asc\");"
+            "});"
+            "});"
+            "});"
+            "});"
+            "function sortTable(table, colNum, direction) {"
+            "var tb = table.tBodies[0];"
+            "var tr = Array.prototype.slice.call(tb.rows, 0);"
+            "var multiplier = direction === \"asc\" ? 1 : -1;"
+            "tr = tr.sort(function(a, b) {"
+            "var aText = a.cells[colNum].textContent.trim();"
+            "var bText = b.cells[colNum].textContent.trim();"
+            "// Remove arrow from text comparison"
+            "aText = aText.replace(/[↑↓]/g, '').trim();"
+            "bText = bText.replace(/[↑↓]/g, '').trim();"
+            "if (!isNaN(aText) && !isNaN(bText)) {"
+            "return multiplier * ("
+            "parseFloat(aText) - parseFloat(bText)"
+            ");"
+            "} else {"
+            "return multiplier * aText.localeCompare(bText);"
+            "}"
+            "});"
+            "for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);"
+            "}"
+            "</script>"
+        )
+        # --- Add the Feature Metrics Help Modal ---
+        html_content += get_feature_metrics_help_modal()
+        html_content += f"{get_html_closing()}"
         with open(
             os.path.join(self.output_dir, "comparison_result.html"),
-            "w"
+            "w",
+            encoding="utf-8",
         ) as file:
             file.write(html_content)
 
@@ -374,10 +414,8 @@
     def generate_plots_explainer(self):
         raise NotImplementedError("Subclasses should implement this method")
 
-    # not working now
     def generate_tree_plots(self):
-        from sklearn.ensemble import RandomForestClassifier, \
-            RandomForestRegressor
+        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
         from xgboost import XGBClassifier, XGBRegressor
         from explainerdashboard.explainers import RandomForestExplainer
 
@@ -385,21 +423,25 @@
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed
 
-        is_rf = isinstance(self.best_model, RandomForestClassifier) or \
-            isinstance(self.best_model, RandomForestRegressor)
+        is_rf = isinstance(
+            self.best_model, (RandomForestClassifier, RandomForestRegressor)
+        )
+        is_xgb = isinstance(self.best_model, (XGBClassifier, XGBRegressor))
 
-        is_xgb = isinstance(self.best_model, XGBClassifier) or \
-            isinstance(self.best_model, XGBRegressor)
+        num_trees = None
+        if is_rf:
+            num_trees = self.best_model.n_estimators
+        elif is_xgb:
+            num_trees = len(self.best_model.get_booster().get_dump())
+        else:
+            LOG.warning("Tree plots not supported for this model type.")
+            return
 
         try:
-            if is_rf:
-                num_trees = self.best_model.n_estimators
-            if is_xgb:
-                num_trees = len(self.best_model.get_booster().get_dump())
             explainer = RandomForestExplainer(self.best_model, X_test, y_test)
             for i in range(num_trees):
                 fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
-                LOG.info(f"Tree {i+1}")
+                LOG.info(f"Tree {i + 1}")
                 LOG.info(fig)
                 self.trees.append(fig)
         except Exception as e: