Mercurial > repos > bgruening > sklearn_train_test_split
comparison model_prediction.py @ 0:0985b0dd6f1a draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit eb703290e2589561ea215c84aa9f71bcfe1712c6"
author | bgruening |
---|---|
date | Fri, 01 Nov 2019 17:26:59 -0400 |
parents | |
children | 5a092779412e |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:0985b0dd6f1a |
---|---|
1 import argparse | |
2 import json | |
3 import numpy as np | |
4 import pandas as pd | |
5 import tabix | |
6 import warnings | |
7 | |
8 from scipy.io import mmread | |
9 from sklearn.pipeline import Pipeline | |
10 | |
11 from galaxy_ml.externals.selene_sdk.sequences import Genome | |
12 from galaxy_ml.utils import (load_model, read_columns, | |
13 get_module, try_get_attr) | |
14 | |
15 | |
# Integer value of the GALAXY_SLOTS environment variable, falling back to 1
# when the variable is unset.
# NOTE(review): presumably the number of parallel jobs/CPU slots Galaxy
# allots to this tool run -- confirm. It is not referenced by any code
# visible in this file.
N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1))
17 | |
18 | |
def _run_prediction(estimator, method, X, **predict_kwargs):
    """Call `estimator.predict` when `method` is 'predict', otherwise
    `estimator.predict_proba`, forwarding any extra keyword arguments.
    """
    if method == 'predict':
        return estimator.predict(X, **predict_kwargs)
    return estimator.predict_proba(X, **predict_kwargs)


def _filter_blacklisted_variants(pred_data_generator):
    """Return the generator's variants without those whose window hits a
    blacklist region.

    When the reference genome carries no `_blacklist_tabix` index, the
    variants are returned unchanged.
    """
    # TODO : remove this helper after galaxy-ml v0.7.13
    variants = pred_data_generator.variants
    reference_genome = pred_data_generator.reference_genome_
    blacklist_tabix = getattr(reference_genome, '_blacklist_tabix', None)
    if not blacklist_tabix:
        return variants

    start_radius = pred_data_generator.start_radius_
    end_radius = pred_data_generator.end_radius_
    # Loop-invariant: chromosome names are only normalised for `Genome`.
    normalise_chrom = isinstance(reference_genome, Genome)

    clean_variants = []
    for chrom, pos, name, ref, alt, strand in variants:
        center = pos + len(ref) // 2
        start = center - start_radius
        end = center + end_radius

        if normalise_chrom:
            if "chr" not in chrom:
                chrom = "chr" + chrom
            if "MT" in chrom:
                # drop the last character, e.g. 'chrMT' -> 'chrM'
                chrom = chrom[:-1]
        try:
            rows = blacklist_tabix.query(chrom, start, end)
            # A single overlapping row is enough to discard the variant.
            if any(True for _ in rows):
                continue
        except tabix.TabixError:
            # Deliberate best effort: a failed query keeps the variant.
            pass

        clean_variants.append((chrom, pos, name, ref, alt, strand))

    return clean_variants


def _write_variant_predictions(estimator, method, pred_data_generator,
                               variants, outfile_predict):
    """Predict variants batch by batch and stream the rows to a tabular file.

    The output file and `pred_data_generator` are always closed, even when
    prediction fails part-way through.
    """
    # predict 1600 sample at once then write to file
    gen_flow = pred_data_generator.flow(batch_size=1600)

    # TODO: multiple threading
    try:
        with open(outfile_predict, 'w') as file_writer:
            file_writer.write('\t'.join(['chrom', 'pos', 'name', 'ref',
                                         'alt', 'strand']))
            header_done = False

            for _ in range(len(gen_flow)):
                index_array = next(gen_flow.index_generator)
                batch_X = gen_flow._get_batches_of_transformed_samples(
                    index_array)

                # `data_generator` is passed to override any data
                # generator the model itself may be carrying.
                batch_preds = _run_prediction(
                    estimator, method, batch_X,
                    data_generator=pred_data_generator)

                if batch_preds.ndim == 1:
                    batch_preds = batch_preds[:, np.newaxis]

                batch_meta = variants[index_array]
                batch_out = np.column_stack([batch_meta, batch_preds])

                if not header_done:
                    # The number of prediction columns is only known once
                    # the first batch has been predicted.
                    heads = np.arange(batch_preds.shape[-1]).astype(str)
                    file_writer.write("\t%s\n" % '\t'.join(heads))
                    header_done = True

                file_writer.writelines(
                    "%s\n" % '\t'.join(row) for row in batch_out)

            if not header_done:
                # No batch was produced: still terminate the header line.
                file_writer.write("\n")
    finally:
        # TODO: make api `pred_data_generator.close()`
        pred_data_generator.close()


