Mercurial > repos > bgruening > keras_train_and_eval
comparison keras_deep_learning.py @ 0:03f61bb3ca43 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5b2ac730ec6d3b762faa9034eddd19ad1b347476"
author | bgruening |
---|---|
date | Mon, 16 Dec 2019 05:36:53 -0500 |
parents | |
children | 3866911c93ae |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:03f61bb3ca43 |
---|---|
1 import argparse | |
2 import json | |
3 import keras | |
4 import pandas as pd | |
5 import pickle | |
6 import six | |
7 import warnings | |
8 | |
9 from ast import literal_eval | |
10 from keras.models import Sequential, Model | |
11 from galaxy_ml.utils import try_get_attr, get_search_params, SafeEval | |
12 | |
13 | |
# Module-wide SafeEval instance (from galaxy_ml.utils). The model builders in
# this file use it to evaluate each layer's free-form `kwargs` text, wrapped
# as 'dict(' + kwargs + ')'.
safe_eval = SafeEval()
15 | |
16 | |
17 def _handle_shape(literal): | |
18 """Eval integer or list/tuple of integers from string | |
19 | |
20 Parameters: | |
21 ----------- | |
22 literal : str. | |
23 """ | |
24 literal = literal.strip() | |
25 if not literal: | |
26 return None | |
27 try: | |
28 return literal_eval(literal) | |
29 except NameError as e: | |
30 print(e) | |
31 return literal | |
32 | |
33 | |
34 def _handle_regularizer(literal): | |
35 """Construct regularizer from string literal | |
36 | |
37 Parameters | |
38 ---------- | |
39 literal : str. E.g. '(0.1, 0)' | |
40 """ | |
41 literal = literal.strip() | |
42 if not literal: | |
43 return None | |
44 | |
45 l1, l2 = literal_eval(literal) | |
46 | |
47 if not l1 and not l2: | |
48 return None | |
49 | |
50 if l1 is None: | |
51 l1 = 0. | |
52 if l2 is None: | |
53 l2 = 0. | |
54 | |
55 return keras.regularizers.l1_l2(l1=l1, l2=l2) | |
56 | |
57 | |
58 def _handle_constraint(config): | |
59 """Construct constraint from galaxy tool parameters. | |
60 Suppose correct dictionary format | |
61 | |
62 Parameters | |
63 ---------- | |
64 config : dict. E.g. | |
65 "bias_constraint": | |
66 {"constraint_options": | |
67 {"max_value":1.0, | |
68 "min_value":0.0, | |
69 "axis":"[0, 1, 2]" | |
70 }, | |
71 "constraint_type": | |
72 "MinMaxNorm" | |
73 } | |
74 """ | |
75 constraint_type = config['constraint_type'] | |
76 if constraint_type in ('None', ''): | |
77 return None | |
78 | |
79 klass = getattr(keras.constraints, constraint_type) | |
80 options = config.get('constraint_options', {}) | |
81 if 'axis' in options: | |
82 options['axis'] = literal_eval(options['axis']) | |
83 | |
84 return klass(**options) | |
85 | |
86 | |
87 def _handle_lambda(literal): | |
88 return None | |
89 | |
90 | |
91 def _handle_layer_parameters(params): | |
92 """Access to handle all kinds of parameters | |
93 """ | |
94 for key, value in six.iteritems(params): | |
95 if value in ('None', ''): | |
96 params[key] = None | |
97 continue | |
98 | |
99 if type(value) in [int, float, bool]\ | |
100 or (type(value) is str and value.isalpha()): | |
101 continue | |
102 | |
103 if key in ['input_shape', 'noise_shape', 'shape', 'batch_shape', | |
104 'target_shape', 'dims', 'kernel_size', 'strides', | |
105 'dilation_rate', 'output_padding', 'cropping', 'size', | |
106 'padding', 'pool_size', 'axis', 'shared_axes'] \ | |
107 and isinstance(value, str): | |
108 params[key] = _handle_shape(value) | |
109 | |
110 elif key.endswith('_regularizer') and isinstance(value, dict): | |
111 params[key] = _handle_regularizer(value) | |
112 | |
113 elif key.endswith('_constraint') and isinstance(value, dict): | |
114 params[key] = _handle_constraint(value) | |
115 | |
116 elif key == 'function': # No support for lambda/function eval | |
117 params.pop(key) | |
118 | |
119 return params | |
120 | |
121 | |
def get_sequential_model(config):
    """Build a keras ``Sequential`` model from Galaxy tool parameters.

    Parameters
    ----------
    config : dict
        Galaxy tool parameters loaded from JSON; reads ``'input_shape'``
        and the ``'layers'`` list. Each layer's ``'layer_selection'`` dict
        is consumed (``'layer_type'`` and ``'kwargs'`` are popped from it).

    Returns
    -------
    keras.models.Sequential
    """
    seq = Sequential()
    shape_of_input = _handle_shape(config['input_shape'])

    for entry in config['layers']:
        layer_opts = entry['layer_selection']
        layer_cls = getattr(keras.layers, layer_opts.pop('layer_type'))
        extra = layer_opts.pop('kwargs', '')

        # Galaxy strings -> Python/keras objects
        layer_opts = _handle_layer_parameters(layer_opts)

        # Free-form "key=value, ..." text overrides the structured options.
        if extra:
            layer_opts.update(safe_eval('dict(' + extra + ')'))

        # Only the first layer added receives the model's input shape.
        is_first = not getattr(seq, '_layers')
        if is_first and shape_of_input is not None:
            layer_opts['input_shape'] = shape_of_input

        seq.add(layer_cls(**layer_opts))

    return seq
