Mercurial > repos > bgruening > sklearn_to_categorical
comparison model_prediction.py @ 0:59e8b4328c82 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author | bgruening |
---|---|
date | Tue, 13 Apr 2021 22:40:10 +0000 |
parents | |
children | f93f0cdbaf18 |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:59e8b4328c82 |
---|---|
1 import argparse | |
2 import json | |
3 import warnings | |
4 | |
5 import numpy as np | |
6 import pandas as pd | |
7 from galaxy_ml.utils import get_module, load_model, read_columns, try_get_attr | |
8 from scipy.io import mmread | |
9 from sklearn.pipeline import Pipeline | |
10 | |
# Worker count for parallel work, taken from the GALAXY_SLOTS environment
# variable (presumably the slot count Galaxy allots to the job); defaults to 1.
N_JOBS = int(__import__("os").getenv("GALAXY_SLOTS", 1))
12 | |
13 | |
def _predict(estimator, method, X, **kwargs):
    """Dispatch to ``estimator.predict`` or ``estimator.predict_proba``.

    Any value of *method* other than ``"predict"`` selects ``predict_proba``.
    Extra keyword arguments are forwarded unchanged to the estimator call.
    """
    if method == "predict":
        return estimator.predict(X, **kwargs)
    return estimator.predict_proba(X, **kwargs)


def _load_estimator(infile_estimator, infile_weights):
    """Load the trained estimator and, if its final step requires them, its weights.

    Raises
    ------
    ValueError
        If the final estimator exposes ``config`` and ``load_weights`` but no
        weights file was supplied.
    """
    # NOTE(review): load_model deserializes the estimator file; confirm it
    # restricts what may be unpickled if this file can come from untrusted users.
    with open(infile_estimator, "rb") as est_handler:
        estimator = load_model(est_handler)

    # For a Pipeline, the weights belong to the estimator in the last step.
    main_est = estimator.steps[-1][-1] if isinstance(estimator, Pipeline) else estimator
    if hasattr(main_est, "config") and hasattr(main_est, "load_weights"):
        # The literal string "None" is treated as "not provided" as well
        # (presumably what the tool wrapper passes for an empty optional input).
        if not infile_weights or infile_weights == "None":
            raise ValueError(
                "The selected model skeleton asks for weights, " "but dataset for weights was not selected!"
            )
        main_est.load_weights(infile_weights)
    return estimator


def _read_tabular(input_options, infile1):
    """Read the tab-separated feature file and return the selected columns as floats."""
    header = "infer" if input_options["header1"] else None
    column_selector = input_options["column_selector_options_1"]
    column_option = column_selector["selected_column_selector_option"]
    # Only these selector options carry an explicit column list in "col1".
    if column_option in (
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    ):
        c = column_selector["col1"]
    else:
        c = None

    df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
    return read_columns(df, c=c, c_option=column_option).astype(float)


def _predict_fasta(estimator, input_options, method, fasta_path):
    """Predict on every sequence of a fasta file via a sequence batch generator.

    Raises
    ------
    ValueError
        If the estimator has no ``data_batch_generator`` attribute.
    """
    if not hasattr(estimator, "data_batch_generator"):
        raise ValueError(
            "To do prediction on sequences in fasta input, "
            "the estimator must be a `KerasGBatchClassifier` "
            "equipped with data_batch_generator!"
        )
    pyfaidx = get_module("pyfaidx")
    sequences = pyfaidx.Fasta(fasta_path)
    n_seqs = len(sequences.keys())
    # X holds sequence indices only; the generator supplies the actual data.
    X = np.arange(n_seqs)[:, np.newaxis]
    seq_length = estimator.data_batch_generator.seq_length
    batch_size = getattr(estimator, "batch_size", 32)
    steps = (n_seqs + batch_size - 1) // batch_size  # ceiling division

    klass = try_get_attr("galaxy_ml.preprocessors", input_options["seq_type"])
    pred_data_generator = klass(fasta_path, seq_length=seq_length)

    return _predict(estimator, method, X, data_generator=pred_data_generator, steps=steps)


