Mercurial > repos > goeckslab > pycaret_predict
diff pycaret_train.py @ 9:c6c1f8777aae draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 4a11e8a4c4e9daa884bddedfa47090476c517667
author | goeckslab |
---|---|
date | Thu, 31 Jul 2025 15:41:24 +0000 |
parents | 1aed7d47c5ec |
children |
line wrap: on
line diff
--- a/pycaret_train.py Fri Jul 25 19:02:32 2025 +0000 +++ b/pycaret_train.py Thu Jul 31 15:41:24 2025 +0000 @@ -103,16 +103,22 @@ help="Tune the best model hyperparameters after training", ) parser.add_argument( + "--test_file", + type=str, + default=None, + help="Path to the test data file", + ) + parser.add_argument( "--random_seed", type=int, default=42, help="Random seed for PyCaret setup", ) parser.add_argument( - "--test_file", - type=str, + "--probability_threshold", + type=float, default=None, - help="Path to the test data file", + help="Probability threshold for classification decision,", ) args = parser.parse_args() @@ -120,7 +126,7 @@ # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation if args.no_cross_validation: args.cross_validation = False - # If --cross_validation was passed, args.cross_validation is True + # If --cross_validation was passed, args.cross_validation is True # If neither was passed, args.cross_validation remains None # Build the model_kwargs dict from CLI args @@ -137,6 +143,7 @@ "feature_ratio": args.feature_ratio, "fix_imbalance": args.fix_imbalance, "tune_model": args.tune_model, + "probability_threshold": args.probability_threshold, } LOG.info(f"Model kwargs: {model_kwargs}")