comparison base_model_trainer.py @ 9:e7dd78077b72 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 84d5cd0b1fa5c1ff0ad892bc39c95dad1ceb4920
author goeckslab
date Sat, 08 Nov 2025 14:20:19 +0000
parents 4bd75b45a7a1
children (none)

comparison 8:ba45bc057d70 → 9:e7dd78077b72
@@ -197,10 +197,32 @@
         )
 
         self.exp.setup(self.data, **self.setup_params)
         self.setup_params.update(self.user_kwargs)
 
+    def _normalize_metric(self, m: str) -> str:
+        if not m:
+            return "R2" if self.task_type == "regression" else "Accuracy"
+        m_low = str(m).strip().lower()
+        alias = {
+            "auc": "AUC", "roc_auc": "AUC", "roc-auc": "AUC",
+            "accuracy": "Accuracy",
+            "precision": "Precision",
+            "recall": "Recall",
+            "f1": "F1",
+            "kappa": "Kappa",
+            "logloss": "Log Loss", "log_loss": "Log Loss",
+            "pr_auc": "PR-AUC-Weighted", "prauc": "PR-AUC-Weighted",
+            "r2": "R2",
+            "mae": "MAE",
+            "mse": "MSE",
+            "rmse": "RMSE",
+            "rmsle": "RMSLE",
+            "mape": "MAPE",
+        }
+        return alias.get(m_low, m)
+
     def train_model(self):
         LOG.info("Training and selecting the best model")
         if self.task_type == "classification":
             self.exp.add_metric(
                 id="PR-AUC-Weighted",
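Reviewer note: the new _normalize_metric helper is a plain alias lookup with a task-dependent default; unknown names pass through unchanged so the downstream library can reject them itself. A standalone sketch of the same rule, written as a free function with a trimmed alias table (names here are illustrative, not part of the commit):

# Standalone sketch of the normalization rule above (illustrative; the commit
# implements it as a method on the trainer class, with a fuller alias table).
METRIC_ALIASES = {
    "auc": "AUC", "roc_auc": "AUC", "roc-auc": "AUC",
    "accuracy": "Accuracy",
    "f1": "F1",
    "logloss": "Log Loss", "log_loss": "Log Loss",
    "pr_auc": "PR-AUC-Weighted", "prauc": "PR-AUC-Weighted",
    "r2": "R2",
    "rmse": "RMSE",
}


def normalize_metric(m, task_type="classification"):
    # None/empty falls back to a task-appropriate default.
    if not m:
        return "R2" if task_type == "regression" else "Accuracy"
    # Unknown names pass through unchanged so the caller sees them rejected
    # downstream instead of being silently remapped.
    return METRIC_ALIASES.get(str(m).strip().lower(), m)


assert normalize_metric("roc-auc") == "AUC"
assert normalize_metric(None, task_type="regression") == "R2"
assert normalize_metric("MCC") == "MCC"  # pass-through for unknown names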
@@ -219,10 +241,19 @@
         compare_kwargs["cross_validation"] = self.cross_validation
 
         # Respect explicit fold count
         if getattr(self, "cross_validation_folds", None) is not None:
             compare_kwargs["fold"] = self.cross_validation_folds
+
+        chosen_metric = self._normalize_metric(getattr(self, "best_model_metric", None))
+        if chosen_metric:
+            compare_kwargs["sort"] = chosen_metric
+            self.chosen_metric_label = chosen_metric
+            try:
+                setattr(self.exp, "_fold_metric", chosen_metric)
+            except Exception as e:
+                LOG.warning(f"Failed to set '_fold_metric' to '{chosen_metric}': {e}", exc_info=True)
 
         LOG.info(f"compare_models kwargs: {compare_kwargs}")
         self.best_model = self.exp.compare_models(**compare_kwargs)
         self.results = self.exp.pull()
         if getattr(self, "tune_model", False):
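Reviewer note: the experiment object follows PyCaret's API (setup / compare_models / pull), and sort selects the leaderboard column used to rank candidate models, so the normalized label must match a metric column exactly. A minimal sketch of that wiring, assuming PyCaret 3.x; the dataset, session_id, and model shortlist are illustrative:

# Minimal sketch of the sort wiring, assuming PyCaret 3.x
# (ClassificationExperiment, setup, compare_models, pull).
from pycaret.classification import ClassificationExperiment
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer(as_frame=True).frame
exp = ClassificationExperiment()
exp.setup(data, target="target", session_id=42)

compare_kwargs = {
    "sort": "F1",             # normalized label -> leaderboard ranking column
    "fold": 5,                # explicit fold count, as in the trainer
    "include": ["lr", "rf"],  # small shortlist to keep the sketch fast
}
best_model = exp.compare_models(**compare_kwargs)
print(exp.pull().head())      # leaderboard sorted by F1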
@@ -367,12 +398,12 @@
             elif key == "Probability Threshold":
                 dv = f"{v:.2f}" if v is not None else "0.5"
             else:
                 dv = v if v is not None else "None"
             setup_rows.append([key, dv])
-        if hasattr(self.exp, "_fold_metric"):
-            setup_rows.append(["best_model_metric", self.exp._fold_metric])
+        if getattr(self, "chosen_metric_label", None):
+            setup_rows.append(["Best Model Metric", self.chosen_metric_label])
 
         df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
         df_setup.to_csv(
             Path(self.output_dir) / "setup_params.csv", index=False
         )
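Reviewer note: the report row now comes from the trainer's own chosen_metric_label instead of the private _fold_metric attribute on the experiment object. A short sketch of the resulting setup_params.csv layout; all values below are illustrative:

# Sketch of the two-column setup_params.csv written above (values illustrative).
from pathlib import Path

import pandas as pd

setup_rows = [
    ["Session ID", "42"],
    ["Probability Threshold", "0.50"],
]

# Mirrors the new guard: the row is only emitted when a sort metric was chosen.
chosen_metric_label = "PR-AUC-Weighted"
if chosen_metric_label:
    setup_rows.append(["Best Model Metric", chosen_metric_label])

out_dir = Path("output")  # stand-in for self.output_dir
out_dir.mkdir(exist_ok=True)
df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
df_setup.to_csv(out_dir / "setup_params.csv", index=False)
# Resulting rows: Session ID,42 / Probability Threshold,0.50 / Best Model Metric,PR-AUC-Weighted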