diff base_model_trainer.py @ 7:f4cb41f458fd draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit b430f8b466655878c3bf63b053655fdbf039ddb0
author goeckslab
date Wed, 09 Jul 2025 01:13:01 +0000
parents a32ff7201629
children
line wrap: on
line diff
--- a/base_model_trainer.py	Wed Jul 02 19:00:03 2025 +0000
+++ b/base_model_trainer.py	Wed Jul 09 01:13:01 2025 +0000
@@ -127,9 +127,11 @@
             and self.cross_validation is not None
             and self.cross_validation is False
         ):
-            self.setup_params["cross_validation"] = self.cross_validation
+            logging.info(
+                "cross_validation is set to False. This will disable cross-validation."
+            )
 
-        if hasattr(self, "cross_validation") and self.cross_validation is not None:
+        if hasattr(self, "cross_validation") and self.cross_validation:
             if hasattr(self, "cross_validation_folds"):
                 self.setup_params["fold"] = self.cross_validation_folds
 
@@ -182,10 +184,11 @@
             )
 
         if hasattr(self, "models") and self.models is not None:
-            self.best_model = self.exp.compare_models(include=self.models)
+            self.best_model = self.exp.compare_models(include=self.models, cross_validation=self.cross_validation)
         else:
-            self.best_model = self.exp.compare_models()
+            self.best_model = self.exp.compare_models(cross_validation=self.cross_validation)
         self.results = self.exp.pull()
+
         if self.task_type == "classification":
             self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
 
@@ -314,7 +317,7 @@
         html_content += (
             "</div>"
             '<div id="summary" class="tab-content">'
-            "<h2>Model Metrics from Cross-Validation Set</h2>"
+            f"<h2>Model Metrics from {'Cross-Validation Set' if self.cross_validation else 'Validation set'}</h2>"
             f"<h2>Best Model: {model_name}</h2>"
             "<h5>The best model is selected by: Accuracy (Classification)"
             " or R2 (Regression).</h5>"