comparison simple_model_fit.py @ 0:af2624d5ab32 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author bgruening
date Sat, 01 May 2021 01:24:32 +0000
parents
children 9349ed2749c6
comparison
equal deleted inserted replaced
-1:000000000000 0:af2624d5ab32
import argparse
import json
import os
import pickle

import pandas as pd
from galaxy_ml.utils import load_model, read_columns
from scipy.io import mmread
from sklearn.pipeline import Pipeline
9
# Number of parallel workers, taken from Galaxy's slot allocation (defaults to 1).
N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1))
11
12
# TODO import from galaxy_ml.utils in future versions
def clean_params(estimator, n_jobs=None):
    """Sanitize hyperparameters on *estimator* in place and return it.

    Parameters that could cause an unauthorized file write (``memory``,
    ``*__memory``, ``*_path``) are reset to ``None``.  When *n_jobs* is
    given, every ``n_jobs``-like parameter is overwritten with it.
    Callback entries are validated against a small allow-list.

    Return
    ------
    The same estimator object, cleaned.
    """
    allowed_callbacks = frozenset(
        ("EarlyStopping", "TerminateOnNaN", "ReduceLROnPlateau", "CSVLogger", "None")
    )

    for key, value in estimator.get_params().items():
        if key == "memory" or key.endswith("__memory") or key.endswith("_path"):
            # potential unauthorized file write — disable it
            estimator.set_params(**{key: None})
        elif n_jobs is not None and (key == "n_jobs" or key.endswith("__n_jobs")):
            estimator.set_params(**{key: n_jobs})
        elif key.endswith("callbacks"):
            for callback in value:
                selected = callback["callback_selection"]["callback_type"]
                if selected not in allowed_callbacks:
                    raise ValueError("Prohibited callback type: %s!" % selected)

    return estimator
48
49
def _get_X_y(params, infile1, infile2):
    """read from inputs and output X and y

    Parameters
    ----------
    params : dict
        Tool inputs parameter
    infile1 : str
        File path to dataset containing features (tabular or sparse matrix-market)
    infile2 : str
        File path to dataset containing target values

    Returns
    -------
    (X, y) : features and target values; y is flattened to 1-D when it is a
        single-column 2-D array.
    """
    # cache of dataframes already read, keyed by "path + header setting",
    # so that infile2 == infile1 is not parsed twice
    loaded_df = {}

    input_type = params["input_options"]["selected_input"]
    # tabular input
    if input_type == "tabular":
        header = "infer" if params["input_options"]["header1"] else None
        column_option = params["input_options"]["column_selector_options_1"][
            "selected_column_selector_option"
        ]
        if column_option in [
            "by_index_number",
            "all_but_by_index_number",
            "by_header_name",
            "all_but_by_header_name",
        ]:
            c = params["input_options"]["column_selector_options_1"]["col1"]
        else:
            c = None

        df_key = infile1 + repr(header)
        df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
        loaded_df[df_key] = df

        X = read_columns(df, c=c, c_option=column_option).astype(float)
    # sparse input
    elif input_type == "sparse":
        # pass the path directly so scipy opens and closes the file itself;
        # the previous open(infile1, "r") leaked the file handle
        X = mmread(infile1)

    # Get target y
    header = "infer" if params["input_options"]["header2"] else None
    column_option = params["input_options"]["column_selector_options_2"][
        "selected_column_selector_option2"
    ]
    if column_option in [
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    ]:
        c = params["input_options"]["column_selector_options_2"]["col2"]
    else:
        c = None

    df_key = infile2 + repr(header)
    if df_key in loaded_df:
        infile2 = loaded_df[df_key]
    else:
        infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True)
        loaded_df[df_key] = infile2

    y = read_columns(
        infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True
    )
    # flatten single-column 2-D target to 1-D
    if len(y.shape) == 2 and y.shape[1] == 1:
        y = y.ravel()

    return X, y
121
122
def main(inputs, infile_estimator, infile1, infile2, out_object, out_weights=None):
    """Fit an estimator on the given data and pickle the result.

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter (JSON)
    infile_estimator : str
        File path of the input estimator (pickled)
    infile1 : str
        File path to dataset containing features
    infile2 : str
        File path to dataset containing target labels
    out_object : str
        File path for output of fitted model or skeleton
    out_weights : str
        File path for output of weights (deep-learning estimators only)
    """
    with open(inputs, "r") as handle:
        params = json.load(handle)

    # load and sanitize the estimator
    with open(infile_estimator, "rb") as handle:
        estimator = load_model(handle)
    estimator = clean_params(estimator, n_jobs=N_JOBS)

    X_train, y_train = _get_X_y(params, infile1, infile2)
    estimator.fit(X_train, y_train)

    # for pipelines, the deep-learning wrapper (if any) is the final step
    final_est = estimator
    if isinstance(final_est, Pipeline):
        final_est = final_est.steps[-1][-1]

    # deep-learning estimators carry an unpicklable backend model: save its
    # weights separately and strip the heavy attributes before pickling
    if hasattr(final_est, "model_") and hasattr(final_est, "save_weights"):
        if out_weights:
            final_est.save_weights(out_weights)
        del final_est.model_
        del final_est.fit_params
        del final_est.model_class_
        if getattr(final_est, "validation_data", None):
            del final_est.validation_data
        if getattr(final_est, "data_generator_", None):
            del final_est.data_generator_

    with open(out_object, "wb") as handle:
        pickle.dump(estimator, handle, pickle.HIGHEST_PROTOCOL)
175
176
if __name__ == "__main__":
    # (short flag, destination name, required) for every CLI option;
    # the long flag is always "--" + destination
    _OPTIONS = (
        ("-i", "inputs", True),
        ("-X", "infile_estimator", False),
        ("-y", "infile1", False),
        ("-g", "infile2", False),
        ("-o", "out_object", False),
        ("-t", "out_weights", False),
    )
    parser = argparse.ArgumentParser()
    for flag, dest, required in _OPTIONS:
        parser.add_argument(flag, "--" + dest, dest=dest, required=required)
    cli = parser.parse_args()

    main(
        cli.inputs,
        cli.infile_estimator,
        cli.infile1,
        cli.infile2,
        cli.out_object,
        cli.out_weights,
    )