comparison simple_model_fit.py @ 30:ab4249158912 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53"
author bgruening
date Thu, 07 Nov 2019 05:45:03 -0500
parents 172365bc2b5f
children 19d6c2745d34
comparison
equal deleted inserted replaced
29:172365bc2b5f 30:ab4249158912
3 import pandas as pd 3 import pandas as pd
4 import pickle 4 import pickle
5 5
6 from galaxy_ml.utils import load_model, read_columns 6 from galaxy_ml.utils import load_model, read_columns
7 from sklearn.pipeline import Pipeline 7 from sklearn.pipeline import Pipeline
8
9
# Worker count read from the GALAXY_SLOTS environment variable (falls back
# to 1 when the variable is not set); passed as ``n_jobs`` to clean_params
# before the estimator is fitted.
10 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1))
11
12
13 # TODO import from galaxy_ml.utils in future versions
def clean_params(estimator, n_jobs=None):
    """Clean unwanted hyperparameter settings on an estimator.

    Every parameter reported by ``estimator.get_params()`` is inspected:

    * ``memory``, ``*__memory`` and ``*_path`` parameters are reset to
      ``None`` (all potential unauthorized file writes);
    * ``n_jobs`` and ``*__n_jobs`` parameters are set to `n_jobs` when
      `n_jobs` is not None;
    * ``*callbacks`` parameters are validated against a whitelist of
      callback types.

    Parameters
    ----------
    estimator : object
        Object exposing ``get_params()`` and ``set_params(**params)``.
        It is modified in place.
    n_jobs : int or None, default=None
        If not None, set it into the estimator, if applicable.

    Returns
    -------
    Cleaned estimator object (the same instance that was passed in).

    Raises
    ------
    ValueError
        If a callback entry declares a type that is not whitelisted.
    """
    ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN',
                         'ReduceLROnPlateau', 'CSVLogger', 'None')

    # get_params() is read once up front; the set_params() calls below act
    # on the estimator, not on the mapping being iterated.
    estimator_params = estimator.get_params()

    for name, p in estimator_params.items():
        # all potential unauthorized file write
        if name == 'memory' or name.endswith(('__memory', '_path')):
            estimator.set_params(**{name: None})
        elif n_jobs is not None and (name == 'n_jobs' or
                                     name.endswith('__n_jobs')):
            estimator.set_params(**{name: n_jobs})
        elif name.endswith('callbacks'):
            # An unset callbacks parameter (None) has nothing to validate;
            # iterating it directly would raise TypeError.
            for cb in p or ():
                cb_type = cb['callback_selection']['callback_type']
                if cb_type not in ALLOWED_CALLBACKS:
                    raise ValueError(
                        "Prohibited callback type: %s!" % cb_type)

    return estimator
8 46
9 47
10 def _get_X_y(params, infile1, infile2): 48 def _get_X_y(params, infile1, infile2):
11 """ read from inputs and output X and y 49 """ read from inputs and output X and y
12 50
105 params = json.load(param_handler) 143 params = json.load(param_handler)
106 144
107 # load model 145 # load model
108 with open(infile_estimator, 'rb') as est_handler: 146 with open(infile_estimator, 'rb') as est_handler:
109 estimator = load_model(est_handler) 147 estimator = load_model(est_handler)
148 estimator = clean_params(estimator, n_jobs=N_JOBS)
110 149
111 X_train, y_train = _get_X_y(params, infile1, infile2) 150 X_train, y_train = _get_X_y(params, infile1, infile2)
112 151
113 estimator.fit(X_train, y_train) 152 estimator.fit(X_train, y_train)
114 153