Mercurial > repos > iuc > decontaminator
comparison predict.py @ 0:b856d3d95413 draft default tip
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/decontaminator commit 3f8e87001f3dfe7d005d0765aeaa930225c93b72
| author | iuc |
|---|---|
| date | Mon, 09 Jan 2023 13:27:09 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:b856d3d95413 |
|---|---|
| 1 #!/usr/bin/env python | |
| 2 # -*- coding: utf-8 -*- | |
| 3 # Credits: Grigorii Sukhorukov, Macha Nikolski | |
| 4 import argparse | |
| 5 import os | |
| 6 from pathlib import Path | |
| 7 | |
| 8 import numpy as np | |
| 9 import pandas as pd | |
| 10 from Bio import SeqIO | |
| 11 from models import model_10 | |
| 12 from utils import preprocess as pp | |
| 13 | |
# Hide all CUDA devices so TensorFlow runs on CPU only, and enable
# XLA JIT compilation on CPU.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit"
# TF_CPP_MIN_LOG_LEVEL (TensorFlow C++ logging verbosity):
#   0 all printed
#   1 I not printed
#   2 I and W not printed
#   3 nothing printed
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
| 22 | |
| 23 | |
def predict_nn(ds_path, nn_weights_path, length, batch_size=256):
    """
    Break contigs into fragments and score each fragment with a
    pretrained neural network.

    ds_path: path to a fasta file with contigs (str)
    nn_weights_path: folder containing the model_{length}.h5 weights (str)
    length: fragment length the model was trained on (int)
    batch_size: prediction batch size (int)

    Returns a DataFrame with one row per fragment and columns
    id, length, fragment, pred_vir, pred_other, NN_decision.
    """
    try:
        seqs_ = list(SeqIO.parse(ds_path, "fasta"))
    except FileNotFoundError:
        raise Exception("test dataset was not found. Change ds variable")

    out_table = {
        "id": [],
        "length": [],
        "fragment": [],
        "pred_vir": [],
        "pred_other": [],
    }
    if not seqs_:
        raise ValueError("All sequences were smaller than length of the model")
    test_fragments = []
    test_fragments_rc = []
    for seq in seqs_:
        # NOTE(review): sliding window of half the fragment length;
        # max_gap presumably bounds the tolerated N-fraction per fragment
        # — confirm against pp.fragmenting.
        fragments_, fragments_rc, _ = pp.fragmenting(
            [seq],
            length,
            max_gap=0.8,
            sl_wind_step=int(length / 2),
        )
        test_fragments.extend(fragments_)
        test_fragments_rc.extend(fragments_rc)
        for j in range(len(fragments_)):
            out_table["id"].append(seq.id)
            out_table["length"].append(len(seq.seq))
            out_table["fragment"].append(j)
    test_encoded = pp.one_hot_encode(test_fragments)
    test_encoded_rc = pp.one_hot_encode(test_fragments_rc)
    model = model_10.model(length)
    model.load_weights(Path(nn_weights_path, f"model_{length}.h5"))
    # forward and reverse-complement encodings are scored jointly
    prediction = model.predict([test_encoded, test_encoded_rc], batch_size=batch_size)
    out_table['pred_vir'].extend(list(prediction[..., 1]))
    out_table['pred_other'].extend(list(prediction[..., 0]))
    # fixed: predictions are exported as tab-separated (.tsv), not csv —
    # message now consistent with predict_test and the writers in predict()
    print('Exporting predictions to tsv file')
    df = pd.DataFrame(out_table)
    df['NN_decision'] = np.where(df['pred_vir'] > df['pred_other'], 'virus', 'other')
    return df
| 70 | |
| 71 | |
def predict_test(ds_path, length):
    """
    Break contigs into fragments and label every fragment as viral.

    Debug-only stand-in for predict_nn — use only for testing!
    Returns the same columns: id, length, fragment, pred_vir,
    pred_other, NN_decision.
    """
    try:
        seqs_ = list(SeqIO.parse(ds_path, "fasta"))
    except FileNotFoundError:
        raise Exception("test dataset was not found. Change ds variable")

    if not seqs_:
        raise ValueError("All sequences were smaller than length of the model")

    ids = []
    contig_lengths = []
    fragment_idx = []
    for record in seqs_:
        pieces, _, _ = pp.fragmenting(
            [record],
            length,
            max_gap=0.8,
            sl_wind_step=int(length / 2),
        )
        n_pieces = len(pieces)
        ids.extend([record.id] * n_pieces)
        contig_lengths.extend([len(record.seq)] * n_pieces)
        fragment_idx.extend(range(n_pieces))

    print('Exporting predictions to tsv file')
    df = pd.DataFrame({"id": ids, "length": contig_lengths, "fragment": fragment_idx})
    # every fragment is unconditionally called viral
    df['pred_vir'] = 1
    df['pred_other'] = 0
    df['NN_decision'] = 'virus'
    return df
