comparison pycaret_train.py @ 0:209b663a4f62 draft

planemo upload for repository https://github.com/goeckslab/gleam commit 5dd048419fcbd285a327f88267e93996cd279ee6
author goeckslab
date Wed, 18 Jun 2025 15:38:19 +0000
parents
children f6a65e05d6ec
comparison
equal deleted inserted replaced
-1:000000000000 0:209b663a4f62
1 import argparse
2 import logging
3
4 from pycaret_classification import ClassificationModelTrainer
5 from pycaret_regression import RegressionModelTrainer
6
# Module logger used for all diagnostics in this script.
LOG = logging.getLogger(__name__)
# Root logging is configured at import time so DEBUG-level records are
# emitted; per the logging docs this is a no-op if the root logger
# already has handlers.
logging.basicConfig(level=logging.DEBUG)
9
10
def _build_parser():
    """Return the argument parser for the PyCaret training command line.

    The boolean setup flags use ``action="store_true"`` together with
    ``default=None``: a flag that is given yields ``True`` and an omitted
    flag yields ``None`` (never ``False``), which lets ``main`` drop
    omitted options instead of forwarding them to the trainer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", help="Path to the input file")
    parser.add_argument("--target_col", help="Column number of the target")
    parser.add_argument("--output_dir",
                        help="Path to the output directory")
    parser.add_argument("--model_type",
                        choices=["classification", "regression"],
                        help="Type of the model")
    parser.add_argument("--train_size", type=float,
                        default=None,
                        help="Train size for PyCaret setup")
    parser.add_argument("--normalize", action="store_true",
                        default=None,
                        help="Normalize data for PyCaret setup")
    parser.add_argument("--feature_selection", action="store_true",
                        default=None,
                        help="Perform feature selection for PyCaret setup")
    parser.add_argument("--cross_validation", action="store_true",
                        default=None,
                        help="Perform cross-validation for PyCaret setup")
    parser.add_argument("--cross_validation_folds", type=int,
                        default=None,
                        help="Number of cross-validation folds "
                             "for PyCaret setup")
    parser.add_argument("--remove_outliers", action="store_true",
                        default=None,
                        help="Remove outliers for PyCaret setup")
    parser.add_argument("--remove_multicollinearity", action="store_true",
                        default=None,
                        help="Remove multicollinearity for PyCaret setup")
    parser.add_argument("--polynomial_features", action="store_true",
                        default=None,
                        help="Generate polynomial features for PyCaret setup")
    parser.add_argument("--feature_interaction", action="store_true",
                        default=None,
                        help="Generate feature interactions for PyCaret setup")
    parser.add_argument("--feature_ratio", action="store_true",
                        default=None,
                        help="Generate feature ratios for PyCaret setup")
    parser.add_argument("--fix_imbalance", action="store_true",
                        default=None,
                        help="Fix class imbalance for PyCaret setup")
    parser.add_argument("--models", nargs='+',
                        default=None,
                        help="Selected models for training; each value "
                             "may itself be a comma-separated list")
    parser.add_argument("--random_seed", type=int,
                        default=42,
                        help="Random seed for PyCaret setup")
    parser.add_argument("--test_file", type=str, default=None,
                        help="Path to the test data file")
    return parser


def _split_model_names(raw_models):
    """Flatten ``--models`` values into a list of model identifiers.

    ``--models`` accepts one or more values, and each value may itself be a
    comma-separated list, so ``["lr,rf", "xgboost"]`` becomes
    ``["lr", "rf", "xgboost"]``. Surrounding whitespace is stripped and
    empty entries (e.g. from ``"lr,,rf"`` or a trailing comma) are dropped.
    """
    return [name.strip()
            for value in raw_models
            for name in value.split(",")
            if name.strip()]


def main():
    """Parse CLI arguments and run the matching PyCaret model trainer.

    Only setup options that were actually supplied on the command line are
    forwarded to the trainer as keyword arguments. A missing
    ``--model_type`` is reported as a usage error: ``parser.error`` prints
    the usage and exits with a non-zero status.
    """
    parser = _build_parser()
    args = parser.parse_args()

    model_kwargs = {
        "train_size": args.train_size,
        "normalize": args.normalize,
        "feature_selection": args.feature_selection,
        "cross_validation": args.cross_validation,
        "cross_validation_folds": args.cross_validation_folds,
        "remove_outliers": args.remove_outliers,
        "remove_multicollinearity": args.remove_multicollinearity,
        "polynomial_features": args.polynomial_features,
        "feature_interaction": args.feature_interaction,
        "feature_ratio": args.feature_ratio,
        "fix_imbalance": args.fix_imbalance,
    }
    LOG.info("Model kwargs: %s", model_kwargs)

    if args.models:
        model_names = _split_model_names(args.models)
        # An all-empty selection (e.g. --models "") is treated as "not
        # given" so the trainer's own default model list is used.
        if model_names:
            model_kwargs["models"] = model_names

    # Drop options that were not supplied (None) so they are not passed
    # to the trainer at all.
    model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}

    if args.model_type == "classification":
        trainer_cls = ClassificationModelTrainer
    elif args.model_type == "regression":
        # fix_imbalance is never forwarded to the regression trainer
        # (class rebalancing presumably does not apply to a continuous
        # target).
        model_kwargs.pop("fix_imbalance", None)
        trainer_cls = RegressionModelTrainer
    else:
        # argparse ``choices`` already rejects unknown values, so this is
        # reached when --model_type is omitted. Exit non-zero rather than
        # returning silently with status 0.
        parser.error("Invalid model type. Please choose "
                     "'classification' or 'regression'.")

    LOG.info("Final model kwargs: %s", model_kwargs)
    trainer = trainer_cls(
        args.input_file,
        args.target_col,
        args.output_dir,
        args.model_type,
        args.random_seed,
        args.test_file,
        **model_kwargs)
    trainer.run()


if __name__ == "__main__":
    main()