Mercurial > repos > goeckslab > tabular_learner
comparison base_model_trainer.py @ 9:e7dd78077b72 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 84d5cd0b1fa5c1ff0ad892bc39c95dad1ceb4920
| author | goeckslab |
|---|---|
| date | Sat, 08 Nov 2025 14:20:19 +0000 |
| parents | 4bd75b45a7a1 |
| children | |
comparison
equal
deleted
inserted
replaced
| 8:ba45bc057d70 | 9:e7dd78077b72 |
|---|---|
| 197 ) | 197 ) |
| 198 | 198 |
| 199 self.exp.setup(self.data, **self.setup_params) | 199 self.exp.setup(self.data, **self.setup_params) |
| 200 self.setup_params.update(self.user_kwargs) | 200 self.setup_params.update(self.user_kwargs) |
| 201 | 201 |
| 202 def _normalize_metric(self, m: str) -> str: | |
| 203 if not m: | |
| 204 return "R2" if self.task_type == "regression" else "Accuracy" | |
| 205 m_low = str(m).strip().lower() | |
| 206 alias = { | |
| 207 "auc": "AUC", "roc_auc": "AUC", "roc-auc": "AUC", | |
| 208 "accuracy": "Accuracy", | |
| 209 "precision": "Precision", | |
| 210 "recall": "Recall", | |
| 211 "f1": "F1", | |
| 212 "kappa": "Kappa", | |
| 213 "logloss": "Log Loss", "log_loss": "Log Loss", | |
| 214 "pr_auc": "PR-AUC-Weighted", "prauc": "PR-AUC-Weighted", | |
| 215 "r2": "R2", | |
| 216 "mae": "MAE", | |
| 217 "mse": "MSE", | |
| 218 "rmse": "RMSE", | |
| 219 "rmsle": "RMSLE", | |
| 220 "mape": "MAPE", | |
| 221 } | |
| 222 return alias.get(m_low, m) | |
| 223 | |
| 202 def train_model(self): | 224 def train_model(self): |
| 203 LOG.info("Training and selecting the best model") | 225 LOG.info("Training and selecting the best model") |
| 204 if self.task_type == "classification": | 226 if self.task_type == "classification": |
| 205 self.exp.add_metric( | 227 self.exp.add_metric( |
| 206 id="PR-AUC-Weighted", | 228 id="PR-AUC-Weighted", |
| 219 compare_kwargs["cross_validation"] = self.cross_validation | 241 compare_kwargs["cross_validation"] = self.cross_validation |
| 220 | 242 |
| 221 # Respect explicit fold count | 243 # Respect explicit fold count |
| 222 if getattr(self, "cross_validation_folds", None) is not None: | 244 if getattr(self, "cross_validation_folds", None) is not None: |
| 223 compare_kwargs["fold"] = self.cross_validation_folds | 245 compare_kwargs["fold"] = self.cross_validation_folds |
| 246 | |
| 247 chosen_metric = self._normalize_metric(getattr(self, "best_model_metric", None)) | |
| 248 if chosen_metric: | |
| 249 compare_kwargs["sort"] = chosen_metric | |
| 250 self.chosen_metric_label = chosen_metric | |
| 251 try: | |
| 252 setattr(self.exp, "_fold_metric", chosen_metric) | |
| 253 except Exception as e: | |
| 254 LOG.warning(f"Failed to set '_fold_metric' to '{chosen_metric}': {e}", exc_info=True) | |
| 224 | 255 |
| 225 LOG.info(f"compare_models kwargs: {compare_kwargs}") | 256 LOG.info(f"compare_models kwargs: {compare_kwargs}") |
| 226 self.best_model = self.exp.compare_models(**compare_kwargs) | 257 self.best_model = self.exp.compare_models(**compare_kwargs) |
| 227 self.results = self.exp.pull() | 258 self.results = self.exp.pull() |
| 228 if getattr(self, "tune_model", False): | 259 if getattr(self, "tune_model", False): |
| 367 elif key == "Probability Threshold": | 398 elif key == "Probability Threshold": |
| 368 dv = f"{v:.2f}" if v is not None else "0.5" | 399 dv = f"{v:.2f}" if v is not None else "0.5" |
| 369 else: | 400 else: |
| 370 dv = v if v is not None else "None" | 401 dv = v if v is not None else "None" |
| 371 setup_rows.append([key, dv]) | 402 setup_rows.append([key, dv]) |
| 372 if hasattr(self.exp, "_fold_metric"): | 403 if getattr(self, "chosen_metric_label", None): |
| 373 setup_rows.append(["best_model_metric", self.exp._fold_metric]) | 404 setup_rows.append(["Best Model Metric", self.chosen_metric_label]) |
| 374 | 405 |
| 375 df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"]) | 406 df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"]) |
| 376 df_setup.to_csv( | 407 df_setup.to_csv( |
| 377 Path(self.output_dir) / "setup_params.csv", index=False | 408 Path(self.output_dir) / "setup_params.csv", index=False |
| 378 ) | 409 ) |
