Mercurial > repos > goeckslab > pycaret_predict
comparison pycaret_train.py @ 9:c6c1f8777aae draft
planemo upload for repository https://github.com/goeckslab/gleam commit 4a11e8a4c4e9daa884bddedfa47090476c517667
| author | goeckslab |
|---|---|
| date | Thu, 31 Jul 2025 15:41:24 +0000 |
| parents | 1aed7d47c5ec |
| children | e2a6fed32d54 |
comparison
equal
deleted
inserted
replaced
| 8:1aed7d47c5ec | 9:c6c1f8777aae |
|---|---|
| 101 action="store_true", | 101 action="store_true", |
| 102 default=False, | 102 default=False, |
| 103 help="Tune the best model hyperparameters after training", | 103 help="Tune the best model hyperparameters after training", |
| 104 ) | 104 ) |
| 105 parser.add_argument( | 105 parser.add_argument( |
| 106 "--test_file", | |
| 107 type=str, | |
| 108 default=None, | |
| 109 help="Path to the test data file", | |
| 110 ) | |
| 111 parser.add_argument( | |
| 106 "--random_seed", | 112 "--random_seed", |
| 107 type=int, | 113 type=int, |
| 108 default=42, | 114 default=42, |
| 109 help="Random seed for PyCaret setup", | 115 help="Random seed for PyCaret setup", |
| 110 ) | 116 ) |
| 111 parser.add_argument( | 117 parser.add_argument( |
| 112 "--test_file", | 118 "--probability_threshold", |
| 113 type=str, | 119 type=float, |
| 114 default=None, | 120 default=None, |
| 115 help="Path to the test data file", | 121 help="Probability threshold for classification decision,", |
| 116 ) | 122 ) |
| 117 | 123 |
| 118 args = parser.parse_args() | 124 args = parser.parse_args() |
| 119 | 125 |
| 120 # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation | 126 # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation |
| 121 if args.no_cross_validation: | 127 if args.no_cross_validation: |
| 122 args.cross_validation = False | 128 args.cross_validation = False |
| 123 # If --cross_validation was passed, args.cross_validation is True | 129 # If --cross_validation was passed, args.cross_validation is True |
| 124 # If neither was passed, args.cross_validation remains None | 130 # If neither was passed, args.cross_validation remains None |
| 125 | 131 |
| 126 # Build the model_kwargs dict from CLI args | 132 # Build the model_kwargs dict from CLI args |
| 127 model_kwargs = { | 133 model_kwargs = { |
| 128 "train_size": args.train_size, | 134 "train_size": args.train_size, |
| 135 "polynomial_features": args.polynomial_features, | 141 "polynomial_features": args.polynomial_features, |
| 136 "feature_interaction": args.feature_interaction, | 142 "feature_interaction": args.feature_interaction, |
| 137 "feature_ratio": args.feature_ratio, | 143 "feature_ratio": args.feature_ratio, |
| 138 "fix_imbalance": args.fix_imbalance, | 144 "fix_imbalance": args.fix_imbalance, |
| 139 "tune_model": args.tune_model, | 145 "tune_model": args.tune_model, |
| 146 "probability_threshold": args.probability_threshold, | |
| 140 } | 147 } |
| 141 LOG.info(f"Model kwargs: {model_kwargs}") | 148 LOG.info(f"Model kwargs: {model_kwargs}") |
| 142 | 149 |
| 143 # If the XML passed a comma-separated string in a single list element, split it out | 150 # If the XML passed a comma-separated string in a single list element, split it out |
| 144 if args.models: | 151 if args.models: |
