Mercurial > repos > bgruening > keras_train_and_eval
comparison model_prediction.py @ 4:3866911c93ae draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author | bgruening |
---|---|
date | Tue, 13 Apr 2021 18:37:35 +0000 |
parents | 03f61bb3ca43 |
children | 785dd890e27d |
comparison
equal
deleted
inserted
replaced
3:f1b9a42d6809 | 4:3866911c93ae |
---|---|
1 import argparse | 1 import argparse |
2 import json | 2 import json |
3 import warnings | |
4 | |
3 import numpy as np | 5 import numpy as np |
4 import pandas as pd | 6 import pandas as pd |
5 import warnings | |
6 | |
7 from scipy.io import mmread | 7 from scipy.io import mmread |
8 from sklearn.pipeline import Pipeline | 8 from sklearn.pipeline import Pipeline |
9 | 9 |
10 from galaxy_ml.utils import (load_model, read_columns, | 10 from galaxy_ml.utils import (get_module, load_model, |
11 get_module, try_get_attr) | 11 read_columns, try_get_attr) |
12 | 12 |
13 | 13 |
14 N_JOBS = int(__import__('os').environ.get('GALAXY_SLOTS', 1)) | 14 N_JOBS = int(__import__("os").environ.get("GALAXY_SLOTS", 1)) |
15 | 15 |
16 | 16 |
17 def main(inputs, infile_estimator, outfile_predict, | 17 def main( |
18 infile_weights=None, infile1=None, | 18 inputs, |
19 fasta_path=None, ref_seq=None, | 19 infile_estimator, |
20 vcf_path=None): | 20 outfile_predict, |
21 infile_weights=None, | |
22 infile1=None, | |
23 fasta_path=None, | |
24 ref_seq=None, | |
25 vcf_path=None, | |
26 ): | |
21 """ | 27 """ |
22 Parameter | 28 Parameter |
23 --------- | 29 --------- |
24 inputs : str | 30 inputs : str |
25 File path to galaxy tool parameter | 31 File path to galaxy tool parameter |
43 File path to dataset containing the reference genome sequence. | 49 File path to dataset containing the reference genome sequence. |
44 | 50 |
45 vcf_path : str | 51 vcf_path : str |
46 File path to dataset containing variants info. | 52 File path to dataset containing variants info. |
47 """ | 53 """ |
48 warnings.filterwarnings('ignore') | 54 warnings.filterwarnings("ignore") |
49 | 55 |
50 with open(inputs, 'r') as param_handler: | 56 with open(inputs, "r") as param_handler: |
51 params = json.load(param_handler) | 57 params = json.load(param_handler) |
52 | 58 |
53 # load model | 59 # load model |
54 with open(infile_estimator, 'rb') as est_handler: | 60 with open(infile_estimator, "rb") as est_handler: |
55 estimator = load_model(est_handler) | 61 estimator = load_model(est_handler) |
56 | 62 |
57 main_est = estimator | 63 main_est = estimator |
58 if isinstance(estimator, Pipeline): | 64 if isinstance(estimator, Pipeline): |
59 main_est = estimator.steps[-1][-1] | 65 main_est = estimator.steps[-1][-1] |
60 if hasattr(main_est, 'config') and hasattr(main_est, 'load_weights'): | 66 if hasattr(main_est, "config") and hasattr(main_est, "load_weights"): |
61 if not infile_weights or infile_weights == 'None': | 67 if not infile_weights or infile_weights == "None": |
62 raise ValueError("The selected model skeleton asks for weights, " | 68 raise ValueError( |
63 "but dataset for weights wan not selected!") | 69 "The selected model skeleton asks for weights, " "but dataset for weights wan not selected!" |
70 ) | |
64 main_est.load_weights(infile_weights) | 71 main_est.load_weights(infile_weights) |
65 | 72 |
66 # handle data input | 73 # handle data input |
67 input_type = params['input_options']['selected_input'] | 74 input_type = params["input_options"]["selected_input"] |
68 # tabular input | 75 # tabular input |
69 if input_type == 'tabular': | 76 if input_type == "tabular": |
70 header = 'infer' if params['input_options']['header1'] else None | 77 header = "infer" if params["input_options"]["header1"] else None |
71 column_option = (params['input_options'] | 78 column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"] |
72 ['column_selector_options_1'] | 79 if column_option in [ |
73 ['selected_column_selector_option']) | 80 "by_index_number", |
74 if column_option in ['by_index_number', 'all_but_by_index_number', | 81 "all_but_by_index_number", |
75 'by_header_name', 'all_but_by_header_name']: | 82 "by_header_name", |
76 c = params['input_options']['column_selector_options_1']['col1'] | 83 "all_but_by_header_name", |
84 ]: | |
85 c = params["input_options"]["column_selector_options_1"]["col1"] | |
77 else: | 86 else: |
78 c = None | 87 c = None |
79 | 88 |
80 df = pd.read_csv(infile1, sep='\t', header=header, parse_dates=True) | 89 df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True) |
81 | 90 |
82 X = read_columns(df, c=c, c_option=column_option).astype(float) | 91 X = read_columns(df, c=c, c_option=column_option).astype(float) |
83 | 92 |
84 if params['method'] == 'predict': | 93 if params["method"] == "predict": |
85 preds = estimator.predict(X) | 94 preds = estimator.predict(X) |
86 else: | 95 else: |
87 preds = estimator.predict_proba(X) | 96 preds = estimator.predict_proba(X) |
88 | 97 |
89 # sparse input | 98 # sparse input |
90 elif input_type == 'sparse': | 99 elif input_type == "sparse": |
91 X = mmread(open(infile1, 'r')) | 100 X = mmread(open(infile1, "r")) |
92 if params['method'] == 'predict': | 101 if params["method"] == "predict": |
93 preds = estimator.predict(X) | 102 preds = estimator.predict(X) |
94 else: | 103 else: |
95 preds = estimator.predict_proba(X) | 104 preds = estimator.predict_proba(X) |
96 | 105 |
97 # fasta input | 106 # fasta input |
98 elif input_type == 'seq_fasta': | 107 elif input_type == "seq_fasta": |
99 if not hasattr(estimator, 'data_batch_generator'): | 108 if not hasattr(estimator, "data_batch_generator"): |
100 raise ValueError( | 109 raise ValueError( |
101 "To do prediction on sequences in fasta input, " | 110 "To do prediction on sequences in fasta input, " |
102 "the estimator must be a `KerasGBatchClassifier`" | 111 "the estimator must be a `KerasGBatchClassifier`" |
103 "equipped with data_batch_generator!") | 112 "equipped with data_batch_generator!" |
104 pyfaidx = get_module('pyfaidx') | 113 ) |
114 pyfaidx = get_module("pyfaidx") | |
105 sequences = pyfaidx.Fasta(fasta_path) | 115 sequences = pyfaidx.Fasta(fasta_path) |
106 n_seqs = len(sequences.keys()) | 116 n_seqs = len(sequences.keys()) |
107 X = np.arange(n_seqs)[:, np.newaxis] | 117 X = np.arange(n_seqs)[:, np.newaxis] |
108 seq_length = estimator.data_batch_generator.seq_length | 118 seq_length = estimator.data_batch_generator.seq_length |
109 batch_size = getattr(estimator, 'batch_size', 32) | 119 batch_size = getattr(estimator, "batch_size", 32) |
110 steps = (n_seqs + batch_size - 1) // batch_size | 120 steps = (n_seqs + batch_size - 1) // batch_size |
111 | 121 |
112 seq_type = params['input_options']['seq_type'] | 122 seq_type = params["input_options"]["seq_type"] |
113 klass = try_get_attr( | 123 klass = try_get_attr("galaxy_ml.preprocessors", seq_type) |
114 'galaxy_ml.preprocessors', seq_type) | 124 |
115 | 125 pred_data_generator = klass(fasta_path, seq_length=seq_length) |
116 pred_data_generator = klass( | 126 |
117 fasta_path, seq_length=seq_length) | 127 if params["method"] == "predict": |
118 | 128 preds = estimator.predict(X, data_generator=pred_data_generator, steps=steps) |
119 if params['method'] == 'predict': | 129 else: |
120 preds = estimator.predict( | 130 preds = estimator.predict_proba(X, data_generator=pred_data_generator, steps=steps) |
121 X, data_generator=pred_data_generator, steps=steps) | |
122 else: | |
123 preds = estimator.predict_proba( | |
124 X, data_generator=pred_data_generator, steps=steps) | |
125 | 131 |
126 # vcf input | 132 # vcf input |
127 elif input_type == 'variant_effect': | 133 elif input_type == "variant_effect": |
128 klass = try_get_attr('galaxy_ml.preprocessors', | 134 klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator") |
129 'GenomicVariantBatchGenerator') | 135 |
130 | 136 options = params["input_options"] |
131 options = params['input_options'] | 137 options.pop("selected_input") |
132 options.pop('selected_input') | 138 if options["blacklist_regions"] == "none": |
133 if options['blacklist_regions'] == 'none': | 139 options["blacklist_regions"] = None |
134 options['blacklist_regions'] = None | 140 |
135 | 141 pred_data_generator = klass(ref_genome_path=ref_seq, vcf_path=vcf_path, **options) |
136 pred_data_generator = klass( | |
137 ref_genome_path=ref_seq, vcf_path=vcf_path, **options) | |
138 | 142 |
139 pred_data_generator.set_processing_attrs() | 143 pred_data_generator.set_processing_attrs() |
140 | 144 |
141 variants = pred_data_generator.variants | 145 variants = pred_data_generator.variants |
142 | 146 |
143 # predict 1600 sample at once then write to file | 147 # predict 1600 sample at once then write to file |
144 gen_flow = pred_data_generator.flow(batch_size=1600) | 148 gen_flow = pred_data_generator.flow(batch_size=1600) |
145 | 149 |
146 file_writer = open(outfile_predict, 'w') | 150 file_writer = open(outfile_predict, "w") |
147 header_row = '\t'.join(['chrom', 'pos', 'name', 'ref', | 151 header_row = "\t".join(["chrom", "pos", "name", "ref", "alt", "strand"]) |
148 'alt', 'strand']) | |
149 file_writer.write(header_row) | 152 file_writer.write(header_row) |
150 header_done = False | 153 header_done = False |
151 | 154 |
152 steps_done = 0 | 155 steps_done = 0 |
153 | 156 |
154 # TODO: multiple threading | 157 # TODO: multiple threading |
155 try: | 158 try: |
156 while steps_done < len(gen_flow): | 159 while steps_done < len(gen_flow): |
157 index_array = next(gen_flow.index_generator) | 160 index_array = next(gen_flow.index_generator) |
158 batch_X = gen_flow._get_batches_of_transformed_samples( | 161 batch_X = gen_flow._get_batches_of_transformed_samples(index_array) |
159 index_array) | 162 |
160 | 163 if params["method"] == "predict": |
161 if params['method'] == 'predict': | |
162 batch_preds = estimator.predict( | 164 batch_preds = estimator.predict( |
163 batch_X, | 165 batch_X, |
164 # The presence of `pred_data_generator` below is to | 166 # The presence of `pred_data_generator` below is to |
165 # override model carrying data_generator if there | 167 # override model carrying data_generator if there |
166 # is any. | 168 # is any. |
167 data_generator=pred_data_generator) | 169 data_generator=pred_data_generator, |
170 ) | |
168 else: | 171 else: |
169 batch_preds = estimator.predict_proba( | 172 batch_preds = estimator.predict_proba( |
170 batch_X, | 173 batch_X, |
171 # The presence of `pred_data_generator` below is to | 174 # The presence of `pred_data_generator` below is to |
172 # override model carrying data_generator if there | 175 # override model carrying data_generator if there |
173 # is any. | 176 # is any. |
174 data_generator=pred_data_generator) | 177 data_generator=pred_data_generator, |
178 ) | |
175 | 179 |
176 if batch_preds.ndim == 1: | 180 if batch_preds.ndim == 1: |
177 batch_preds = batch_preds[:, np.newaxis] | 181 batch_preds = batch_preds[:, np.newaxis] |
178 | 182 |
179 batch_meta = variants[index_array] | 183 batch_meta = variants[index_array] |
180 batch_out = np.column_stack([batch_meta, batch_preds]) | 184 batch_out = np.column_stack([batch_meta, batch_preds]) |
181 | 185 |
182 if not header_done: | 186 if not header_done: |
183 heads = np.arange(batch_preds.shape[-1]).astype(str) | 187 heads = np.arange(batch_preds.shape[-1]).astype(str) |
184 heads_str = '\t'.join(heads) | 188 heads_str = "\t".join(heads) |
185 file_writer.write("\t%s\n" % heads_str) | 189 file_writer.write("\t%s\n" % heads_str) |
186 header_done = True | 190 header_done = True |
187 | 191 |
188 for row in batch_out: | 192 for row in batch_out: |
189 row_str = '\t'.join(row) | 193 row_str = "\t".join(row) |
190 file_writer.write("%s\n" % row_str) | 194 file_writer.write("%s\n" % row_str) |
191 | 195 |
192 steps_done += 1 | 196 steps_done += 1 |
193 | 197 |
194 finally: | 198 finally: |
198 return 0 | 202 return 0 |
199 # end input | 203 # end input |
200 | 204 |
201 # output | 205 # output |
202 if len(preds.shape) == 1: | 206 if len(preds.shape) == 1: |
203 rval = pd.DataFrame(preds, columns=['Predicted']) | 207 rval = pd.DataFrame(preds, columns=["Predicted"]) |
204 else: | 208 else: |
205 rval = pd.DataFrame(preds) | 209 rval = pd.DataFrame(preds) |
206 | 210 |
207 rval.to_csv(outfile_predict, sep='\t', header=True, index=False) | 211 rval.to_csv(outfile_predict, sep="\t", header=True, index=False) |
208 | 212 |
209 | 213 |
210 if __name__ == '__main__': | 214 if __name__ == "__main__": |
211 aparser = argparse.ArgumentParser() | 215 aparser = argparse.ArgumentParser() |
212 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 216 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
213 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") | 217 aparser.add_argument("-e", "--infile_estimator", dest="infile_estimator") |
214 aparser.add_argument("-w", "--infile_weights", dest="infile_weights") | 218 aparser.add_argument("-w", "--infile_weights", dest="infile_weights") |
215 aparser.add_argument("-X", "--infile1", dest="infile1") | 219 aparser.add_argument("-X", "--infile1", dest="infile1") |
217 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") | 221 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") |
218 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 222 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") |
219 aparser.add_argument("-v", "--vcf_path", dest="vcf_path") | 223 aparser.add_argument("-v", "--vcf_path", dest="vcf_path") |
220 args = aparser.parse_args() | 224 args = aparser.parse_args() |
221 | 225 |
222 main(args.inputs, args.infile_estimator, args.outfile_predict, | 226 main( |
223 infile_weights=args.infile_weights, infile1=args.infile1, | 227 args.inputs, |
224 fasta_path=args.fasta_path, ref_seq=args.ref_seq, | 228 args.infile_estimator, |
225 vcf_path=args.vcf_path) | 229 args.outfile_predict, |
230 infile_weights=args.infile_weights, | |
231 infile1=args.infile1, | |
232 fasta_path=args.fasta_path, | |
233 ref_seq=args.ref_seq, | |
234 vcf_path=args.vcf_path, | |
235 ) |