Mercurial > repos > bgruening > sklearn_train_test_split
comparison keras_deep_learning.py @ 11:5da2217cd788 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 13:33:25 +0000 |
parents | 3312fb686ffb |
children |
comparison
equal
deleted
inserted
replaced
10:6e25381dad5c | 11:5da2217cd788 |
---|---|
1 import argparse | 1 import argparse |
2 import json | 2 import json |
3 import pickle | |
4 import warnings | 3 import warnings |
5 from ast import literal_eval | 4 from ast import literal_eval |
6 | 5 |
7 import keras | |
8 import pandas as pd | |
9 import six | 6 import six |
10 from galaxy_ml.utils import get_search_params, SafeEval, try_get_attr | 7 from galaxy_ml.model_persist import dump_model_to_h5 |
11 from keras.models import Model, Sequential | 8 from galaxy_ml.utils import SafeEval, try_get_attr |
9 from tensorflow import keras | |
10 from tensorflow.keras.models import Model, Sequential | |
12 | 11 |
13 safe_eval = SafeEval() | 12 safe_eval = SafeEval() |
14 | 13 |
15 | 14 |
16 def _handle_shape(literal): | 15 def _handle_shape(literal): |
17 """ | 16 """Eval integer or list/tuple of integers from string |
18 Eval integer or list/tuple of integers from string | |
19 | 17 |
20 Parameters: | 18 Parameters: |
21 ----------- | 19 ----------- |
22 literal : str. | 20 literal : str. |
23 """ | 21 """ |
30 print(e) | 28 print(e) |
31 return literal | 29 return literal |
32 | 30 |
33 | 31 |
34 def _handle_regularizer(literal): | 32 def _handle_regularizer(literal): |
35 """ | 33 """Construct regularizer from string literal |
36 Construct regularizer from string literal | |
37 | 34 |
38 Parameters | 35 Parameters |
39 ---------- | 36 ---------- |
40 literal : str. E.g. '(0.1, 0)' | 37 literal : str. E.g. '(0.1, 0)' |
41 """ | 38 """ |
55 | 52 |
56 return keras.regularizers.l1_l2(l1=l1, l2=l2) | 53 return keras.regularizers.l1_l2(l1=l1, l2=l2) |
57 | 54 |
58 | 55 |
59 def _handle_constraint(config): | 56 def _handle_constraint(config): |
60 """ | 57 """Construct constraint from galaxy tool parameters. |
61 Construct constraint from galaxy tool parameters. | |
62 Suppose correct dictionary format | 58 Suppose correct dictionary format |
63 | 59 |
64 Parameters | 60 Parameters |
65 ---------- | 61 ---------- |
66 config : dict. E.g. | 62 config : dict. E.g. |
89 def _handle_lambda(literal): | 85 def _handle_lambda(literal): |
90 return None | 86 return None |
91 | 87 |
92 | 88 |
93 def _handle_layer_parameters(params): | 89 def _handle_layer_parameters(params): |
94 """ | 90 """Access to handle all kinds of parameters""" |
95 Access to handle all kinds of parameters | |
96 """ | |
97 for key, value in six.iteritems(params): | 91 for key, value in six.iteritems(params): |
98 if value in ("None", ""): | 92 if value in ("None", ""): |
99 params[key] = None | 93 params[key] = None |
100 continue | 94 continue |
101 | 95 |
102 if type(value) in [int, float, bool] or ( | 96 if type(value) in [int, float, bool] or ( |
103 type(value) is str and value.isalpha() | 97 type(value) is str and value.isalpha() |
104 ): | 98 ): |
105 continue | 99 continue |
106 | 100 |
107 if ( | 101 if key in [ |
108 key | 102 "input_shape", |
109 in [ | 103 "noise_shape", |
110 "input_shape", | 104 "shape", |
111 "noise_shape", | 105 "batch_shape", |
112 "shape", | 106 "target_shape", |
113 "batch_shape", | 107 "dims", |
114 "target_shape", | 108 "kernel_size", |
115 "dims", | 109 "strides", |
116 "kernel_size", | 110 "dilation_rate", |
117 "strides", | 111 "output_padding", |
118 "dilation_rate", | 112 "cropping", |
119 "output_padding", | 113 "size", |
120 "cropping", | 114 "padding", |
121 "size", | 115 "pool_size", |
122 "padding", | 116 "axis", |
123 "pool_size", | 117 "shared_axes", |
124 "axis", | 118 ] and isinstance(value, str): |
125 "shared_axes", | |
126 ] | |
127 and isinstance(value, str) | |
128 ): | |
129 params[key] = _handle_shape(value) | 119 params[key] = _handle_shape(value) |
130 | 120 |
131 elif key.endswith("_regularizer") and isinstance(value, dict): | 121 elif key.endswith("_regularizer") and isinstance(value, dict): |
132 params[key] = _handle_regularizer(value) | 122 params[key] = _handle_regularizer(value) |
133 | 123 |
139 | 129 |
140 return params | 130 return params |
141 | 131 |
142 | 132 |
143 def get_sequential_model(config): | 133 def get_sequential_model(config): |
144 """ | 134 """Construct keras Sequential model from Galaxy tool parameters |
145 Construct keras Sequential model from Galaxy tool parameters | |
146 | 135 |
147 Parameters: | 136 Parameters: |
148 ----------- | 137 ----------- |
149 config : dictionary, galaxy tool parameters loaded by JSON | 138 config : dictionary, galaxy tool parameters loaded by JSON |
150 """ | 139 """ |
163 if kwargs: | 152 if kwargs: |
164 kwargs = safe_eval("dict(" + kwargs + ")") | 153 kwargs = safe_eval("dict(" + kwargs + ")") |
165 options.update(kwargs) | 154 options.update(kwargs) |
166 | 155 |
167 # add input_shape to the first layer only | 156 # add input_shape to the first layer only |
168 if not getattr(model, "_layers") and input_shape is not None: | 157 if not model.get_config()["layers"] and input_shape is not None: |
169 options["input_shape"] = input_shape | 158 options["input_shape"] = input_shape |
170 | 159 |
171 model.add(klass(**options)) | 160 model.add(klass(**options)) |
172 | 161 |
173 return model | 162 return model |
174 | 163 |
175 | 164 |
176 def get_functional_model(config): | 165 def get_functional_model(config): |
177 """ | 166 """Construct keras functional model from Galaxy tool parameters |
178 Construct keras functional model from Galaxy tool parameters | |
179 | 167 |
180 Parameters | 168 Parameters |
181 ----------- | 169 ----------- |
182 config : dictionary, galaxy tool parameters loaded by JSON | 170 config : dictionary, galaxy tool parameters loaded by JSON |
183 """ | 171 """ |
219 | 207 |
220 return Model(inputs=input_layers, outputs=output_layers) | 208 return Model(inputs=input_layers, outputs=output_layers) |
221 | 209 |
222 | 210 |
223 def get_batch_generator(config): | 211 def get_batch_generator(config): |
224 """ | 212 """Construct keras online data generator from Galaxy tool parameters |
225 Construct keras online data generator from Galaxy tool parameters | |
226 | 213 |
227 Parameters | 214 Parameters |
228 ----------- | 215 ----------- |
229 config : dictionary, galaxy tool parameters loaded by JSON | 216 config : dictionary, galaxy tool parameters loaded by JSON |
230 """ | 217 """ |
244 | 231 |
245 return klass(**config) | 232 return klass(**config) |
246 | 233 |
247 | 234 |
248 def config_keras_model(inputs, outfile): | 235 def config_keras_model(inputs, outfile): |
249 """ | 236 """config keras model layers and output JSON |
250 config keras model layers and output JSON | |
251 | 237 |
252 Parameters | 238 Parameters |
253 ---------- | 239 ---------- |
254 inputs : dict | 240 inputs : dict |
255 loaded galaxy tool parameters from `keras_model_config` | 241 loaded galaxy tool parameters from `keras_model_config` |
269 | 255 |
270 with open(outfile, "w") as f: | 256 with open(outfile, "w") as f: |
271 json.dump(json.loads(json_string), f, indent=2) | 257 json.dump(json.loads(json_string), f, indent=2) |
272 | 258 |
273 | 259 |
274 def build_keras_model( | 260 def build_keras_model(inputs, outfile, model_json, batch_mode=False): |
275 inputs, | 261 """for `keras_model_builder` tool |
276 outfile, | |
277 model_json, | |
278 infile_weights=None, | |
279 batch_mode=False, | |
280 outfile_params=None, | |
281 ): | |
282 """ | |
283 for `keras_model_builder` tool | |
284 | 262 |
285 Parameters | 263 Parameters |
286 ---------- | 264 ---------- |
287 inputs : dict | 265 inputs : dict |
288 loaded galaxy tool parameters from `keras_model_builder` tool. | 266 loaded galaxy tool parameters from `keras_model_builder` tool. |
289 outfile : str | 267 outfile : str |
290 Path to galaxy dataset containing the keras_galaxy model output. | 268 Path to galaxy dataset containing the keras_galaxy model output. |
291 model_json : str | 269 model_json : str |
292 Path to dataset containing keras model JSON. | 270 Path to dataset containing keras model JSON. |
293 infile_weights : str or None | |
294 If string, path to dataset containing model weights. | |
295 batch_mode : bool, default=False | 271 batch_mode : bool, default=False |
296 Whether to build online batch classifier. | 272 Whether to build online batch classifier. |
297 outfile_params : str, default=None | |
298 File path to search parameters output. | |
299 """ | 273 """ |
300 with open(model_json, "r") as f: | 274 with open(model_json, "r") as f: |
301 json_model = json.load(f) | 275 json_model = json.load(f) |
302 | 276 |
303 config = json_model["config"] | 277 config = json_model["config"] |
305 options = {} | 279 options = {} |
306 | 280 |
307 if json_model["class_name"] == "Sequential": | 281 if json_model["class_name"] == "Sequential": |
308 options["model_type"] = "sequential" | 282 options["model_type"] = "sequential" |
309 klass = Sequential | 283 klass = Sequential |
310 elif json_model["class_name"] == "Model": | 284 elif json_model["class_name"] == "Functional": |
311 options["model_type"] = "functional" | 285 options["model_type"] = "functional" |
312 klass = Model | 286 klass = Model |
313 else: | 287 else: |
314 raise ValueError("Unknow Keras model class: %s" % json_model["class_name"]) | 288 raise ValueError("Unknow Keras model class: %s" % json_model["class_name"]) |
315 | 289 |
316 # load prefitted model | 290 # load prefitted model |
317 if inputs["mode_selection"]["mode_type"] == "prefitted": | 291 if inputs["mode_selection"]["mode_type"] == "prefitted": |
318 estimator = klass.from_config(config) | 292 # estimator = klass.from_config(config) |
319 estimator.load_weights(infile_weights) | 293 # estimator.load_weights(infile_weights) |
294 raise Exception("Prefitted was deprecated!") | |
320 # build train model | 295 # build train model |
321 else: | 296 else: |
322 cls_name = inputs["mode_selection"]["learning_type"] | 297 cls_name = inputs["mode_selection"]["learning_type"] |
323 klass = try_get_attr("galaxy_ml.keras_galaxy_models", cls_name) | 298 klass = try_get_attr("galaxy_ml.keras_galaxy_models", cls_name) |
324 | 299 |
336 ] | 311 ] |
337 ) | 312 ) |
338 ) | 313 ) |
339 | 314 |
340 train_metrics = inputs["mode_selection"]["compile_params"]["metrics"] | 315 train_metrics = inputs["mode_selection"]["compile_params"]["metrics"] |
316 if not isinstance(train_metrics, list): # for older galaxy | |
317 train_metrics = train_metrics.split(",") | |
341 if train_metrics[-1] == "none": | 318 if train_metrics[-1] == "none": |
342 train_metrics = train_metrics[:-1] | 319 train_metrics.pop() |
343 options["metrics"] = train_metrics | 320 options["metrics"] = train_metrics |
344 | 321 |
345 options.update(inputs["mode_selection"]["fit_params"]) | 322 options.update(inputs["mode_selection"]["fit_params"]) |
346 options["seed"] = inputs["mode_selection"]["random_seed"] | 323 options["seed"] = inputs["mode_selection"]["random_seed"] |
347 | 324 |
353 options["prediction_steps"] = inputs["mode_selection"]["prediction_steps"] | 330 options["prediction_steps"] = inputs["mode_selection"]["prediction_steps"] |
354 options["class_positive_factor"] = inputs["mode_selection"][ | 331 options["class_positive_factor"] = inputs["mode_selection"][ |
355 "class_positive_factor" | 332 "class_positive_factor" |
356 ] | 333 ] |
357 estimator = klass(config, **options) | 334 estimator = klass(config, **options) |
358 if outfile_params: | |
359 hyper_params = get_search_params(estimator) | |
360 # TODO: remove this after making `verbose` tunable | |
361 for h_param in hyper_params: | |
362 if h_param[1].endswith("verbose"): | |
363 h_param[0] = "@" | |
364 df = pd.DataFrame(hyper_params, columns=["", "Parameter", "Value"]) | |
365 df.to_csv(outfile_params, sep="\t", index=False) | |
366 | 335 |
367 print(repr(estimator)) | 336 print(repr(estimator)) |
368 # save model by pickle | 337 # save model |
369 with open(outfile, "wb") as f: | 338 dump_model_to_h5(estimator, outfile, verbose=1) |
370 pickle.dump(estimator, f, pickle.HIGHEST_PROTOCOL) | |
371 | 339 |
372 | 340 |
373 if __name__ == "__main__": | 341 if __name__ == "__main__": |
374 warnings.simplefilter("ignore") | 342 warnings.simplefilter("ignore") |
375 | 343 |
376 aparser = argparse.ArgumentParser() | 344 aparser = argparse.ArgumentParser() |
377 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 345 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
378 aparser.add_argument("-m", "--model_json", dest="model_json") | 346 aparser.add_argument("-m", "--model_json", dest="model_json") |
379 aparser.add_argument("-t", "--tool_id", dest="tool_id") | 347 aparser.add_argument("-t", "--tool_id", dest="tool_id") |
380 aparser.add_argument("-w", "--infile_weights", dest="infile_weights") | |
381 aparser.add_argument("-o", "--outfile", dest="outfile") | 348 aparser.add_argument("-o", "--outfile", dest="outfile") |
382 aparser.add_argument("-p", "--outfile_params", dest="outfile_params") | |
383 args = aparser.parse_args() | 349 args = aparser.parse_args() |
384 | 350 |
385 input_json_path = args.inputs | 351 input_json_path = args.inputs |
386 with open(input_json_path, "r") as param_handler: | 352 with open(input_json_path, "r") as param_handler: |
387 inputs = json.load(param_handler) | 353 inputs = json.load(param_handler) |
388 | 354 |
389 tool_id = args.tool_id | 355 tool_id = args.tool_id |
390 outfile = args.outfile | 356 outfile = args.outfile |
391 outfile_params = args.outfile_params | |
392 model_json = args.model_json | 357 model_json = args.model_json |
393 infile_weights = args.infile_weights | |
394 | 358 |
395 # for keras_model_config tool | 359 # for keras_model_config tool |
396 if tool_id == "keras_model_config": | 360 if tool_id == "keras_model_config": |
397 config_keras_model(inputs, outfile) | 361 config_keras_model(inputs, outfile) |
398 | 362 |
401 batch_mode = False | 365 batch_mode = False |
402 if tool_id == "keras_batch_models": | 366 if tool_id == "keras_batch_models": |
403 batch_mode = True | 367 batch_mode = True |
404 | 368 |
405 build_keras_model( | 369 build_keras_model( |
406 inputs=inputs, | 370 inputs=inputs, model_json=model_json, batch_mode=batch_mode, outfile=outfile |
407 model_json=model_json, | |
408 infile_weights=infile_weights, | |
409 batch_mode=batch_mode, | |
410 outfile=outfile, | |
411 outfile_params=outfile_params, | |
412 ) | 371 ) |