changeset 5:3d42f82b3c7f draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 4a11e8a4c4e9daa884bddedfa47090476c517667
author goeckslab
date Thu, 31 Jul 2025 15:41:07 +0000
parents 11fdac5affb3
children
files base_model_trainer.py pycaret_train.py tabular_learner.xml
diffstat 3 files changed, 40 insertions(+), 15 deletions(-) [+]
line wrap: on
line diff
--- a/base_model_trainer.py	Fri Jul 25 19:02:12 2025 +0000
+++ b/base_model_trainer.py	Thu Jul 31 15:41:07 2025 +0000
@@ -175,7 +175,13 @@
 
         if self.task_type == "classification":
             self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
-        _ = self.exp.predict_model(self.best_model)
+
+        prob_thresh = getattr(self, "probability_threshold", None)
+        if self.task_type == "classification" and prob_thresh is not None:
+            _ = self.exp.predict_model(self.best_model, probability_threshold=prob_thresh)
+        else:
+            _ = self.exp.predict_model(self.best_model)
+
         self.test_result_df = self.exp.pull()
         if self.task_type == "classification":
             self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
@@ -233,7 +239,7 @@
             best_model_name = type(self.best_model).__name__
         LOG.info(f"Best model determined as: {best_model_name}")
 
-        # 2) Compute training sample count
+    # 2) Compute training sample count
         try:
             n_train = self.exp.X_train.shape[0]
         except Exception:
@@ -241,7 +247,10 @@
         total_rows = self.data.shape[0]
 
         # 3) Build setup parameters table
-        all_params = self.setup_params
+        all_params = self.setup_params.copy()
+        if self.task_type == "classification" and hasattr(self, "probability_threshold"):
+            all_params["probability_threshold"] = self.probability_threshold
+
         display_keys = [
             "Target",
             "Session ID",
@@ -255,6 +264,7 @@
             "Polynomial Features",
             "Fix Imbalance",
             "Models",
+            "Probability Threshold",
         ]
         setup_rows = []
         for key in display_keys:
@@ -281,6 +291,8 @@
                 dv = v if v is not None else "None"
             elif key == "Models":
                 dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None"
+            elif key == "Probability Threshold":
+                dv = v if v is not None else "None"
             else:
                 dv = v if v is not None else "None"
             setup_rows.append([key, dv])
--- a/pycaret_train.py	Fri Jul 25 19:02:12 2025 +0000
+++ b/pycaret_train.py	Thu Jul 31 15:41:07 2025 +0000
@@ -103,16 +103,22 @@
         help="Tune the best model hyperparameters after training",
     )
     parser.add_argument(
+        "--test_file",
+        type=str,
+        default=None,
+        help="Path to the test data file",
+    )
+    parser.add_argument(
         "--random_seed",
         type=int,
         default=42,
         help="Random seed for PyCaret setup",
     )
     parser.add_argument(
-        "--test_file",
-        type=str,
+        "--probability_threshold",
+        type=float,
         default=None,
-        help="Path to the test data file",
+        help="Probability threshold for classification decision,",
     )
 
     args = parser.parse_args()
@@ -120,7 +126,7 @@
     # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation
     if args.no_cross_validation:
         args.cross_validation = False
-    # If --cross_validation was passed, args.cross_validation is True
+    # If --cross_validation was passed,  args.cross_validation is True
     # If neither was passed, args.cross_validation remains None
 
     # Build the model_kwargs dict from CLI args
@@ -137,6 +143,7 @@
         "feature_ratio": args.feature_ratio,
         "fix_imbalance": args.fix_imbalance,
         "tune_model": args.tune_model,
+        "probability_threshold": args.probability_threshold,
     }
     LOG.info(f"Model kwargs: {model_kwargs}")
 
--- a/tabular_learner.xml	Fri Jul 25 19:02:12 2025 +0000
+++ b/tabular_learner.xml	Thu Jul 31 15:41:07 2025 +0000
@@ -22,10 +22,10 @@
         #end if
         #if $customize_defaults == "true"
                 #if $train_size
-                --train_size '$train_size' 
+                --train_size '$train_size'
                 #end if
                 #if $normalize
-                --normalize  
+                --normalize
                 #end if
                 #if $feature_selection
                 --feature_selection
@@ -34,27 +34,30 @@
                     --cross_validation
                     #if $cross_validation_folds
                         --cross_validation_folds '$cross_validation_folds'
-                    #end if 
+                    #end if
                 #end if
                 #if $enable_cross_validation == "false"
                     --no_cross_validation
                 #end if
                 #if $remove_outliers
-                --remove_outliers  
+                --remove_outliers
                 #end if
                 #if $remove_multicollinearity
-                --remove_multicollinearity 
+                --remove_multicollinearity
                 #end if
                 #if $polynomial_features
-                --polynomial_features  
+                --polynomial_features
                 #end if
                 #if $fix_imbalance
-                --fix_imbalance 
+                --fix_imbalance
+                #end if
+                #if $probability_threshold
+                --probability_threshold '$probability_threshold'
                 #end if
         #end if
         #if $test_file
             --test_file '$test_file'
-        #end if 
+        #end if
         --model_type '$model_type'
         ]]>
     </command>
@@ -150,6 +153,7 @@
                 <param name="remove_multicollinearity" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Remove Multicollinearity" help="Whether to remove multicollinear features before training." />
                 <param name="polynomial_features" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Polynomial Features" help="Whether to create polynomial features before training." />
                 <param name="fix_imbalance" type="boolean" truevalue="True" falsevalue="False" checked="false" label="Fix Imbalance" help="ONLY for classfication! Whether to use SMOTE or similar methods to fix imbalance in the input dataset." />
+                <param name="probability_threshold" type="float" min="0.0" max="1.0" value="0.5" label="Classification Probability Threshold" help="Only applies to classification. Probability above which a prediction is considered positive. Default is 0.5." />
             </when>
             <when value="false">
                 <!-- No additional parameters to show if the user selects 'No' -->
@@ -175,6 +179,7 @@
             <param name="cross_validation_folds" value="5"/>
             <param name="remove_outliers" value="true"/>
             <param name="remove_multicollinearity" value="true"/>
+            <param name="probability_threshold" value="0.4" />
             <output name="model" file="expected_model_classification_customized.h5" compare="sim_size"/>
             <output name="comparison_result">
                 <assert_contents>
@@ -197,6 +202,7 @@
             <param name="enable_cross_validation" value="false"/>
             <param name="remove_outliers" value="true"/>
             <param name="remove_multicollinearity" value="true"/>
+            <param name="probability_threshold" value="0.6" />
             <output name="model" file="expected_model_classification_customized_cross_off.h5" compare="sim_size"/>
             <output name="comparison_result">
                 <assert_contents>