Mercurial > repos > bgruening > sklearn_model_fit
comparison simple_model_fit.py @ 1:49525d37843a draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit d6333e7294e67be5968a41f404b66699cad4ae53"
author | bgruening |
---|---|
date | Thu, 07 Nov 2019 05:31:59 -0500 |
parents | 734c66aa945a |
children | 26decbf4bdb8 |
comparison
equal
deleted
inserted
replaced
0:734c66aa945a | 1:49525d37843a |
---|---|
3 import pandas as pd | 3 import pandas as pd |
4 import pickle | 4 import pickle |
5 | 5 |
6 from galaxy_ml.utils import load_model, read_columns | 6 from galaxy_ml.utils import load_model, read_columns |
7 from sklearn.pipeline import Pipeline | 7 from sklearn.pipeline import Pipeline |
| | 8 |
| | 9 |
| | 10 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) |
| | 11 |
| | 12 |
| | 13 # TODO import from galaxy_ml.utils in future versions |
| | 14 def clean_params(estimator, n_jobs=None): |
| | 15 """clean unwanted hyperparameter settings |
| | 16 |
| | 17 If n_jobs is not None, set it into the estimator, if applicable |
| | 18 |
| | 19 Return |
| | 20 ------ |
| | 21 Cleaned estimator object |
| | 22 """ |
| | 23 ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN', |
| | 24 'ReduceLROnPlateau', 'CSVLogger', 'None') |
| | 25 |
| | 26 estimator_params = estimator.get_params() |
| | 27 |
| | 28 for name, p in estimator_params.items(): |
| | 29 # all potential unauthorized file write |
| | 30 if name == 'memory' or name.endswith('__memory') \ |
| | 31 or name.endswith('_path'): |
| | 32 new_p = {name: None} |
| | 33 estimator.set_params(**new_p) |
| | 34 elif n_jobs is not None and (name == 'n_jobs' or |
| | 35 name.endswith('__n_jobs')): |
| | 36 new_p = {name: n_jobs} |
| | 37 estimator.set_params(**new_p) |
| | 38 elif name.endswith('callbacks'): |
| | 39 for cb in p: |
| | 40 cb_type = cb['callback_selection']['callback_type'] |
| | 41 if cb_type not in ALLOWED_CALLBACKS: |
| | 42 raise ValueError( |
| | 43 "Prohibited callback type: %s!" % cb_type) |
| | 44 |
| | 45 return estimator |
8 | 46 |
9 | 47 |
10 def _get_X_y(params, infile1, infile2): | 48 def _get_X_y(params, infile1, infile2): |
11 """ read from inputs and output X and y | 49 """ read from inputs and output X and y |
12 | 50 |
105 params = json.load(param_handler) | 143 params = json.load(param_handler) |
106 | 144 |
107 # load model | 145 # load model |
108 with open(infile_estimator, 'rb') as est_handler: | 146 with open(infile_estimator, 'rb') as est_handler: |
109 estimator = load_model(est_handler) | 147 estimator = load_model(est_handler) |
| | 148 estimator = clean_params(estimator, n_jobs=N_JOBS) |
110 | 149 |
111 X_train, y_train = _get_X_y(params, infile1, infile2) | 150 X_train, y_train = _get_X_y(params, infile1, infile2) |
112 | 151 |
113 estimator.fit(X_train, y_train) | 152 estimator.fit(X_train, y_train) |
114 | 153 |