comparison simple_model_fit.py @ 0:0985b0dd6f1a draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
author bgruening
date Fri, 01 Nov 2019 17:26:59 -0400
parents
children 910ebff96ddc
comparison
equal deleted inserted replaced
-1:000000000000 0:0985b0dd6f1a
1 import argparse
2 import json
3 import pandas as pd
4 import pickle
5
6 from galaxy_ml.utils import load_model, read_columns
7 from sklearn.pipeline import Pipeline
8
9
10 def _get_X_y(params, infile1, infile2):
11 """ read from inputs and output X and y
12
13 Parameters
14 ----------
15 params : dict
16 Tool inputs parameter
17 infile1 : str
18 File path to dataset containing features
19 infile2 : str
20 File path to dataset containing target values
21
22 """
23 # store read dataframe object
24 loaded_df = {}
25
26 input_type = params['input_options']['selected_input']
27 # tabular input
28 if input_type == 'tabular':
29 header = 'infer' if params['input_options']['header1'] else None
30 column_option = (params['input_options']['column_selector_options_1']
31 ['selected_column_selector_option'])
32 if column_option in ['by_index_number', 'all_but_by_index_number',
33 'by_header_name', 'all_but_by_header_name']:
34 c = params['input_options']['column_selector_options_1']['col1']
35 else:
36 c = None
37
38 df_key = infile1 + repr(header)
39 df = pd.read_csv(infile1, sep='\t', header=header,
40 parse_dates=True)
41 loaded_df[df_key] = df
42
43 X = read_columns(df, c=c, c_option=column_option).astype(float)
44 # sparse input
45 elif input_type == 'sparse':
46 X = mmread(open(infile1, 'r'))
47
48 # Get target y
49 header = 'infer' if params['input_options']['header2'] else None
50 column_option = (params['input_options']['column_selector_options_2']
51 ['selected_column_selector_option2'])
52 if column_option in ['by_index_number', 'all_but_by_index_number',
53 'by_header_name', 'all_but_by_header_name']:
54 c = params['input_options']['column_selector_options_2']['col2']
55 else:
56 c = None
57
58 df_key = infile2 + repr(header)
59 if df_key in loaded_df:
60 infile2 = loaded_df[df_key]
61 else:
62 infile2 = pd.read_csv(infile2, sep='\t',
63 header=header, parse_dates=True)
64 loaded_df[df_key] = infile2
65
66 y = read_columns(
67 infile2,
68 c=c,
69 c_option=column_option,
70 sep='\t',
71 header=header,
72 parse_dates=True)
73 if len(y.shape) == 2 and y.shape[1] == 1:
74 y = y.ravel()
75
76 return X, y
77
78
79 def main(inputs, infile_estimator, infile1, infile2, out_object,
80 out_weights=None):
81 """ main
82
83 Parameters
84 ----------
85 inputs : str
86 File path to galaxy tool parameter
87
88 infile_estimator : str
89 File paths of input estimator
90
91 infile1 : str
92 File path to dataset containing features
93
94 infile2 : str
95 File path to dataset containing target labels
96
97 out_object : str
98 File path for output of fitted model or skeleton
99
100 out_weights : str
101 File path for output of weights
102
103 """
104 with open(inputs, 'r') as param_handler:
105 params = json.load(param_handler)
106
107 # load model
108 with open(infile_estimator, 'rb') as est_handler:
109 estimator = load_model(est_handler)
110
111 X_train, y_train = _get_X_y(params, infile1, infile2)
112
113 estimator.fit(X_train, y_train)
114
115 main_est = estimator
116 if isinstance(main_est, Pipeline):
117 main_est = main_est.steps[-1][-1]
118 if hasattr(main_est, 'model_') \
119 and hasattr(main_est, 'save_weights'):
120 if out_weights:
121 main_est.save_weights(out_weights)
122 del main_est.model_
123 del main_est.fit_params
124 del main_est.model_class_
125 del main_est.validation_data
126 if getattr(main_est, 'data_generator_', None):
127 del main_est.data_generator_
128
129 with open(out_object, 'wb') as output_handler:
130 pickle.dump(estimator, output_handler,
131 pickle.HIGHEST_PROTOCOL)
132
133
if __name__ == '__main__':
    aparser = argparse.ArgumentParser()
    # Short flags are kept as-is for backward compatibility with existing
    # callers (note: -X is the estimator, -y the features, -g the target).
    # Every path except --out_weights is used unconditionally by main(), so
    # mark them required to fail fast with a clear usage error instead of a
    # TypeError from open(None) / None + str deep inside main().
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
    aparser.add_argument("-X", "--infile_estimator", dest="infile_estimator",
                         required=True)
    aparser.add_argument("-y", "--infile1", dest="infile1", required=True)
    aparser.add_argument("-g", "--infile2", dest="infile2", required=True)
    aparser.add_argument("-o", "--out_object", dest="out_object",
                         required=True)
    aparser.add_argument("-t", "--out_weights", dest="out_weights")
    args = aparser.parse_args()

    main(args.inputs, args.infile_estimator, args.infile1,
         args.infile2, args.out_object, args.out_weights)