changeset 9:e7dd78077b72 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 84d5cd0b1fa5c1ff0ad892bc39c95dad1ceb4920
author goeckslab
date Sat, 08 Nov 2025 14:20:19 +0000
parents ba45bc057d70
children
files base_model_trainer.py pycaret_macros.xml pycaret_train.py tabular_learner.xml test-data/expected_best_model_classification_customized.csv test-data/expected_model_classification_customized.h5
diffstat 6 files changed, 70 insertions(+), 22 deletions(-) [+]
line wrap: on
line diff
--- a/base_model_trainer.py	Mon Sep 08 22:38:55 2025 +0000
+++ b/base_model_trainer.py	Sat Nov 08 14:20:19 2025 +0000
@@ -199,6 +199,28 @@
         self.exp.setup(self.data, **self.setup_params)
         self.setup_params.update(self.user_kwargs)
 
+    def _normalize_metric(self, m: str) -> str:
+        if not m:
+            return "R2" if self.task_type == "regression" else "Accuracy"
+        m_low = str(m).strip().lower()
+        alias = {
+            "auc": "AUC", "roc_auc": "AUC", "roc-auc": "AUC",
+            "accuracy": "Accuracy",
+            "precision": "Precision",
+            "recall": "Recall",
+            "f1": "F1",
+            "kappa": "Kappa",
+            "logloss": "Log Loss", "log_loss": "Log Loss",
+            "pr_auc": "PR-AUC-Weighted", "prauc": "PR-AUC-Weighted",
+            "r2": "R2",
+            "mae": "MAE",
+            "mse": "MSE",
+            "rmse": "RMSE",
+            "rmsle": "RMSLE",
+            "mape": "MAPE",
+        }
+        return alias.get(m_low, m)
+
     def train_model(self):
         LOG.info("Training and selecting the best model")
         if self.task_type == "classification":
@@ -222,6 +244,15 @@
         if getattr(self, "cross_validation_folds", None) is not None:
             compare_kwargs["fold"] = self.cross_validation_folds
 
+        chosen_metric = self._normalize_metric(getattr(self, "best_model_metric", None))
+        if chosen_metric:
+            compare_kwargs["sort"] = chosen_metric
+            self.chosen_metric_label = chosen_metric
+            try:
+                setattr(self.exp, "_fold_metric", chosen_metric)
+            except Exception as e:
+                LOG.warning(f"Failed to set '_fold_metric' to '{chosen_metric}': {e}", exc_info=True)
+
         LOG.info(f"compare_models kwargs: {compare_kwargs}")
         self.best_model = self.exp.compare_models(**compare_kwargs)
         self.results = self.exp.pull()
@@ -369,8 +400,8 @@
             else:
                 dv = v if v is not None else "None"
             setup_rows.append([key, dv])
-        if hasattr(self.exp, "_fold_metric"):
-            setup_rows.append(["best_model_metric", self.exp._fold_metric])
+        if getattr(self, "chosen_metric_label", None):
+            setup_rows.append(["Best Model Metric", self.chosen_metric_label])
 
         df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
         df_setup.to_csv(
--- a/pycaret_macros.xml	Mon Sep 08 22:38:55 2025 +0000
+++ b/pycaret_macros.xml	Sat Nov 08 14:20:19 2025 +0000
@@ -1,5 +1,5 @@
 <macros>
-    <token name="@TABULAR_LEARNER_VERSION@">0.1.0.1</token>
+    <token name="@TABULAR_LEARNER_VERSION@">0.1.1</token>
     <token name="@PYCARET_VERSION@">3.3.2</token>
     <token name="@SUFFIX@">1</token>
     <token name="@PYCARET_PREDICT_VERSION@">@PYCARET_VERSION@+@SUFFIX@</token>
--- a/pycaret_train.py	Mon Sep 08 22:38:55 2025 +0000
+++ b/pycaret_train.py	Sat Nov 08 14:20:19 2025 +0000
@@ -120,6 +120,12 @@
         default=None,
         help="Probability threshold for classification decision,",
     )
+    parser.add_argument(
+        "--best_model_metric",
+        type=str,
+        default=None,
+        help="Metric used to select the best model (e.g. AUC, Accuracy, R2, RMSE).",
+    )
 
     args = parser.parse_args()
 
@@ -144,6 +150,7 @@
         "fix_imbalance": args.fix_imbalance,
         "tune_model": args.tune_model,
         "probability_threshold": args.probability_threshold,
+        "best_model_metric": args.best_model_metric,
     }
     LOG.info(f"Model kwargs: {model_kwargs}")
 
