comparison pycaret_train.py @ 10:49f73a3c12f3 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 1ffd143e57fa952ee9dd84fc141771520aea0791
author goeckslab
date Wed, 26 Nov 2025 17:49:36 +0000
parents e7dd78077b72
children
comparison
equal deleted inserted replaced
9:e7dd78077b72 10:49f73a3c12f3
1 import argparse 1 import argparse
2 import logging 2 import logging
3 import os
3 4
4 from pycaret_classification import ClassificationModelTrainer 5 from pycaret_classification import ClassificationModelTrainer
5 from pycaret_regression import RegressionModelTrainer 6 from pycaret_regression import RegressionModelTrainer
6 7
7 logging.basicConfig(level=logging.DEBUG) 8 logging.basicConfig(level=logging.DEBUG)
113 type=int, 114 type=int,
114 default=42, 115 default=42,
115 help="Random seed for PyCaret setup", 116 help="Random seed for PyCaret setup",
116 ) 117 )
117 parser.add_argument( 118 parser.add_argument(
119 "--n-jobs",
120 dest="n_jobs",
121 type=int,
122 default=None,
123 help="Number of parallel jobs; defaults to GALAXY_SLOTS or 1 if unset/invalid.",
124 )
125 parser.add_argument(
118 "--probability_threshold", 126 "--probability_threshold",
119 type=float, 127 type=float,
120 default=None, 128 default=None,
121 help="Probability threshold for classification decision,", 129 help="Probability threshold for classification decision,",
122 ) 130 )
126 default=None, 134 default=None,
127 help="Metric used to select the best model (e.g. AUC, Accuracy, R2, RMSE).", 135 help="Metric used to select the best model (e.g. AUC, Accuracy, R2, RMSE).",
128 ) 136 )
129 137
130 args = parser.parse_args() 138 args = parser.parse_args()
139
140 # Derive n_jobs from CLI or GALAXY_SLOTS env var
141 if args.n_jobs is not None:
142 n_jobs = args.n_jobs
143 else:
144 slots_str = os.environ.get("GALAXY_SLOTS")
145 try:
146 n_jobs = int(slots_str) if slots_str is not None else 1
147 except ValueError:
148 n_jobs = 1
131 149
132 # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation 150 # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation
133 if args.no_cross_validation: 151 if args.no_cross_validation:
134 args.cross_validation = False 152 args.cross_validation = False
135 # If --cross_validation was passed, args.cross_validation is True 153 # If --cross_validation was passed, args.cross_validation is True
147 "polynomial_features": args.polynomial_features, 165 "polynomial_features": args.polynomial_features,
148 "feature_interaction": args.feature_interaction, 166 "feature_interaction": args.feature_interaction,
149 "feature_ratio": args.feature_ratio, 167 "feature_ratio": args.feature_ratio,
150 "fix_imbalance": args.fix_imbalance, 168 "fix_imbalance": args.fix_imbalance,
151 "tune_model": args.tune_model, 169 "tune_model": args.tune_model,
170 "n_jobs": n_jobs,
152 "probability_threshold": args.probability_threshold, 171 "probability_threshold": args.probability_threshold,
153 "best_model_metric": args.best_model_metric, 172 "best_model_metric": args.best_model_metric,
154 } 173 }
155 LOG.info(f"Model kwargs: {model_kwargs}") 174 LOG.info(f"Model kwargs: {model_kwargs}")
156 175