def _predict_variant_effect(
    estimator,
    input_options,
    method,
    outfile_predict,
    ref_seq,
    vcf_path,
    batch_size=1600,
):
    """Stream variant-effect predictions to *outfile_predict*, batch by batch.

    The output is tab-separated. Its header contains the six metadata column
    names (chrom, pos, name, ref, alt, strand), followed by one column per
    prediction output, labelled ``0``, ``1``, ...
    """
    klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator")

    # All remaining tool options are forwarded to the generator, so drop the
    # input selector. Copy first so the caller's params are not mutated.
    options = dict(input_options)
    options.pop("selected_input")
    if options["blacklist_regions"] == "none":
        options["blacklist_regions"] = None

    pred_data_generator = klass(ref_genome_path=ref_seq, vcf_path=vcf_path, **options)
    pred_data_generator.set_processing_attrs()

    variants = pred_data_generator.variants
    gen_flow = pred_data_generator.flow(batch_size=batch_size)

    # TODO: multiple threading
    try:
        with open(outfile_predict, "w") as file_writer:
            # Write the metadata part of the header now. The prediction column
            # labels are appended once the first batch shows how many outputs
            # the model produces.
            file_writer.write("\t".join(["chrom", "pos", "name", "ref", "alt", "strand"]))
            header_done = False

            for _ in range(len(gen_flow)):
                # Drive the flow through its batch internals so that
                # index_array is kept; it selects the matching variant metadata.
                index_array = next(gen_flow.index_generator)
                batch_X = gen_flow._get_batches_of_transformed_samples(index_array)

                # Passing `pred_data_generator` overrides any data_generator
                # the model may already carry.
                batch_preds = _predict(estimator, method, batch_X, data_generator=pred_data_generator)
                if batch_preds.ndim == 1:
                    batch_preds = batch_preds[:, np.newaxis]

                batch_out = np.column_stack([variants[index_array], batch_preds])

                if not header_done:
                    heads_str = "\t".join(str(i) for i in range(batch_preds.shape[-1]))
                    file_writer.write("\t%s\n" % heads_str)
                    header_done = True

                # map(str) keeps the join valid whether column_stack produced
                # a string array or an object array holding numbers.
                file_writer.writelines("\t".join(map(str, row)) + "\n" for row in batch_out)

            if not header_done:
                # No batches at all: still terminate the header line.
                file_writer.write("\n")
    finally:
        pred_data_generator.close()


def main(
    inputs,
    infile_estimator,
    outfile_predict,
    infile_weights=None,
    infile1=None,
    fasta_path=None,
    ref_seq=None,
    vcf_path=None,
):
    """Run prediction with a trained estimator and write the results.

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter (JSON).

    infile_estimator : str
        File path to trained estimator input

    outfile_predict : str
        File path to save the prediction results, tabular

    infile_weights : str
        File path to weights input

    infile1 : str
        File path to dataset containing features (tabular, or Matrix
        Market for sparse input)

    fasta_path : str
        File path to dataset containing fasta file

    ref_seq : str
        File path to dataset containing the reference genome sequence.

    vcf_path : str
        File path to dataset containing variants info.

    Returns
    -------
    int or None
        ``0`` after streaming variant-effect predictions; ``None`` for the
        other input types.

    Raises
    ------
    ValueError
        If required weights are missing, the estimator cannot handle fasta
        input, or ``selected_input`` is not a supported input type.
    """
    warnings.filterwarnings("ignore")

    with open(inputs, "r") as param_handler:
        params = json.load(param_handler)

    estimator = _load_estimator(infile_estimator, infile_weights)

    input_options = params["input_options"]
    input_type = input_options["selected_input"]
    method = params["method"]

    if input_type == "variant_effect":
        # Results are written to outfile_predict incrementally.
        _predict_variant_effect(estimator, input_options, method, outfile_predict, ref_seq, vcf_path)
        return 0

    if input_type == "tabular":
        preds = _predict(estimator, method, _read_tabular(input_options, infile1))
    elif input_type == "sparse":
        with open(infile1, "r") as sparse_handler:
            X = mmread(sparse_handler)
        preds = _predict(estimator, method, X)
    elif input_type == "seq_fasta":
        preds = _predict_fasta(estimator, input_options, method, fasta_path)
    else:
        raise ValueError("Unsupported input type: %r" % (input_type,))

    # 1-D predictions get a named column; 2-D ones keep integer column labels.
    if len(preds.shape) == 1:
        rval = pd.DataFrame(preds, columns=["Predicted"])
    else:
        rval = pd.DataFrame(preds)

    rval.to_csv(outfile_predict, sep="\t", header=True, index=False)
209 | |
210 | |
if __name__ == "__main__":
    aparser = argparse.ArgumentParser(description="Make predictions with a trained estimator.")
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True, help="JSON file with the tool parameters")
    # main() always opens the estimator and the output file, so both are
    # required; otherwise a missing flag would surface as an obscure TypeError.
    aparser.add_argument(
        "-e", "--infile_estimator", dest="infile_estimator", required=True, help="trained estimator file"
    )
    aparser.add_argument("-w", "--infile_weights", dest="infile_weights", help="model weights file, if needed")
    aparser.add_argument("-X", "--infile1", dest="infile1", help="feature dataset (tabular or sparse)")
    aparser.add_argument(
        "-O", "--outfile_predict", dest="outfile_predict", required=True, help="output file for predictions"
    )
    aparser.add_argument("-f", "--fasta_path", dest="fasta_path", help="fasta file for sequence input")
    aparser.add_argument("-r", "--ref_seq", dest="ref_seq", help="reference genome sequence file")
    aparser.add_argument("-v", "--vcf_path", dest="vcf_path", help="variants (VCF) file")
    args = aparser.parse_args()

    main(
        args.inputs,
        args.infile_estimator,
        args.outfile_predict,
        infile_weights=args.infile_weights,
        infile1=args.infile1,
        fasta_path=args.fasta_path,
        ref_seq=args.ref_seq,
        vcf_path=args.vcf_path,
    )