comparison pycaret_train.py @ 8:1aed7d47c5ec draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author goeckslab
date Fri, 25 Jul 2025 19:02:32 +0000
parents f4cb41f458fd
children
comparison
equal deleted inserted replaced
7:f4cb41f458fd 8:1aed7d47c5ec
10 10
11 def main(): 11 def main():
12 parser = argparse.ArgumentParser() 12 parser = argparse.ArgumentParser()
13 parser.add_argument("--input_file", help="Path to the input file") 13 parser.add_argument("--input_file", help="Path to the input file")
14 parser.add_argument("--target_col", help="Column number of the target") 14 parser.add_argument("--target_col", help="Column number of the target")
15 parser.add_argument("--output_dir", 15 parser.add_argument("--output_dir", help="Path to the output directory")
16 help="Path to the output directory") 16 parser.add_argument(
17 parser.add_argument("--model_type", 17 "--model_type",
18 choices=["classification", "regression"], 18 choices=["classification", "regression"],
19 help="Type of the model") 19 help="Type of the model",
20 parser.add_argument("--train_size", type=float, 20 )
21 default=None, 21 parser.add_argument(
22 help="Train size for PyCaret setup") 22 "--train_size",
23 parser.add_argument("--normalize", action="store_true", 23 type=float,
24 default=None, 24 default=None,
25 help="Normalize data for PyCaret setup") 25 help="Train size for PyCaret setup",
26 parser.add_argument("--feature_selection", action="store_true", 26 )
27 default=None, 27 parser.add_argument(
28 help="Perform feature selection for PyCaret setup") 28 "--normalize",
29 parser.add_argument("--cross_validation", action="store_true", 29 action="store_true",
30 default=None, 30 default=None,
31 help="Perform cross-validation for PyCaret setup") 31 help="Normalize data for PyCaret setup",
32 parser.add_argument("--no_cross_validation", action="store_true", 32 )
33 default=None, 33 parser.add_argument(
34 help="Don't perform cross-validation for PyCaret setup") 34 "--feature_selection",
35 parser.add_argument("--cross_validation_folds", type=int, 35 action="store_true",
36 default=None, 36 default=None,
37 help="Number of cross-validation folds \ 37 help="Perform feature selection for PyCaret setup",
38 for PyCaret setup") 38 )
39 parser.add_argument("--remove_outliers", action="store_true", 39 parser.add_argument(
40 default=None, 40 "--cross_validation",
41 help="Remove outliers for PyCaret setup") 41 action="store_true",
42 parser.add_argument("--remove_multicollinearity", action="store_true", 42 default=None,
43 default=None, 43 help="Enable cross-validation for PyCaret setup",
44 help="Remove multicollinearity for PyCaret setup") 44 )
45 parser.add_argument("--polynomial_features", action="store_true", 45 parser.add_argument(
46 default=None, 46 "--no_cross_validation",
47 help="Generate polynomial features for PyCaret setup") 47 action="store_true",
48 parser.add_argument("--feature_interaction", action="store_true", 48 default=None,
49 default=None, 49 help="Disable cross-validation for PyCaret setup",
50 help="Generate feature interactions for PyCaret setup") 50 )
51 parser.add_argument("--feature_ratio", action="store_true", 51 parser.add_argument(
52 default=None, 52 "--cross_validation_folds",
53 help="Generate feature ratios for PyCaret setup") 53 type=int,
54 parser.add_argument("--fix_imbalance", action="store_true", 54 default=None,
55 default=None, 55 help="Number of cross-validation folds for PyCaret setup",
56 help="Fix class imbalance for PyCaret setup") 56 )
57 parser.add_argument("--models", nargs='+', 57 parser.add_argument(
58 default=None, 58 "--remove_outliers",
59 help="Selected models for training") 59 action="store_true",
60 parser.add_argument("--random_seed", type=int, 60 default=None,
61 default=42, 61 help="Remove outliers for PyCaret setup",
62 help="Random seed for PyCaret setup") 62 )
63 parser.add_argument("--test_file", type=str, default=None, 63 parser.add_argument(
64 help="Path to the test data file") 64 "--remove_multicollinearity",
65 action="store_true",
66 default=None,
67 help="Remove multicollinearity for PyCaret setup",
68 )
69 parser.add_argument(
70 "--polynomial_features",
71 action="store_true",
72 default=None,
73 help="Generate polynomial features for PyCaret setup",
74 )
75 parser.add_argument(
76 "--feature_interaction",
77 action="store_true",
78 default=None,
79 help="Generate feature interactions for PyCaret setup",
80 )
81 parser.add_argument(
82 "--feature_ratio",
83 action="store_true",
84 default=None,
85 help="Generate feature ratios for PyCaret setup",
86 )
87 parser.add_argument(
88 "--fix_imbalance",
89 action="store_true",
90 default=None,
91 help="Fix class imbalance for PyCaret setup",
92 )
93 parser.add_argument(
94 "--models",
95 nargs="+",
96 default=None,
97 help="Selected models for training",
98 )
99 parser.add_argument(
100 "--tune_model",
101 action="store_true",
102 default=False,
103 help="Tune the best model hyperparameters after training",
104 )
105 parser.add_argument(
106 "--random_seed",
107 type=int,
108 default=42,
109 help="Random seed for PyCaret setup",
110 )
111 parser.add_argument(
112 "--test_file",
113 type=str,
114 default=None,
115 help="Path to the test data file",
116 )
65 117
66 args = parser.parse_args() 118 args = parser.parse_args()
67 119
68 cross_validation = True 120 # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation
69 if args.no_cross_validation: 121 if args.no_cross_validation:
70 cross_validation = False 122 args.cross_validation = False
123 # If --cross_validation was passed, args.cross_validation is True
124 # If neither was passed, args.cross_validation remains None
71 125
126 # Build the model_kwargs dict from CLI args
72 model_kwargs = { 127 model_kwargs = {
73 "train_size": args.train_size, 128 "train_size": args.train_size,
74 "normalize": args.normalize, 129 "normalize": args.normalize,
75 "feature_selection": args.feature_selection, 130 "feature_selection": args.feature_selection,
76 "cross_validation": cross_validation, 131 "cross_validation": args.cross_validation,
77 "cross_validation_folds": args.cross_validation_folds, 132 "cross_validation_folds": args.cross_validation_folds,
78 "remove_outliers": args.remove_outliers, 133 "remove_outliers": args.remove_outliers,
79 "remove_multicollinearity": args.remove_multicollinearity, 134 "remove_multicollinearity": args.remove_multicollinearity,
80 "polynomial_features": args.polynomial_features, 135 "polynomial_features": args.polynomial_features,
81 "feature_interaction": args.feature_interaction, 136 "feature_interaction": args.feature_interaction,
82 "feature_ratio": args.feature_ratio, 137 "feature_ratio": args.feature_ratio,
83 "fix_imbalance": args.fix_imbalance, 138 "fix_imbalance": args.fix_imbalance,
139 "tune_model": args.tune_model,
84 } 140 }
85 LOG.info(f"Model kwargs: {model_kwargs}") 141 LOG.info(f"Model kwargs: {model_kwargs}")
86 142
87 # Remove None values from model_kwargs 143 # If the XML passed a comma-separated string in a single list element, split it out
88
89 LOG.info(f"Model kwargs 2: {model_kwargs}")
90 if args.models: 144 if args.models:
91 model_kwargs["models"] = args.models[0].split(",") 145 model_kwargs["models"] = args.models[0].split(",")
92 146
147 # Drop None entries so PyCaret uses its default values
93 model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None} 148 model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
149 LOG.info(f"Model kwargs 2: {model_kwargs}")
94 150
151 # Instantiate the appropriate trainer
95 if args.model_type == "classification": 152 if args.model_type == "classification":
96 trainer = ClassificationModelTrainer( 153 trainer = ClassificationModelTrainer(
97 args.input_file, 154 args.input_file,
98 args.target_col, 155 args.target_col,
99 args.output_dir, 156 args.output_dir,
100 args.model_type, 157 args.model_type,
101 args.random_seed, 158 args.random_seed,
102 args.test_file, 159 args.test_file,
103 **model_kwargs) 160 **model_kwargs,
161 )
104 elif args.model_type == "regression": 162 elif args.model_type == "regression":
105 if "fix_imbalance" in model_kwargs: 163 # regression doesn't support fix_imbalance
106 del model_kwargs["fix_imbalance"] 164 model_kwargs.pop("fix_imbalance", None)
107 trainer = RegressionModelTrainer( 165 trainer = RegressionModelTrainer(
108 args.input_file, 166 args.input_file,
109 args.target_col, 167 args.target_col,
110 args.output_dir, 168 args.output_dir,
111 args.model_type, 169 args.model_type,
112 args.random_seed, 170 args.random_seed,
113 args.test_file, 171 args.test_file,
114 **model_kwargs) 172 **model_kwargs,
173 )
115 else: 174 else:
116 LOG.error("Invalid model type. Please choose \ 175 LOG.error("Invalid model type. Please choose 'classification' or 'regression'.")
117 'classification' or 'regression'.")
118 return 176 return
177
119 trainer.run() 178 trainer.run()
120 179
121 180
122 if __name__ == "__main__": 181 if __name__ == "__main__":
123 main() 182 main()