view vcfs2fasta.py @ 22:96f393ad7fc6 draft default tip

Uploaded
author ulfschaefer
date Wed, 23 Dec 2015 04:50:58 -0500
parents
children
line wrap: on
line source

#!/usr/bin/env python
'''
Merge SNP data from multiple VCF files into a single fasta file.

Created on 5 Oct 2015

@author: alex
'''
import argparse
from collections import OrderedDict
from collections import defaultdict
import glob
import itertools
import logging
import os

from Bio import SeqIO
from bintrees import FastRBTree

# Try importing the matplotlib and numpy for stats.
try:
    from matplotlib import pyplot as plt
    import numpy
    can_stats = True
except ImportError:
    can_stats = False

import vcf

from phe.variant_filters import IUPAC_CODES


def plot_stats(pos_stats, total_samples, plots_dir="plots", discarded=None):
    """Save one PNG per contig plotting per-position fractions of samples.

    For every contig in *pos_stats*, writes ``<plots_dir>/<contig>.png``
    ("/" in the contig name replaced by "-") with four stacked panels. The
    panels show the fraction of samples with a SNP ("mut"), with an N ("N"),
    with a mixed base ("mix") and with an uncallable genotype ("gap").

    :param pos_stats: contig -> position -> counter dict with the keys
        "mut", "N", "mix" and "gap".
    :param total_samples: denominator for the fractions (must be non-zero).
    :param plots_dir: output directory; created if it does not exist.
    :param discarded: optional contig -> iterable of positions to leave out.
    """
    if discarded is None:
        discarded = {}
    if not os.path.exists(plots_dir):
        os.makedirs(plots_dir)

    # (counter key, marker format, panel title), top to bottom.
    panels = (
        ("mut", "ro", "Fraction of samples with SNPs"),
        ("N", "bo", "Fraction of samples with Ns"),
        ("mix", "go", "Fraction of samples with mixed bases"),
        ("gap", "yo", "Fraction of samples with uncallable genotype (gap)"),
    )

    plt.style.use('ggplot')
    for contig in pos_stats:
        contig_stats = pos_stats[contig]
        # A set keeps the per-position membership test O(1).
        skip = set(discarded.get(contig, ()))
        positions = [pos for pos in contig_stats if pos not in skip]
        x = numpy.array(positions)

        fig, axes = plt.subplots(4, sharex=True, sharey=True)
        fig.set_size_inches(12, 15)
        for ax, (key, fmt, title) in zip(axes, panels):
            y = numpy.array([float(contig_stats[pos][key]) / total_samples for pos in positions])
            ax.plot(x, y, fmt)
            ax.set_title(title)
        # The y axes are shared, so this limit applies to all four panels.
        axes[0].set_ylim(0, 1.1)

        fig.savefig(os.path.join(plots_dir, "%s.png" % contig.replace("/", "-")), dpi=100)
        # Release the figure; otherwise every contig's figure stays in memory.
        plt.close(fig)

def get_mixture(record, threshold):
    """Return mixture calls for a SNP record from its first sample's allelic depths.

    Every pair (i, j), i < j, of alleles is considered, with allele 0 being
    REF and allele k > 0 being ``ALT[k - 1]``. For each pair the ratio
    AD[j] / sum(AD) is computed.

    * A ratio of exactly 1.0 records the plain allele j (not a mixture).
    * A ratio >= *threshold* records the IUPAC code of the two alleles.

    :param record: VCF record with ``REF``, ``ALT`` and ``samples[0].data.AD``.
    :param threshold: minimum ratio for a pair to be reported as a mixture.
    :returns: dict mapping ratio -> list of codes, in pair order. It is empty
        when AD is absent or missing, has fewer than two entries, or sums to
        zero.
    """
    mixtures = {}
    try:
        depths = record.samples[0].data.AD
    except AttributeError:
        # No AD field for this sample: nothing can be called.
        return mixtures

    # A missing AD value or a single depth cannot describe a mixture.
    if depths is None or len(depths) < 2:
        return mixtures
    total_depth = sum(depths)
    if not total_depth:
        return mixtures

    # Go over all combinations of allele pairs.
    for i, j in itertools.combinations(range(len(depths)), 2):
        first = str(record.REF) if i == 0 else str(record.ALT[i - 1])
        # j >= 1 always, so the second allele is always an ALT.
        alleles = [first, str(record.ALT[j - 1])]
        ratio = float(depths[j]) / total_depth

        if ratio == 1.0:
            logging.debug("This is only designed for mixtures! %s %s %s %s", record, ratio, depths, record.FILTER)
            mixtures.setdefault(ratio, []).append(alleles[-1])
        elif ratio >= threshold:
            try:
                code = IUPAC_CODES[frozenset(alleles)]
            except KeyError:
                logging.warning("Could not retrieve IUPAC code for %s from %s", alleles, record)
            else:
                # Append for every pair, not only the first pair with this ratio.
                mixtures.setdefault(ratio, []).append(code)

    return mixtures

