Mercurial > repos > goeckslab > tabular_learner
diff base_model_trainer.py @ 3:f6a65e05d6ec draft
planemo upload for repository https://github.com/goeckslab/gleam commit b430f8b466655878c3bf63b053655fdbf039ddb0
author | goeckslab |
---|---|
date | Wed, 09 Jul 2025 01:12:48 +0000 |
parents | 77c88226bfde |
children | 11fdac5affb3 |
line wrap: on
line diff
--- a/base_model_trainer.py Wed Jul 02 18:59:39 2025 +0000 +++ b/base_model_trainer.py Wed Jul 09 01:12:48 2025 +0000 @@ -127,9 +127,11 @@ and self.cross_validation is not None and self.cross_validation is False ): - self.setup_params["cross_validation"] = self.cross_validation + logging.info( + "cross_validation is set to False. This will disable cross-validation." + ) - if hasattr(self, "cross_validation") and self.cross_validation is not None: + if hasattr(self, "cross_validation") and self.cross_validation: if hasattr(self, "cross_validation_folds"): self.setup_params["fold"] = self.cross_validation_folds @@ -182,10 +184,11 @@ ) if hasattr(self, "models") and self.models is not None: - self.best_model = self.exp.compare_models(include=self.models) + self.best_model = self.exp.compare_models(include=self.models, cross_validation=self.cross_validation) else: - self.best_model = self.exp.compare_models() + self.best_model = self.exp.compare_models(cross_validation=self.cross_validation) self.results = self.exp.pull() + if self.task_type == "classification": self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) @@ -314,7 +317,7 @@ html_content += ( "</div>" '<div id="summary" class="tab-content">' - "<h2>Model Metrics from Cross-Validation Set</h2>" + f"<h2>Model Metrics from {'Cross-Validation Set' if self.cross_validation else 'Validation set'}</h2>" f"<h2>Best Model: {model_name}</h2>" "<h5>The best model is selected by: Accuracy (Classification)" " or R2 (Regression).</h5>"