Mercurial > repos > goeckslab > pycaret_predict
comparison pycaret_train.py @ 9:c6c1f8777aae draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 4a11e8a4c4e9daa884bddedfa47090476c517667
author | goeckslab |
---|---|
date | Thu, 31 Jul 2025 15:41:24 +0000 |
parents | 1aed7d47c5ec |
children |
comparison
equal
deleted
inserted
replaced
8:1aed7d47c5ec | 9:c6c1f8777aae |
---|---|
101 action="store_true", | 101 action="store_true", |
102 default=False, | 102 default=False, |
103 help="Tune the best model hyperparameters after training", | 103 help="Tune the best model hyperparameters after training", |
104 ) | 104 ) |
105 parser.add_argument( | 105 parser.add_argument( |
106 "--test_file", | |
107 type=str, | |
108 default=None, | |
109 help="Path to the test data file", | |
110 ) | |
111 parser.add_argument( | |
106 "--random_seed", | 112 "--random_seed", |
107 type=int, | 113 type=int, |
108 default=42, | 114 default=42, |
109 help="Random seed for PyCaret setup", | 115 help="Random seed for PyCaret setup", |
110 ) | 116 ) |
111 parser.add_argument( | 117 parser.add_argument( |
112 "--test_file", | 118 "--probability_threshold", |
113 type=str, | 119 type=float, |
114 default=None, | 120 default=None, |
115 help="Path to the test data file", | 121 help="Probability threshold for classification decision,", |
116 ) | 122 ) |
117 | 123 |
118 args = parser.parse_args() | 124 args = parser.parse_args() |
119 | 125 |
120 # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation | 126 # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation |
121 if args.no_cross_validation: | 127 if args.no_cross_validation: |
122 args.cross_validation = False | 128 args.cross_validation = False |
123 # If --cross_validation was passed, args.cross_validation is True | 129 # If --cross_validation was passed, args.cross_validation is True |
124 # If neither was passed, args.cross_validation remains None | 130 # If neither was passed, args.cross_validation remains None |
125 | 131 |
126 # Build the model_kwargs dict from CLI args | 132 # Build the model_kwargs dict from CLI args |
127 model_kwargs = { | 133 model_kwargs = { |
128 "train_size": args.train_size, | 134 "train_size": args.train_size, |
135 "polynomial_features": args.polynomial_features, | 141 "polynomial_features": args.polynomial_features, |
136 "feature_interaction": args.feature_interaction, | 142 "feature_interaction": args.feature_interaction, |
137 "feature_ratio": args.feature_ratio, | 143 "feature_ratio": args.feature_ratio, |
138 "fix_imbalance": args.fix_imbalance, | 144 "fix_imbalance": args.fix_imbalance, |
139 "tune_model": args.tune_model, | 145 "tune_model": args.tune_model, |
146 "probability_threshold": args.probability_threshold, | |
140 } | 147 } |
141 LOG.info(f"Model kwargs: {model_kwargs}") | 148 LOG.info(f"Model kwargs: {model_kwargs}") |
142 | 149 |
143 # If the XML passed a comma-separated string in a single list element, split it out | 150 # If the XML passed a comma-separated string in a single list element, split it out |
144 if args.models: | 151 if args.models: |