#!/usr/bin/env python

import sys
import argparse as ap
from ete3 import NCBITaxa

def parse_kraken_results(db, report, taxid):
    """Label every read listed in a Kraken classification file.

    Each non-blank line of ``report`` is split on tabs; the second column is
    taken as the read ID and the third as the taxonomic ID assigned to it.

    Args:
        db: Database file handed to ``NCBITaxa``.
        report: Path to the tab-separated Kraken classification file.
        taxid: Target taxonomic ID (anything ``int()`` accepts).

    Returns:
        dict mapping read ID to one of:
        ``"unclassified"`` (taxonomic ID 0),
        ``"target"`` (the target taxon itself or any taxon below it), or
        ``"other"`` (everything else).
        If a read ID occurs more than once, the last occurrence wins.

    Raises:
        ValueError: If a non-blank line has fewer than three columns or its
            taxonomic ID column is not an integer.
    """
    ncbi = NCBITaxa(db)
    target = int(taxid)

    # intermediate_nodes=True is required: by default only leaf taxa are
    # returned, so a read assigned to an internal node underneath the target
    # (e.g. a genus when the target is a family) would be labelled "other".
    lineage = set(ncbi.get_descendant_taxa(target, intermediate_nodes=True))
    lineage.add(target)

    kraken_class = {}

    with open(report, 'r') as classification:
        for line_no, line in enumerate(classification, 1):
            if not line.strip():
                continue  # tolerate blank lines (e.g. at the end of the file)

            # Only strip the line terminator: stripping all whitespace would
            # drop a trailing tab and change the column count when the last
            # column is empty.
            fields = line.rstrip("\r\n").split("\t")
            if len(fields) < 3:
                raise ValueError(
                    "%s, line %d: expected at least 3 tab-separated columns, "
                    "found %d" % (report, line_no, len(fields)))

            read_id = fields[1]
            # Convert once; the file yields strings, so comparing the raw
            # column against the integer 0 would never match.
            tax_id = int(fields[2])

            if tax_id == 0:
                kraken_class[read_id] = "unclassified"
            elif tax_id in lineage:
                kraken_class[read_id] = "target"
            else:
                kraken_class[read_id] = "other"

    return kraken_class


def kraken_trim(db, report, taxid, paired, fastq, fastq2):
    """Write copies of the FASTQ input(s) without reads labelled "other".

    Reads are kept when ``parse_kraken_results`` labels them ``"target"`` or
    ``"unclassified"``. Reads labelled ``"other"`` and reads that do not
    appear in the report at all are dropped. Output goes to
    ``input_1.fastq`` (and ``input_2.fastq`` when ``paired`` is true) in the
    current working directory.

    Args:
        db: Database file forwarded to ``parse_kraken_results``.
        report: Kraken classification file forwarded to
            ``parse_kraken_results``.
        taxid: Target taxonomic ID.
        paired: When true, ``fastq2`` is processed as well as ``fastq``.
        fastq: Path to the (forward) FASTQ file.
        fastq2: Path to the reverse FASTQ file; only used when ``paired``.
    """
    kraken = parse_kraken_results(db, report, taxid)

    files = [fastq, fastq2] if paired else [fastq]

    for index, fastq_in in enumerate(files):
        with open(fastq_in, 'r') as f_in, \
                open('input_%d.fastq' % (index + 1), 'w') as f_out:
            # Records are consumed with readline() only. Mixing file
            # iteration with readline() raises ValueError on Python 2 file
            # objects, so the for-loop form is avoided here.
            while True:
                header = f_in.readline()
                if not header:
                    break  # end of file
                if not header.strip():
                    continue  # stray blank line between records

                # Sequence, separator and quality lines of this record.
                record_tail = [f_in.readline() for _ in range(3)]

                # The ID is the first whitespace-delimited token without the
                # leading "@" and without anything from the first "/" on
                # (drops a "/1" or "/2" mate suffix). Splitting on any
                # whitespace also removes the trailing newline, which would
                # otherwise stay attached to headers that contain neither a
                # space nor a "/" and prevent them from ever matching.
                read_id = header.split()[0].split("/")[0][1:]

                if kraken.get(read_id, "other") != "other":
                    f_out.write(header)
                    f_out.writelines(record_tail)

# Command-line interface. The parser is built at import time; parsing the
# arguments and running the trim only happen when executed as a script.
parser = ap.ArgumentParser(prog='kraken_trim', conflict_handler='resolve',
                           description="Trims contaminated reads using Kraken reports")

# Named input_group rather than input so the builtin input() is not shadowed.
input_group = parser.add_argument_group('Input', '')
input_group.add_argument('db', help="sqlite formatted ETE3 taxa database")
input_group.add_argument('report', help="Kraken report")
input_group.add_argument('taxid', type=int, help="Target taxonomic ID")
input_group.add_argument('fastq', help="FASTQ file")
input_group.add_argument('fastq2', nargs='?', help="Reverse FASTQ mate ")
# The first long option string is '--p', so the parsed value is stored as args.p.
input_group.add_argument('--p', '--paired', action='store_true', help="Paired FASTQ files")

if __name__ == "__main__":
    # With no arguments at all, show the short usage line and exit non-zero.
    if len(sys.argv) == 1:
        parser.print_usage()
        sys.exit(1)

    args = parser.parse_args()

    # Paired mode opens fastq2; without it the run would fail later with an
    # unhelpful error from open(None), so reject it as a usage error.
    if args.p and args.fastq2 is None:
        parser.error("--p/--paired requires the reverse FASTQ mate (fastq2)")

    # A second FASTQ without --p is not processed; say so instead of
    # silently ignoring it.
    if args.fastq2 is not None and not args.p:
        sys.stderr.write("warning: fastq2 is ignored unless --p/--paired is given\n")

    kraken_trim(args.db, args.report, args.taxid, args.p, args.fastq, args.fastq2)