--- a/tabular_learner.xml	Mon Sep 08 22:38:55 2025 +0000
+++ b/tabular_learner.xml	Sat Nov 08 14:20:19 2025 +0000
@@ -59,6 +59,9 @@
             --test_file '$test_file'
         #end if
         --model_type '$model_type'
+        #if $best_model_metric
+            --best_model_metric '$best_model_metric'
+        #end if
         ]]>
     </command>
     <inputs>
@@ -104,6 +107,16 @@
                     <option value="lightgbm">Light Gradient Boosting Machine</option>
                     <option value="catboost">CatBoost Classifier</option>
                 </param>
+                <param name="best_model_metric" type="select" label="Select metric to pick the best model" help="PyCaret will rank models by this metric. Default is Accuracy.">
+                    <option value="Accuracy" selected="true">Accuracy</option>
+                    <option value="AUC">ROC-AUC</option>
+                    <option value="Precision">Precision</option>
+                    <option value="Recall">Recall</option>
+                    <option value="F1">F1</option>
+                    <option value="Kappa">Cohen’s Kappa</option>
+                    <option value="Log Loss">Log Loss (lower is better)</option>
+                    <option value="PR-AUC-Weighted">PR-AUC (weighted)</option>
+                </param>
             </when>
             <when value="regression">
                 <param name="regression_models" type="select" multiple="true" label="Only Select Regression Models if you don't want to compare all models">
@@ -133,6 +146,14 @@
                     <option value="lightgbm">Light Gradient Boosting Machine</option>
                     <option value="catboost">CatBoost Regressor</option>
                 </param>
+                <param name="best_model_metric" type="select" label="Select metric to pick the best model" help="PyCaret will rank models by this metric. Default is R².">
+                    <option value="R2" selected="true">R²</option>
+                    <option value="MAE">MAE</option>
+                    <option value="MSE">MSE</option>
+                    <option value="RMSE">RMSE</option>
+                    <option value="RMSLE">RMSLE</option>
+                    <option value="MAPE">MAPE</option>
+                </param>
             </when>
         </conditional>
         <param name="tune_model" type="boolean" truevalue="True" falsevalue="False" label="Tune hyperparameters" help="Hyperparameter tuning on the best model" />
@@ -179,6 +200,7 @@
             <param name="input_file" value="pcr.tsv"/>
             <param name="target_feature" value="11"/> 
             <param name="model_type" value="classification"/>
+            <param name="best_model_metric" value="F1"/>
             <param name="random_seed" value="42"/>
             <param name="customize_defaults" value="true"/>
             <param name="train_size" value="0.8"/>
@@ -195,6 +217,8 @@
                     <has_text text="Validation Summary" />
                     <has_text text="Test Summary" />
                     <has_text text="Feature Importance" />
+                    <has_text text="Best Model Metric" />
+                    <has_text text="F1" />
                 </assert_contents>
             </output>
             <output name="best_model_csv" value="expected_best_model_classification_customized.csv" />
@@ -257,6 +281,7 @@
             <param name="input_file" value="auto-mpg.tsv"/>
             <param name="target_feature" value="1"/> 
             <param name="model_type" value="regression"/>
+            <param name="best_model_metric" value="RMSE"/>
             <param name="random_seed" value="42"/>
             <output name="model" file="expected_model_regression.h5" compare="sim_size" />
             <output name="comparison_result">
@@ -264,6 +289,8 @@
                     <has_text text="Validation Summary" />
                     <has_text text="Test Summary" />
                     <has_text text="Feature Importance" />
+                    <has_text text="Best Model Metric" />
+                    <has_text text="RMSE" />
                 </assert_contents>
             </output>
             <output name="best_model_csv" value="expected_best_model_regression.csv" />
--- a/test-data/expected_best_model_classification_customized.csv	Mon Sep 08 22:38:55 2025 +0000
+++ b/test-data/expected_best_model_classification_customized.csv	Sat Nov 08 14:20:19 2025 +0000
@@ -1,20 +1,3 @@
 Parameter,Value
-boosting_type,gbdt
-class_weight,
-colsample_bytree,1.0
-importance_type,split
-learning_rate,0.1
-max_depth,-1
-min_child_samples,20
-min_child_weight,0.001
-min_split_gain,0.0
-n_estimators,100
-n_jobs,-1
-num_leaves,31
-objective,
-random_state,42
-reg_alpha,0.0
-reg_lambda,0.0
-subsample,1.0
-subsample_for_bin,200000
-subsample_freq,0
+priors,
+var_smoothing,1e-09
Binary file test-data/expected_model_classification_customized.h5 has changed