comparison pycaret_train.py @ 9:c6c1f8777aae draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 4a11e8a4c4e9daa884bddedfa47090476c517667
author goeckslab
date Thu, 31 Jul 2025 15:41:24 +0000
parents 1aed7d47c5ec
children
comparison
equal deleted inserted replaced
8:1aed7d47c5ec 9:c6c1f8777aae
101 action="store_true", 101 action="store_true",
102 default=False, 102 default=False,
103 help="Tune the best model hyperparameters after training", 103 help="Tune the best model hyperparameters after training",
104 ) 104 )
105 parser.add_argument( 105 parser.add_argument(
106 "--test_file",
107 type=str,
108 default=None,
109 help="Path to the test data file",
110 )
111 parser.add_argument(
106 "--random_seed", 112 "--random_seed",
107 type=int, 113 type=int,
108 default=42, 114 default=42,
109 help="Random seed for PyCaret setup", 115 help="Random seed for PyCaret setup",
110 ) 116 )
111 parser.add_argument( 117 parser.add_argument(
112 "--test_file", 118 "--probability_threshold",
113 type=str, 119 type=float,
114 default=None, 120 default=None,
115 help="Path to the test data file", 121 help="Probability threshold for classification decision,",
116 ) 122 )
117 123
118 args = parser.parse_args() 124 args = parser.parse_args()
119 125
120 # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation 126 # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation
121 if args.no_cross_validation: 127 if args.no_cross_validation:
122 args.cross_validation = False 128 args.cross_validation = False
123 # If --cross_validation was passed, args.cross_validation is True 129 # If --cross_validation was passed, args.cross_validation is True
124 # If neither was passed, args.cross_validation remains None 130 # If neither was passed, args.cross_validation remains None
125 131
126 # Build the model_kwargs dict from CLI args 132 # Build the model_kwargs dict from CLI args
127 model_kwargs = { 133 model_kwargs = {
128 "train_size": args.train_size, 134 "train_size": args.train_size,
135 "polynomial_features": args.polynomial_features, 141 "polynomial_features": args.polynomial_features,
136 "feature_interaction": args.feature_interaction, 142 "feature_interaction": args.feature_interaction,
137 "feature_ratio": args.feature_ratio, 143 "feature_ratio": args.feature_ratio,
138 "fix_imbalance": args.fix_imbalance, 144 "fix_imbalance": args.fix_imbalance,
139 "tune_model": args.tune_model, 145 "tune_model": args.tune_model,
146 "probability_threshold": args.probability_threshold,
140 } 147 }
141 LOG.info(f"Model kwargs: {model_kwargs}") 148 LOG.info(f"Model kwargs: {model_kwargs}")
142 149
143 # If the XML passed a comma-separated string in a single list element, split it out 150 # If the XML passed a comma-separated string in a single list element, split it out
144 if args.models: 151 if args.models: