annotate predict.py @ 2:ea2cccb9f73e draft

planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit c3685ed6a70b47012b62b95a2a3db062bd3b7475
author iuc
date Thu, 05 Jan 2023 14:27:54 +0000
parents 9b12bc1b1e2c
children 302332b914ef
Ignore whitespace changes - Everywhere: Within whitespace: At end of lines:
rev   line source
0
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
1 #!/usr/bin/env python
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
2 # -*- coding: utf-8 -*-
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
3 # Credits: Grigorii Sukhorukov, Macha Nikolski
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
4 import argparse
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
5 import os
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
6 from pathlib import Path
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
7
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
8 import numpy as np
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
9 import pandas as pd
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
10 from Bio import SeqIO
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
11 from joblib import load
1
9b12bc1b1e2c planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
iuc
parents: 0
diff changeset
12 from models import model_10, model_5, model_7
0
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
13 from utils import preprocess as pp
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
14
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
15 os.environ["CUDA_VISIBLE_DEVICES"] = ""
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
16 os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit"
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
17 # TF_CPP_MIN_LOG_LEVEL: 0 = print all messages, 1 = hide INFO, 2 = hide INFO and WARNING, 3 = print nothing
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
18 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
19
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
20
1
9b12bc1b1e2c planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
iuc
parents: 0
diff changeset
def predict_nn(ds_path, nn_weights_path, length, use_10, batch_size=256):
    """
    Breaks down contigs into fragments
    and uses pretrained neural networks to give predictions for fragments

    ds_path: path to the fasta file with contigs
    nn_weights_path: folder holding the ``model_{k}_{length}.h5`` weight files
    length: fragment length handed to the fragmenting helper and to the models
    use_10: when true, model_10 is used in addition to model_5 and model_7
    batch_size: batch size forwarded to ``model.predict``

    Returns a pandas DataFrame with one row per fragment: contig id, contig
    length, fragment index within the contig and, for every network used,
    the three scores stored as pred_plant_*, pred_vir_* and pred_bact_*.

    Raises FileNotFoundError when ds_path does not exist and ValueError when
    the fasta file yields no sequences.
    """
    try:
        seqs_ = list(SeqIO.parse(ds_path, "fasta"))
    except FileNotFoundError as err:
        # Keep the original message but preserve the exception type and cause
        # (FileNotFoundError is still an Exception for existing callers).
        raise FileNotFoundError("test dataset was not found. Change ds variable") from err
    if not seqs_:
        raise ValueError("All sequences were smaller than length of the model")

    # Single source of truth for which networks run; drives both the output
    # columns and the prediction loop below.
    networks = [(5, model_5), (7, model_7)]
    if use_10:
        networks.append((10, model_10))

    out_table = {"id": [], "length": [], "fragment": []}
    for s, _ in networks:
        for organism in ("plant", "vir", "bact"):
            out_table[f"pred_{organism}_{s}"] = []

    test_fragments = []
    test_fragments_rc = []
    for seq in seqs_:
        fragments_, fragments_rc, _ = pp.fragmenting(
            [seq], length, max_gap=0.8, sl_wind_step=int(length / 2)
        )
        test_fragments.extend(fragments_)
        test_fragments_rc.extend(fragments_rc)
        # One bookkeeping row per fragment produced for this contig.
        n_fragments = len(fragments_)
        out_table["id"].extend([seq.id] * n_fragments)
        out_table["length"].extend([len(seq.seq)] * n_fragments)
        out_table["fragment"].extend(range(n_fragments))

    test_encoded = pp.one_hot_encode(test_fragments)
    test_encoded_rc = pp.one_hot_encode(test_fragments_rc)

    for s, module in networks:
        model = module.model(length)
        model.load_weights(Path(nn_weights_path, f"model_{s}_{length}.h5"))
        prediction = model.predict([test_encoded, test_encoded_rc], batch_size)
        # Last axis of the prediction is indexed 0/1/2 -> plant/virus/bacteria columns.
        out_table[f"pred_plant_{s}"].extend(list(prediction[..., 0]))
        out_table[f"pred_vir_{s}"].extend(list(prediction[..., 1]))
        out_table[f"pred_bact_{s}"].extend(list(prediction[..., 2]))

    return pd.DataFrame(out_table)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
