view predict.py @ 0:b856d3d95413 draft default tip

planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/decontaminator commit 3f8e87001f3dfe7d005d0765aeaa930225c93b72
author iuc
date Mon, 09 Jan 2023 13:27:09 +0000
parents
children
line wrap: on
line source

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Credits: Grigorii Sukhorukov, Macha Nikolski
import argparse
import os
from pathlib import Path

import numpy as np
import pandas as pd
from Bio import SeqIO
from models import model_10
from utils import preprocess as pp

os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TF_XLA_FLAGS"] = "--tf_xla_cpu_global_jit"
# loglevel :
# 0 all printed
# 1 I not printed
# 2 I and W not printed
# 3 nothing printed
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'


def predict_nn(ds_path, nn_weights_path, length, batch_size=256):
    """
    Breaks down contigs into fragments
    and uses pretrained neural networks to give predictions for fragments
    """
    try:
        seqs_ = list(SeqIO.parse(ds_path, "fasta"))
    except FileNotFoundError:
        raise Exception("test dataset was not found. Change ds variable")

    out_table = {
        "id": [],
        "length": [],
        "fragment": [],
        "pred_vir": [],
        "pred_other": [],
    }
    if not seqs_:
        raise ValueError("All sequences were smaller than length of the model")
    test_fragments = []
    test_fragments_rc = []
    for seq in seqs_:
        fragments_, fragments_rc, _ = \
            pp.fragmenting(
                [seq],
                length,
                max_gap=0.8,
                sl_wind_step=int(length / 2)
            )
        test_fragments.extend(fragments_)
        test_fragments_rc.extend(fragments_rc)
        for j in range(len(fragments_)):
            out_table["id"].append(seq.id)
            out_table["length"].append(len(seq.seq))
            out_table["fragment"].append(j)
    test_encoded = pp.one_hot_encode(test_fragments)
    test_encoded_rc = pp.one_hot_encode(test_fragments_rc)
    model = model_10.model(length)
    model.load_weights(Path(nn_weights_path, f"model_{length}.h5"))
    prediction = model.predict([test_encoded, test_encoded_rc], batch_size)
    out_table['pred_vir'].extend(list(prediction[..., 1]))
    out_table['pred_other'].extend(list(prediction[..., 0]))
    print('Exporting predictions to csv file')
    df = pd.DataFrame(out_table)
    df['NN_decision'] = np.where(df['pred_vir'] > df['pred_other'], 'virus', 'other')
    return df


def predict_test(ds_path, length):
    """
    Breaks down contigs into fragments
    and gives 1 as prediction to all fragments
    use only for testing!
    """
    try:
        seqs_ = list(SeqIO.parse(ds_path, "fasta"))
    except FileNotFoundError:
        raise Exception("test dataset was not found. Change ds variable")

    out_table = {
        "id": [],
        "length": [],
        "fragment": [],
    }
    if not seqs_:
        raise ValueError("All sequences were smaller than length of the model")
    for seq in seqs_:
        fragments_, fragments_rc, _ = \
            pp.fragmenting(
                [seq],
                length,
                max_gap=0.8,
                sl_wind_step=int(length / 2)
            )
        for j in range(len(fragments_)):
            out_table["id"].append(seq.id)
            out_table["length"].append(len(seq.seq))
            out_table["fragment"].append(j)
    print('Exporting predictions to tsv file')
    df = pd.DataFrame(out_table)
    df['pred_vir'] = 1
    df['pred_other'] = 0
    df['NN_decision'] = 'virus'
    return df


def predict_contigs(df):
    """
    Based on predictions of predict_rf for fragments
    gives a final prediction for the whole contig
    """
    df = (
        df.groupby(
            ["id",
             "length",
             'NN_decision'],
            sort=False
        ).size().unstack(fill_value=0)
    )
    df = df.reset_index()
    df = df.reindex(
        ['length', 'id', 'virus', 'other', ],
        axis=1
    ).fillna(value=0)
    df['decision'] = np.where(df['virus'] >= df['other'], 'virus', 'other')
    df = df.sort_values(by='length', ascending=False)
    df = df.loc[:, ['length', 'id', 'virus', 'other', 'decision']]
    df = df.rename(
        columns={
            'virus': '# viral fragments',
            'other': '# other fragments',
        }
    )
    df['# viral / # total'] = (df['# viral fragments'] / (df['# viral fragments'] + df['# other fragments'])).round(3)
    df['# viral / # total * length'] = df['# viral / # total'] * df['length']
    df = df.sort_values(by='# viral / # total * length', ascending=False)
    return df


