comparison model_prediction.py @ 0:af2624d5ab32 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author bgruening
date Sat, 01 May 2021 01:24:32 +0000
parents
children 9349ed2749c6
comparison
equal deleted inserted replaced
-1:000000000000 0:af2624d5ab32
1 import argparse
2 import json
3 import warnings
4
5 import numpy as np
6 import pandas as pd
7 from galaxy_ml.utils import get_module, load_model, read_columns, try_get_attr
8 from scipy.io import mmread
9 from sklearn.pipeline import Pipeline
10
# Number of worker slots for this job, read from the GALAXY_SLOTS environment
# variable; falls back to 1 when the variable is not set.
N_JOBS = int(__import__("os").getenv("GALAXY_SLOTS", 1))
12
13
def _run_prediction(estimator, method, X, **predict_params):
    """Call ``estimator.predict`` or ``estimator.predict_proba`` on ``X``.

    ``method == "predict"`` selects ``predict``; any other value selects
    ``predict_proba``. Extra keyword arguments are forwarded unchanged.
    """
    if method == "predict":
        return estimator.predict(X, **predict_params)
    return estimator.predict_proba(X, **predict_params)


def _write_variant_predictions(
    estimator, method, pred_data_generator, outfile_predict
):
    """Predict batch by batch and stream the rows to a tabular file.

    Each output row is the variant metadata selected by the batch's index
    array followed by the prediction columns for that variant.
    """
    variants = pred_data_generator.variants

    # predict 1600 samples at once, then write them to file
    gen_flow = pred_data_generator.flow(batch_size=1600)
    # NOTE(review): assumes len(gen_flow) (the number of batches) does not
    # change while iterating -- confirm against the generator implementation.
    n_steps = len(gen_flow)

    meta_header = "\t".join(["chrom", "pos", "name", "ref", "alt", "strand"])

    with open(outfile_predict, "w") as file_writer:
        # The header line is completed once the first batch reveals how many
        # prediction columns there are.
        file_writer.write(meta_header)
        header_done = False

        # TODO: multiple threading
        for _ in range(n_steps):
            index_array = next(gen_flow.index_generator)
            batch_X = gen_flow._get_batches_of_transformed_samples(index_array)

            # `data_generator` is passed explicitly to override any
            # data_generator the model itself may carry.
            batch_preds = _run_prediction(
                estimator,
                method,
                batch_X,
                data_generator=pred_data_generator,
            )

            if batch_preds.ndim == 1:
                batch_preds = batch_preds[:, np.newaxis]

            batch_meta = variants[index_array]
            batch_out = np.column_stack([batch_meta, batch_preds])

            if not header_done:
                heads = np.arange(batch_preds.shape[-1]).astype(str)
                file_writer.write("\t%s\n" % "\t".join(heads))
                header_done = True

            file_writer.writelines(
                "%s\n" % "\t".join(row) for row in batch_out
            )

        if not header_done:
            # No batch was produced: still terminate the header line.
            file_writer.write("\n")


