diff pycaret_train.py @ 5:3d42f82b3c7f draft

planemo upload for repository https://github.com/goeckslab/gleam commit 4a11e8a4c4e9daa884bddedfa47090476c517667
author goeckslab
date Thu, 31 Jul 2025 15:41:07 +0000
parents 11fdac5affb3
children 4bd75b45a7a1
line wrap: on
line diff
--- a/pycaret_train.py	Fri Jul 25 19:02:12 2025 +0000
+++ b/pycaret_train.py	Thu Jul 31 15:41:07 2025 +0000
@@ -103,16 +103,22 @@
         help="Tune the best model hyperparameters after training",
     )
     parser.add_argument(
+        "--test_file",
+        type=str,
+        default=None,
+        help="Path to the test data file",
+    )
+    parser.add_argument(
         "--random_seed",
         type=int,
         default=42,
         help="Random seed for PyCaret setup",
     )
     parser.add_argument(
-        "--test_file",
-        type=str,
+        "--probability_threshold",
+        type=float,
         default=None,
-        help="Path to the test data file",
+        help="Probability threshold for classification decision,",
     )
 
     args = parser.parse_args()
@@ -120,7 +126,7 @@
     # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation
     if args.no_cross_validation:
         args.cross_validation = False
-    # If --cross_validation was passed, args.cross_validation is True
+    # If --cross_validation was passed,  args.cross_validation is True
     # If neither was passed, args.cross_validation remains None
 
     # Build the model_kwargs dict from CLI args
@@ -137,6 +143,7 @@
         "feature_ratio": args.feature_ratio,
         "fix_imbalance": args.fix_imbalance,
         "tune_model": args.tune_model,
+        "probability_threshold": args.probability_threshold,
     }
     LOG.info(f"Model kwargs: {model_kwargs}")