comparison model_prediction.py @ 0:2d7016b3ae92 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 2afb24f3c81d625312186750a714d702363012b5"
author bgruening
date Fri, 02 Oct 2020 08:45:21 +0000
parents
children 132805688fa3
comparison
equal deleted inserted replaced
-1:000000000000 0:2d7016b3ae92
1 import argparse
2 import json
3 import numpy as np
4 import pandas as pd
5 import warnings
6
7 from scipy.io import mmread
8 from sklearn.pipeline import Pipeline
9
10 from galaxy_ml.utils import (load_model, read_columns,
11 get_module, try_get_attr)
12
13
# Number of parallel jobs, read from the GALAXY_SLOTS environment variable
# (falls back to 1 when the variable is unset).
N_JOBS = int(__import__('os').getenv('GALAXY_SLOTS', 1))
15
16
def _predict(estimator, method, X, **kwargs):
    """Dispatch to ``estimator.predict`` or ``estimator.predict_proba``.

    Parameters
    ----------
    estimator : object
        Fitted estimator exposing ``predict`` and/or ``predict_proba``.
    method : str
        ``'predict'`` selects ``predict``; any other value selects
        ``predict_proba``.
    X : array-like
        Samples to predict on (or sample indices for generator-backed
        estimators).
    **kwargs
        Extra keyword arguments forwarded unchanged, e.g.
        ``data_generator`` or ``steps``.

    Returns
    -------
    Whatever the selected estimator method returns.
    """
    if method == 'predict':
        return estimator.predict(X, **kwargs)
    return estimator.predict_proba(X, **kwargs)


def _write_variant_predictions(estimator, method, pred_data_generator,
                               outfile_predict, batch_size=1600):
    """Predict variant effects batch by batch and stream them to a file.

    The output is tab separated. The header row starts with
    ``chrom, pos, name, ref, alt, strand`` for the variant metadata columns
    (assumes ``pred_data_generator.variants`` has exactly these six
    columns -- not verified here), followed by one column per model output
    named ``0, 1, ...``. ``pred_data_generator.close()`` is always called,
    even when prediction or writing fails.

    Parameters
    ----------
    estimator : object
        Estimator whose ``predict``/``predict_proba`` accepts a
        ``data_generator`` keyword.
    method : str
        ``'predict'`` or anything else (-> ``predict_proba``).
    pred_data_generator : object
        Generator exposing ``variants``, ``flow(batch_size=...)`` and
        ``close()``; ``set_processing_attrs()`` must already have been
        called.
    outfile_predict : str
        Destination path.
    batch_size : int
        Number of variants predicted per batch.
    """
    variants = pred_data_generator.variants
    gen_flow = pred_data_generator.flow(batch_size=batch_size)
    n_steps = len(gen_flow)

    try:
        with open(outfile_predict, 'w') as file_writer:
            file_writer.write('\t'.join(['chrom', 'pos', 'name', 'ref',
                                         'alt', 'strand']))
            header_done = False

            # TODO: multiple threading
            for _ in range(n_steps):
                index_array = next(gen_flow.index_generator)
                batch_X = gen_flow._get_batches_of_transformed_samples(
                    index_array)

                # Passing `pred_data_generator` overrides any
                # data_generator the model itself carries.
                batch_preds = _predict(estimator, method, batch_X,
                                       data_generator=pred_data_generator)

                # Single-output predictions become one column so they can
                # be stacked next to the variant metadata.
                if batch_preds.ndim == 1:
                    batch_preds = batch_preds[:, np.newaxis]

                batch_out = np.column_stack([variants[index_array],
                                             batch_preds])

                # The prediction column names are only known once the
                # first batch reveals the output width.
                if not header_done:
                    heads = np.arange(batch_preds.shape[-1]).astype(str)
                    file_writer.write("\t%s\n" % '\t'.join(heads))
                    header_done = True

                for row in batch_out:
                    # str() guards against object-dtype rows that mix
                    # strings and numbers.
                    file_writer.write('\t'.join(map(str, row)) + '\n')

            if not header_done:
                # No batches were produced: still terminate the header line.
                file_writer.write('\n')
    finally:
        pred_data_generator.close()


