Mercurial > repos > goeckslab > pycaret_compare
diff pycaret_train.py @ 0:915447b14520 draft
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
author | goeckslab |
---|---|
date | Wed, 11 Dec 2024 05:00:00 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pycaret_train.py Wed Dec 11 05:00:00 2024 +0000 @@ -0,0 +1,117 @@ +import argparse +import logging + +from pycaret_classification import ClassificationModelTrainer + +from pycaret_regression import RegressionModelTrainer + +logging.basicConfig(level=logging.DEBUG) +LOG = logging.getLogger(__name__) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--input_file", help="Path to the input file") + parser.add_argument("--target_col", help="Column number of the target") + parser.add_argument("--output_dir", + help="Path to the output directory") + parser.add_argument("--model_type", + choices=["classification", "regression"], + help="Type of the model") + parser.add_argument("--train_size", type=float, + default=None, + help="Train size for PyCaret setup") + parser.add_argument("--normalize", action="store_true", + default=None, + help="Normalize data for PyCaret setup") + parser.add_argument("--feature_selection", action="store_true", + default=None, + help="Perform feature selection for PyCaret setup") + parser.add_argument("--cross_validation", action="store_true", + default=None, + help="Perform cross-validation for PyCaret setup") + parser.add_argument("--cross_validation_folds", type=int, + default=None, + help="Number of cross-validation folds \ + for PyCaret setup") + parser.add_argument("--remove_outliers", action="store_true", + default=None, + help="Remove outliers for PyCaret setup") + parser.add_argument("--remove_multicollinearity", action="store_true", + default=None, + help="Remove multicollinearity for PyCaret setup") + parser.add_argument("--polynomial_features", action="store_true", + default=None, + help="Generate polynomial features for PyCaret setup") + parser.add_argument("--feature_interaction", action="store_true", + default=None, + help="Generate feature interactions for PyCaret setup") + parser.add_argument("--feature_ratio", action="store_true", + default=None, + help="Generate feature ratios for PyCaret setup") + parser.add_argument("--fix_imbalance", action="store_true", + default=None, + help="Fix class imbalance for PyCaret setup") + parser.add_argument("--models", nargs='+', + default=None, + help="Selected models for training") + parser.add_argument("--random_seed", type=int, + default=42, + help="Random seed for PyCaret setup") + parser.add_argument("--test_file", type=str, default=None, + help="Path to the test data file") + + args = parser.parse_args() + + model_kwargs = { + "train_size": args.train_size, + "normalize": args.normalize, + "feature_selection": args.feature_selection, + "cross_validation": args.cross_validation, + "cross_validation_folds": args.cross_validation_folds, + "remove_outliers": args.remove_outliers, + "remove_multicollinearity": args.remove_multicollinearity, + "polynomial_features": args.polynomial_features, + "feature_interaction": args.feature_interaction, + "feature_ratio": args.feature_ratio, + "fix_imbalance": args.fix_imbalance, + } + LOG.info(f"Model kwargs: {model_kwargs}") + + # Remove None values from model_kwargs + + LOG.info(f"Model kwargs 2: {model_kwargs}") + if args.models: + model_kwargs["models"] = args.models[0].split(",") + + model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None} + + if args.model_type == "classification": + trainer = ClassificationModelTrainer( + args.input_file, + args.target_col, + args.output_dir, + args.model_type, + args.random_seed, + args.test_file, + **model_kwargs) + elif args.model_type == "regression": + if "fix_imbalance" in model_kwargs: + del model_kwargs["fix_imbalance"] + trainer = RegressionModelTrainer( + args.input_file, + args.target_col, + args.output_dir, + args.model_type, + args.random_seed, + args.test_file, + **model_kwargs) + else: + LOG.error("Invalid model type. Please choose \ + 'classification' or 'regression'.") + return + trainer.run() + + +if __name__ == "__main__": + main()