152 | |
153 | |
def _pick_layers(all_layers, indexes, param_name):
    """Return the layers at 1-based position(s) `indexes` as a list.

    Parameters
    ----------
    all_layers : list
        Layers/tensors built so far, in definition order.
    indexes : int or sequence of int
        A single 1-based index, or several.
    param_name : str
        Name of the Galaxy parameter, used in error messages.

    Raises
    ------
    ValueError
        If no index is given, or an index is not an integer in
        ``1..len(all_layers)``.
    """
    # A single index such as "1" evaluates to a bare int; wrap it so that
    # both "1" and "[1, 2]" are accepted.
    if isinstance(indexes, int):
        indexes = [indexes]
    if not indexes:
        raise ValueError("No layer index given for `%s`." % param_name)

    picked = []
    for idx in indexes:
        # Reject 0/negative explicitly: `all_layers[idx - 1]` would
        # otherwise silently wrap around to a layer from the end.
        if not isinstance(idx, int) or not 1 <= idx <= len(all_layers):
            raise ValueError(
                "Invalid layer index %r for `%s`: expected an integer "
                "between 1 and %d." % (idx, param_name, len(all_layers)))
        picked.append(all_layers[idx - 1])
    return picked


def get_functional_model(config):
    """Construct a keras functional ``Model`` from Galaxy tool parameters.

    Parameters
    ----------
    config : dict
        Galaxy tool parameters loaded from JSON. Uses ``'layers'`` plus the
        1-based ``'input_layers'`` / ``'output_layers'`` index strings,
        e.g. ``'1'`` or ``'[1, 2]'``.

    Returns
    -------
    keras.models.Model

    Raises
    ------
    ValueError
        If a layer index refers to a layer that is not defined earlier.
    """
    all_layers = []
    for layer in config['layers']:
        options = layer['layer_selection']
        layer_type = options.pop('layer_type')
        klass = getattr(keras.layers, layer_type)
        inbound_nodes = options.pop('inbound_nodes', None)
        kwargs = options.pop('kwargs', '')

        # Galaxy string parameters need special conversion
        options = _handle_layer_parameters(options)

        # Free-form "key=value, ..." text overrides the structured options.
        if kwargs:
            options.update(safe_eval('dict(' + kwargs + ')'))

        if 'merging_layers' in options:
            # Merge layer: called on a list of earlier layers.
            raw = options.pop('merging_layers')
            idxs = literal_eval(raw) if isinstance(raw, str) else raw
            merging_layers = _pick_layers(all_layers, idxs, 'merging_layers')
            new_layer = klass(**options)(merging_layers)
        elif inbound_nodes is not None:
            # Non-input layer: called on exactly one earlier layer.
            (inbound,) = _pick_layers(all_layers, inbound_nodes,
                                      'inbound_nodes')
            new_layer = klass(**options)(inbound)
        else:
            # No inbound layer: the instance is stored as-is
            # (presumably an Input layer).
            new_layer = klass(**options)

        all_layers.append(new_layer)

    input_layers = _pick_layers(
        all_layers, _handle_shape(config['input_layers']), 'input_layers')
    output_layers = _pick_layers(
        all_layers, _handle_shape(config['output_layers']), 'output_layers')

    return Model(inputs=input_layers, outputs=output_layers)
