comparison keras_deep_learning.py @ 11:5da2217cd788 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:33:25 +0000
parents 3312fb686ffb
children
comparison
equal deleted inserted replaced
10:6e25381dad5c 11:5da2217cd788
1 import argparse 1 import argparse
2 import json 2 import json
3 import pickle
4 import warnings 3 import warnings
5 from ast import literal_eval 4 from ast import literal_eval
6 5
7 import keras
8 import pandas as pd
9 import six 6 import six
10 from galaxy_ml.utils import get_search_params, SafeEval, try_get_attr 7 from galaxy_ml.model_persist import dump_model_to_h5
11 from keras.models import Model, Sequential 8 from galaxy_ml.utils import SafeEval, try_get_attr
9 from tensorflow import keras
10 from tensorflow.keras.models import Model, Sequential
12 11
13 safe_eval = SafeEval() 12 safe_eval = SafeEval()
14 13
15 14
16 def _handle_shape(literal): 15 def _handle_shape(literal):
17 """ 16 """Eval integer or list/tuple of integers from string
18 Eval integer or list/tuple of integers from string
19 17
20 Parameters: 18 Parameters:
21 ----------- 19 -----------
22 literal : str. 20 literal : str.
23 """ 21 """
30 print(e) 28 print(e)
31 return literal 29 return literal
32 30
33 31
34 def _handle_regularizer(literal): 32 def _handle_regularizer(literal):
35 """ 33 """Construct regularizer from string literal
36 Construct regularizer from string literal
37 34
38 Parameters 35 Parameters
39 ---------- 36 ----------
40 literal : str. E.g. '(0.1, 0)' 37 literal : str. E.g. '(0.1, 0)'
41 """ 38 """
55 52
56 return keras.regularizers.l1_l2(l1=l1, l2=l2) 53 return keras.regularizers.l1_l2(l1=l1, l2=l2)
57 54
58 55
59 def _handle_constraint(config): 56 def _handle_constraint(config):
60 """ 57 """Construct constraint from galaxy tool parameters.
61 Construct constraint from galaxy tool parameters.
62 Suppose correct dictionary format 58 Suppose correct dictionary format
63 59
64 Parameters 60 Parameters
65 ---------- 61 ----------
66 config : dict. E.g. 62 config : dict. E.g.
89 def _handle_lambda(literal): 85 def _handle_lambda(literal):
90 return None 86 return None
91 87
92 88
93 def _handle_layer_parameters(params): 89 def _handle_layer_parameters(params):
94 """ 90 """Access to handle all kinds of parameters"""
95 Access to handle all kinds of parameters
96 """
97 for key, value in six.iteritems(params): 91 for key, value in six.iteritems(params):
98 if value in ("None", ""): 92 if value in ("None", ""):
99 params[key] = None 93 params[key] = None
100 continue 94 continue
101 95
102 if type(value) in [int, float, bool] or ( 96 if type(value) in [int, float, bool] or (
103 type(value) is str and value.isalpha() 97 type(value) is str and value.isalpha()
104 ): 98 ):
105 continue 99 continue
106 100
107 if ( 101 if key in [
108 key 102 "input_shape",
109 in [ 103 "noise_shape",
110 "input_shape", 104 "shape",
111 "noise_shape", 105 "batch_shape",
112 "shape", 106 "target_shape",
113 "batch_shape", 107 "dims",
114 "target_shape", 108 "kernel_size",
115 "dims", 109 "strides",
116 "kernel_size", 110 "dilation_rate",
117 "strides", 111 "output_padding",
118 "dilation_rate", 112 "cropping",
119 "output_padding", 113 "size",
120 "cropping", 114 "padding",
121 "size", 115 "pool_size",
122 "padding", 116 "axis",
123 "pool_size", 117 "shared_axes",
124 "axis", 118 ] and isinstance(value, str):
125 "shared_axes",
126 ]
127 and isinstance(value, str)
128 ):
129 params[key] = _handle_shape(value) 119 params[key] = _handle_shape(value)
130 120
131 elif key.endswith("_regularizer") and isinstance(value, dict): 121 elif key.endswith("_regularizer") and isinstance(value, dict):
132 params[key] = _handle_regularizer(value) 122 params[key] = _handle_regularizer(value)
133 123
139 129
140 return params 130 return params
141 131
142 132
143 def get_sequential_model(config): 133 def get_sequential_model(config):
144 """ 134 """Construct keras Sequential model from Galaxy tool parameters
145 Construct keras Sequential model from Galaxy tool parameters
146 135
147 Parameters: 136 Parameters:
148 ----------- 137 -----------
149 config : dictionary, galaxy tool parameters loaded by JSON 138 config : dictionary, galaxy tool parameters loaded by JSON
150 """ 139 """
163 if kwargs: 152 if kwargs:
164 kwargs = safe_eval("dict(" + kwargs + ")") 153 kwargs = safe_eval("dict(" + kwargs + ")")
165 options.update(kwargs) 154 options.update(kwargs)
166 155
167 # add input_shape to the first layer only 156 # add input_shape to the first layer only
168 if not getattr(model, "_layers") and input_shape is not None: 157 if not model.get_config()["layers"] and input_shape is not None:
169 options["input_shape"] = input_shape 158 options["input_shape"] = input_shape
170 159
171 model.add(klass(**options)) 160 model.add(klass(**options))
172 161
173 return model 162 return model
174 163
175 164
176 def get_functional_model(config): 165 def get_functional_model(config):
177 """ 166 """Construct keras functional model from Galaxy tool parameters
178 Construct keras functional model from Galaxy tool parameters
179 167
180 Parameters 168 Parameters
181 ----------- 169 -----------
182 config : dictionary, galaxy tool parameters loaded by JSON 170 config : dictionary, galaxy tool parameters loaded by JSON
183 """ 171 """
219 207
220 return Model(inputs=input_layers, outputs=output_layers) 208 return Model(inputs=input_layers, outputs=output_layers)
221 209
222 210
223 def get_batch_generator(config): 211 def get_batch_generator(config):
224 """ 212 """Construct keras online data generator from Galaxy tool parameters
225 Construct keras online data generator from Galaxy tool parameters
226 213
227 Parameters 214 Parameters
228 ----------- 215 -----------
229 config : dictionary, galaxy tool parameters loaded by JSON 216 config : dictionary, galaxy tool parameters loaded by JSON
230 """ 217 """
244 231
245 return klass(**config) 232 return klass(**config)
246 233
247 234
248 def config_keras_model(inputs, outfile): 235 def config_keras_model(inputs, outfile):
249 """ 236 """config keras model layers and output JSON
250 config keras model layers and output JSON
251 237
252 Parameters 238 Parameters
253 ---------- 239 ----------
254 inputs : dict 240 inputs : dict
255 loaded galaxy tool parameters from `keras_model_config` 241 loaded galaxy tool parameters from `keras_model_config`
269 255
270 with open(outfile, "w") as f: 256 with open(outfile, "w") as f:
271 json.dump(json.loads(json_string), f, indent=2) 257 json.dump(json.loads(json_string), f, indent=2)
272 258
273 259
274 def build_keras_model( 260 def build_keras_model(inputs, outfile, model_json, batch_mode=False):
275 inputs, 261 """for `keras_model_builder` tool
276 outfile,
277 model_json,
278 infile_weights=None,
279 batch_mode=False,
280 outfile_params=None,
281 ):
282 """
283 for `keras_model_builder` tool
284 262
285 Parameters 263 Parameters
286 ---------- 264 ----------
287 inputs : dict 265 inputs : dict
288 loaded galaxy tool parameters from `keras_model_builder` tool. 266 loaded galaxy tool parameters from `keras_model_builder` tool.
289 outfile : str 267 outfile : str
290 Path to galaxy dataset containing the keras_galaxy model output. 268 Path to galaxy dataset containing the keras_galaxy model output.
291 model_json : str 269 model_json : str
292 Path to dataset containing keras model JSON. 270 Path to dataset containing keras model JSON.
293 infile_weights : str or None
294 If string, path to dataset containing model weights.
295 batch_mode : bool, default=False 271 batch_mode : bool, default=False
296 Whether to build online batch classifier. 272 Whether to build online batch classifier.
297 outfile_params : str, default=None
298 File path to search parameters output.
299 """ 273 """
300 with open(model_json, "r") as f: 274 with open(model_json, "r") as f:
301 json_model = json.load(f) 275 json_model = json.load(f)
302 276
303 config = json_model["config"] 277 config = json_model["config"]
305 options = {} 279 options = {}
306 280
307 if json_model["class_name"] == "Sequential": 281 if json_model["class_name"] == "Sequential":
308 options["model_type"] = "sequential" 282 options["model_type"] = "sequential"
309 klass = Sequential 283 klass = Sequential
310 elif json_model["class_name"] == "Model": 284 elif json_model["class_name"] == "Functional":
311 options["model_type"] = "functional" 285 options["model_type"] = "functional"
312 klass = Model 286 klass = Model
313 else: 287 else:
314 raise ValueError("Unknow Keras model class: %s" % json_model["class_name"]) 288 raise ValueError("Unknow Keras model class: %s" % json_model["class_name"])
315 289
316 # load prefitted model 290 # load prefitted model
317 if inputs["mode_selection"]["mode_type"] == "prefitted": 291 if inputs["mode_selection"]["mode_type"] == "prefitted":
318 estimator = klass.from_config(config) 292 # estimator = klass.from_config(config)
319 estimator.load_weights(infile_weights) 293 # estimator.load_weights(infile_weights)
294 raise Exception("Prefitted was deprecated!")
320 # build train model 295 # build train model
321 else: 296 else:
322 cls_name = inputs["mode_selection"]["learning_type"] 297 cls_name = inputs["mode_selection"]["learning_type"]
323 klass = try_get_attr("galaxy_ml.keras_galaxy_models", cls_name) 298 klass = try_get_attr("galaxy_ml.keras_galaxy_models", cls_name)
324 299
336 ] 311 ]
337 ) 312 )
338 ) 313 )
339 314
340 train_metrics = inputs["mode_selection"]["compile_params"]["metrics"] 315 train_metrics = inputs["mode_selection"]["compile_params"]["metrics"]
316 if not isinstance(train_metrics, list): # for older galaxy
317 train_metrics = train_metrics.split(",")
341 if train_metrics[-1] == "none": 318 if train_metrics[-1] == "none":
342 train_metrics = train_metrics[:-1] 319 train_metrics.pop()
343 options["metrics"] = train_metrics 320 options["metrics"] = train_metrics
344 321
345 options.update(inputs["mode_selection"]["fit_params"]) 322 options.update(inputs["mode_selection"]["fit_params"])
346 options["seed"] = inputs["mode_selection"]["random_seed"] 323 options["seed"] = inputs["mode_selection"]["random_seed"]
347 324
353 options["prediction_steps"] = inputs["mode_selection"]["prediction_steps"] 330 options["prediction_steps"] = inputs["mode_selection"]["prediction_steps"]
354 options["class_positive_factor"] = inputs["mode_selection"][ 331 options["class_positive_factor"] = inputs["mode_selection"][
355 "class_positive_factor" 332 "class_positive_factor"
356 ] 333 ]
357 estimator = klass(config, **options) 334 estimator = klass(config, **options)
358 if outfile_params:
359 hyper_params = get_search_params(estimator)
360 # TODO: remove this after making `verbose` tunable
361 for h_param in hyper_params:
362 if h_param[1].endswith("verbose"):
363 h_param[0] = "@"
364 df = pd.DataFrame(hyper_params, columns=["", "Parameter", "Value"])
365 df.to_csv(outfile_params, sep="\t", index=False)
366 335
367 print(repr(estimator)) 336 print(repr(estimator))
368 # save model by pickle 337 # save model
369 with open(outfile, "wb") as f: 338 dump_model_to_h5(estimator, outfile, verbose=1)
370 pickle.dump(estimator, f, pickle.HIGHEST_PROTOCOL)
371 339
372 340
373 if __name__ == "__main__": 341 if __name__ == "__main__":
374 warnings.simplefilter("ignore") 342 warnings.simplefilter("ignore")
375 343
376 aparser = argparse.ArgumentParser() 344 aparser = argparse.ArgumentParser()
377 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) 345 aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
378 aparser.add_argument("-m", "--model_json", dest="model_json") 346 aparser.add_argument("-m", "--model_json", dest="model_json")
379 aparser.add_argument("-t", "--tool_id", dest="tool_id") 347 aparser.add_argument("-t", "--tool_id", dest="tool_id")
380 aparser.add_argument("-w", "--infile_weights", dest="infile_weights")
381 aparser.add_argument("-o", "--outfile", dest="outfile") 348 aparser.add_argument("-o", "--outfile", dest="outfile")
382 aparser.add_argument("-p", "--outfile_params", dest="outfile_params")
383 args = aparser.parse_args() 349 args = aparser.parse_args()
384 350
385 input_json_path = args.inputs 351 input_json_path = args.inputs
386 with open(input_json_path, "r") as param_handler: 352 with open(input_json_path, "r") as param_handler:
387 inputs = json.load(param_handler) 353 inputs = json.load(param_handler)
388 354
389 tool_id = args.tool_id 355 tool_id = args.tool_id
390 outfile = args.outfile 356 outfile = args.outfile
391 outfile_params = args.outfile_params
392 model_json = args.model_json 357 model_json = args.model_json
393 infile_weights = args.infile_weights
394 358
395 # for keras_model_config tool 359 # for keras_model_config tool
396 if tool_id == "keras_model_config": 360 if tool_id == "keras_model_config":
397 config_keras_model(inputs, outfile) 361 config_keras_model(inputs, outfile)
398 362
401 batch_mode = False 365 batch_mode = False
402 if tool_id == "keras_batch_models": 366 if tool_id == "keras_batch_models":
403 batch_mode = True 367 batch_mode = True
404 368
405 build_keras_model( 369 build_keras_model(
406 inputs=inputs, 370 inputs=inputs, model_json=model_json, batch_mode=batch_mode, outfile=outfile
407 model_json=model_json,
408 infile_weights=infile_weights,
409 batch_mode=batch_mode,
410 outfile=outfile,
411 outfile_params=outfile_params,
412 ) 371 )