75
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
76
1
9b12bc1b1e2c planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
iuc
parents: 0
diff changeset
def predict_rf(df, rf_weights_path, length, use_10):
    """
    Using predictions by predict_nn and weights of a trained RF classifier gives a single prediction for a fragment

    The classifier is loaded from ``RF_{length}.joblib`` inside rf_weights_path
    and fed the plant/virus score columns of df (including the *_10 columns
    when use_10 is true). The columns RF_decision, RF_pred_plant, RF_pred_vir
    and RF_pred_bact are written into df itself, which is then returned.
    """
    classifier = load(Path(rf_weights_path, f"RF_{length}.joblib"))

    feature_columns = ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7"]
    if use_10:
        feature_columns = feature_columns + ["pred_plant_10", "pred_vir_10"]
    features = df[feature_columns]

    # Hard class label per fragment, translated to a readable name.
    # NOTE(review): presumes the classifier emits labels 0/1/2 - confirm against training code.
    label_names = {0: "plant", 1: "virus", 2: "bacteria"}
    labels = classifier.predict(features)
    df["RF_decision"] = np.vectorize(label_names.get)(labels)

    # Per-class probabilities, stored column by column.
    probabilities = classifier.predict_proba(features)
    for position, column in enumerate(("RF_pred_plant", "RF_pred_vir", "RF_pred_bact")):
        df[column] = probabilities[..., position]
    return df
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
97
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
98
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
def predict_contigs(df):
    """
    Based on predictions of predict_rf for fragments gives a final prediction for the whole contig

    df: per-fragment table with at least the columns id, length and
        RF_decision (values 'virus', 'plant' or 'bacteria').

    Returns a DataFrame with one row per (id, length) contig holding the
    integer number of fragments assigned to each class, the contig-level
    decision, the viral fraction rounded to 3 decimals and that fraction
    multiplied by the contig length. Rows are sorted by the last column in
    descending order.
    """
    classes = ['virus', 'plant', 'bacteria']
    counts = (
        df.groupby(["id", "length", 'RF_decision'], sort=False).size().unstack(fill_value=0)
    )
    counts = counts.reset_index()
    # A class that never occurs has no column after unstack; create it and
    # keep every count an integer (NaN-filling would silently produce floats).
    counts = counts.reindex(['length', 'id'] + classes, axis=1)
    counts[classes] = counts[classes].fillna(0).astype(int)

    # Strict majority wins for virus and plant; every other case, ties
    # included, falls through to 'bacteria'.
    conditions = [
        (counts['virus'] > counts['plant']) & (counts['virus'] > counts['bacteria']),
        (counts['plant'] > counts['virus']) & (counts['plant'] > counts['bacteria']),
    ]
    counts['decision'] = np.select(conditions, ['virus', 'plant'], default='bacteria')

    counts = counts.rename(columns={'virus': '# viral fragments', 'bacteria': '# bacterial fragments', 'plant': '# plant fragments'})
    total = counts['# viral fragments'] + counts['# bacterial fragments'] + counts['# plant fragments']
    counts['# viral / # total'] = (counts['# viral fragments'] / total).round(3)
    counts['# viral / # total * length'] = counts['# viral / # total'] * counts['length']
    return counts.sort_values(by='# viral / # total * length', ascending=False)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
