comparison model_prediction.py @ 0:59e8b4328c82 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author bgruening
date Tue, 13 Apr 2021 22:40:10 +0000
parents
children f93f0cdbaf18
comparison
equal deleted inserted replaced
-1:000000000000 0:59e8b4328c82
1 import argparse
2 import json
3 import warnings
4
5 import numpy as np
6 import pandas as pd
7 from galaxy_ml.utils import get_module, load_model, read_columns, try_get_attr
8 from scipy.io import mmread
9 from sklearn.pipeline import Pipeline
10
# Number of worker slots Galaxy grants via GALAXY_SLOTS; 1 when the variable is unset.
N_JOBS = int(__import__("os").getenv("GALAXY_SLOTS", 1))
12
13
# Values of ``params["input_options"]["selected_input"]`` that ``main`` can handle.
_SUPPORTED_INPUTS = ("tabular", "sparse", "seq_fasta", "variant_effect")


def main(
    inputs,
    infile_estimator,
    outfile_predict,
    infile_weights=None,
    infile1=None,
    fasta_path=None,
    ref_seq=None,
    vcf_path=None,
):
    """Load a trained estimator and write its predictions to a tabular file.

    Parameters
    ----------
    inputs : str
        File path to the galaxy tool parameters (JSON).

    infile_estimator : str
        File path to trained estimator input.

    outfile_predict : str
        File path to save the prediction results, tab-separated.

    infile_weights : str, optional
        File path to weights input. Required when the final estimator
        exposes both ``config`` and ``load_weights``.

    infile1 : str, optional
        File path to dataset containing features (``tabular`` / ``sparse``).

    fasta_path : str, optional
        File path to dataset containing fasta file (``seq_fasta``).

    ref_seq : str, optional
        File path to dataset containing the reference genome sequence
        (``variant_effect``).

    vcf_path : str, optional
        File path to dataset containing variants info (``variant_effect``).

    Returns
    -------
    int or None
        ``0`` after streaming ``variant_effect`` predictions, else ``None``.

    Raises
    ------
    ValueError
        If ``selected_input`` is not supported, if the model needs weights
        but none were given, or if fasta input is used with an estimator
        lacking ``data_batch_generator``.
    """
    warnings.filterwarnings("ignore")

    with open(inputs, "r") as param_handler:
        params = json.load(param_handler)

    input_options = params["input_options"]
    input_type = input_options["selected_input"]
    # Fail fast, before the model load, instead of hitting a NameError on
    # ``preds`` at the very end.
    if input_type not in _SUPPORTED_INPUTS:
        raise ValueError(
            "Unsupported selected_input %r; expected one of: %s" % (input_type, ", ".join(_SUPPORTED_INPUTS))
        )
    method = params["method"]

    estimator = _load_estimator(infile_estimator, infile_weights)

    if input_type == "variant_effect":
        # Predictions are streamed batch by batch directly to the output file.
        _predict_variants(estimator, method, input_options, ref_seq, vcf_path, outfile_predict)
        return 0

    if input_type == "tabular":
        preds = _predict_tabular(estimator, method, input_options, infile1)
    elif input_type == "sparse":
        preds = _predict_sparse(estimator, method, infile1)
    else:  # "seq_fasta"
        preds = _predict_fasta(estimator, method, input_options, fasta_path)

    _write_predictions(preds, outfile_predict)


def _get_predict_fn(estimator, method):
    """Return ``estimator.predict`` when method is "predict", else ``estimator.predict_proba``."""
    return estimator.predict if method == "predict" else estimator.predict_proba


def _load_estimator(infile_estimator, infile_weights):
    """Load the estimator with ``load_model`` and attach weights when the model requires them."""
    with open(infile_estimator, "rb") as est_handler:
        estimator = load_model(est_handler)

    # For a Pipeline, the weights belong to its final step.
    main_est = estimator.steps[-1][-1] if isinstance(estimator, Pipeline) else estimator
    if hasattr(main_est, "config") and hasattr(main_est, "load_weights"):
        # Galaxy may pass the literal string "None" for an unselected dataset.
        if not infile_weights or infile_weights == "None":
            raise ValueError(
                "The selected model skeleton asks for weights, " "but dataset for weights was not selected!"
            )
        main_est.load_weights(infile_weights)
    return estimator


def _predict_tabular(estimator, method, input_options, infile1):
    """Predict on a tab-separated feature file, honoring the column selector options."""
    header = "infer" if input_options["header1"] else None
    column_selector = input_options["column_selector_options_1"]
    column_option = column_selector["selected_column_selector_option"]
    if column_option in (
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    ):
        c = column_selector["col1"]
    else:
        c = None

    df = pd.read_csv(infile1, sep="\t", header=header, parse_dates=True)
    X = read_columns(df, c=c, c_option=column_option).astype(float)
    return _get_predict_fn(estimator, method)(X)


def _predict_sparse(estimator, method, infile1):
    """Predict on a sparse feature file read with ``scipy.io.mmread``."""
    # Close the handle deterministically (it used to be leaked).
    with open(infile1, "r") as sparse_handler:
        X = mmread(sparse_handler)
    return _get_predict_fn(estimator, method)(X)