def print_stats(stats, pos_stats, total_vars):
    """Print per-sample and per-position statistics as CSV lines on stdout.

    First, for every contig in *stats*, one ``sample,n_count,total_vars`` line
    is printed per sample, where n_count is the number of positions listed
    under "n_pos". Then, for every contig in *stats*, one ``contig,pos,N,-,mut``
    line is printed per position in *pos_stats*. Missing counters print as
    ``NA``, and contigs absent from *pos_stats* produce no position lines.
    """
    for contig in stats:
        for sample, info in stats[contig].items():
            print("%s,%i,%i" % (sample, len(info.get("n_pos", [])), total_vars))

    for contig in stats:
        for pos, info in pos_stats.get(contig, {}).items():
            # %s so the "NA" fallback can be formatted alongside integer counts.
            print("%s,%i,%s,%s,%s" % (contig, pos, info.get("N", "NA"), info.get("-", "NA"), info.get("mut", "NA")))


def get_args(argv=None):
    """Build the command-line parser and return the parsed arguments.

    Exactly one of --directory/--input is required, as is --out.
    --include and --exclude are mutually exclusive.

    :param argv: argument list to parse; ``None`` (default) parses sys.argv.
    """
    parser = argparse.ArgumentParser(description="Combine multiple VCFs into a single FASTA file.")

    sources = parser.add_mutually_exclusive_group(required=True)
    sources.add_argument("--directory", "-d", help="Path to the directory with .filtered.vcf files.")
    sources.add_argument("--input", "-i", type=str, nargs='+', help="List of VCF files to process.")

    parser.add_argument("--out", "-o", required=True, help="Path to the output FASTA file.")

    parser.add_argument("--with-mixtures", type=float, help="Specify this option with a threshold to output mixtures at or above this threshold.")

    parser.add_argument("--column-Ns", type=float, help="Keeps only columns whose fractions of Ns (and '-') are below the specified threshold.")

    parser.add_argument("--sample-Ns", type=float, help="Removes samples whose fraction of Ns is above the specified threshold.")

    parser.add_argument("--reference", type=str, help="If path to reference specified (FASTA), then whole genome will be written.")

    regions = parser.add_mutually_exclusive_group()
    regions.add_argument("--include", help="Tab-separated regions file (chrom, start, end); only VCF records inside these regions are used.")
    regions.add_argument("--exclude", help="Tab-separated regions file (chrom, start, end); VCF records inside these regions are ignored.")

    parser.add_argument("--with-stats", help="If a path is specified, then position of the outputted SNPs is stored in this file. Requires numpy and matplotlib.")
    parser.add_argument("--plots-dir", default="plots", help="Where to write summary plots on SNPs extracted. Requires numpy and matplotlib.")

    parser.add_argument("--debug", action="store_true", help="More verbose logging (default: turned off).")
    parser.add_argument("--local", action="store_true", help="Re-read the VCF instead of storing it in memory.")

    return parser.parse_args(argv)

def _sample_name(vcf_path):
    """Return the sample name for a VCF path: its basename minus the last extension."""
    return os.path.splitext(os.path.basename(vcf_path))[0]


def _load_reference(fasta_path):
    """Read a FASTA file into an OrderedDict of record id -> sequence string, in file order."""
    ref_seq = OrderedDict()
    with open(fasta_path) as fp:
        for record in SeqIO.parse(fp, "fasta"):
            ref_seq[record.id] = str(record.seq)
    return ref_seq