198 | |
199 | |
def get_batch_generator(config):
    """Build a galaxy_ml online batch data generator from tool parameters.

    Parameters
    ----------
    config : dict
        Galaxy tool parameters loaded from JSON. ``'generator_type'`` is
        popped (the dict is modified in place). The remaining keys, plus
        placeholder path entries, are passed to the generator class as
        keyword arguments.

    Returns
    -------
    The generator instance, or ``None`` when ``generator_type`` is
    ``'none'``.
    """
    gen_type = config.pop('generator_type')
    if gen_type == 'none':
        return None

    gen_class = try_get_attr('galaxy_ml.preprocessors', gen_type)

    # Data locations are placeholders at build time; presumably they are
    # filled in by the training step later -- TODO confirm.
    if gen_type == 'GenomicIntervalBatchGenerator':
        placeholder_keys = ('ref_genome_path', 'intervals_path',
                            'target_path', 'features')
    else:
        placeholder_keys = ('fasta_path',)
    config.update(dict.fromkeys(placeholder_keys, 'to_be_determined'))

    return gen_class(**config)
222 | |
223 | |
def config_keras_model(inputs, outfile):
    """Build a keras model from `keras_model_config` parameters and write
    its architecture to `outfile` as indented JSON.

    Parameters
    ----------
    inputs : dict
        Loaded Galaxy tool parameters from the `keras_model_config` tool.
    outfile : str
        Path to the Galaxy dataset that receives the keras model JSON.
    """
    model_selection = inputs['model_selection']
    if model_selection['model_type'] == 'sequential':
        builder = get_sequential_model
    else:
        builder = get_functional_model
    model = builder(model_selection)

    json_text = model.to_json()
    # Round-trip through `json` purely to re-serialize with indentation.
    with open(outfile, 'w') as fh:
        json.dump(json.loads(json_text), fh, indent=2)
