diff vcfs2fasta.py @ 22:96f393ad7fc6 draft default tip

Uploaded
author ulfschaefer
date Wed, 23 Dec 2015 04:50:58 -0500
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/vcfs2fasta.py	Wed Dec 23 04:50:58 2015 -0500
@@ -0,0 +1,437 @@
+#!/usr/bin/env python
+'''
+Merge SNP data from multiple VCF files into a single fasta file.
+
+Created on 5 Oct 2015
+
+@author: alex
+'''
+import argparse
+from collections import OrderedDict
+from collections import defaultdict
+import glob
+import itertools
+import logging
+import os
+
+from Bio import SeqIO
+from bintrees import FastRBTree
+
+# Try importing the matplotlib and numpy for stats.
+try:
+    from matplotlib import pyplot as plt
+    import numpy
+    can_stats = True
+except ImportError:
+    can_stats = False
+
+import vcf
+
+from phe.variant_filters import IUPAC_CODES
+
+
+def plot_stats(pos_stats, total_samples, plots_dir="plots", discarded={}):
+    if not os.path.exists(plots_dir):
+        os.makedirs(plots_dir)
+
+    for contig in pos_stats:
+        plt.style.use('ggplot')
+
+        x = numpy.array([pos for pos in pos_stats[contig] if pos not in discarded.get(contig, [])])
+        y = numpy.array([ float(pos_stats[contig][pos]["mut"]) / total_samples for pos in pos_stats[contig] if pos not in discarded.get(contig, []) ])
+
+        f, (ax1, ax2, ax3, ax4) = plt.subplots(4, sharex=True, sharey=True)
+        f.set_size_inches(12, 15)
+        ax1.plot(x, y, 'ro')
+        ax1.set_title("Fraction of samples with SNPs")
+        plt.ylim(0, 1.1)
+
+        y = numpy.array([ float(pos_stats[contig][pos]["N"]) / total_samples for pos in pos_stats[contig] if pos not in discarded.get(contig, [])])
+        ax2.plot(x, y, 'bo')
+        ax2.set_title("Fraction of samples with Ns")
+
+        y = numpy.array([ float(pos_stats[contig][pos]["mix"]) / total_samples for pos in pos_stats[contig] if pos not in discarded.get(contig, [])])
+        ax3.plot(x, y, 'go')
+        ax3.set_title("Fraction of samples with mixed bases")
+
+        y = numpy.array([ float(pos_stats[contig][pos]["gap"]) / total_samples for pos in pos_stats[contig] if pos not in discarded.get(contig, [])])
+        ax4.plot(x, y, 'yo')
+        ax4.set_title("Fraction of samples with uncallable genotype (gap)")
+
+        contig = contig.replace("/", "-")
+        plt.savefig(os.path.join(plots_dir, "%s.png" % contig), dpi=100)
+
+def get_mixture(record, threshold):
+    mixtures = {}
+    try:
+        if len(record.samples[0].data.AD) > 1:
+
+            total_depth = sum(record.samples[0].data.AD)
+            # Go over all combinations of touples.
+            for comb in itertools.combinations(range(0, len(record.samples[0].data.AD)), 2):
+                i = comb[0]
+                j = comb[1]
+
+                alleles = list()
+
+                if 0 in comb:
+                    alleles.append(str(record.REF))
+
+                if i != 0:
+                    alleles.append(str(record.ALT[i - 1]))
+                    mixture = record.samples[0].data.AD[i]
+                if j != 0:
+                    alleles.append(str(record.ALT[j - 1]))
+                    mixture = record.samples[0].data.AD[j]
+
+                ratio = float(mixture) / total_depth
+                if ratio == 1.0:
+                    logging.debug("This is only designed for mixtures! %s %s %s %s", record, ratio, record.samples[0].data.AD, record.FILTER)
+
+                    if ratio not in mixtures:
+                        mixtures[ratio] = []
+                    mixtures[ratio].append(alleles.pop())
+
+                elif ratio >= threshold:
+                    try:
+                        code = IUPAC_CODES[frozenset(alleles)]
+                        if ratio not in mixtures:
+                            mixtures[ratio] = []
+                            mixtures[ratio].append(code)
+                    except KeyError:
+                        logging.warn("Could not retrieve IUPAC code for %s from %s", alleles, record)
+    except AttributeError:
+        mixtures = {}
+
+    return mixtures
+
+def print_stats(stats, pos_stats, total_vars):
+    for contig in stats:
+        for sample, info in stats[contig].items():
+            print "%s,%i,%i" % (sample, len(info.get("n_pos", [])), total_vars)
+
+    for contig in stats:
+        for pos, info in pos_stats[contig].iteritems():
+            print "%s,%i,%i,%i,%i" % (contig, pos, info.get("N", "NA"), info.get("-", "NA"), info.get("mut", "NA"))
+
+
+def get_args():
+    args = argparse.ArgumentParser(description="Combine multiple VCFs into a single FASTA file.")
+
+    group = args.add_mutually_exclusive_group(required=True)
+    group.add_argument("--directory", "-d", help="Path to the directory with .vcf files.")
+    group.add_argument("--input", "-i", type=str, nargs='+', help="List of VCF files to process.")
+
+    args.add_argument("--out", "-o", required=True, help="Path to the output FASTA file.")
+
+    args.add_argument("--with-mixtures", type=float, help="Specify this option with a threshold to output mixtures above this threshold.")
+
+    args.add_argument("--column-Ns", type=float, help="Keeps columns with fraction of Ns above specified threshold.")
+
+    args.add_argument("--sample-Ns", type=float, help="Keeps samples with fraction of Ns above specified threshold.")
+
+    args.add_argument("--reference", type=str, help="If path to reference specified (FASTA), then whole genome will be written.")
+
+    group = args.add_mutually_exclusive_group()
+
+    group.add_argument("--include")
+    group.add_argument("--exclude")
+
+    args.add_argument("--with-stats", help="If a path is specified, then position of the outputed SNPs is stored in this file. Requires mumpy and matplotlib.")
+    args.add_argument("--plots-dir", default="plots", help="Where to write summary plots on SNPs extracted. Requires mumpy and matplotlib.")
+
+    args.add_argument("--debug", action="store_true", help="More verbose logging (default: turned off).")
+    args.add_argument("--local", action="store_true", help="Re-read the VCF instead of storing it in memory.")
+
+    return args.parse_args()
+
+def main():
+    """
+    Process VCF files and merge them into a single fasta file.
+    """
+    args = get_args()
+
+    logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO)
+
+    contigs = list()
+
+    sample_stats = dict()
+
+    # All positions available for analysis.
+    avail_pos = dict()
+    # Stats about each position in each chromosome.
+    pos_stats = dict()
+    indel_summary = defaultdict(int)
+    # Cached version of the data.
+    vcf_data = dict()
+    mixtures = dict()
+
+    empty_tree = FastRBTree()
+
+    exclude = False
+    include = False
+
+    if args.reference:
+        ref_seq = OrderedDict()
+        with open(args.reference) as fp:
+            for record in SeqIO.parse(fp, "fasta"):
+                ref_seq[record.id] = str(record.seq)
+
+        args.reference = ref_seq
+
+    if args.exclude or args.include:
+        pos = {}
+        chr_pos = []
+        bed_file = args.include if args.include is not None else args.exclude
+
+        with open(bed_file) as fp:
+            for line in fp:
+                data = line.strip().split("\t")
+
+                chr_pos += [ (i, False,) for i in xrange(int(data[1]), int(data[2]) + 1)]
+
+                if data[0] not in pos:
+                    pos[data[0]] = []
+
+                pos[data[0]] += chr_pos
+
+
+        pos = {chrom: FastRBTree(l) for chrom, l in pos.items()}
+
+        if args.include:
+            include = pos
+        else:
+            exclude = pos
+
+
+    if args.directory is not None and args.input is None:
+        args.input = glob.glob(os.path.join(args.directory, "*.filtered.vcf"))
+
+    # First pass to get the references and the positions to be analysed.
+    for vcf_in in args.input:
+        sample_name, _ = os.path.splitext(os.path.basename(vcf_in))
+        vcf_data[vcf_in] = list()
+        reader = vcf.Reader(filename=vcf_in)
+
+        for record in reader:
+            if include and include.get(record.CHROM, empty_tree).get(record.POS, True) or exclude and not exclude.get(record.CHROM, empty_tree).get(record.POS, True):
+                continue
+
+            if not args.local:
+                vcf_data[vcf_in].append(record)
+
+            if record.CHROM not in contigs:
+                contigs.append(record.CHROM)
+                avail_pos[record.CHROM] = FastRBTree()
+                mixtures[record.CHROM] = {}
+                sample_stats[record.CHROM] = {}
+
+            if sample_name not in mixtures[record.CHROM]:
+                mixtures[record.CHROM][sample_name] = FastRBTree()
+
+            if sample_name not in sample_stats[record.CHROM]:
+                sample_stats[record.CHROM][sample_name] = {}
+
+            if not record.FILTER:
+                if record.is_snp:
+                    if record.POS in avail_pos[record.CHROM] and avail_pos[record.CHROM][record.POS] != record.REF:
+                        logging.critical("SOMETHING IS REALLY WRONG because reference for the same position is DIFFERENT! %s in %s", record.POS, vcf_in)
+                        return 2
+
+                    if record.CHROM not in pos_stats:
+                        pos_stats[record.CHROM] = {}
+
+                    avail_pos[record.CHROM].insert(record.POS, str(record.REF))
+                    pos_stats[record.CHROM][record.POS] = {"N":0, "-": 0, "mut": 0, "mix": 0, "gap": 0}
+
+            elif args.with_mixtures and record.is_snp:
+                mix = get_mixture(record, args.with_mixtures)
+
+                for ratio, code in mix.items():
+                    for c in code:
+                        avail_pos[record.CHROM].insert(record.POS, str(record.REF))
+                        if record.CHROM not in pos_stats:
+                            pos_stats[record.CHROM] = {}
+                        pos_stats[record.CHROM][record.POS] = {"N": 0, "-": 0, "mut": 0, "mix": 0, "gap": 0}
+
+                        if sample_name not in mixtures[record.CHROM]:
+                            mixtures[record.CHROM][sample_name] = FastRBTree()
+
+                        mixtures[record.CHROM][sample_name].insert(record.POS, c)
+            elif not record.is_deletion and not record.is_indel:
+                if record.CHROM not in pos_stats:
+                    pos_stats[record.CHROM] = {}
+                pos_stats[record.CHROM][record.POS] = {"N": 0, "-": 0, "mut": 0, "mix": 0, "gap": 0}
+                avail_pos[record.CHROM].insert(record.POS, str(record.REF))
+            else:
+                logging.debug("Discarding %s from %s as DEL and/or INDEL", record.POS, vcf_in)
+                indel_summary[vcf_in] += 1
+                try:
+                    vcf_data[vcf_in].remove(record)
+                except ValueError:
+                    pass
+
+
+    all_data = { contig: {} for contig in contigs}
+    samples = []
+
+    for vcf_in in args.input:
+
+        sample_seq = ""
+        sample_name, _ = os.path.splitext(os.path.basename(vcf_in))
+        samples.append(sample_name)
+
+        # Initialise the data for this sample to be REF positions.
+        for contig in contigs:
+            all_data[contig][sample_name] = { pos: avail_pos[contig][pos] for pos in avail_pos[contig] }
+
+        # Re-read data from VCF if local is specified, otherwise get it from memory.
+        iterator = vcf.Reader(filename=vcf_in) if args.local else vcf_data[vcf_in]
+        for record in iterator:
+            # Array of filters that have been applied.
+            filters = []
+
+            # If position is our available position.
+            if avail_pos.get(record.CHROM, empty_tree).get(record.POS, False):
+                if not record.FILTER:
+                    if record.is_snp:
+                        if len(record.ALT) > 1:
+                            logging.info("POS %s passed filters but has multiple alleles. Inserting N")
+                            all_data[record.CHROM][sample_name][record.POS] = "N"
+                        else:
+                            all_data[record.CHROM][sample_name][record.POS] = record.ALT[0].sequence
+                            pos_stats[record.CHROM][record.POS]["mut"] += 1
+                else:
+
+                    # Currently we are only using first filter to call consensus.
+                    extended_code = mixtures[record.CHROM][sample_name].get(record.POS, "N")
+
+#                     extended_code = PHEFilterBase.call_concensus(record)
+
+                    # Calculate the stats
+                    if extended_code == "N":
+                        pos_stats[record.CHROM][record.POS]["N"] += 1
+
+                        if "n_pos" not in sample_stats[record.CHROM][sample_name]:
+                            sample_stats[record.CHROM][sample_name]["n_pos"] = []
+                        sample_stats[record.CHROM][sample_name]["n_pos"].append(record.POS)
+
+                    elif extended_code == "-":
+                        pos_stats[record.CHROM][record.POS]["-"] += 1
+                    else:
+                        pos_stats[record.CHROM][record.POS]["mix"] += 1
+#                         print "Good mixture %s: %i (%s)" % (sample_name, record.POS, extended_code)
+                    # Record if there was uncallable genoty/gap in the data.
+                    if record.samples[0].data.GT == "./.":
+                        pos_stats[record.CHROM][record.POS]["gap"] += 1
+
+                    # Save the extended code of the SNP.
+                    all_data[record.CHROM][sample_name][record.POS] = extended_code
+        del vcf_data[vcf_in]
+
+    # Output the data to the fasta file.
+    # The data is already aligned so simply output it.
+    discarded = {}
+
+    if args.reference:
+        # These should be in the same order as the order in reference.
+        contigs = args.reference.keys()
+
+    if args.sample_Ns:
+        delete_samples = []
+        for contig in contigs:
+            for sample in samples:
+
+                # Skip if the contig not in sample_stats
+                if contig not in sample_stats:
+                    continue
+
+                sample_n_ratio = float(len(sample_stats[contig][sample]["n_pos"])) / len(avail_pos[contig])
+                if sample_n_ratio > args.sample_Ns:
+                    for pos in sample_stats[contig][sample]["n_pos"]:
+                        pos_stats[contig][pos]["N"] -= 1
+
+                    logging.info("Removing %s due to high Ns in sample: %s", sample , sample_n_ratio)
+
+                    delete_samples.append(sample)
+
+        samples = [sample for sample in samples if sample not in delete_samples]
+    snp_positions = []
+    with open(args.out, "w") as fp:
+
+        for sample in samples:
+            sample_seq = ""
+            for contig in contigs:
+                if contig in avail_pos:
+                    if args.reference:
+                        positions = xrange(1, len(args.reference[contig]) + 1)
+                    else:
+                        positions = avail_pos[contig].keys()
+                    for pos in positions:
+                        if pos in avail_pos[contig]:
+                            if not args.column_Ns or float(pos_stats[contig][pos]["N"]) / len(samples) < args.column_Ns and \
+                                float(pos_stats[contig][pos]["-"]) / len(samples) < args.column_Ns:
+                                sample_seq += all_data[contig][sample][pos]
+                            else:
+                                if contig not in discarded:
+                                    discarded[contig] = []
+                                discarded[contig].append(pos)
+                        elif args.reference:
+                            sample_seq += args.reference[contig][pos - 1]
+                elif args.reference:
+                    sample_seq += args.reference[contig]
+
+            fp.write(">%s\n%s\n" % (sample, sample_seq))
+        # Do the same for reference data.
+        ref_snps = ""
+
+        for contig in contigs:
+            if contig in avail_pos:
+                if args.reference:
+                    positions = xrange(1, len(args.reference[contig]) + 1)
+                else:
+                    positions = avail_pos[contig].keys()
+                for pos in positions:
+                    if pos in avail_pos[contig]:
+                        if not args.column_Ns or float(pos_stats[contig][pos]["N"]) / len(samples) < args.column_Ns and \
+                                float(pos_stats[contig][pos]["-"]) / len(samples) < args.column_Ns:
+
+                            ref_snps += str(avail_pos[contig][pos])
+                            snp_positions.append((contig, pos,))
+                    elif args.reference:
+                        ref_snps += args.reference[contig][pos - 1]
+            elif args.reference:
+                    ref_snps += args.reference[contig]
+
+        fp.write(">reference\n%s\n" % ref_snps)
+
+    if can_stats and args.with_stats:
+        with open(args.with_stats, "wb") as fp:
+            fp.write("contig\tposition\tmutations\tn_frac\n")
+            for values in snp_positions:
+                fp.write("%s\t%s\t%s\t%s\n" % (values[0],
+                                             values[1],
+                                             float(pos_stats[values[0]][values[1]]["mut"]) / len(args.input),
+                                             float(pos_stats[values[0]][values[1]]["N"]) / len(args.input)))
+        plot_stats(pos_stats, len(samples), discarded=discarded, plots_dir=os.path.abspath(args.plots_dir))
+    # print_stats(sample_stats, pos_stats, total_vars=len(avail_pos[contig]))
+
+    total_discarded = 0
+    for _, i in discarded.items():
+        total_discarded += len(i)
+    logging.info("Discarded total of %i poor quality columns", float(total_discarded) / len(args.input))
+    logging.info("Samples with indels:")
+    for sample, count in indel_summary.iteritems():
+        logging.info("%s\t%s", sample, count)
+    return 0
+
+if __name__ == '__main__':
+    import time
+
+#     with PyCallGraph(output=graphviz):
+#     T0 = time.time()
+    r = main()
+#     T1 = time.time()
+
+#     print "Time taken: %i" % (T1 - T0)
+    exit(r)