comparison keras_train_and_eval.py @ 0:59e8b4328c82 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author bgruening
date Tue, 13 Apr 2021 22:40:10 +0000
parents
children f93f0cdbaf18
import argparse
import json
import os
import pickle
import warnings
from itertools import chain

import joblib
import numpy as np
import pandas as pd
from galaxy_ml.externals.selene_sdk.utils import compute_score
from galaxy_ml.keras_galaxy_models import _predict_generator
from galaxy_ml.model_validations import train_test_split
from galaxy_ml.utils import (
    clean_params,
    get_main_estimator,
    get_module,
    get_scoring,
    load_model,
    read_columns,
    SafeEval,
    try_get_attr,
)
from scipy.io import mmread
from sklearn.metrics.scorer import _check_multimetric_scoring
from sklearn.model_selection import _search, _validation
from sklearn.model_selection._validation import _score
from sklearn.pipeline import Pipeline
from sklearn.utils import indexable, safe_indexing


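# Patch sklearn's search and validation modules to use galaxy_ml's
# _fit_and_score, so that both hyperparameter search and cross-validation
# run the extended implementation.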
_fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score")
setattr(_search, "_fit_and_score", _fit_and_score)
setattr(_validation, "_fit_and_score", _fit_and_score)

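# N_JOBS honors the Galaxy job runner's GALAXY_SLOTS allocation (default 1).
# `os` is deleted right after use, presumably to keep it out of the namespace
# later exposed to evaluated parameters (an assumption, not stated in the
# source).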
N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1))
CACHE_DIR = os.path.join(os.getcwd(), "cached")
del os
NON_SEARCHABLE = ("n_jobs", "pre_dispatch", "memory", "_path", "nthread", "callbacks")
ALLOWED_CALLBACKS = (
    "EarlyStopping",
    "TerminateOnNaN",
    "ReduceLROnPlateau",
    "CSVLogger",
    "None",
)


def _eval_swap_params(params_builder):
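    """Evaluate the hyperparameter swapping section of the tool form.

    Returns a dict mapping parameter names to evaluated values, suitable
    for `estimator.set_params`.
    """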
    swap_params = {}

    for p in params_builder["param_set"]:
        swap_value = p["sp_value"].strip()
        if swap_value == "":
            continue

        param_name = p["sp_name"]
        if param_name.lower().endswith(NON_SEARCHABLE):
            warnings.warn("Warning: `%s` is not eligible for search and was omitted!" % param_name)
            continue

        if not swap_value.startswith(":"):
            safe_eval = SafeEval(load_scipy=True, load_numpy=True)
            ev = safe_eval(swap_value)
        else:
            # A leading `:` asks for estimator evaluation
            safe_eval_es = SafeEval(load_estimators=True)
            swap_value = swap_value[1:].strip()
            # TODO: maybe add a regular expression check
            ev = safe_eval_es(swap_value)

        swap_params[param_name] = ev

    return swap_params


def train_test_split_none(*arrays, **kwargs):
    """Extend train_test_split to accept None arrays
    and to support splitting by group names.
    """
    nones = []
    new_arrays = []
    for idx, arr in enumerate(arrays):
        if arr is None:
            nones.append(idx)
        else:
            new_arrays.append(arr)

    if kwargs["shuffle"] == "None":
        kwargs["shuffle"] = None

    group_names = kwargs.pop("group_names", None)

    if group_names is not None and group_names.strip():
        group_names = [name.strip() for name in group_names.split(",")]
        new_arrays = indexable(*new_arrays)
        groups = kwargs["labels"]
        n_samples = new_arrays[0].shape[0]
        index_arr = np.arange(n_samples)
        test = index_arr[np.isin(groups, group_names)]
        train = index_arr[~np.isin(groups, group_names)]
        rval = list(chain.from_iterable((safe_indexing(a, train), safe_indexing(a, test)) for a in new_arrays))
    else:
        rval = train_test_split(*new_arrays, **kwargs)

    # Re-insert None placeholders at the right positions; assignment to an
    # empty slice inserts rather than replaces, which the original slice
    # `rval[pos * 2: 2]` got wrong for pos == 0.
    for pos in nones:
        rval[pos * 2: pos * 2] = [None, None]

    return rval


def _evaluate(y_true, pred_probas, scorer, is_multimetric=True):
    """Compute scores for predictions with the given scorer(s)

    Parameters
    ----------
    y_true : array
        True label or target values
    pred_probas : array
        Prediction values, probabilities for classification problems
    scorer : dict
        dict of `sklearn.metrics.scorer.SCORER`
    is_multimetric : bool, default is True
    """
    if y_true.ndim == 1 or y_true.shape[-1] == 1:
        pred_probas = pred_probas.ravel()
        pred_labels = (pred_probas > 0.5).astype("int32")
        targets = y_true.ravel().astype("int32")
        if not is_multimetric:
            preds = pred_labels if scorer.__class__.__name__ == "_PredictScorer" else pred_probas
            score = scorer._score_func(targets, preds, **scorer._kwargs)

            return score
        else:
            scores = {}
            for name, one_scorer in scorer.items():
                preds = pred_labels if one_scorer.__class__.__name__ == "_PredictScorer" else pred_probas
                score = one_scorer._score_func(targets, preds, **one_scorer._kwargs)
                scores[name] = score

    # TODO: multi-class metrics
    # multi-label
    else:
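        # Multi-label targets: keep the 2-D shape and score all outputs
        # together via compute_score.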
        pred_labels = (pred_probas > 0.5).astype("int32")
        targets = y_true.astype("int32")
        if not is_multimetric:
            preds = pred_labels if scorer.__class__.__name__ == "_PredictScorer" else pred_probas
            score, _ = compute_score(preds, targets, scorer._score_func)
            return score
        else:
            scores = {}
            for name, one_scorer in scorer.items():
                preds = pred_labels if one_scorer.__class__.__name__ == "_PredictScorer" else pred_probas
                score, _ = compute_score(preds, targets, one_scorer._score_func)
                scores[name] = score

    return scores


def main(
    inputs,
    infile_estimator,
    infile1,
    infile2,
    outfile_result,
    outfile_object=None,
    outfile_weights=None,
    outfile_y_true=None,
    outfile_y_preds=None,
    groups=None,
    ref_seq=None,
    intervals=None,
    targets=None,
    fasta_path=None,
):
176 """
177 Parameter
178 ---------
179 inputs : str
180 File path to galaxy tool parameter
181
182 infile_estimator : str
183 File path to estimator
184
185 infile1 : str
186 File path to dataset containing features
187
188 infile2 : str
189 File path to dataset containing target values
190
191 outfile_result : str
192 File path to save the results, either cv_results or test result
193
194 outfile_object : str, optional
195 File path to save searchCV object
196
197 outfile_weights : str, optional
198 File path to save deep learning model weights
199
200 outfile_y_true : str, optional
201 File path to target values for prediction
202
203 outfile_y_preds : str, optional
204 File path to save deep learning model weights
205
206 groups : str
207 File path to dataset containing groups labels
208
209 ref_seq : str
210 File path to dataset containing genome sequence file
211
212 intervals : str
213 File path to dataset containing interval file
214
215 targets : str
216 File path to dataset compressed target bed file
217
218 fasta_path : str
219 File path to dataset containing fasta file
220 """
    warnings.simplefilter("ignore")

    with open(inputs, "r") as param_handler:
        params = json.load(param_handler)

    # load estimator
    with open(infile_estimator, "rb") as estimator_handler:
        estimator = load_model(estimator_handler)

    estimator = clean_params(estimator)

    # swap hyperparameters
    swapping = params["experiment_schemes"]["hyperparams_swapping"]
    swap_params = _eval_swap_params(swapping)
    estimator.set_params(**swap_params)

    estimator_params = estimator.get_params()

    # store read dataframe object
    loaded_df = {}

    input_type = params["input_options"]["selected_input"]
    # tabular input
    if input_type == "tabular":
        header = "infer" if params["input_options"]["header1"] else None
        column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"]
        if column_option in [
            "by_index_number",
            "all_but_by_index_number",
            "by_header_name",
            "all_but_by_header_name",
        ]:
            c = params["input_options"]["column_selector_options_1"]["col1"]
        else:
            c = None

        df_key = infile1 + repr(header)
        df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
        loaded_df[df_key] = df

        X = read_columns(df, c=c, c_option=column_option).astype(float)
    # sparse input
    elif input_type == "sparse":
        X = mmread(open(infile1, "r"))

    # fasta file input
    elif input_type == "seq_fasta":
        pyfaidx = get_module("pyfaidx")
        sequences = pyfaidx.Fasta(fasta_path)
        n_seqs = len(sequences.keys())
        X = np.arange(n_seqs)[:, np.newaxis]
        for param in estimator_params.keys():
            if param.endswith("fasta_path"):
                estimator.set_params(**{param: fasta_path})
                break
        else:
            raise ValueError(
                "The selected estimator doesn't support "
                "fasta file input! Please consider using "
                "KerasGBatchClassifier with "
                "FastaDNABatchGenerator/FastaProteinBatchGenerator "
                "or having GenomeOneHotEncoder/ProteinOneHotEncoder "
                "in the pipeline!"
            )

    elif input_type == "refseq_and_interval":
        path_params = {
            "data_batch_generator__ref_genome_path": ref_seq,
            "data_batch_generator__intervals_path": intervals,
            "data_batch_generator__target_path": targets,
        }
        estimator.set_params(**path_params)
        n_intervals = sum(1 for line in open(intervals))
        X = np.arange(n_intervals)[:, np.newaxis]

    # Get target y
    header = "infer" if params["input_options"]["header2"] else None
    column_option = params["input_options"]["column_selector_options_2"]["selected_column_selector_option2"]
    if column_option in [
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    ]:
        c = params["input_options"]["column_selector_options_2"]["col2"]
    else:
        c = None

    df_key = infile2 + repr(header)
    if df_key in loaded_df:
        infile2 = loaded_df[df_key]
    else:
        infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
        loaded_df[df_key] = infile2

    y = read_columns(
        infile2,
        c=c,
        c_option=column_option,
        sep="\t",
        header=header,
        parse_dates=True,
    )
    if len(y.shape) == 2 and y.shape[1] == 1:
        y = y.ravel()
    if input_type == "refseq_and_interval":
        estimator.set_params(data_batch_generator__features=y.ravel().tolist())
        y = None
    # end y

    # load groups
    if groups:
        groups_selector = (params["experiment_schemes"]["test_split"]["split_algos"]).pop("groups_selector")

        header = "infer" if groups_selector["header_g"] else None
        column_option = groups_selector["column_selector_options_g"]["selected_column_selector_option_g"]
        if column_option in [
            "by_index_number",
            "all_but_by_index_number",
            "by_header_name",
            "all_but_by_header_name",
        ]:
            c = groups_selector["column_selector_options_g"]["col_g"]
        else:
            c = None

        df_key = groups + repr(header)
        if df_key in loaded_df:
            groups = loaded_df[df_key]

        groups = read_columns(
            groups,
            c=c,
            c_option=column_option,
            sep="\t",
            header=header,
            parse_dates=True,
        )
        groups = groups.ravel()

    # free cached dataframes
    del loaded_df

    # caching iraps_core fits could increase search speed significantly
    memory = joblib.Memory(location=CACHE_DIR, verbose=0)
    main_est = get_main_estimator(estimator)
    if main_est.__class__.__name__ == "IRAPSClassifier":
        main_est.set_params(memory=memory)

    # handle scorer: convert to a scorer dict
    scoring = params["experiment_schemes"]["metrics"]["scoring"]
    if scoring is not None:
        # get_scoring() expects secondary_scoring to be a comma-separated
        # string (not a list), so convert it if it is specified
        secondary_scoring = scoring.get("secondary_scoring", None)
        if secondary_scoring is not None:
            scoring["secondary_scoring"] = ",".join(scoring["secondary_scoring"])

    scorer = get_scoring(scoring)
    scorer, _ = _check_multimetric_scoring(estimator, scoring=scorer)
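    # After _check_multimetric_scoring, `scorer` is a dict of
    # name -> scorer callables, as expected by _score and _evaluate.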

    # handle test (first) split
    test_split_options = params["experiment_schemes"]["test_split"]["split_algos"]

    if test_split_options["shuffle"] == "group":
        test_split_options["labels"] = groups
    if test_split_options["shuffle"] == "stratified":
        if y is not None:
            test_split_options["labels"] = y
        else:
            raise ValueError("Stratified shuffle split is not applicable on empty target values!")

    (
        X_train,
        X_test,
        y_train,
        y_test,
        groups_train,
        _groups_test,
    ) = train_test_split_none(X, y, groups, **test_split_options)

    exp_scheme = params["experiment_schemes"]["selected_exp_scheme"]

    # handle validation (second) split
    if exp_scheme == "train_val_test":
        val_split_options = params["experiment_schemes"]["val_split"]["split_algos"]

        if val_split_options["shuffle"] == "group":
            val_split_options["labels"] = groups_train
        if val_split_options["shuffle"] == "stratified":
            if y_train is not None:
                val_split_options["labels"] = y_train
            else:
                raise ValueError("Stratified shuffle split is not applicable on empty target values!")

        (
            X_train,
            X_val,
            y_train,
            y_val,
            groups_train,
            _groups_val,
        ) = train_test_split_none(X_train, y_train, groups_train, **val_split_options)

    # train and eval
    if hasattr(estimator, "validation_data"):
        if exp_scheme == "train_val_test":
            estimator.fit(X_train, y_train, validation_data=(X_val, y_val))
        else:
            estimator.fit(X_train, y_train, validation_data=(X_test, y_test))
    else:
        estimator.fit(X_train, y_train)

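    # Keras batch classifiers expose `evaluate`; for those, predictions are
    # drawn from the fitted data generator instead of a plain predict call.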
    if hasattr(estimator, "evaluate"):
        steps = estimator.prediction_steps
        batch_size = estimator.batch_size
        generator = estimator.data_generator_.flow(X_test, y=y_test, batch_size=batch_size)
        predictions, y_true = _predict_generator(estimator.model_, generator, steps=steps)
        scores = _evaluate(y_true, predictions, scorer, is_multimetric=True)

    else:
        if hasattr(estimator, "predict_proba"):
            predictions = estimator.predict_proba(X_test)
        else:
            predictions = estimator.predict(X_test)

        y_true = y_test
        scores = _score(estimator, X_test, y_test, scorer, is_multimetric=True)
    if outfile_y_true:
        try:
            pd.DataFrame(y_true).to_csv(outfile_y_true, sep="\t", index=False)
            pd.DataFrame(predictions).astype(np.float32).to_csv(
                outfile_y_preds,
                sep="\t",
                index=False,
                float_format="%g",
                chunksize=10000,
            )
        except Exception as e:
            print("Error in saving predictions: %s" % e)

    # handle output
    for name, score in scores.items():
        scores[name] = [score]
    df = pd.DataFrame(scores)
    df = df[sorted(df.columns)]
    df.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False)

    memory.clear(warn=False)

    if outfile_object:
        main_est = estimator
        if isinstance(estimator, Pipeline):
            main_est = estimator.steps[-1][-1]

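        # Keras models are not directly picklable; weights go to a separate
        # file and unpicklable attributes are removed before pickling.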
        if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"):
            if outfile_weights:
                main_est.save_weights(outfile_weights)
            del main_est.model_
            del main_est.fit_params
            del main_est.model_class_
            if getattr(main_est, "validation_data", None):
                del main_est.validation_data
            if getattr(main_est, "data_generator_", None):
                del main_est.data_generator_

        with open(outfile_object, "wb") as output_handler:
            pickle.dump(estimator, output_handler, pickle.HIGHEST_PROTOCOL)


if __name__ == "__main__":
    aparser = argparse.ArgumentParser()
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
    aparser.add_argument("-e", "--estimator", dest="infile_estimator")
    aparser.add_argument("-X", "--infile1", dest="infile1")
    aparser.add_argument("-y", "--infile2", dest="infile2")
    aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
    aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
    aparser.add_argument("-w", "--outfile_weights", dest="outfile_weights")
    aparser.add_argument("-l", "--outfile_y_true", dest="outfile_y_true")
    aparser.add_argument("-p", "--outfile_y_preds", dest="outfile_y_preds")
    aparser.add_argument("-g", "--groups", dest="groups")
    aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
    aparser.add_argument("-b", "--intervals", dest="intervals")
    aparser.add_argument("-t", "--targets", dest="targets")
    aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
    args = aparser.parse_args()

    main(
        args.inputs,
        args.infile_estimator,
        args.infile1,
        args.infile2,
        args.outfile_result,
        outfile_object=args.outfile_object,
        outfile_weights=args.outfile_weights,
        outfile_y_true=args.outfile_y_true,
        outfile_y_preds=args.outfile_y_preds,
        groups=args.groups,
        ref_seq=args.ref_seq,
        intervals=args.intervals,
        targets=args.targets,
        fasta_path=args.fasta_path,
    )
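# Example invocation (hypothetical file names):
#   python keras_train_and_eval.py -i inputs.json -e estimator_input \
#       -X features.tsv -y targets.tsv -O results.tsv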