def _load_regions(bed_file):
    """Read a tab-separated regions file into {chrom: FastRBTree of positions}.

    Columns 1-3 are chrom, start and end. Every position from start to end
    (inclusive) is added to that line's own chromosome, with the value False.
    Blank lines and "#"/"track"/"browser" header lines are skipped, and
    lines with fewer than three columns are skipped with a warning.
    """
    regions = {}
    with open(bed_file) as fp:
        for line in fp:
            data = line.strip().split("\t")
            if not data[0] or data[0].startswith(("#", "track", "browser")):
                continue
            if len(data) < 3:
                logging.warning("Skipping malformed line in %s: %r", bed_file, line)
                continue
            # NOTE(review): standard BED is 0-based half-open, but start/end are
            # used here as an inclusive range (unchanged behaviour) - confirm
            # which convention the input files follow.
            start, end = int(data[1]), int(data[2])
            regions.setdefault(data[0], []).extend((i, False) for i in range(start, end + 1))
    return dict((chrom, FastRBTree(pairs)) for chrom, pairs in regions.items())


def _outside_regions(record, include, exclude):
    """Return True if *record* is outside the --include regions or inside the --exclude regions."""
    if include:
        return record.POS not in include.get(record.CHROM, ())
    if exclude:
        return record.POS in exclude.get(record.CHROM, ())
    return False


def _new_pos_stats():
    """Return a fresh per-position counter dict with every count at zero."""
    return {"N": 0, "-": 0, "mut": 0, "mix": 0, "gap": 0}


def _mark_available(avail_pos, pos_stats, record):
    """Register the record's position as an output column (REF base) and reset its counters."""
    avail_pos[record.CHROM].insert(record.POS, str(record.REF))
    pos_stats.setdefault(record.CHROM, {})[record.POS] = _new_pos_stats()


def _is_discarded_indel(record, with_mixtures):
    """Return True for records dropped as DEL/INDEL.

    These are filtered records that are not mixture-eligible SNPs and are
    deletions or indels.
    """
    if not record.FILTER:
        return False
    if with_mixtures and record.is_snp:
        return False
    return bool(record.is_deletion or record.is_indel)


def _keep_column(stats, n_samples, threshold):
    """Return True if a column passes the --column-Ns filter.

    A falsy *threshold* disables the filter. Otherwise the fractions of
    samples called "N" and called "-" at this column must both be below it.
    With no samples left there is nothing to measure, so the column is kept.
    """
    if not threshold or not n_samples:
        return True
    return (float(stats["N"]) / n_samples < threshold and
            float(stats["-"]) / n_samples < threshold)


def _contig_layout(avail, ref_seq, contig_stats, n_samples, threshold):
    """Plan the output columns for one contig.

    :param avail: FastRBTree of available positions, or None if the contig had no records.
    :param ref_seq: the contig's reference sequence, or None without --reference.
    :returns: (layout, dropped).
        ``layout`` is a list whose items are either an int (an available
        position whose base is looked up per sample) or a str (literal
        reference bases).
        ``dropped`` lists available positions rejected by --column-Ns.
        With a reference, the whole sequence is laid out and available
        positions past its end are ignored. Without one, only kept
        available positions are laid out.
    """
    layout = []
    dropped = []
    if avail is None:
        if ref_seq is not None:
            layout.append(ref_seq)
        return layout, dropped

    written = 0  # number of reference bases already placed in the layout
    for pos in avail.keys():
        if ref_seq is not None:
            if pos > len(ref_seq):
                # FastRBTree yields keys in ascending order, so the rest are past the end too.
                break
            layout.append(ref_seq[written:pos - 1])
            written = pos
        if _keep_column(contig_stats[pos], n_samples, threshold):
            layout.append(pos)
        else:
            dropped.append(pos)
    if ref_seq is not None:
        layout.append(ref_seq[written:])
    return layout, dropped


def _render(layout, bases):
    """Join a layout into a sequence, looking up int positions in *bases*."""
    return "".join(item if isinstance(item, str) else bases[item] for item in layout)