def main(
    inputs,
    infile_estimator,
    outfile_predict,
    infile_weights=None,
    infile1=None,
    fasta_path=None,
    ref_seq=None,
    vcf_path=None,
):
    """Load a trained estimator, predict on the input and write the results.

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter

    infile_estimator : str
        File path to trained estimator input

    outfile_predict : str
        File path to save the prediction results, tabular

    infile_weights : str
        File path to weights input

    infile1 : str
        File path to dataset containing features

    fasta_path : str
        File path to dataset containing fasta file

    ref_seq : str
        File path to dataset containing the reference genome sequence.

    vcf_path : str
        File path to dataset containing variants info.

    Returns
    -------
    int or None
        0 after the "variant_effect" input has been processed, None for
        every other input type.

    Raises
    ------
    ValueError
        If the model asks for weights but none were supplied, if fasta
        input is used with an estimator lacking ``data_batch_generator``,
        or if the selected input type is not supported.
    """
    warnings.filterwarnings("ignore")

    with open(inputs, "r") as param_handler:
        params = json.load(param_handler)

    # load model
    with open(infile_estimator, "rb") as est_handler:
        estimator = load_model(est_handler)

    # For a pipeline, the final step is the estimator that may need weights.
    main_est = estimator
    if isinstance(estimator, Pipeline):
        main_est = estimator.steps[-1][-1]
    if hasattr(main_est, "config") and hasattr(main_est, "load_weights"):
        if not infile_weights or infile_weights == "None":
            raise ValueError(
                "The selected model skeleton asks for weights, "
                "but dataset for weights was not selected!"
            )
        main_est.load_weights(infile_weights)

    # handle data input
    input_type = params["input_options"]["selected_input"]
    # tabular input
    if input_type == "tabular":
        header = "infer" if params["input_options"]["header1"] else None
        column_option = params["input_options"]["column_selector_options_1"][
            "selected_column_selector_option"
        ]
        if column_option in [
            "by_index_number",
            "all_but_by_index_number",
            "by_header_name",
            "all_but_by_header_name",
        ]:
            c = params["input_options"]["column_selector_options_1"]["col1"]
        else:
            c = None

        df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)

        X = read_columns(df, c=c, c_option=column_option).astype(float)

        preds = _run_prediction(estimator, params["method"], X)

    # sparse input
    elif input_type == "sparse":
        # Close the handle as soon as the matrix has been read.
        with open(infile1, "r") as sparse_handler:
            X = mmread(sparse_handler)
        preds = _run_prediction(estimator, params["method"], X)

    # fasta input
    elif input_type == "seq_fasta":
        if not hasattr(estimator, "data_batch_generator"):
            raise ValueError(
                "To do prediction on sequences in fasta input, "
                "the estimator must be a `KerasGBatchClassifier` "
                "equipped with data_batch_generator!"
            )
        pyfaidx = get_module("pyfaidx")
        sequences = pyfaidx.Fasta(fasta_path)
        n_seqs = len(sequences.keys())
        # One row per sequence holding its index.
        X = np.arange(n_seqs)[:, np.newaxis]
        seq_length = estimator.data_batch_generator.seq_length
        batch_size = getattr(estimator, "batch_size", 32)
        # Ceiling division: number of batches needed to cover every sequence.
        steps = (n_seqs + batch_size - 1) // batch_size

        seq_type = params["input_options"]["seq_type"]
        klass = try_get_attr("galaxy_ml.preprocessors", seq_type)

        pred_data_generator = klass(fasta_path, seq_length=seq_length)

        preds = _run_prediction(
            estimator,
            params["method"],
            X,
            data_generator=pred_data_generator,
            steps=steps,
        )

    # vcf input
    elif input_type == "variant_effect":
        klass = try_get_attr(
            "galaxy_ml.preprocessors", "GenomicVariantBatchGenerator"
        )

        # Everything left in the options after dropping the selector key is
        # forwarded to the generator as keyword arguments.
        options = params["input_options"]
        options.pop("selected_input")
        if options["blacklist_regions"] == "none":
            options["blacklist_regions"] = None

        pred_data_generator = klass(
            ref_genome_path=ref_seq, vcf_path=vcf_path, **options
        )

        pred_data_generator.set_processing_attrs()

        try:
            _write_variant_predictions(
                estimator, params["method"], pred_data_generator, outfile_predict
            )
        finally:
            # TODO: make api `pred_data_generator.close()`
            pred_data_generator.close()
        return 0

    else:
        raise ValueError("Unsupported input type: %s" % input_type)
    # end input

    # output
    if len(preds.shape) == 1:
        rval = pd.DataFrame(preds, columns=["Predicted"])
    else:
        rval = pd.DataFrame(preds)

    rval.to_csv(outfile_predict, sep="\t", header=True, index=False)
218
219
if __name__ == "__main__":
    # (short flag, long flag, extra add_argument options); each value is
    # stored under the long flag's name, which matches a parameter of main().
    cli_options = (
        ("-i", "--inputs", {"required": True}),
        ("-e", "--infile_estimator", {}),
        ("-w", "--infile_weights", {}),
        ("-X", "--infile1", {}),
        ("-O", "--outfile_predict", {}),
        ("-f", "--fasta_path", {}),
        ("-r", "--ref_seq", {}),
        ("-v", "--vcf_path", {}),
    )

    aparser = argparse.ArgumentParser()
    for short_flag, long_flag, extra in cli_options:
        aparser.add_argument(
            short_flag, long_flag, dest=long_flag.lstrip("-"), **extra
        )

    # Every parsed attribute maps one-to-one onto a keyword of main().
    main(**vars(aparser.parse_args()))