Mercurial > repos > goeckslab > pycaret_predict
comparison pycaret_train.py @ 8:1aed7d47c5ec draft
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
| author | goeckslab |
|---|---|
| date | Fri, 25 Jul 2025 19:02:32 +0000 |
| parents | f4cb41f458fd |
| children | c6c1f8777aae |
comparison
equal
deleted
inserted
replaced
| 7:f4cb41f458fd | 8:1aed7d47c5ec |
|---|---|
| 10 | 10 |
| 11 def main(): | 11 def main(): |
| 12 parser = argparse.ArgumentParser() | 12 parser = argparse.ArgumentParser() |
| 13 parser.add_argument("--input_file", help="Path to the input file") | 13 parser.add_argument("--input_file", help="Path to the input file") |
| 14 parser.add_argument("--target_col", help="Column number of the target") | 14 parser.add_argument("--target_col", help="Column number of the target") |
| 15 parser.add_argument("--output_dir", | 15 parser.add_argument("--output_dir", help="Path to the output directory") |
| 16 help="Path to the output directory") | 16 parser.add_argument( |
| 17 parser.add_argument("--model_type", | 17 "--model_type", |
| 18 choices=["classification", "regression"], | 18 choices=["classification", "regression"], |
| 19 help="Type of the model") | 19 help="Type of the model", |
| 20 parser.add_argument("--train_size", type=float, | 20 ) |
| 21 default=None, | 21 parser.add_argument( |
| 22 help="Train size for PyCaret setup") | 22 "--train_size", |
| 23 parser.add_argument("--normalize", action="store_true", | 23 type=float, |
| 24 default=None, | 24 default=None, |
| 25 help="Normalize data for PyCaret setup") | 25 help="Train size for PyCaret setup", |
| 26 parser.add_argument("--feature_selection", action="store_true", | 26 ) |
| 27 default=None, | 27 parser.add_argument( |
| 28 help="Perform feature selection for PyCaret setup") | 28 "--normalize", |
| 29 parser.add_argument("--cross_validation", action="store_true", | 29 action="store_true", |
| 30 default=None, | 30 default=None, |
| 31 help="Perform cross-validation for PyCaret setup") | 31 help="Normalize data for PyCaret setup", |
| 32 parser.add_argument("--no_cross_validation", action="store_true", | 32 ) |
| 33 default=None, | 33 parser.add_argument( |
| 34 help="Don't perform cross-validation for PyCaret setup") | 34 "--feature_selection", |
| 35 parser.add_argument("--cross_validation_folds", type=int, | 35 action="store_true", |
| 36 default=None, | 36 default=None, |
| 37 help="Number of cross-validation folds \ | 37 help="Perform feature selection for PyCaret setup", |
| 38 for PyCaret setup") | 38 ) |
| 39 parser.add_argument("--remove_outliers", action="store_true", | 39 parser.add_argument( |
| 40 default=None, | 40 "--cross_validation", |
| 41 help="Remove outliers for PyCaret setup") | 41 action="store_true", |
| 42 parser.add_argument("--remove_multicollinearity", action="store_true", | 42 default=None, |
| 43 default=None, | 43 help="Enable cross-validation for PyCaret setup", |
| 44 help="Remove multicollinearity for PyCaret setup") | 44 ) |
| 45 parser.add_argument("--polynomial_features", action="store_true", | 45 parser.add_argument( |
| 46 default=None, | 46 "--no_cross_validation", |
| 47 help="Generate polynomial features for PyCaret setup") | 47 action="store_true", |
| 48 parser.add_argument("--feature_interaction", action="store_true", | 48 default=None, |
| 49 default=None, | 49 help="Disable cross-validation for PyCaret setup", |
| 50 help="Generate feature interactions for PyCaret setup") | 50 ) |
| 51 parser.add_argument("--feature_ratio", action="store_true", | 51 parser.add_argument( |
| 52 default=None, | 52 "--cross_validation_folds", |
| 53 help="Generate feature ratios for PyCaret setup") | 53 type=int, |
| 54 parser.add_argument("--fix_imbalance", action="store_true", | 54 default=None, |
| 55 default=None, | 55 help="Number of cross-validation folds for PyCaret setup", |
| 56 help="Fix class imbalance for PyCaret setup") | 56 ) |
| 57 parser.add_argument("--models", nargs='+', | 57 parser.add_argument( |
| 58 default=None, | 58 "--remove_outliers", |
| 59 help="Selected models for training") | 59 action="store_true", |
| 60 parser.add_argument("--random_seed", type=int, | 60 default=None, |
| 61 default=42, | 61 help="Remove outliers for PyCaret setup", |
| 62 help="Random seed for PyCaret setup") | 62 ) |
| 63 parser.add_argument("--test_file", type=str, default=None, | 63 parser.add_argument( |
| 64 help="Path to the test data file") | 64 "--remove_multicollinearity", |
| | 65 action="store_true", |
| | 66 default=None, |
| | 67 help="Remove multicollinearity for PyCaret setup", |
| | 68 ) |
| | 69 parser.add_argument( |
| | 70 "--polynomial_features", |
| | 71 action="store_true", |
| | 72 default=None, |
| | 73 help="Generate polynomial features for PyCaret setup", |
| | 74 ) |
| | 75 parser.add_argument( |
| | 76 "--feature_interaction", |
| | 77 action="store_true", |
| | 78 default=None, |
| | 79 help="Generate feature interactions for PyCaret setup", |
| | 80 ) |
| | 81 parser.add_argument( |
| | 82 "--feature_ratio", |
| | 83 action="store_true", |
| | 84 default=None, |
| | 85 help="Generate feature ratios for PyCaret setup", |
| | 86 ) |
| | 87 parser.add_argument( |
| | 88 "--fix_imbalance", |
| | 89 action="store_true", |
| | 90 default=None, |
| | 91 help="Fix class imbalance for PyCaret setup", |
| | 92 ) |
| | 93 parser.add_argument( |
| | 94 "--models", |
| | 95 nargs="+", |
| | 96 default=None, |
| | 97 help="Selected models for training", |
| | 98 ) |
| | 99 parser.add_argument( |
| | 100 "--tune_model", |
| | 101 action="store_true", |
| | 102 default=False, |
| | 103 help="Tune the best model hyperparameters after training", |
| | 104 ) |
| | 105 parser.add_argument( |
| | 106 "--random_seed", |
| | 107 type=int, |
| | 108 default=42, |
| | 109 help="Random seed for PyCaret setup", |
| | 110 ) |
| | 111 parser.add_argument( |
| | 112 "--test_file", |
| | 113 type=str, |
| | 114 default=None, |
| | 115 help="Path to the test data file", |
| | 116 ) |
| 65 | 117 |
| 66 args = parser.parse_args() | 118 args = parser.parse_args() |
| 67 | 119 |
| 68 cross_validation = True | 120 # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation |
| 69 if args.no_cross_validation: | 121 if args.no_cross_validation: |
| 70 cross_validation = False | 122 args.cross_validation = False |
| | 123 # If --cross_validation was passed, args.cross_validation is True |
| | 124 # If neither was passed, args.cross_validation remains None |
| 71 | 125 |
| | 126 # Build the model_kwargs dict from CLI args |
| 72 model_kwargs = { | 127 model_kwargs = { |
| 73 "train_size": args.train_size, | 128 "train_size": args.train_size, |
| 74 "normalize": args.normalize, | 129 "normalize": args.normalize, |
| 75 "feature_selection": args.feature_selection, | 130 "feature_selection": args.feature_selection, |
| 76 "cross_validation": cross_validation, | 131 "cross_validation": args.cross_validation, |
| 77 "cross_validation_folds": args.cross_validation_folds, | 132 "cross_validation_folds": args.cross_validation_folds, |
| 78 "remove_outliers": args.remove_outliers, | 133 "remove_outliers": args.remove_outliers, |
| 79 "remove_multicollinearity": args.remove_multicollinearity, | 134 "remove_multicollinearity": args.remove_multicollinearity, |
| 80 "polynomial_features": args.polynomial_features, | 135 "polynomial_features": args.polynomial_features, |
| 81 "feature_interaction": args.feature_interaction, | 136 "feature_interaction": args.feature_interaction, |
| 82 "feature_ratio": args.feature_ratio, | 137 "feature_ratio": args.feature_ratio, |
| 83 "fix_imbalance": args.fix_imbalance, | 138 "fix_imbalance": args.fix_imbalance, |
| | 139 "tune_model": args.tune_model, |
| 84 } | 140 } |
| 85 LOG.info(f"Model kwargs: {model_kwargs}") | 141 LOG.info(f"Model kwargs: {model_kwargs}") |
| 86 | 142 |
| 87 # Remove None values from model_kwargs | 143 # If the XML passed a comma-separated string in a single list element, split it out |
| 88 | |
| 89 LOG.info(f"Model kwargs 2: {model_kwargs}") | |
| 90 if args.models: | 144 if args.models: |
| 91 model_kwargs["models"] = args.models[0].split(",") | 145 model_kwargs["models"] = args.models[0].split(",") |
| 92 | 146 |
| | 147 # Drop None entries so PyCaret uses its default values |
| 93 model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None} | 148 model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None} |
| | 149 LOG.info(f"Model kwargs 2: {model_kwargs}") |
| 94 | 150 |
| | 151 # Instantiate the appropriate trainer |
| 95 if args.model_type == "classification": | 152 if args.model_type == "classification": |
| 96 trainer = ClassificationModelTrainer( | 153 trainer = ClassificationModelTrainer( |
| 97 args.input_file, | 154 args.input_file, |
| 98 args.target_col, | 155 args.target_col, |
| 99 args.output_dir, | 156 args.output_dir, |
| 100 args.model_type, | 157 args.model_type, |
| 101 args.random_seed, | 158 args.random_seed, |
| 102 args.test_file, | 159 args.test_file, |
| 103 **model_kwargs) | 160 **model_kwargs, |
| | 161 ) |
| 104 elif args.model_type == "regression": | 162 elif args.model_type == "regression": |
| 105 if "fix_imbalance" in model_kwargs: | 163 # regression doesn't support fix_imbalance |
| 106 del model_kwargs["fix_imbalance"] | 164 model_kwargs.pop("fix_imbalance", None) |
| 107 trainer = RegressionModelTrainer( | 165 trainer = RegressionModelTrainer( |
| 108 args.input_file, | 166 args.input_file, |
| 109 args.target_col, | 167 args.target_col, |
| 110 args.output_dir, | 168 args.output_dir, |
| 111 args.model_type, | 169 args.model_type, |
| 112 args.random_seed, | 170 args.random_seed, |
| 113 args.test_file, | 171 args.test_file, |
| 114 **model_kwargs) | 172 **model_kwargs, |
| | 173 ) |
| 115 else: | 174 else: |
| 116 LOG.error("Invalid model type. Please choose \ | 175 LOG.error("Invalid model type. Please choose 'classification' or 'regression'.") |
| 117 'classification' or 'regression'.") | |
| 118 return | 176 return |
| | 177 |
| 119 trainer.run() | 178 trainer.run() |
| 120 | 179 |
| 121 | 180 |
| 122 if __name__ == "__main__": | 181 if __name__ == "__main__": |
| 123 main() | 182 main() |
