comparison simple_model_fit.py @ 0:2d7016b3ae92 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 2afb24f3c81d625312186750a714d702363012b5"
author bgruening
date Fri, 02 Oct 2020 08:45:21 +0000
parents
children 132805688fa3
comparison
equal deleted inserted replaced
-1:000000000000 0:2d7016b3ae92
1 import argparse
2 import json
3 import pandas as pd
4 import pickle
5
6 from galaxy_ml.utils import load_model, read_columns
7 from sklearn.pipeline import Pipeline
8
9
# Worker count taken from the GALAXY_SLOTS environment variable (1 if unset).
N_JOBS = int(__import__('os').getenv('GALAXY_SLOTS', 1))
11
12
# TODO import from galaxy_ml.utils in future versions
def clean_params(estimator, n_jobs=None):
    """Clean unwanted hyperparameter settings.

    Every parameter reported by ``estimator.get_params()`` is inspected:

    * ``memory``, ``*__memory`` and ``*_path`` parameters (potential
      unauthorized file writes) are reset to None;
    * if `n_jobs` is not None, ``n_jobs`` and ``*__n_jobs`` parameters are
      set to it;
    * ``*callbacks`` parameters are checked against an allow-list of
      callback types.

    Parameters
    ----------
    estimator : object
        Object exposing sklearn-style ``get_params`` / ``set_params``.
    n_jobs : int or None
        Value to set into ``n_jobs`` parameters; None leaves them untouched.

    Return
    ------
    Cleaned estimator object (the same object that was passed in; changes
    are applied through ``set_params``).

    Raises
    ------
    ValueError
        If a callback's ``callback_type`` is not in the allow-list.
    """
    ALLOWED_CALLBACKS = ('EarlyStopping', 'TerminateOnNaN',
                         'ReduceLROnPlateau', 'CSVLogger', 'None')

    estimator_params = estimator.get_params()

    for name, p in estimator_params.items():
        # all potential unauthorized file write
        if name == 'memory' or name.endswith(('__memory', '_path')):
            estimator.set_params(**{name: None})
        elif n_jobs is not None and (name == 'n_jobs' or
                                     name.endswith('__n_jobs')):
            estimator.set_params(**{name: n_jobs})
        elif name.endswith('callbacks'):
            # An unset callbacks parameter (None) has nothing to validate;
            # iterating it directly would raise TypeError.
            for cb in p or ():
                cb_type = cb['callback_selection']['callback_type']
                if cb_type not in ALLOWED_CALLBACKS:
                    raise ValueError(
                        "Prohibited callback type: %s!" % cb_type)

    return estimator
46
47
def _get_column_selection(selector_options, option_key, col_key):
    """Return ``(columns, option)`` for one column-selector parameter group.

    ``columns`` is ``selector_options[col_key]`` when the selected option
    picks columns by index number or header name, and None otherwise.
    """
    column_option = selector_options[option_key]
    if column_option in ('by_index_number', 'all_but_by_index_number',
                         'by_header_name', 'all_but_by_header_name'):
        return selector_options[col_key], column_option
    return None, column_option


