Mercurial > repos > goeckslab > pycaret_predict
comparison base_model_trainer.py @ 7:f4cb41f458fd draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit b430f8b466655878c3bf63b053655fdbf039ddb0
author | goeckslab |
---|---|
date | Wed, 09 Jul 2025 01:13:01 +0000 |
parents | a32ff7201629 |
children |
comparison
equal
deleted
inserted
replaced
6:a32ff7201629 | 7:f4cb41f458fd |
---|---|
125 if ( | 125 if ( |
126 hasattr(self, "cross_validation") | 126 hasattr(self, "cross_validation") |
127 and self.cross_validation is not None | 127 and self.cross_validation is not None |
128 and self.cross_validation is False | 128 and self.cross_validation is False |
129 ): | 129 ): |
130 self.setup_params["cross_validation"] = self.cross_validation | 130 logging.info( |
131 | 131 "cross_validation is set to False. This will disable cross-validation." |
132 if hasattr(self, "cross_validation") and self.cross_validation is not None: | 132 ) |
133 | |
134 if hasattr(self, "cross_validation") and self.cross_validation: | |
133 if hasattr(self, "cross_validation_folds"): | 135 if hasattr(self, "cross_validation_folds"): |
134 self.setup_params["fold"] = self.cross_validation_folds | 136 self.setup_params["fold"] = self.cross_validation_folds |
135 | 137 |
136 if hasattr(self, "remove_outliers") and self.remove_outliers is not None: | 138 if hasattr(self, "remove_outliers") and self.remove_outliers is not None: |
137 self.setup_params["remove_outliers"] = self.remove_outliers | 139 self.setup_params["remove_outliers"] = self.remove_outliers |
180 score_func=average_precision_score, | 182 score_func=average_precision_score, |
181 average="weighted", | 183 average="weighted", |
182 ) | 184 ) |
183 | 185 |
184 if hasattr(self, "models") and self.models is not None: | 186 if hasattr(self, "models") and self.models is not None: |
185 self.best_model = self.exp.compare_models(include=self.models) | 187 self.best_model = self.exp.compare_models(include=self.models, cross_validation=self.cross_validation) |
186 else: | 188 else: |
187 self.best_model = self.exp.compare_models() | 189 self.best_model = self.exp.compare_models(cross_validation=self.cross_validation) |
188 self.results = self.exp.pull() | 190 self.results = self.exp.pull() |
191 | |
189 if self.task_type == "classification": | 192 if self.task_type == "classification": |
190 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) | 193 self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True) |
191 | 194 |
192 _ = self.exp.predict_model(self.best_model) | 195 _ = self.exp.predict_model(self.best_model) |
193 self.test_result_df = self.exp.pull() | 196 self.test_result_df = self.exp.pull() |
312 "Explainer Plots</div>" | 315 "Explainer Plots</div>" |
313 ) | 316 ) |
314 html_content += ( | 317 html_content += ( |
315 "</div>" | 318 "</div>" |
316 '<div id="summary" class="tab-content">' | 319 '<div id="summary" class="tab-content">' |
317 "<h2>Model Metrics from Cross-Validation Set</h2>" | 320 f"<h2>Model Metrics from {'Cross-Validation Set' if self.cross_validation else 'Validation set'}</h2>" |
318 f"<h2>Best Model: {model_name}</h2>" | 321 f"<h2>Best Model: {model_name}</h2>" |
319 "<h5>The best model is selected by: Accuracy (Classification)" | 322 "<h5>The best model is selected by: Accuracy (Classification)" |
320 " or R2 (Regression).</h5>" | 323 " or R2 (Regression).</h5>" |
321 f"{self.results.to_html(index=False, classes='table sortable')}" | 324 f"{self.results.to_html(index=False, classes='table sortable')}" |
322 "<h2>Best Model's Hyperparameters</h2>" | 325 "<h2>Best Model's Hyperparameters</h2>" |