121
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
122
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
123 def predict(test_ds, weights, out_path, return_viral, limit):
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
124 """Predicts viral contigs from the fasta file
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
125
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
126 test_ds: path to the input file with contigs in fasta format (str or list of str)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
127 weights: path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
128 out_path: path to the folder to store predictions (str)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
129 return_viral: whether to return contigs annotated as viral in separate fasta file (True/False)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
130 limit: Do predictions only for contigs > l. We suggest l=750. (int)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
131 """
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
132 test_ds = test_ds
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
133 if isinstance(test_ds, list):
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
134 pass
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
135 elif isinstance(test_ds, str):
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
136 test_ds = [test_ds]
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
137 else:
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
138 raise ValueError('test_ds was incorrectly assigned in the config file')
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
139
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
140 assert Path(test_ds[0]).exists(), f'{test_ds[0]} does not exist'
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
141 assert Path(weights).exists(), f'{weights} does not exist'
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
142 assert isinstance(limit, int), 'limit should be an integer'
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
143 Path(out_path).mkdir(parents=True, exist_ok=True)
1
9b12bc1b1e2c planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
iuc
parents: 0
diff changeset
144 use_10 = Path(weights, 'model_10_500.h5').exists()
0
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
145 for ts in test_ds:
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
146 dfs_fr = []
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
147 dfs_cont = []
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
148 for l_ in 500, 1000:
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
149 # print(f'starting prediction for {Path(ts).name} for fragment length {l_}')
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
150 df = predict_nn(
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
151 ds_path=ts,
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
152 nn_weights_path=weights,
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
153 length=l_,
1
9b12bc1b1e2c planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
iuc
parents: 0
diff changeset
154 use_10=use_10
0
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
155 )
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
156 print(df)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
157 df = predict_rf(
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
158 df=df,
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
159 rf_weights_path=weights,
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
160 length=l_,
1
9b12bc1b1e2c planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
iuc
parents: 0
diff changeset
161 use_10=use_10
0
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
162 )
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
163 df = df.round(3)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
164 dfs_fr.append(df)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
165 df = predict_contigs(df)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
166 dfs_cont.append(df)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
167 # print('prediction finished')
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
168 df_500 = dfs_fr[0][(dfs_fr[0]['length'] >= limit) & (dfs_fr[0]['length'] < 1500)]
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
169 df_1000 = dfs_fr[1][(dfs_fr[1]['length'] >= 1500)]
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
170 df = pd.concat([df_1000, df_500], ignore_index=True)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
171 pred_fr = Path(out_path, 'predicted_fragments.csv')
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
172 df.to_csv(pred_fr)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
173
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
174 df_500 = dfs_cont[0][(dfs_cont[0]['length'] >= limit) & (dfs_cont[0]['length'] < 1500)]
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
175 df_1000 = dfs_cont[1][(dfs_cont[1]['length'] >= 1500)]
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
176 df = pd.concat([df_1000, df_500], ignore_index=True)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
177 pred_contigs = Path(out_path, 'predicted.csv')
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
178 df.to_csv(pred_contigs)
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
179
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
180 if return_viral:
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
181 viral_ids = list(df[df["decision"] == "virus"]["id"])
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
182 seqs_ = list(SeqIO.parse(ts, "fasta"))
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
183 viral_seqs = [s_ for s_ in seqs_ if s_.id in viral_ids]
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
184 SeqIO.write(viral_seqs, Path(out_path, 'viral.fasta'), 'fasta')
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
185
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
186
457fd8fd681a planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff changeset
def str_to_bool(value):
    """Convert a command-line ``True``/``False`` style string to a bool.

    argparse hands option values over as strings, and any non-empty string
    (including ``"False"``) is truthy, so the flag has to be converted
    explicitly before it is used in an ``if``.

    Args:
        value: the raw option value (a string), or an actual bool.

    Returns:
        The corresponding boolean.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised
            spelling of true or false.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected True or False, got {value!r}')


def parse_cli_args(argv=None):
    """Parse the command-line options of the prediction script.

    Args:
        argv: list of argument strings; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``test_ds`` (str),
        ``weights`` (str), ``out_path`` (str), ``return_viral`` (bool)
        and ``limit`` (int, 750 when not given).
    """
    parser = argparse.ArgumentParser()
    # The options below are mandatory: without them the script previously
    # failed with a NameError right before calling predict().
    parser.add_argument("--test_ds", required=True, help="path to the input file with contigs in fasta format (str or list of str)")
    parser.add_argument("--weights", required=True, help="path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)")
    parser.add_argument("--out_path", required=True, help="path to the folder to store predictions (str)")
    # Converted to a real bool so that "--return_viral False" disables the output.
    parser.add_argument("--return_viral", required=True, type=str_to_bool, help="whether to return contigs annotated as viral in separate fasta file (True/False)")
    parser.add_argument("--limit", help="Do predictions only for contigs > l. We suggest l=750. (int)", type=int, default=750)
    return parser.parse_args(argv)


if __name__ == '__main__':
    args = parse_cli_args()
    # Pass the parsed values directly; truthiness guards would drop
    # legitimate falsy values such as --limit 0.
    predict(args.test_ds, args.weights, args.out_path, args.return_viral, args.limit)