comparison pycaret_train.py @ 0:915447b14520 draft

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
author goeckslab
date Wed, 11 Dec 2024 05:00:00 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:915447b14520
1 import argparse
2 import logging
3
4 from pycaret_classification import ClassificationModelTrainer
5
6 from pycaret_regression import RegressionModelTrainer
7
# Module-level logger used by main() below.
LOG = logging.getLogger(__name__)
# Configure the root logger at DEBUG level for this script
# (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.DEBUG)
10
11
def main():
    """Parse command-line arguments, build a PyCaret trainer and run it.

    ``--model_type`` selects ClassificationModelTrainer or
    RegressionModelTrainer. Both receive the same positional arguments
    (input file, target column, output dir, model type, random seed,
    test file), followed by only those optional setup options that were
    actually set on the command line.

    Exits through argparse (status 2) on missing or invalid arguments,
    so that a caller such as Galaxy sees the failure.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_file", help="Path to the input file")
    parser.add_argument("--target_col", help="Column number of the target")
    parser.add_argument("--output_dir",
                        help="Path to the output directory")
    # Required: without a model type there is nothing to run. The job must
    # fail with a non-zero exit status rather than log an error and exit 0.
    parser.add_argument("--model_type",
                        choices=["classification", "regression"],
                        required=True,
                        help="Type of the model")
    parser.add_argument("--train_size", type=float,
                        default=None,
                        help="Train size for PyCaret setup")
    # The store_true flags below default to None instead of False. An
    # omitted flag then stays None and is dropped from model_kwargs further
    # down, rather than being forwarded as an explicit False.
    parser.add_argument("--normalize", action="store_true",
                        default=None,
                        help="Normalize data for PyCaret setup")
    parser.add_argument("--feature_selection", action="store_true",
                        default=None,
                        help="Perform feature selection for PyCaret setup")
    parser.add_argument("--cross_validation", action="store_true",
                        default=None,
                        help="Perform cross-validation for PyCaret setup")
    parser.add_argument("--cross_validation_folds", type=int,
                        default=None,
                        help="Number of cross-validation folds "
                             "for PyCaret setup")
    parser.add_argument("--remove_outliers", action="store_true",
                        default=None,
                        help="Remove outliers for PyCaret setup")
    parser.add_argument("--remove_multicollinearity", action="store_true",
                        default=None,
                        help="Remove multicollinearity for PyCaret setup")
    parser.add_argument("--polynomial_features", action="store_true",
                        default=None,
                        help="Generate polynomial features for PyCaret setup")
    parser.add_argument("--feature_interaction", action="store_true",
                        default=None,
                        help="Generate feature interactions for PyCaret setup")
    parser.add_argument("--feature_ratio", action="store_true",
                        default=None,
                        help="Generate feature ratios for PyCaret setup")
    parser.add_argument("--fix_imbalance", action="store_true",
                        default=None,
                        help="Fix class imbalance for PyCaret setup")
    parser.add_argument("--models", nargs='+',
                        default=None,
                        help="Selected models for training "
                             "(comma- and/or space-separated)")
    parser.add_argument("--random_seed", type=int,
                        default=42,
                        help="Random seed for PyCaret setup")
    parser.add_argument("--test_file", type=str, default=None,
                        help="Path to the test data file")

    args = parser.parse_args()

    model_kwargs = {
        "train_size": args.train_size,
        "normalize": args.normalize,
        "feature_selection": args.feature_selection,
        "cross_validation": args.cross_validation,
        "cross_validation_folds": args.cross_validation_folds,
        "remove_outliers": args.remove_outliers,
        "remove_multicollinearity": args.remove_multicollinearity,
        "polynomial_features": args.polynomial_features,
        "feature_interaction": args.feature_interaction,
        "feature_ratio": args.feature_ratio,
        "fix_imbalance": args.fix_imbalance,
    }

    if args.models:
        # Accept "--models lr,rf", "--models lr rf" and mixtures of both.
        # Previously only the first token was read and the rest were
        # silently ignored.
        models = [name.strip()
                  for token in args.models
                  for name in token.split(",")
                  if name.strip()]
        if models:
            model_kwargs["models"] = models

    # Forward only the options that were actually set.
    model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}

    if args.model_type == "classification":
        trainer_cls = ClassificationModelTrainer
    elif args.model_type == "regression":
        # fix_imbalance (a class-imbalance option) is never passed to the
        # regression trainer.
        model_kwargs.pop("fix_imbalance", None)
        trainer_cls = RegressionModelTrainer
    else:
        # Unreachable while argparse enforces `required` and `choices`. Kept
        # as a guard that still exits with a non-zero status.
        parser.error("Invalid model type. Please choose "
                     "'classification' or 'regression'.")

    LOG.info("Model kwargs: %s", model_kwargs)

    trainer = trainer_cls(
        args.input_file,
        args.target_col,
        args.output_dir,
        args.model_type,
        args.random_seed,
        args.test_file,
        **model_kwargs)
    trainer.run()


if __name__ == "__main__":
    main()