#!/usr/bin/env python


"""
Author -- Gundula Povysil
Contact --

Takes a VCF file with mutations and a BAM file as input and writes
all tags of reads that carry the mutation to a user-specified JSON file.
Also creates a FASTQ file of the reads of tags with the mutation.

=======  ==========  =================  ================================
Version  Date        Author             Description
0.2.1    2019-10-27  Gundula Povysil    -
=======  ==========  =================  ================================

USAGE: python mut2read.py --mutFile DCS_Mutations.vcf --bamFile DCS.bam
       --familiesFile Aligned_Families.tabular --outputFastq Interesting_Reads.fastq
       --outputJson tag_count_dict.json [--refalttiers]
"""

import argparse
import json
import os
import re
import sys

import numpy as np
import pysam
from cyvcf2 import VCF

def make_argparser():
    """Build the command-line parser used by :func:`mut2read`.

    Returns:
        argparse.ArgumentParser: parser whose namespace provides the paths
        ``mutFile``, ``bamFile``, ``familiesFile``, ``outputFastq`` and
        ``outputJson`` plus the boolean flag ``refalttiers``.
    """
    parser = argparse.ArgumentParser(description='Takes a vcf file with mutations and a BAM file as input and prints all tags of reads that carry the mutation to a user specified output file and creates a fastq file of reads of tags with mutation.')
    # NOTE(review): the add_argument lines were missing; the option names are
    # reconstructed from the attributes mut2read() reads — confirm against callers.
    # All five paths are required because mut2read() opens each of them.
    parser.add_argument('--mutFile', required=True,
                        help='VCF file with DCS mutations.')
    parser.add_argument('--bamFile', required=True,
                        help='BAM file with aligned DCS reads.')
    parser.add_argument('--familiesFile', required=True,
                        help='TABULAR file with aligned families.')
    parser.add_argument('--outputFastq', required=True,
                        help='Output FASTQ file of reads with mutations.')
    parser.add_argument('--outputJson', required=True,
                        help='Output JSON file to store collected data.')
    parser.add_argument('--refalttiers', action="store_true",
                        help='Store also information about the reference allele.')
    return parser

def mut2read(argv):
    """Find the tags of DCS reads that carry each VCF variant and export them.

    For every variant of the mutation file (``args.mutFile``; only the first
    ALT allele is used, and variants without ALT or with an ``N`` in ALT are
    skipped), the reads of the BAM file (``args.bamFile``) at the variant
    position are classified as showing ALT, REF, ``N`` or another allele.
    Afterwards:

    * ``args.outputJson`` receives the tuple ``(tag_dict, cvrg_dict, tag_dict_ref)``.
      ``tag_dict``/``tag_dict_ref`` map a read tag (query name up to the first
      ``_``) to ``{variant key: allele}`` for reads showing ALT/REF, and
      ``cvrg_dict`` maps a variant key to ``(ref count, alt count, median read
      length)``. The variant key is ``"<chrom>#<start>#<REF>#<ALT>"``.
    * ``args.outputFastq`` receives one FASTQ record for every row of
      ``args.familiesFile`` whose tag showed ALT (with ``refalttiers``: ALT or REF).

    Args:
        argv: full command line; ``argv[0]`` (the program name) is ignored.

    Exits through ``sys.exit`` with an error message if an input file is missing.
    """
    parser = make_argparser()
    args = parser.parse_args(argv[1:])

    file1 = args.mutFile
    file2 = args.bamFile
    file3 = args.familiesFile
    outfile = args.outputFastq
    json_file = args.outputJson
    refalttiers = args.refalttiers

    for path in (file1, file2, file3):
        if not os.path.isfile(path):
            sys.exit("Error: Could not find '{}'".format(path))

    tag_dict = {}      # tag -> {variant key: ALT} for reads showing ALT
    tag_dict_ref = {}  # tag -> {variant key: REF} for reads showing REF
    cvrg_dict = {}     # variant key -> (ref count, alt count, median read length)

    # NOTE: pileup() over a region needs an index (.bai) for file2, e.g. one
    # created beforehand with pysam.index(file2).
    with pysam.AlignmentFile(file2, "rb") as bam:
        for variant in VCF(file1):
            if len(variant.ALT) == 0:
                continue  # nothing to look for without an alternative allele
            chrom = variant.CHROM
            stop_pos = variant.start  # compared with the pileup's reference_pos below
            ref = variant.REF.upper()
            alt = variant.ALT[0].upper()
            if "N" in alt:
                # skip indels with N in alt allele --> it is not an indel but just
                # a mismatch at the position where the N is (checked this in IGV)
                continue
            chrom_stop_pos = str(chrom) + "#" + str(stop_pos) + "#" + ref + "#" + alt
            dcs_len = []
            for pileupcolumn in bam.pileup(chrom, stop_pos - 1, stop_pos + 1, max_depth=100000000):
                if pileupcolumn.reference_pos != stop_pos:
                    continue
                count_alt = count_ref = count_n = count_other = 0
                # reported in the log line below, but nothing here increments them
                count_indel = count_lowq = 0
                for pileupread in pileupcolumn.pileups:
                    if pileupread.is_refskip:
                        continue
                    nuc = _read_allele(pileupread, ref, alt)
                    dcs_len.append(len(pileupread.alignment.query_sequence))
                    tag = pileupread.alignment.query_name.split("_", 1)[0]

                    if nuc == alt:
                        count_alt += 1
                        tag_dict.setdefault(tag, {})[chrom_stop_pos] = alt
                    elif nuc == ref:
                        count_ref += 1
                        tag_dict_ref.setdefault(tag, {})[chrom_stop_pos] = ref
                    elif nuc == "N":
                        count_n += 1
                    else:
                        count_other += 1

                # guard avoids numpy's empty-slice warning; the value stays NaN
                dcs_median = np.median(np.array(dcs_len)) if dcs_len else float("nan")
                cvrg_dict[chrom_stop_pos] = (count_ref, count_alt, dcs_median)
                print("coverage at pos %s = %s, ref = %s, alt = %s, other bases = %s, N = %s, indel = %s, low quality = %s, median length of DCS = %s\n" %
                      (pileupcolumn.reference_pos, count_ref + count_alt, count_ref, count_alt, count_other, count_n,
                       count_indel, count_lowq, dcs_median))

    with open(json_file, "w") as f:
        json.dump((tag_dict, cvrg_dict, tag_dict_ref), f)

    # create fastq from aligned reads of the selected tags
    with open(outfile, "w") as out, open(file3, "r") as families:
        for line in families:
            splits = line.rstrip("\n").split("\t")
            tag = splits[0]
            if tag in tag_dict or (refalttiers and tag in tag_dict_ref):
                _write_fastq_record(out, splits)


