Mercurial > repos > goeckslab > tabular_learner
comparison pycaret_train.py @ 10:49f73a3c12f3 draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 1ffd143e57fa952ee9dd84fc141771520aea0791
| author | goeckslab |
|---|---|
| date | Wed, 26 Nov 2025 17:49:36 +0000 |
| parents | e7dd78077b72 |
| children |
comparison
equal
deleted
inserted
replaced
| 9:e7dd78077b72 | 10:49f73a3c12f3 |
|---|---|
| 1 import argparse | 1 import argparse |
| 2 import logging | 2 import logging |
| | 3 import os |
| 3 | 4 |
| 4 from pycaret_classification import ClassificationModelTrainer | 5 from pycaret_classification import ClassificationModelTrainer |
| 5 from pycaret_regression import RegressionModelTrainer | 6 from pycaret_regression import RegressionModelTrainer |
| 6 | 7 |
| 7 logging.basicConfig(level=logging.DEBUG) | 8 logging.basicConfig(level=logging.DEBUG) |
| 113 type=int, | 114 type=int, |
| 114 default=42, | 115 default=42, |
| 115 help="Random seed for PyCaret setup", | 116 help="Random seed for PyCaret setup", |
| 116 ) | 117 ) |
| 117 parser.add_argument( | 118 parser.add_argument( |
| | 119 "--n-jobs", |
| | 120 dest="n_jobs", |
| | 121 type=int, |
| | 122 default=None, |
| | 123 help="Number of parallel jobs; defaults to GALAXY_SLOTS or 1 if unset/invalid.", |
| | 124 ) |
| | 125 parser.add_argument( |
| 118 "--probability_threshold", | 126 "--probability_threshold", |
| 119 type=float, | 127 type=float, |
| 120 default=None, | 128 default=None, |
| 121 help="Probability threshold for classification decision,", | 129 help="Probability threshold for classification decision,", |
| 122 ) | 130 ) |
| 126 default=None, | 134 default=None, |
| 127 help="Metric used to select the best model (e.g. AUC, Accuracy, R2, RMSE).", | 135 help="Metric used to select the best model (e.g. AUC, Accuracy, R2, RMSE).", |
| 128 ) | 136 ) |
| 129 | 137 |
| 130 args = parser.parse_args() | 138 args = parser.parse_args() |
| | 139 |
| | 140 # Derive n_jobs from CLI or GALAXY_SLOTS env var |
| | 141 if args.n_jobs is not None: |
| | 142 n_jobs = args.n_jobs |
| | 143 else: |
| | 144 slots_str = os.environ.get("GALAXY_SLOTS") |
| | 145 try: |
| | 146 n_jobs = int(slots_str) if slots_str is not None else 1 |
| | 147 except ValueError: |
| | 148 n_jobs = 1 |
| 131 | 149 |
| 132 # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation | 150 # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation |
| 133 if args.no_cross_validation: | 151 if args.no_cross_validation: |
| 134 args.cross_validation = False | 152 args.cross_validation = False |
| 135 # If --cross_validation was passed, args.cross_validation is True | 153 # If --cross_validation was passed, args.cross_validation is True |
| 147 "polynomial_features": args.polynomial_features, | 165 "polynomial_features": args.polynomial_features, |
| 148 "feature_interaction": args.feature_interaction, | 166 "feature_interaction": args.feature_interaction, |
| 149 "feature_ratio": args.feature_ratio, | 167 "feature_ratio": args.feature_ratio, |
| 150 "fix_imbalance": args.fix_imbalance, | 168 "fix_imbalance": args.fix_imbalance, |
| 151 "tune_model": args.tune_model, | 169 "tune_model": args.tune_model, |
| | 170 "n_jobs": n_jobs, |
| 152 "probability_threshold": args.probability_threshold, | 171 "probability_threshold": args.probability_threshold, |
| 153 "best_model_metric": args.best_model_metric, | 172 "best_model_metric": args.best_model_metric, |
| 154 } | 173 } |
| 155 LOG.info(f"Model kwargs: {model_kwargs}") | 174 LOG.info(f"Model kwargs: {model_kwargs}") |
| 156 | 175 |