def _predict_fasta(estimator, method, input_options, fasta_path):
    """Predict on fasta sequences through a galaxy_ml sequence data generator."""
    if not hasattr(estimator, "data_batch_generator"):
        raise ValueError(
            "To do prediction on sequences in fasta input, "
            "the estimator must be a `KerasGBatchClassifier` "
            "equipped with data_batch_generator!"
        )
    pyfaidx = get_module("pyfaidx")
    sequences = pyfaidx.Fasta(fasta_path)
    n_seqs = len(sequences.keys())
    # One row index per sequence, shape (n_seqs, 1); the data generator is
    # presumably what turns these indices into model inputs — TODO confirm.
    X = np.arange(n_seqs)[:, np.newaxis]
    seq_length = estimator.data_batch_generator.seq_length
    batch_size = getattr(estimator, "batch_size", 32)
    # Ceiling division so a trailing partial batch is still predicted.
    steps = (n_seqs + batch_size - 1) // batch_size

    klass = try_get_attr("galaxy_ml.preprocessors", input_options["seq_type"])
    pred_data_generator = klass(fasta_path, seq_length=seq_length)

    return _get_predict_fn(estimator, method)(X, data_generator=pred_data_generator, steps=steps)


def _predict_variants(estimator, method, input_options, ref_seq, vcf_path, outfile_predict):
    """Stream variant-effect predictions to ``outfile_predict``, one batch at a time.

    ``input_options`` is modified in place: ``selected_input`` is removed and
    a ``blacklist_regions`` of ``"none"`` becomes ``None``; every remaining
    key is forwarded to ``GenomicVariantBatchGenerator``.
    """
    klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator")

    options = input_options
    options.pop("selected_input")
    if options["blacklist_regions"] == "none":
        options["blacklist_regions"] = None

    pred_data_generator = klass(ref_genome_path=ref_seq, vcf_path=vcf_path, **options)
    pred_data_generator.set_processing_attrs()
    variants = pred_data_generator.variants

    # predict 1600 samples at once, then write them to file
    gen_flow = pred_data_generator.flow(batch_size=1600)
    predict_fn = _get_predict_fn(estimator, method)

    try:
        with open(outfile_predict, "w") as file_writer:
            file_writer.write("\t".join(["chrom", "pos", "name", "ref", "alt", "strand"]))
            header_done = False

            # TODO: multiple threading
            for _ in range(len(gen_flow)):
                index_array = next(gen_flow.index_generator)
                batch_X = gen_flow._get_batches_of_transformed_samples(index_array)

                # Passing `pred_data_generator` overrides any data_generator
                # the model itself may carry.
                batch_preds = predict_fn(batch_X, data_generator=pred_data_generator)
                if batch_preds.ndim == 1:
                    batch_preds = batch_preds[:, np.newaxis]

                if not header_done:
                    # Prediction columns are labelled 0..k-1 after the metadata columns.
                    heads = np.arange(batch_preds.shape[-1]).astype(str)
                    file_writer.write("\t%s\n" % "\t".join(heads))
                    header_done = True

                batch_meta = variants[index_array]
                # Stringify predictions explicitly instead of relying on
                # implicit str/float dtype promotion inside column_stack.
                batch_out = np.column_stack([batch_meta, batch_preds.astype(str)])
                for row in batch_out:
                    file_writer.write("%s\n" % "\t".join(map(str, row)))

            if not header_done:
                # No batches at all: still terminate the header line.
                file_writer.write("\n")
    finally:
        # TODO: make api `pred_data_generator.close()`
        pred_data_generator.close()


def _write_predictions(preds, outfile_predict):
    """Write in-memory predictions as a tab-separated table with a header row."""
    if len(preds.shape) == 1:
        rval = pd.DataFrame(preds, columns=["Predicted"])
    else:
        rval = pd.DataFrame(preds)

    rval.to_csv(outfile_predict, sep="\t", header=True, index=False)
209
210
if __name__ == "__main__":
    # Command-line entry point; every option maps 1:1 onto a `main` argument.
    aparser = argparse.ArgumentParser(description="Make predictions with a trained estimator.")
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True, help="galaxy tool parameters (JSON)")
    # `main` always opens the estimator and the output file, so both are mandatory.
    aparser.add_argument(
        "-e", "--infile_estimator", dest="infile_estimator", required=True, help="trained estimator file"
    )
    aparser.add_argument("-w", "--infile_weights", dest="infile_weights", help="model weights file")
    aparser.add_argument("-X", "--infile1", dest="infile1", help="feature dataset (tabular or sparse input)")
    aparser.add_argument(
        "-O", "--outfile_predict", dest="outfile_predict", required=True, help="output file for predictions"
    )
    aparser.add_argument("-f", "--fasta_path", dest="fasta_path", help="fasta file (seq_fasta input)")
    aparser.add_argument("-r", "--ref_seq", dest="ref_seq", help="reference genome sequence (variant_effect input)")
    aparser.add_argument("-v", "--vcf_path", dest="vcf_path", help="variants file (variant_effect input)")
    args = aparser.parse_args()

    main(
        args.inputs,
        args.infile_estimator,
        args.outfile_predict,
        infile_weights=args.infile_weights,
        infile1=args.infile1,
        fasta_path=args.fasta_path,
        ref_seq=args.ref_seq,
        vcf_path=args.vcf_path,
    )
223 main(
224 args.inputs,
225 args.infile_estimator,
226 args.outfile_predict,
227 infile_weights=args.infile_weights,
228 infile1=args.infile1,
229 fasta_path=args.fasta_path,
230 ref_seq=args.ref_seq,
231 vcf_path=args.vcf_path,
232 )