# CIGAR operations that consume bases of the read (SAM specification).
_QUERY_CONSUMING_OPS = frozenset("MIS=X")


def _cigar_query_offset(split_cigar, op_index):
    """Return the number of read bases consumed before ``split_cigar[op_index]``.

    *split_cigar* is a CIGAR string split on its digit runs (capturing the
    digits), so lengths sit at odd indices and each operation letter follows
    its length. Only query-consuming operations are summed, so earlier
    deletions, reference skips and hard clips do not shift the position.
    """
    return sum(int(split_cigar[i]) for i in range(1, op_index - 1, 2)
               if split_cigar[i + 1] in _QUERY_CONSUMING_OPS)


def _read_allele(pileupread, ref, alt):
    """Return the upper-cased allele that *pileupread* shows at the variant site.

    *ref* and *alt* are the upper-cased VCF alleles. An insertion
    (``len(ref) < len(alt)``) or deletion (``len(ref) > len(alt)``) is
    recognised when the read's CIGAR has an ``I``/``D`` operation of the
    matching length starting right after the read's base at the site;
    otherwise the read's bases at the site are returned.
    """
    alignment = pileupread.alignment
    seq = alignment.query_sequence
    if pileupread.is_del:
        # query_position is None for a deleted base; use the next aligned base
        pos = pileupread.query_position_or_next
        end = pos + len(alt) - 1
    else:
        pos = pileupread.query_position
        end = pos + len(alt)
    split_cigar = re.split(r"(\d+)", alignment.cigarstring)

    if len(ref) < len(alt):  # insertion
        nuc = seq[pos]
        for op_idx, op in enumerate(split_cigar):  # the DCS may hold several insertions
            if (op == "I"
                    and _cigar_query_offset(split_cigar, op_idx) == pos + 1
                    and int(split_cigar[op_idx - 1]) == len(alt) - 1):
                nuc = seq[pos:end]
                break
    elif len(ref) > len(alt):  # deletion
        if "D" in split_cigar:
            nuc = seq[pos:pos + len(ref)]
            for op_idx, op in enumerate(split_cigar):  # the DCS may hold several deletions
                if (op == "D"
                        and _cigar_query_offset(split_cigar, op_idx) == pos + 1
                        and int(split_cigar[op_idx - 1]) == len(ref) - 1):
                    # an empty slice (read deleted at the site) counts as ALT
                    nuc = seq[pos:end] or alt
                    break
        elif len(alignment.get_reference_positions(full_length=True)[pos:pos + len(ref)]) < len(ref):
            # DCS has reference but the position is at the very end of the DCS,
            # so not all reference positions are there
            nuc = alignment.get_reference_sequence()[pos:pos + len(ref)]
            if nuc.upper() == ref[:len(nuc)]:
                nuc = ref
        else:
            nuc = seq[pos:pos + len(ref)]
    else:  # SNV
        nuc = seq[pos]
    return nuc.upper()


def _write_fastq_record(out, splits):
    """Write one FASTQ record built from the aligned-families row *splits* to *out*.

    Header ``@<col0>.<col1>.<col2>``; the sequence is column 4 with ``-``
    removed and the quality string is column 5 with spaces removed.
    """
    curr_seq = splits[4].replace("-", "")
    curr_qual = splits[5].replace(" ", "")
    out.write("@" + splits[0] + "." + splits[1] + "." + splits[2] + "\n")
    out.write(curr_seq + "\n")
    out.write("+\n")
    out.write(curr_qual + "\n")

if __name__ == '__main__':