def main(inputs, infile_estimator, outfile_predict,
         infile_weights=None, infile1=None,
         fasta_path=None, ref_seq=None,
         vcf_path=None):
    """Load a trained estimator and write its predictions to a file.

    Parameters
    ----------
    inputs : str
        File path to the galaxy tool parameters (JSON).

    infile_estimator : str
        File path to the trained estimator input.

    outfile_predict : str
        File path where the prediction results are saved (tabular).

    infile_weights : str
        File path to the weights input. Required when the (final step of
        the) estimator has both ``config`` and ``load_weights``.

    infile1 : str
        File path to the dataset containing features (tabular or sparse
        input).

    fasta_path : str
        File path to the dataset containing the fasta file.

    ref_seq : str
        File path to the dataset containing the reference genome sequence.

    vcf_path : str
        File path to the dataset containing variants info.

    Returns
    -------
    int or None
        ``0`` after the ``variant_effect`` path (which streams its own
        output); ``None`` otherwise.

    Raises
    ------
    ValueError
        If weights are required but not supplied, if fasta input is used
        with an estimator lacking ``data_batch_generator``, or if the
        selected input type is not supported.
    """
    warnings.filterwarnings('ignore')

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    method = params['method']

    # load model
    with open(infile_estimator, 'rb') as est_handler:
        estimator = load_model(est_handler)

    # For pipelines, only the final step is checked for external weights.
    main_est = estimator
    if isinstance(estimator, Pipeline):
        main_est = estimator.steps[-1][-1]
    if hasattr(main_est, 'config') and hasattr(main_est, 'load_weights'):
        # The literal string 'None' also counts as "no weights file"
        # (presumably how the wrapper passes an unset optional dataset).
        if not infile_weights or infile_weights == 'None':
            raise ValueError("The selected model skeleton asks for weights, "
                             "but dataset for weights was not selected!")
        main_est.load_weights(infile_weights)

    # handle data input
    input_type = params['input_options']['selected_input']
    # tabular input
    if input_type == 'tabular':
        header = 'infer' if params['input_options']['header1'] else None
        column_option = (params['input_options']
                               ['column_selector_options_1']
                               ['selected_column_selector_option'])
        if column_option in ['by_index_number', 'all_but_by_index_number',
                             'by_header_name', 'all_but_by_header_name']:
            c = params['input_options']['column_selector_options_1']['col1']
        else:
            c = None

        df = pd.read_csv(infile1, sep='\t', header=header, parse_dates=True)

        X = read_columns(df, c=c, c_option=column_option).astype(float)

        preds = _predict(estimator, method, X)

    # sparse input
    elif input_type == 'sparse':
        with open(infile1, 'r') as sparse_handler:
            X = mmread(sparse_handler)
        preds = _predict(estimator, method, X)

    # fasta input
    elif input_type == 'seq_fasta':
        if not hasattr(estimator, 'data_batch_generator'):
            raise ValueError(
                "To do prediction on sequences in fasta input, "
                "the estimator must be a `KerasGBatchClassifier` "
                "equipped with data_batch_generator!")
        pyfaidx = get_module('pyfaidx')
        sequences = pyfaidx.Fasta(fasta_path)
        n_seqs = len(sequences.keys())
        # Samples are passed as a column of indices; the data generator
        # (built from fasta_path below) presumably resolves them to
        # sequences.
        X = np.arange(n_seqs)[:, np.newaxis]
        seq_length = estimator.data_batch_generator.seq_length
        batch_size = getattr(estimator, 'batch_size', 32)
        # Ceiling division: one extra step for a partial final batch.
        steps = (n_seqs + batch_size - 1) // batch_size

        seq_type = params['input_options']['seq_type']
        klass = try_get_attr(
            'galaxy_ml.preprocessors', seq_type)

        pred_data_generator = klass(
            fasta_path, seq_length=seq_length)

        preds = _predict(estimator, method, X,
                         data_generator=pred_data_generator, steps=steps)

    # vcf input
    elif input_type == 'variant_effect':
        klass = try_get_attr('galaxy_ml.preprocessors',
                             'GenomicVariantBatchGenerator')

        # Remaining input options are forwarded as generator kwargs, so
        # the selector key is removed first.
        options = params['input_options']
        options.pop('selected_input')
        if options['blacklist_regions'] == 'none':
            options['blacklist_regions'] = None

        pred_data_generator = klass(
            ref_genome_path=ref_seq, vcf_path=vcf_path, **options)

        pred_data_generator.set_processing_attrs()

        # Predict 1600 samples per batch and stream them to the file.
        _write_variant_predictions(estimator, method, pred_data_generator,
                                   outfile_predict, batch_size=1600)
        return 0

    else:
        raise ValueError("Unsupported input type: %r" % (input_type,))
    # end input

    # output
    if preds.ndim == 1:
        rval = pd.DataFrame(preds, columns=['Predicted'])
    else:
        rval = pd.DataFrame(preds)

    rval.to_csv(outfile_predict, sep='\t', header=True, index=False)
208
209
if __name__ == '__main__':
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("-i", "--inputs", dest="inputs", required=True)
    # Optional path arguments, registered in their original order
    # (this order also determines the --help listing).
    for short_flag, long_flag, dest_name in (
            ("-e", "--infile_estimator", "infile_estimator"),
            ("-w", "--infile_weights", "infile_weights"),
            ("-X", "--infile1", "infile1"),
            ("-O", "--outfile_predict", "outfile_predict"),
            ("-f", "--fasta_path", "fasta_path"),
            ("-r", "--ref_seq", "ref_seq"),
            ("-v", "--vcf_path", "vcf_path")):
        cli_parser.add_argument(short_flag, long_flag, dest=dest_name)
    cli_args = cli_parser.parse_args()

    main(cli_args.inputs,
         cli_args.infile_estimator,
         cli_args.outfile_predict,
         infile_weights=cli_args.infile_weights,
         infile1=cli_args.infile1,
         fasta_path=cli_args.fasta_path,
         ref_seq=cli_args.ref_seq,
         vcf_path=cli_args.vcf_path)