Mercurial > repos > goeckslab > tabular_learner
diff pycaret_train.py @ 5:3d42f82b3c7f draft
planemo upload for repository https://github.com/goeckslab/gleam commit 4a11e8a4c4e9daa884bddedfa47090476c517667
author | goeckslab |
---|---|
date | Thu, 31 Jul 2025 15:41:07 +0000 |
parents | 11fdac5affb3 |
children | 4bd75b45a7a1 |
line wrap: on
line diff
--- a/pycaret_train.py Fri Jul 25 19:02:12 2025 +0000 +++ b/pycaret_train.py Thu Jul 31 15:41:07 2025 +0000 @@ -103,16 +103,22 @@ help="Tune the best model hyperparameters after training", ) parser.add_argument( + "--test_file", + type=str, + default=None, + help="Path to the test data file", + ) + parser.add_argument( "--random_seed", type=int, default=42, help="Random seed for PyCaret setup", ) parser.add_argument( - "--test_file", - type=str, + "--probability_threshold", + type=float, default=None, - help="Path to the test data file", + help="Probability threshold for classification decision,", ) args = parser.parse_args() @@ -120,7 +126,7 @@ # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation if args.no_cross_validation: args.cross_validation = False - # If --cross_validation was passed, args.cross_validation is True + # If --cross_validation was passed, args.cross_validation is True # If neither was passed, args.cross_validation remains None # Build the model_kwargs dict from CLI args @@ -137,6 +143,7 @@ "feature_ratio": args.feature_ratio, "fix_imbalance": args.fix_imbalance, "tune_model": args.tune_model, + "probability_threshold": args.probability_threshold, } LOG.info(f"Model kwargs: {model_kwargs}")