| 108 | |
| 109 | |
def predict_contigs(df):
    """
    Aggregate per-fragment predictions into one decision per contig.

    df: DataFrame with columns id, length, NN_decision (one row per fragment).
    Returns a DataFrame with per-contig fragment counts, the majority
    decision, and viral-fraction scores, sorted by
    '# viral / # total * length' descending.
    """
    # count fragments per (contig, decision); unstack turns the decision
    # values into columns
    counts = (
        df.groupby(["id", "length", "NN_decision"], sort=False)
        .size()
        .unstack(fill_value=0)
        .reset_index()
    )
    # guarantee both decision columns exist even if one never occurred
    counts = counts.reindex(["length", "id", "virus", "other"], axis=1).fillna(value=0)
    # ties go to 'virus'
    counts["decision"] = np.where(counts["virus"] >= counts["other"], "virus", "other")
    counts = counts.sort_values(by="length", ascending=False)
    counts = counts.loc[:, ["length", "id", "virus", "other", "decision"]]
    counts = counts.rename(
        columns={
            "virus": "# viral fragments",
            "other": "# other fragments",
        }
    )
    total_fragments = counts["# viral fragments"] + counts["# other fragments"]
    counts["# viral / # total"] = (counts["# viral fragments"] / total_fragments).round(3)
    counts["# viral / # total * length"] = counts["# viral / # total"] * counts["length"]
    return counts.sort_values(by="# viral / # total * length", ascending=False)
| 141 | |
| 142 | |
def predict(test_ds, weights, out_path, return_viral=True):
    """Filter contaminant contigs out of fasta file(s).

    test_ds: path to the input file with contigs in fasta format
        (str or list of str)
    weights: path to the folder containing weights for NN and RF modules
        trained on 500 and 1000 fragment lengths (str)
    out_path: path to the folder to store predictions (str)
    return_viral: whether to write contigs annotated as viral to a
        separate fasta file (True/False)

    Writes predicted_fragments.tsv, predicted.tsv and optionally
    viral.fasta into out_path.
    """
    # removed no-op `test_ds = test_ds`; normalize a single path to a list
    if isinstance(test_ds, str):
        test_ds = [test_ds]
    elif not isinstance(test_ds, list):
        raise ValueError('test_ds was incorrectly assigned in the config file')

    assert Path(test_ds[0]).exists(), f'{test_ds[0]} does not exist'
    # assert Path(weights).exists(), f'{weights} does not exist'
    # minimum contig length kept in the reports
    limit = 0
    Path(out_path).mkdir(parents=True, exist_ok=True)

    # parameter to activate test function. Only for debugging on github
    # test is launched when the weights directory is empty
    use_test_f = not Path(weights, 'model_1000.h5').exists()
    for ts in test_ds:
        dfs_fr = []
        dfs_cont = []
        for l_ in 500, 1000:
            print(f'starting prediction for {Path(ts).name} '
                  f'for fragment length {l_}')
            if use_test_f:
                df = predict_test(ds_path=ts, length=l_)
            else:
                df = predict_nn(
                    ds_path=ts,
                    nn_weights_path=weights,
                    length=l_,
                )
            df = df.round(3)
            dfs_fr.append(df)
            dfs_cont.append(predict_contigs(df))
        print('prediction finished')

        # contigs shorter than 1500 bp use the 500-model predictions,
        # longer ones the 1000-model predictions
        df_500 = dfs_fr[0][(dfs_fr[0]['length'] >= limit) & (dfs_fr[0]['length'] < 1500)]
        df_1000 = dfs_fr[1][dfs_fr[1]['length'] >= 1500]
        df = pd.concat([df_1000, df_500], ignore_index=True)
        pred_fr = Path(out_path, "predicted_fragments.tsv")
        df.to_csv(pred_fr, sep='\t')

        df_500 = dfs_cont[0][(dfs_cont[0]['length'] >= limit) & (dfs_cont[0]['length'] < 1500)]
        df_1000 = dfs_cont[1][dfs_cont[1]['length'] >= 1500]
        df = pd.concat([df_1000, df_500], ignore_index=True)
        pred_contigs = Path(out_path, "predicted.tsv")
        df.to_csv(pred_contigs, sep='\t')

        if return_viral:
            # set membership: O(1) per contig instead of scanning a list
            viral_ids = set(df[df["decision"] == "virus"]["id"])
            viral_seqs = [s_ for s_ in SeqIO.parse(ts, "fasta") if s_.id in viral_ids]
            SeqIO.write(viral_seqs, Path(out_path, "viral.fasta"), 'fasta')
| 213 | |
| 214 | |
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # required=True replaces the old pattern that left the variables
    # unbound (NameError) when a flag was omitted
    parser.add_argument("--test_ds", required=True,
                        help="path to the input "
                             "file with contigs "
                             "in fasta format "
                             "(str or list of str)")
    parser.add_argument("--weights", required=True,
                        help="path to the folder containing "
                             "weights for NN and RF modules "
                             "trained on 500 and 1000 "
                             "fragment lengths (str)")
    parser.add_argument("--out_path", required=True,
                        help="path to the folder to store "
                             "predictions (str)")
    parser.add_argument("--return_viral", default="True",
                        help="whether to return "
                             "contigs annotated "
                             "as viral in separate "
                             "fasta file (True/False)")

    args = parser.parse_args()
    # argparse yields a string, so the old truthiness check treated
    # "--return_viral False" as True; parse it explicitly instead
    return_viral = str(args.return_viral).strip().lower() not in ("false", "0", "no")

    predict(args.test_ds, args.weights, args.out_path, return_viral)