def predict(test_ds, weights, out_path, return_viral=True):
    """filters out contaminant contigs from the fasta file.

    test_ds: path to the input file with
        contigs in fasta format (str or list of str)
    weights: path to the folder containing weights
        for NN and RF modules trained on 500 and 1000 fragment lengths (str)
    out_path: path to the folder to store predictions (str)
    return_viral: whether to return contigs annotated as
        viral in separate fasta file (True/False)
    """

    test_ds = test_ds
    if isinstance(test_ds, list):
        pass
    elif isinstance(test_ds, str):
        test_ds = [test_ds]
    else:
        raise ValueError('test_ds was incorrectly assigned in the config file')

    assert Path(test_ds[0]).exists(), f'{test_ds[0]} does not exist'
    # assert Path(weights).exists(), f'{weights} does not exist'
    limit = 0
    Path(out_path).mkdir(parents=True, exist_ok=True)

    # parameter to activate test function. Only for debugging on github
    # test is launched when the weights directory is empty
    use_test_f = not Path(weights, 'model_1000.h5').exists()
    for ts in test_ds:
        dfs_fr = []
        dfs_cont = []
        for l_ in 500, 1000:
            print(f'starting prediction for {Path(ts).name} '
                  f'for fragment length {l_}')
            if use_test_f:
                df = predict_test(ds_path=ts, length=l_, )
            else:
                df = predict_nn(
                    ds_path=ts,
                    nn_weights_path=weights,
                    length=l_,
                )
            df = df.round(3)
            dfs_fr.append(df)
            df = predict_contigs(df)
            dfs_cont.append(df)
            print('prediction finished')
        df_500 = dfs_fr[0][(dfs_fr[0]['length'] >= limit) & (dfs_fr[0]['length'] < 1500)]
        df_1000 = dfs_fr[1][(dfs_fr[1]['length'] >= 1500)]
        df = pd.concat([df_1000, df_500], ignore_index=True)
        pred_fr = Path(out_path, "predicted_fragments.tsv")
        df.to_csv(pred_fr, sep='\t')

        df_500 = dfs_cont[0][(dfs_cont[0]['length']
                              >= limit) & (dfs_cont[0]['length'] < 1500)]
        df_1000 = dfs_cont[1][(dfs_cont[1]['length']
                               >= 1500)]
        df = pd.concat([df_1000, df_500], ignore_index=True)
        pred_contigs = Path(out_path, "predicted.tsv")
        df.to_csv(pred_contigs, sep='\t')

        if return_viral:
            viral_ids = list(df[df["decision"] == "virus"]["id"])
            seqs_ = list(SeqIO.parse(ts, "fasta"))
            viral_seqs = [s_ for s_ in seqs_ if s_.id in viral_ids]
            SeqIO.write(
                viral_seqs,
                Path(
                    out_path,
                    "viral.fasta"), 'fasta')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_ds", help="path to the input "
                                          "file with contigs "
                                          "in fasta format "
                                          "(str or list of str)")
    parser.add_argument("--weights", help="path to the folder containing "
                                          "weights for NN and RF modules "
                                          "trained on 500 and 1000 "
                                          "fragment lengths (str)")
    parser.add_argument("--out_path", help="path to the folder to store "
                                           "predictions (str)")
    parser.add_argument("--return_viral", help="whether to return "
                                               "contigs annotated "
                                               "as viral in separate "
                                               "fasta file (True/False)")

    args = parser.parse_args()
    if args.test_ds:
        test_ds = args.test_ds
    if args.weights:
        weights = args.weights
    if args.out_path:
        out_path = args.out_path
    if args.return_viral:
        return_viral = args.return_viral

    predict(test_ds, weights, out_path, return_viral)