Repository 'virhunter'
hg clone https://toolshed.g2.bx.psu.edu/repos/iuc/virhunter

Changeset 1:9b12bc1b1e2c (2022-11-30)
Previous changeset 0:457fd8fd681a (2022-11-09) Next changeset 2:ea2cccb9f73e (2023-01-05)
Commit message:
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
modified:
macros.xml
predict.py

diff -r 457fd8fd681a -r 9b12bc1b1e2c macros.xml
--- a/macros.xml Wed Nov 09 12:19:26 2022 +0000
+++ b/macros.xml Wed Nov 30 17:31:52 2022 +0000
@@ -1,6 +1,6 @@
 <macros>
     <token name="@TOOL_VERSION@">1.0.0</token>
-    <token name="@VERSION_SUFFIX@">0</token>
+    <token name="@VERSION_SUFFIX@">1</token>
     <xml name="requirements">
         <requirements>
             <requirement type="package" version="1.23.3">numpy</requirement>
@@ -14,7 +14,7 @@
     </xml>
     <xml name="citations">
         <citations>
-            <citation type="doi">10.1038/s41467-019-12528-4</citation>
+            <citation type="doi">10.3389/fbinf.2022.867111</citation>
         </citations>
     </xml>
-</macros>
\ No newline at end of file
+</macros>
b
diff -r 457fd8fd681a -r 9b12bc1b1e2c predict.py
--- a/predict.py Wed Nov 09 12:19:26 2022 +0000
+++ b/predict.py Wed Nov 30 17:31:52 2022 +0000
@@ -9,7 +9,7 @@
 import pandas as pd
 from Bio import SeqIO
 from joblib import load
-from models import model_5, model_7
+from models import model_10, model_5, model_7
 from utils import preprocess as pp
 
 os.environ["CUDA_VISIBLE_DEVICES"] = ""
@@ -18,7 +18,7 @@
 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
 
 
-def predict_nn(ds_path, nn_weights_path, length, batch_size=256):
+def predict_nn(ds_path, nn_weights_path, length, use_10, batch_size=256):
     """
     Breaks down contigs into fragments
     and uses pretrained neural networks to give predictions for fragments
@@ -37,10 +37,14 @@
         "pred_plant_7": [],
         "pred_vir_7": [],
         "pred_bact_7": [],
-        # "pred_plant_10": [],
-        # "pred_vir_10": [],
-        # "pred_bact_10": [],
     }
+    if use_10:
+        out_table_ = {
+            "pred_plant_10": [],
+            "pred_vir_10": [],
+            "pred_bact_10": [],
+        }
+        out_table.update(out_table_)
     if not seqs_:
         raise ValueError("All sequences were smaller than length of the model")
     test_fragments = []
@@ -56,24 +60,32 @@
             out_table["fragment"].append(j)
     test_encoded = pp.one_hot_encode(test_fragments)
     test_encoded_rc = pp.one_hot_encode(test_fragments_rc)
-    # for model, s in zip([model_5.model(length), model_7.model(length), model_10.model(length)], [5, 7, 10]):
-    for model, s in zip([model_5.model(length), model_7.model(length)], [5, 7]):
+    if use_10:
+        zipped_models = zip([model_5.model(length), model_7.model(length), model_10.model(length)], [5, 7, 10])
+    else:
+        zipped_models = zip([model_5.model(length), model_7.model(length)], [5, 7])
+    for model, s in zipped_models:
         model.load_weights(Path(nn_weights_path, f"model_{s}_{length}.h5"))
         prediction = model.predict([test_encoded, test_encoded_rc], batch_size)
         out_table[f"pred_plant_{s}"].extend(list(prediction[..., 0]))
         out_table[f"pred_vir_{s}"].extend(list(prediction[..., 1]))
         out_table[f"pred_bact_{s}"].extend(list(prediction[..., 2]))
+
     return pd.DataFrame(out_table)
 
 
-def predict_rf(df, rf_weights_path, length):
+def predict_rf(df, rf_weights_path, length, use_10):
     """
     Using predictions by predict_nn and weights of a trained RF classifier gives a single prediction for a fragment
     """
 
     clf = load(Path(rf_weights_path, f"RF_{length}.joblib"))
-    X = df[["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7"]]
-    # X = ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", "pred_plant_10", "pred_vir_10", ]]
+    if use_10:
+        X = df[
+            ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", "pred_plant_10", "pred_vir_10", ]]
+    else:
+        X = df[
+            ["pred_plant_5", "pred_vir_5", "pred_plant_7", "pred_vir_7", ]]
     y_pred = clf.predict(X)
     mapping = {0: "plant", 1: "virus", 2: "bacteria"}
     df["RF_decision"] = np.vectorize(mapping.get)(y_pred)
@@ -89,9 +101,7 @@
     Based on predictions of predict_rf for fragments gives a final prediction for the whole contig
     """
     df = (
-        df.groupby(["id", "length", 'RF_decision'], sort=False)
-        .size()
-        .unstack(fill_value=0)
+        df.groupby(["id", "length", 'RF_decision'], sort=False).size().unstack(fill_value=0)
     )
     df = df.reset_index()
     df = df.reindex(['length', 'id', 'virus', 'plant', 'bacteria'], axis=1)
@@ -131,7 +141,7 @@
     assert Path(weights).exists(), f'{weights} does not exist'
     assert isinstance(limit, int), 'limit should be an integer'
     Path(out_path).mkdir(parents=True, exist_ok=True)
-
+    use_10 = Path(weights, 'model_10_500.h5').exists()
     for ts in test_ds:
         dfs_fr = []
         dfs_cont = []
@@ -141,12 +151,14 @@
                 ds_path=ts,
                 nn_weights_path=weights,
                 length=l_,
+                use_10=use_10
             )
             print(df)
             df = predict_rf(
                 df=df,
                 rf_weights_path=weights,
                 length=l_,
+                use_10=use_10
             )
             df = df.round(3)
             dfs_fr.append(df)
@@ -178,7 +190,7 @@
     parser.add_argument("--weights", help="path to the folder containing weights for NN and RF modules trained on 500 and 1000 fragment lengths (str)")
     parser.add_argument("--out_path", help="path to the folder to store predictions (str)")
     parser.add_argument("--return_viral", help="whether to return contigs annotated as viral in separate fasta file (True/False)")
-    parser.add_argument("--limit", help="Do predictions only for contigs > l. We suggest l=750. (int)", type=int)
+    parser.add_argument("--limit", help="Do predictions only for contigs > l. We suggest l=750. (int)", type=int, default=750)
 
     args = parser.parse_args()
     if args.test_ds: