Mercurial > repos > goeckslab > tabular_learner
comparison pycaret_train.py @ 0:209b663a4f62 draft
planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
author | goeckslab |
---|---|
date | Wed, 18 Jun 2025 15:38:19 +0000 |
parents | |
children | f6a65e05d6ec |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:209b663a4f62 |
---|---|
1 import argparse | |
2 import logging | |
3 | |
4 from pycaret_classification import ClassificationModelTrainer | |
5 from pycaret_regression import RegressionModelTrainer | |
6 | |
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)
# Configure the root logger at DEBUG level (no-op if the root logger
# already has handlers installed).
logging.basicConfig(level=logging.DEBUG)
9 | |
10 | |
def main(argv=None):
    """Parse command-line options and run the selected model trainer.

    Builds the argument parser, gathers the optional PyCaret setup
    switches into keyword arguments (omitting every option the user did
    not supply), instantiates the trainer class matching ``--model_type``
    and calls its ``run()`` method.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            behaviour of the original zero-argument ``main()``.

    Raises:
        SystemExit: Raised by argparse for invalid options or ``--help``,
            and with status 1 when ``--model_type`` was not supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", help="Path to the input file")
    parser.add_argument("--target_col", help="Column number of the target")
    parser.add_argument("--output_dir",
                        help="Path to the output directory")
    parser.add_argument("--model_type",
                        choices=["classification", "regression"],
                        help="Type of the model")
    # Every optional setup switch below defaults to None (not False), so
    # that options the user did not pass can be told apart and left out
    # of the keyword arguments handed to the trainer.
    parser.add_argument("--train_size", type=float,
                        default=None,
                        help="Train size for PyCaret setup")
    parser.add_argument("--normalize", action="store_true",
                        default=None,
                        help="Normalize data for PyCaret setup")
    parser.add_argument("--feature_selection", action="store_true",
                        default=None,
                        help="Perform feature selection for PyCaret setup")
    parser.add_argument("--cross_validation", action="store_true",
                        default=None,
                        help="Perform cross-validation for PyCaret setup")
    parser.add_argument("--cross_validation_folds", type=int,
                        default=None,
                        help="Number of cross-validation folds "
                             "for PyCaret setup")
    parser.add_argument("--remove_outliers", action="store_true",
                        default=None,
                        help="Remove outliers for PyCaret setup")
    parser.add_argument("--remove_multicollinearity", action="store_true",
                        default=None,
                        help="Remove multicollinearity for PyCaret setup")
    parser.add_argument("--polynomial_features", action="store_true",
                        default=None,
                        help="Generate polynomial features for PyCaret setup")
    parser.add_argument("--feature_interaction", action="store_true",
                        default=None,
                        help="Generate feature interactions for PyCaret setup")
    parser.add_argument("--feature_ratio", action="store_true",
                        default=None,
                        help="Generate feature ratios for PyCaret setup")
    parser.add_argument("--fix_imbalance", action="store_true",
                        default=None,
                        help="Fix class imbalance for PyCaret setup")
    parser.add_argument("--models", nargs='+',
                        default=None,
                        help="Selected models for training")
    parser.add_argument("--random_seed", type=int,
                        default=42,
                        help="Random seed for PyCaret setup")
    parser.add_argument("--test_file", type=str, default=None,
                        help="Path to the test data file")

    args = parser.parse_args(argv)

    setup_options = (
        "train_size",
        "normalize",
        "feature_selection",
        "cross_validation",
        "cross_validation_folds",
        "remove_outliers",
        "remove_multicollinearity",
        "polynomial_features",
        "feature_interaction",
        "feature_ratio",
        "fix_imbalance",
    )
    model_kwargs = {name: getattr(args, name) for name in setup_options}
    LOG.info("Model kwargs: %s", model_kwargs)

    if args.models:
        # Accept both "--models a,b" and "--models a b": split every
        # token on commas instead of looking at the first token only.
        models = [
            name.strip()
            for token in args.models
            for name in token.split(",")
            if name.strip()
        ]
        if models:
            model_kwargs["models"] = models

    # Remove None values so only explicitly supplied options are passed.
    model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}

    if args.model_type == "classification":
        trainer_cls = ClassificationModelTrainer
    elif args.model_type == "regression":
        # NOTE(review): fix_imbalance is withheld from the regression
        # trainer - presumably it does not accept that option; confirm
        # against RegressionModelTrainer.
        model_kwargs.pop("fix_imbalance", None)
        trainer_cls = RegressionModelTrainer
    else:
        # Reached when --model_type is omitted (argparse itself rejects
        # values outside `choices`). Exit non-zero so the failure is
        # visible to whatever launched the script.
        LOG.error("Invalid model type. Please choose "
                  "'classification' or 'regression'.")
        raise SystemExit(1)

    # Logged after filtering so it shows exactly what the trainer receives.
    LOG.info("Model kwargs 2: %s", model_kwargs)

    trainer = trainer_cls(
        args.input_file,
        args.target_col,
        args.output_dir,
        args.model_type,
        args.random_seed,
        args.test_file,
        **model_kwargs)
    trainer.run()
113 | |
114 | |
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()