def main(inputs, infile_estimator, outfile_predict,
         infile_weights=None, infile1=None,
         fasta_path=None, ref_seq=None,
         vcf_path=None):
    """Load a trained estimator, predict on the selected input and write
    the predictions to a tabular file.

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter

    infile_estimator : str
        File path to trained estimator input

    outfile_predict : str
        File path to save the prediction results, tabular

    infile_weights : str
        File path to weights input

    infile1 : str
        File path to dataset containing features

    fasta_path : str
        File path to dataset containing fasta file

    ref_seq : str
        File path to dataset containing the reference genome sequence.

    vcf_path : str
        File path to dataset containing variants info.

    Returns
    -------
    int or None
        0 for the 'variant_effect' input type (kept for backward
        compatibility), None for every other input type.

    Raises
    ------
    ValueError
        If the model needs weights but none were supplied, if fasta input
        is used with an estimator lacking `data_batch_generator`, or if
        the selected input type is not supported.
    """
    warnings.filterwarnings('ignore')

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    # load model
    with open(infile_estimator, 'rb') as est_handler:
        estimator = load_model(est_handler)

    main_est = estimator
    if isinstance(estimator, Pipeline):
        main_est = estimator.steps[-1][-1]
    if hasattr(main_est, 'config') and hasattr(main_est, 'load_weights'):
        if not infile_weights or infile_weights == 'None':
            raise ValueError("The selected model skeleton asks for weights, "
                             "but dataset for weights was not selected!")
        main_est.load_weights(infile_weights)

    # handle data input
    input_type = params['input_options']['selected_input']
    method = params['method']

    # tabular input
    if input_type == 'tabular':
        header = 'infer' if params['input_options']['header1'] else None
        column_option = (params['input_options']
                         ['column_selector_options_1']
                         ['selected_column_selector_option'])
        if column_option in ['by_index_number', 'all_but_by_index_number',
                             'by_header_name', 'all_but_by_header_name']:
            c = params['input_options']['column_selector_options_1']['col1']
        else:
            c = None

        df = pd.read_csv(infile1, sep='\t', header=header, parse_dates=True)

        X = read_columns(df, c=c, c_option=column_option).astype(float)

        preds = _run_prediction(estimator, method, X)

    # sparse input
    elif input_type == 'sparse':
        # Close the handle once the matrix has been read.
        with open(infile1, 'r') as sparse_handler:
            X = mmread(sparse_handler)

        preds = _run_prediction(estimator, method, X)

    # fasta input
    elif input_type == 'seq_fasta':
        if not hasattr(estimator, 'data_batch_generator'):
            raise ValueError(
                "To do prediction on sequences in fasta input, "
                "the estimator must be a `KerasGBatchClassifier` "
                "equipped with data_batch_generator!")
        pyfaidx = get_module('pyfaidx')
        sequences = pyfaidx.Fasta(fasta_path)
        n_seqs = len(sequences.keys())
        # One row per sequence; each row only holds the sequence index.
        X = np.arange(n_seqs)[:, np.newaxis]
        seq_length = estimator.data_batch_generator.seq_length
        batch_size = getattr(estimator, 'batch_size', 32)
        # ceiling division: number of batches needed to cover all sequences
        steps = (n_seqs + batch_size - 1) // batch_size

        seq_type = params['input_options']['seq_type']
        klass = try_get_attr(
            'galaxy_ml.preprocessors', seq_type)

        pred_data_generator = klass(
            fasta_path, seq_length=seq_length)

        preds = _run_prediction(
            estimator, method, X,
            data_generator=pred_data_generator, steps=steps)

    # vcf input
    elif input_type == 'variant_effect':
        klass = try_get_attr('galaxy_ml.preprocessors',
                             'GenomicVariantBatchGenerator')

        # The remaining input options are forwarded to the generator as
        # keyword arguments, so the selector key has to be removed first.
        options = params['input_options']
        options.pop('selected_input')
        if options['blacklist_regions'] == 'none':
            options['blacklist_regions'] = None

        pred_data_generator = klass(
            ref_genome_path=ref_seq, vcf_path=vcf_path, **options)

        pred_data_generator.fit()

        clean_variants = _filter_blacklisted_variants(pred_data_generator)
        pred_data_generator.variants = clean_variants

        _write_variant_predictions(
            estimator, method, pred_data_generator,
            np.array(clean_variants), outfile_predict)
        return 0

    else:
        raise ValueError(
            "Unsupported input type: %r" % (input_type,))
    # end input

    # output
    if len(preds.shape) == 1:
        rval = pd.DataFrame(preds, columns=['Predicted'])
    else:
        rval = pd.DataFrame(preds)

    rval.to_csv(outfile_predict, sep='\t', header=True, index=False)
246 | |
if __name__ == '__main__':
    # Command-line entry point: every option is forwarded to `main` under
    # the keyword that matches its `dest`.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("-i", "--inputs", dest="inputs", required=True)

    # Optional arguments, declared as (short flag, dest); the long flag is
    # always "--<dest>".
    optional_flags = (
        ("-e", "infile_estimator"),
        ("-w", "infile_weights"),
        ("-X", "infile1"),
        ("-O", "outfile_predict"),
        ("-f", "fasta_path"),
        ("-r", "ref_seq"),
        ("-v", "vcf_path"),
    )
    for short_flag, dest_name in optional_flags:
        cli_parser.add_argument(short_flag, "--" + dest_name, dest=dest_name)

    cli_args = cli_parser.parse_args()

    # The namespace attributes are exactly the parameter names of `main`.
    main(**vars(cli_args))