def _get_X_y(params, infile1, infile2):
    """Read from inputs and output X and y.

    Parameters
    ----------
    params : dict
        Tool inputs parameter; ``params['input_options']`` holds the input
        type, header flags and column selections.
    infile1 : str
        File path to dataset containing features (tab-separated table for
        ``'tabular'`` input, Matrix Market file for ``'sparse'`` input).
    infile2 : str
        File path to tab-separated dataset containing target values.

    Returns
    -------
    X, y
        X is the float-cast result of ``read_columns`` for tabular input, or
        the object returned by ``scipy.io.mmread`` for sparse input. y is
        flattened to 1-D when the selection yields a single column.

    Raises
    ------
    ValueError
        If ``selected_input`` is neither ``'tabular'`` nor ``'sparse'``.
    """
    input_options = params['input_options']
    # Dataframes already read, keyed by file path + header setting, so a
    # target file identical to the feature file is parsed only once.
    loaded_df = {}

    input_type = input_options['selected_input']
    # tabular input
    if input_type == 'tabular':
        header = 'infer' if input_options['header1'] else None
        c, column_option = _get_column_selection(
            input_options['column_selector_options_1'],
            'selected_column_selector_option', 'col1')

        df = pd.read_csv(infile1, sep='\t', header=header,
                         parse_dates=True)
        loaded_df[infile1 + repr(header)] = df

        X = read_columns(df, c=c, c_option=column_option).astype(float)
    # sparse input
    elif input_type == 'sparse':
        # Only sparse input needs scipy's Matrix Market reader.
        from scipy.io import mmread
        # Context manager so the file handle is closed after reading.
        with open(infile1, 'r') as sparse_handler:
            X = mmread(sparse_handler)
    else:
        # Fail clearly instead of an UnboundLocalError on X further down.
        raise ValueError("Unsupported input type: %s" % input_type)

    # Get target y
    header = 'infer' if input_options['header2'] else None
    c, column_option = _get_column_selection(
        input_options['column_selector_options_2'],
        'selected_column_selector_option2', 'col2')

    df_key = infile2 + repr(header)
    if df_key in loaded_df:
        target_df = loaded_df[df_key]
    else:
        target_df = pd.read_csv(infile2, sep='\t',
                                header=header, parse_dates=True)
        loaded_df[df_key] = target_df

    y = read_columns(
        target_df,
        c=c,
        c_option=column_option,
        sep='\t',
        header=header,
        parse_dates=True)
    if len(y.shape) == 2 and y.shape[1] == 1:
        y = y.ravel()

    return X, y
115
116
def main(inputs, infile_estimator, infile1, infile2, out_object,
         out_weights=None):
    """Fit the input estimator on the training data and pickle the result.

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter (JSON).

    infile_estimator : str
        File paths of input estimator

    infile1 : str
        File path to dataset containing features

    infile2 : str
        File path to dataset containing target labels

    out_object : str
        File path for output of fitted model or skeleton

    out_weights : str
        File path for output of weights. Only used when the final estimator
        (the last pipeline step for a Pipeline) has both ``model_`` and
        ``save_weights``.
    """
    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    # load model
    with open(infile_estimator, 'rb') as est_handler:
        estimator = load_model(est_handler)
    estimator = clean_params(estimator, n_jobs=N_JOBS)

    X_train, y_train = _get_X_y(params, infile1, infile2)

    estimator.fit(X_train, y_train)

    # For a Pipeline, inspect its final step.
    main_est = estimator
    if isinstance(main_est, Pipeline):
        main_est = main_est.steps[-1][-1]
    if hasattr(main_est, 'model_') \
            and hasattr(main_est, 'save_weights'):
        if out_weights:
            main_est.save_weights(out_weights)
        # Strip the fitted model and fit-time state before pickling so only
        # a skeleton is written. Not every attribute is set on every
        # estimator / fit path, so a missing one is skipped instead of
        # aborting with AttributeError after the fit has completed.
        for attr in ('model_', 'fit_params', 'model_class_',
                     'validation_data'):
            try:
                delattr(main_est, attr)
            except AttributeError:
                pass
        if getattr(main_est, 'data_generator_', None):
            del main_est.data_generator_

    with open(out_object, 'wb') as output_handler:
        pickle.dump(estimator, output_handler,
                    pickle.HIGHEST_PROTOCOL)
171
172
if __name__ == '__main__':
    # Command-line entry point. Every dest equals a parameter name of
    # main(), so the parsed namespace is forwarded as keyword arguments.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("-i", "--inputs", dest="inputs", required=True)
    for short_flag, long_flag, dest_name in (
            ("-X", "--infile_estimator", "infile_estimator"),
            ("-y", "--infile1", "infile1"),
            ("-g", "--infile2", "infile2"),
            ("-o", "--out_object", "out_object"),
            ("-t", "--out_weights", "out_weights")):
        cli_parser.add_argument(short_flag, long_flag, dest=dest_name)
    cli_args = cli_parser.parse_args()

    main(**vars(cli_args))