def main():
    """
    Process VCF files and merge them into a single fasta file.

    First pass: for each contig, collect the positions to output and their
    REF bases. These are passing SNPs, mixture calls when --with-mixtures is
    set, and filtered non-indel records.

    Second pass: fill in each sample's base at every such position (ALT for
    passing SNPs; mixture code or "N" for filtered records) and count
    per-position statistics.

    Finally, optionally drop samples and columns with too many Ns. Then
    write one FASTA record per sample plus a "reference" record.

    :returns: 0 on success, 2 if records disagree on the REF base at a position.
    """
    args = get_args()

    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)

    # Contig names in order of first appearance.
    contigs = []
    # contig -> sample -> {"n_pos": [positions called N]}.
    sample_stats = {}
    # All positions available for analysis: contig -> FastRBTree(position -> REF base).
    avail_pos = {}
    # Stats about each position in each chromosome: contig -> position -> counters.
    pos_stats = {}
    # VCF path -> number of filtered DEL/INDEL records discarded.
    indel_summary = defaultdict(int)
    # Cached records per VCF path (unless --local).
    vcf_data = {}
    # contig -> sample -> FastRBTree(position -> mixture code).
    mixtures = {}

    reference = _load_reference(args.reference) if args.reference else None

    include = exclude = None
    if args.include:
        include = _load_regions(args.include)
    elif args.exclude:
        exclude = _load_regions(args.exclude)

    if args.directory is not None and args.input is None:
        # Sorted so the sample order in the output does not depend on the filesystem.
        args.input = sorted(glob.glob(os.path.join(args.directory, "*.filtered.vcf")))
        if not args.input:
            logging.warning("No *.filtered.vcf files found in %s", args.directory)

    # First pass to get the references and the positions to be analysed.
    for vcf_in in args.input:
        sample_name = _sample_name(vcf_in)
        kept_records = None
        if not args.local:
            kept_records = vcf_data[vcf_in] = []

        for record in vcf.Reader(filename=vcf_in):
            if _outside_regions(record, include, exclude):
                continue

            chrom, pos = record.CHROM, record.POS
            if chrom not in avail_pos:
                contigs.append(chrom)
                avail_pos[chrom] = FastRBTree()
                mixtures[chrom] = {}
                sample_stats[chrom] = {}
            if sample_name not in mixtures[chrom]:
                mixtures[chrom][sample_name] = FastRBTree()
            if sample_name not in sample_stats[chrom]:
                sample_stats[chrom][sample_name] = {}

            if _is_discarded_indel(record, args.with_mixtures):
                logging.debug("Discarding %s from %s as DEL and/or INDEL", pos, vcf_in)
                indel_summary[vcf_in] += 1
                continue

            if not record.FILTER:
                if record.is_snp:
                    if pos in avail_pos[chrom] and avail_pos[chrom][pos] != record.REF:
                        logging.critical("SOMETHING IS REALLY WRONG because reference for the same position is DIFFERENT! %s in %s", pos, vcf_in)
                        return 2
                    _mark_available(avail_pos, pos_stats, record)
            elif args.with_mixtures and record.is_snp:
                for codes in get_mixture(record, args.with_mixtures).values():
                    for code in codes:
                        _mark_available(avail_pos, pos_stats, record)
                        mixtures[chrom][sample_name].insert(pos, code)
            else:
                # Filtered but not an indel: the position is output and this
                # sample's base is filled in during the second pass.
                _mark_available(avail_pos, pos_stats, record)

            if kept_records is not None:
                kept_records.append(record)

    all_data = dict((contig, {}) for contig in contigs)
    samples = []

    for vcf_in in args.input:
        sample_name = _sample_name(vcf_in)
        samples.append(sample_name)

        # Initialise the data for this sample to be REF positions.
        for contig in contigs:
            all_data[contig][sample_name] = dict(avail_pos[contig].items())

        # Re-read data from VCF if local is specified, otherwise get it from memory.
        if args.local:
            # Apply the same DEL/INDEL discarding as the first pass, so that
            # --local produces the same output as the in-memory path.
            records = (rec for rec in vcf.Reader(filename=vcf_in)
                       if not _is_discarded_indel(rec, args.with_mixtures))
        else:
            # pop() releases the cached records once this sample is done.
            records = vcf_data.pop(vcf_in, [])

        for record in records:
            chrom, pos = record.CHROM, record.POS
            # Only positions selected in the first pass are output.
            if chrom not in avail_pos or pos not in avail_pos[chrom]:
                continue
            counters = pos_stats[chrom][pos]

            if not record.FILTER:
                if record.is_snp:
                    if len(record.ALT) > 1:
                        logging.info("POS %s passed filters but has multiple alleles. Inserting N", pos)
                        all_data[chrom][sample_name][pos] = "N"
                    else:
                        all_data[chrom][sample_name][pos] = record.ALT[0].sequence
                        counters["mut"] += 1
                continue

            # Filtered record: use this sample's mixture code if one was called, else N.
            extended_code = mixtures.get(chrom, {}).get(sample_name, {}).get(pos, "N")

            # Calculate the stats.
            if extended_code == "N":
                counters["N"] += 1
                sample_stats.setdefault(chrom, {}).setdefault(sample_name, {}).setdefault("n_pos", []).append(pos)
            elif extended_code == "-":
                counters["-"] += 1
            else:
                counters["mix"] += 1
            # Record if there was an uncallable genotype/gap in the data.
            if record.samples[0].data.GT == "./.":
                counters["gap"] += 1

            # Save the extended code of the SNP.
            all_data[chrom][sample_name][pos] = extended_code

    if reference:
        # Output follows the reference's contig order (same order as the FASTA).
        missing = [contig for contig in contigs if contig not in reference]
        if missing:
            logging.warning("Contigs not found in the reference will not be written: %s", ", ".join(missing))
        contigs = list(reference.keys())

    if args.sample_Ns:
        delete_samples = set()
        for contig in contigs:
            # Skip contigs without records (e.g. reference-only contigs).
            if contig not in sample_stats:
                continue
            n_avail = len(avail_pos[contig])
            if not n_avail:
                continue
            for sample in samples:
                n_pos = sample_stats[contig].get(sample, {}).get("n_pos", [])
                sample_n_ratio = float(len(n_pos)) / n_avail
                if sample_n_ratio > args.sample_Ns and sample not in delete_samples:
                    logging.info("Removing %s due to high Ns in sample: %s", sample, sample_n_ratio)
                    delete_samples.add(sample)

        # Take every removed sample's Ns out of the column counts on all
        # contigs, so that --column-Ns fractions only reflect kept samples.
        for contig, per_sample in sample_stats.items():
            for sample in delete_samples:
                for pos in per_sample.get(sample, {}).get("n_pos", []):
                    pos_stats[contig][pos]["N"] -= 1

        samples = [sample for sample in samples if sample not in delete_samples]

    # Plan the output columns once per contig. Every sample and the reference
    # record share this layout, so the FASTA stays aligned.
    n_samples = len(samples)
    discarded = {}
    layouts = []
    for contig in contigs:
        layout, dropped = _contig_layout(
            avail_pos.get(contig),
            reference[contig] if reference else None,
            pos_stats.get(contig, {}),
            n_samples,
            args.column_Ns)
        if dropped:
            discarded[contig] = dropped
        layouts.append((contig, layout))

    snp_positions = [(contig, item) for contig, layout in layouts
                     for item in layout if not isinstance(item, str)]

    # Output the data to the fasta file. The data is already aligned so simply output it.
    with open(args.out, "w") as fp:
        for sample in samples:
            sample_seq = "".join(
                _render(layout, all_data.get(contig, {}).get(sample, {}))
                for contig, layout in layouts)
            fp.write(">%s\n%s\n" % (sample, sample_seq))
        # Do the same for reference data.
        ref_snps = "".join(_render(layout, avail_pos.get(contig, {})) for contig, layout in layouts)
        fp.write(">reference\n%s\n" % ref_snps)

    if args.with_stats:
        if not can_stats:
            logging.warning("--with-stats requires numpy and matplotlib; no stats written.")
        else:
            n_inputs = len(args.input)
            with open(args.with_stats, "w") as fp:
                fp.write("contig\tposition\tmutations\tn_frac\n")
                for contig, pos in snp_positions:
                    counters = pos_stats[contig][pos]
                    fp.write("%s\t%s\t%s\t%s\n" % (contig,
                                                 pos,
                                                 float(counters["mut"]) / n_inputs,
                                                 float(counters["N"]) / n_inputs))
            # Fractions are per remaining sample; with none left there is nothing to plot.
            if samples:
                plot_stats(pos_stats, len(samples), discarded=discarded, plots_dir=os.path.abspath(args.plots_dir))

    total_discarded = sum(len(positions) for positions in discarded.values())
    logging.info("Discarded total of %i poor quality columns", total_discarded)
    logging.info("Samples with indels:")
    for sample, count in indel_summary.items():
        logging.info("%s\t%s", sample, count)
    return 0

if __name__ == '__main__':
    # sys.exit is always available; the bare exit() builtin is only injected by the site module.
    import sys

    sys.exit(main())