247 | |
248 | |
def build_keras_model(inputs, outfile, model_json, infile_weights=None,
                      batch_mode=False, outfile_params=None):
    """Build (or load) a keras estimator and pickle it, for the
    `keras_model_builder` tool.

    Parameters
    ----------
    inputs : dict
        loaded galaxy tool parameters from `keras_model_builder` tool.
    outfile : str
        Path to galaxy dataset containing the keras_galaxy model output.
    model_json : str
        Path to dataset containing keras model JSON.
    infile_weights : str or None
        If string, path to dataset containing model weights. Required when
        ``inputs['mode_selection']['mode_type'] == 'prefitted'``.
    batch_mode : bool, default=False
        Whether to build online batch classifier.
    outfile_params : str, default=None
        File path to search parameters output.

    Raises
    ------
    ValueError
        If the model JSON's ``class_name`` is neither ``'Sequential'`` nor
        ``'Model'``, or if a prefitted model is requested without
        `infile_weights`.
    """
    with open(model_json, 'r') as f:
        json_model = json.load(f)

    config = json_model['config']
    class_name = json_model['class_name']

    options = {}

    if class_name == 'Sequential':
        options['model_type'] = 'sequential'
        klass = Sequential
    elif class_name == 'Model':
        options['model_type'] = 'functional'
        klass = Model
    else:
        raise ValueError("Unknown Keras model class: %s" % class_name)

    mode_selection = inputs['mode_selection']

    if mode_selection['mode_type'] == 'prefitted':
        # Rebuild the architecture and restore previously fitted weights.
        if not infile_weights:
            raise ValueError("`infile_weights` is required to load a "
                             "prefitted model.")
        estimator = klass.from_config(config)
        estimator.load_weights(infile_weights)
    else:
        # Wrap the architecture in a trainable galaxy_ml keras estimator.
        cls_name = mode_selection['learning_type']
        klass = try_get_attr('galaxy_ml.keras_galaxy_models', cls_name)

        compile_params = mode_selection['compile_params']
        optimizer_selection = compile_params['optimizer_selection']

        options['loss'] = compile_params['loss']
        options['optimizer'] = optimizer_selection['optimizer_type'].lower()
        options.update(optimizer_selection['optimizer_options'])

        # Metrics arrive as a comma-separated string; blank entries and the
        # 'none' placeholder are not real metrics, wherever they appear.
        options['metrics'] = [
            metric for metric in
            (item.strip() for item in compile_params['metrics'].split(','))
            if metric and metric != 'none'
        ]

        options.update(mode_selection['fit_params'])
        options['seed'] = mode_selection['random_seed']

        if batch_mode:
            generator = get_batch_generator(
                mode_selection['generator_selection'])
            options['data_batch_generator'] = generator
            options['prediction_steps'] = mode_selection['prediction_steps']
            options['class_positive_factor'] = \
                mode_selection['class_positive_factor']
        estimator = klass(config, **options)
        if outfile_params:
            hyper_params = get_search_params(estimator)
            # TODO: remove this after making `verbose` tunable
            for h_param in hyper_params:
                if h_param[1].endswith('verbose'):
                    h_param[0] = '@'
            df = pd.DataFrame(hyper_params, columns=['', 'Parameter', 'Value'])
            df.to_csv(outfile_params, sep='\t', index=False)

    print(repr(estimator))
    # save model by pickle
    with open(outfile, 'wb') as f:
        pickle.dump(estimator, f, pickle.HIGHEST_PROTOCOL)
334 | |
335 | |
if __name__ == '__main__':
    warnings.simplefilter('ignore')

    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--inputs", dest="inputs", required=True)
    parser.add_argument("-m", "--model_json", dest="model_json")
    parser.add_argument("-t", "--tool_id", dest="tool_id")
    parser.add_argument("-w", "--infile_weights", dest="infile_weights")
    parser.add_argument("-o", "--outfile", dest="outfile")
    parser.add_argument("-p", "--outfile_params", dest="outfile_params")
    cli = parser.parse_args()

    with open(cli.inputs, 'r') as param_handler:
        tool_inputs = json.load(param_handler)

    if cli.tool_id == 'keras_model_config':
        # keras_model_config tool: write the model architecture JSON
        config_keras_model(tool_inputs, cli.outfile)
    else:
        # keras_model_builder tool; 'keras_batch_models' builds the online
        # batch variant
        build_keras_model(inputs=tool_inputs,
                          model_json=cli.model_json,
                          infile_weights=cli.infile_weights,
                          batch_mode=(cli.tool_id == 'keras_batch_models'),
                          outfile=cli.outfile,
                          outfile_params=cli.outfile_params)
359 if tool_id == 'keras_model_config': | |
360 config_keras_model(inputs, outfile) | |
361 | |
362 # for keras_model_builder tool | |
363 else: | |
364 batch_mode = False | |
365 if tool_id == 'keras_batch_models': | |
366 batch_mode = True | |
367 | |
368 build_keras_model(inputs=inputs, | |
369 model_json=model_json, | |
370 infile_weights=infile_weights, | |
371 batch_mode=batch_mode, | |
372 outfile=outfile, | |
373 outfile_params=outfile_params) |