Mercurial > repos > iuc > virhunter
annotate predict.py @ 2:ea2cccb9f73e draft
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit c3685ed6a70b47012b62b95a2a3db062bd3b7475
author | iuc |
---|---|
date | Thu, 05 Jan 2023 14:27:54 +0000 |
parents | 9b12bc1b1e2c |
children | 302332b914ef |
rev | line source |
---|---|
0
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
1 #!/usr/bin/env python |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
2 # -*- coding: utf-8 -*- |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
3 # Credits: Grigorii Sukhorukov, Macha Nikolski |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
4 import argparse |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
5 import os |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
6 from pathlib import Path |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
7 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
8 import numpy as np |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
9 import pandas as pd |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
10 from Bio import SeqIO |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
11 from joblib import load |
1
9b12bc1b1e2c
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
iuc
parents:
0
diff
changeset
|
12 from models import model_10, model_5, model_7 |
0
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
13 from utils import preprocess as pp |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
14 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
# Process-wide environment settings, applied when this module is imported.
# NOTE(review): these are assigned after the `models` import above; if that
# import already pulls in TensorFlow, some of them may be read too late -
# confirm the intended ordering.
os.environ.update({
    "CUDA_VISIBLE_DEVICES": "",
    "TF_XLA_FLAGS": "--tf_xla_cpu_global_jit",
    # loglevel : 0 all printed, 1 I not printed, 2 I and W not printed, 3 nothing printed
    "TF_CPP_MIN_LOG_LEVEL": "3",
})
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
19 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
20 |
1
9b12bc1b1e2c
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
iuc
parents:
0
diff
changeset
|
def predict_nn(ds_path, nn_weights_path, length, use_10, batch_size=256):
    """
    Breaks down contigs into fragments
    and uses pretrained neural networks to give predictions for fragments

    ds_path: path to the fasta file with the contigs
    nn_weights_path: folder holding the weight files named model_{s}_{length}.h5
    length: fragment length handed to pp.fragmenting and to the model constructors
    use_10: when true, model_10 is run in addition to model_5 and model_7
    batch_size: batch size forwarded to model.predict

    Returns a pandas DataFrame with one row per fragment. Columns are id,
    length (length of the whole contig), fragment (index of the fragment
    within its contig) and, for every model s that was run,
    pred_plant_{s}, pred_vir_{s}, pred_bact_{s}.

    Raises Exception when ds_path does not exist and ValueError when the
    input contains no sequences or no fragment could be produced from it.
    """
    try:
        seqs_ = list(SeqIO.parse(ds_path, "fasta"))
    except FileNotFoundError as err:
        # Same exception type as before; chaining keeps the original traceback.
        raise Exception("test dataset was not found. Change ds variable") from err
    if not seqs_:
        raise ValueError("All sequences were smaller than length of the model")

    # Suffixes of the networks to run; they select both the model module
    # and the weight file name.
    model_ids = [5, 7, 10] if use_10 else [5, 7]
    model_modules = {5: model_5, 7: model_7, 10: model_10}
    # Order matches indices 0, 1, 2 on the last axis of the model output.
    targets = ("plant", "vir", "bact")

    out_table = {"id": [], "length": [], "fragment": []}
    for s in model_ids:
        for target in targets:
            out_table[f"pred_{target}_{s}"] = []

    test_fragments = []
    test_fragments_rc = []
    for seq in seqs_:
        # Fragment one contig at a time so each fragment can be tied back to
        # its contig; the sliding window step is half the fragment length.
        fragments_, fragments_rc, _ = pp.fragmenting([seq], length, max_gap=0.8,
                                                     sl_wind_step=int(length / 2))
        test_fragments.extend(fragments_)
        test_fragments_rc.extend(fragments_rc)
        n_fragments = len(fragments_)
        out_table["id"].extend([seq.id] * n_fragments)
        out_table["length"].extend([len(seq.seq)] * n_fragments)
        out_table["fragment"].extend(range(n_fragments))

    if not test_fragments:
        # Without this guard an input that yields no fragments only fails
        # later, inside encoding or prediction, with an unrelated message.
        raise ValueError("No fragments could be produced from the input sequences")

    test_encoded = pp.one_hot_encode(test_fragments)
    test_encoded_rc = pp.one_hot_encode(test_fragments_rc)

    for s in model_ids:
        # Build each network right before it is used rather than all up front.
        model = model_modules[s].model(length)
        model.load_weights(Path(nn_weights_path, f"model_{s}_{length}.h5"))
        prediction = model.predict([test_encoded, test_encoded_rc], batch_size)
        for column, target in enumerate(targets):
            out_table[f"pred_{target}_{s}"].extend(prediction[..., column])

    return pd.DataFrame(out_table)
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
75 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
76 |
1
9b12bc1b1e2c
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
iuc
parents:
0
diff
changeset
|
def predict_rf(df, rf_weights_path, length, use_10):
    """
    Adds random forest predictions to the per-fragment table built from predict_nn output

    The classifier is loaded from RF_{length}.joblib inside rf_weights_path and
    fed the plant and virus score columns of models 5 and 7 (plus model 10
    when use_10 is true). The columns RF_decision, RF_pred_plant, RF_pred_vir
    and RF_pred_bact are written into df; the same df object is returned.
    """
    forest = load(Path(rf_weights_path, f"RF_{length}.joblib"))

    feature_columns = ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7"]
    if use_10:
        feature_columns += ["pred_plant_10", "pred_vir_10"]
    features = df[feature_columns]

    # Translate the numeric class returned by the classifier into its label.
    label_names = {0: "plant", 1: "virus", 2: "bacteria"}
    to_label = np.vectorize(label_names.get)
    df["RF_decision"] = to_label(forest.predict(features))

    # One probability column per class, taken along the last axis.
    probabilities = forest.predict_proba(features)
    for position, column in enumerate(("RF_pred_plant", "RF_pred_vir", "RF_pred_bact")):
        df[column] = probabilities[..., position]
    return df
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
97 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
98 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
def predict_contigs(df):
    """
    Turns the per-fragment decisions of predict_rf into one verdict per contig

    df needs the columns id, length and RF_decision. The result has one row
    per (id, length) pair with the number of fragments assigned to each class,
    the final decision, the share of viral fragments and that share multiplied
    by the contig length. Rows are sorted by the last value, largest first.
    """
    classes = ['virus', 'plant', 'bacteria']
    leading = ['length', 'id']

    # Number of fragments per contig for every label that occurs.
    fragment_counts = df.groupby(["id", "length", 'RF_decision'], sort=False).size()
    table = fragment_counts.unstack(fill_value=0).reset_index()
    # Labels that never occurred get a zero-filled column of their own.
    table = table.reindex(columns=leading + classes).fillna(value=0)

    viral, plant, bacterial = (table[name] for name in classes)
    virus_wins = (viral > plant) & (viral > bacterial)
    plant_wins = (plant > viral) & (plant > bacterial)
    bacteria_wins = (bacterial >= plant) & (bacterial >= viral)
    # Anything matched by none of the masks (e.g. a virus/plant tie) falls
    # back to the default.
    table['decision'] = np.select(
        [virus_wins, plant_wins, bacteria_wins], classes, default='bacteria'
    )

    new_names = {
        'virus': '# viral fragments',
        'bacteria': '# bacterial fragments',
        'plant': '# plant fragments',
    }
    table = table.loc[:, leading + classes + ['decision']].rename(columns=new_names)

    total = table['# viral fragments'] + table['# bacterial fragments'] + table['# plant fragments']
    table['# viral / # total'] = (table['# viral fragments'] / total).round(3)
    table['# viral / # total * length'] = table['# viral / # total'] * table['length']
    return table.sort_values(by='# viral / # total * length', ascending=False)
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
121 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
122 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
123 def predict(test_ds, weights, out_path, return_viral, limit): |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
124 """Predicts viral contigs from the fasta file |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
125 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
126 test_ds: path to the input file with contigs in fasta format (str or list of str) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
127 weights: path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
128 out_path: path to the folder to store predictions (str) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
129 return_viral: whether to return contigs annotated as viral in separate fasta file (True/False) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
130 limit: Do predictions only for contigs > l. We suggest l=750. (int) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
131 """ |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
132 test_ds = test_ds |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
133 if isinstance(test_ds, list): |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
134 pass |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
135 elif isinstance(test_ds, str): |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
136 test_ds = [test_ds] |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
137 else: |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
138 raise ValueError('test_ds was incorrectly assigned in the config file') |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
139 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
140 assert Path(test_ds[0]).exists(), f'{test_ds[0]} does not exist' |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
141 assert Path(weights).exists(), f'{weights} does not exist' |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
142 assert isinstance(limit, int), 'limit should be an integer' |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
143 Path(out_path).mkdir(parents=True, exist_ok=True) |
1
9b12bc1b1e2c
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
iuc
parents:
0
diff
changeset
|
144 use_10 = Path(weights, 'model_10_500.h5').exists() |
0
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
145 for ts in test_ds: |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
146 dfs_fr = [] |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
147 dfs_cont = [] |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
148 for l_ in 500, 1000: |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
149 # print(f'starting prediction for {Path(ts).name} for fragment length {l_}') |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
150 df = predict_nn( |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
151 ds_path=ts, |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
152 nn_weights_path=weights, |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
153 length=l_, |
1
9b12bc1b1e2c
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
iuc
parents:
0
diff
changeset
|
154 use_10=use_10 |
0
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
155 ) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
156 print(df) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
157 df = predict_rf( |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
158 df=df, |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
159 rf_weights_path=weights, |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
160 length=l_, |
1
9b12bc1b1e2c
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
iuc
parents:
0
diff
changeset
|
161 use_10=use_10 |
0
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
162 ) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
163 df = df.round(3) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
164 dfs_fr.append(df) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
165 df = predict_contigs(df) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
166 dfs_cont.append(df) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
167 # print('prediction finished') |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
168 df_500 = dfs_fr[0][(dfs_fr[0]['length'] >= limit) & (dfs_fr[0]['length'] < 1500)] |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
169 df_1000 = dfs_fr[1][(dfs_fr[1]['length'] >= 1500)] |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
170 df = pd.concat([df_1000, df_500], ignore_index=True) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
171 pred_fr = Path(out_path, 'predicted_fragments.csv') |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
172 df.to_csv(pred_fr) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
173 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
174 df_500 = dfs_cont[0][(dfs_cont[0]['length'] >= limit) & (dfs_cont[0]['length'] < 1500)] |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
175 df_1000 = dfs_cont[1][(dfs_cont[1]['length'] >= 1500)] |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
176 df = pd.concat([df_1000, df_500], ignore_index=True) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
177 pred_contigs = Path(out_path, 'predicted.csv') |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
178 df.to_csv(pred_contigs) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
179 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
180 if return_viral: |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
181 viral_ids = list(df[df["decision"] == "virus"]["id"]) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
182 seqs_ = list(SeqIO.parse(ts, "fasta")) |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
183 viral_seqs = [s_ for s_ in seqs_ if s_.id in viral_ids] |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
184 SeqIO.write(viral_seqs, Path(out_path, 'viral.fasta'), 'fasta') |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
185 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
186 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
def parse_bool_arg(value):
    """Convert a command-line token into a real boolean.

    argparse hands every option value over as a string, and any non-empty
    string (including "False") is truthy, so the flag has to be parsed
    explicitly before it is used in an ``if`` test.

    Args:
        value: the raw token from the command line (or an actual bool).

    Returns:
        True for "true"/"1"/"yes"/"y", False for "false"/"0"/"no"/"n"
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean,
            which argparse reports to the user as a normal usage error.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', '1', 'yes', 'y'):
        return True
    if token in ('false', '0', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError(f'expected True or False, got {value!r}')


if __name__ == '__main__':
    # Command-line entry point: parse the options and run the prediction.
    parser = argparse.ArgumentParser()
    # The three paths are mandatory. Marking them required makes argparse
    # print a usage error instead of the script failing later with a
    # NameError on an unbound variable.
    parser.add_argument("--test_ds", required=True, help="path to the input file with contigs in fasta format (str or list of str)")
    parser.add_argument("--weights", required=True, help="path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)")
    parser.add_argument("--out_path", required=True, help="path to the folder to store predictions (str)")
    # Parsed into a real bool so that "--return_viral False" disables the
    # extra fasta output; omitting the option also disables it.
    parser.add_argument("--return_viral", type=parse_bool_arg, default=False, help="whether to return contigs annotated as viral in separate fasta file (True/False)")
    # predict() keeps contigs whose length is >= limit, so the help text says so.
    parser.add_argument("--limit", help="Do predictions only for contigs >= l. We suggest l=750. (int)", type=int, default=750)

    args = parser.parse_args()
    # Pass the parsed values straight through: a legitimate falsy value such
    # as "--limit 0" must reach predict() rather than being dropped.
    predict(args.test_ds, args.weights, args.out_path, args.return_viral, args.limit)