predict.py @ 0:b856d3d95413 draft default tip

planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/decontaminator commit 3f8e87001f3dfe7d005d0765aeaa930225c93b72
author iuc
date Mon, 09 Jan 2023 13:27:09 +0000
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Credits: Grigorii Sukhorukov, Macha Nikolski
import argparse
import os
from pathlib import Path

import numpy as np
import pandas as pd
from Bio import SeqIO
from models import model_10
from utils import preprocess as pp

os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit"
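# the two settings above pin TensorFlow to CPU-only execution (an empty
# CUDA_VISIBLE_DEVICES hides all GPUs) and enable XLA JIT compilation on CPU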
# loglevel :
# 0 all printed
# 1 I not printed
# 2 I and W not printed
# 3 nothing printed
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


def predict_nn(ds_path, nn_weights_path, length, batch_size=256):
    """
    Breaks contigs down into fragments
    and uses a pretrained neural network to predict each fragment
    """
    try:
        seqs_ = list(SeqIO.parse(ds_path, "fasta"))
    except FileNotFoundError:
        raise FileNotFoundError(f"input fasta file {ds_path} was not found")

    out_table = {
        "id": [],
        "length": [],
        "fragment": [],
        "pred_vir": [],
        "pred_other": [],
    }
    if not seqs_:
        raise ValueError("no sequences were found in the input fasta file")
    test_fragments = []
    test_fragments_rc = []
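    # each contig is cut into overlapping windows of `length` nt with a sliding
    # step of length / 2 (50% overlap); max_gap presumably bounds the tolerated
    # fraction of gap characters per fragment (assumption from the parameter name)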
    for seq in seqs_:
        fragments_, fragments_rc, _ = \
            pp.fragmenting(
                [seq],
                length,
                max_gap=0.8,
                sl_wind_step=int(length / 2)
            )
        test_fragments.extend(fragments_)
        test_fragments_rc.extend(fragments_rc)
        for j in range(len(fragments_)):
            out_table["id"].append(seq.id)
            out_table["length"].append(len(seq.seq))
            out_table["fragment"].append(j)
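    # forward and reverse-complement fragments are one-hot encoded separately
    # and fed to the network as two inputs, one per strand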
    test_encoded = pp.one_hot_encode(test_fragments)
    test_encoded_rc = pp.one_hot_encode(test_fragments_rc)
    model = model_10.model(length)
    model.load_weights(Path(nn_weights_path, f"model_{length}.h5"))
    prediction = model.predict([test_encoded, test_encoded_rc], batch_size)
    out_table['pred_vir'].extend(list(prediction[..., 1]))
    out_table['pred_other'].extend(list(prediction[..., 0]))
    print('Exporting predictions to tsv file')
    df = pd.DataFrame(out_table)
    df['NN_decision'] = np.where(df['pred_vir'] > df['pred_other'], 'virus', 'other')
    return df
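

# A minimal usage sketch (hypothetical paths; assumes the weights folder
# contains a model_500.h5 file):
#     df_fragments = predict_nn("contigs.fasta", "weights", length=500)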


def predict_test(ds_path, length):
    """
    Breaks contigs down into fragments
    and predicts 1 (virus) for every fragment.
    Use only for testing!
    """
    try:
        seqs_ = list(SeqIO.parse(ds_path, "fasta"))
    except FileNotFoundError:
        raise FileNotFoundError(f"input fasta file {ds_path} was not found")

    out_table = {
        "id": [],
        "length": [],
        "fragment": [],
    }
    if not seqs_:
        raise ValueError("no sequences were found in the input fasta file")
    for seq in seqs_:
        fragments_, _, _ = \
            pp.fragmenting(
                [seq],
                length,
                max_gap=0.8,
                sl_wind_step=int(length / 2)
            )
        for j in range(len(fragments_)):
            out_table["id"].append(seq.id)
            out_table["length"].append(len(seq.seq))
            out_table["fragment"].append(j)
    print('Exporting predictions to tsv file')
    df = pd.DataFrame(out_table)
    df['pred_vir'] = 1
    df['pred_other'] = 0
    df['NN_decision'] = 'virus'
    return df


def predict_contigs(df):
    """
    Aggregates the per-fragment predictions of predict_nn (or predict_test)
    into a final prediction for each whole contig
    """
    df = (
        df.groupby(
            ["id",
             "length",
             'NN_decision'],
            sort=False
        ).size().unstack(fill_value=0)
    )
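    # the groupby/unstack above pivots NN_decision into columns, giving
    # per-contig counts of fragments predicted 'virus' and 'other'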
    df = df.reset_index()
    df = df.reindex(
        ['length', 'id', 'virus', 'other'],
        axis=1
    ).fillna(value=0)
    df['decision'] = np.where(df['virus'] >= df['other'], 'virus', 'other')
    df = df.sort_values(by='length', ascending=False)
    df = df.loc[:, ['length', 'id', 'virus', 'other', 'decision']]
    df = df.rename(
        columns={
            'virus': '# viral fragments',
            'other': '# other fragments',
        }
    )
    df['# viral / # total'] = (df['# viral fragments'] / (df['# viral fragments'] + df['# other fragments'])).round(3)
    df['# viral / # total * length'] = df['# viral / # total'] * df['length']
    df = df.sort_values(by='# viral / # total * length', ascending=False)
    return df
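

# Worked example of the ranking score: a 10,000 nt contig with 15 of its 20
# fragments called viral gets '# viral / # total' = 0.75 and a score of
# 0.75 * 10000 = 7500, so long, consistently viral contigs sort first.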


def predict(test_ds, weights, out_path, return_viral=True):
    """Filters out contaminant contigs from the input fasta file.

    test_ds: path to the input file with
    contigs in fasta format (str or list of str)
    weights: path to the folder containing weights
    for NN and RF modules trained on 500 and 1000 fragment lengths (str)
    out_path: path to the folder to store predictions (str)
    return_viral: whether to write contigs annotated as
    viral to a separate fasta file (True/False)
    """

    if isinstance(test_ds, str):
        test_ds = [test_ds]
    elif not isinstance(test_ds, list):
        raise ValueError('test_ds must be a str or a list of str')

    for ts in test_ds:
        assert Path(ts).exists(), f'{ts} does not exist'
    # assert Path(weights).exists(), f'{weights} does not exist'
    limit = 0
    Path(out_path).mkdir(parents=True, exist_ok=True)

    # switch for the dummy test function, used only for debugging on GitHub;
    # the test path is taken when the weights directory is empty
    use_test_f = not Path(weights, 'model_1000.h5').exists()
    for ts in test_ds:
        dfs_fr = []
        dfs_cont = []
        for l_ in (500, 1000):
            print(f'starting prediction for {Path(ts).name} '
                  f'for fragment length {l_}')
            if use_test_f:
                df = predict_test(ds_path=ts, length=l_)
            else:
                df = predict_nn(
                    ds_path=ts,
                    nn_weights_path=weights,
                    length=l_,
                )
            df = df.round(3)
            dfs_fr.append(df)
            df = predict_contigs(df)
            dfs_cont.append(df)
        print('prediction finished')
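        # combine the two models: contigs shorter than 1500 nt are taken from
        # the 500 model, those of 1500 nt and longer from the 1000 model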
        df_500 = dfs_fr[0][(dfs_fr[0]['length'] >= limit) & (dfs_fr[0]['length'] < 1500)]
        df_1000 = dfs_fr[1][dfs_fr[1]['length'] >= 1500]
        df = pd.concat([df_1000, df_500], ignore_index=True)
        pred_fr = Path(out_path, "predicted_fragments.tsv")
        df.to_csv(pred_fr, sep='\t')

        df_500 = dfs_cont[0][(dfs_cont[0]['length'] >= limit) & (dfs_cont[0]['length'] < 1500)]
        df_1000 = dfs_cont[1][dfs_cont[1]['length'] >= 1500]
        df = pd.concat([df_1000, df_500], ignore_index=True)
        pred_contigs = Path(out_path, "predicted.tsv")
        df.to_csv(pred_contigs, sep='\t')
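
        # at this point out_path contains predicted_fragments.tsv (per-fragment
        # scores) and predicted.tsv (per-contig decisions)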
        if return_viral:
            viral_ids = list(df[df["decision"] == "virus"]["id"])
            seqs_ = list(SeqIO.parse(ts, "fasta"))
            viral_seqs = [s_ for s_ in seqs_ if s_.id in viral_ids]
            SeqIO.write(viral_seqs, Path(out_path, "viral.fasta"), 'fasta')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_ds", required=True,
                        help="path to the input file with contigs "
                             "in fasta format (str or list of str)")
    parser.add_argument("--weights", required=True,
                        help="path to the folder containing weights "
                             "for NN and RF modules trained on 500 and 1000 "
                             "fragment lengths (str)")
    parser.add_argument("--out_path", required=True,
                        help="path to the folder to store predictions (str)")
    parser.add_argument("--return_viral", default='True',
                        help="whether to return contigs annotated as viral "
                             "in a separate fasta file (True/False)")

    args = parser.parse_args()
    # argparse yields strings, so the literal "False" would otherwise be truthy
    return_viral = args.return_viral not in ('False', 'false', '0')
    predict(args.test_ds, args.weights, args.out_path, return_viral)
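
# Example invocation (hypothetical paths):
#     python predict.py --test_ds contigs.fasta --weights weights/ \
#         --out_path predictions/ --return_viral True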