comparison vcfs2fasta.py @ 14:f72039c5faa4 draft

Uploaded
author ulfschaefer
date Wed, 16 Dec 2015 07:29:05 -0500
parents
children
comparison
equal deleted inserted replaced
13:2e69ce9dca65 14:f72039c5faa4
1 #!/usr/bin/env python
2 '''
3 Merge SNP data from multiple VCF files into a single fasta file.
4
5 Created on 5 Oct 2015
6
7 @author: alex
8 '''
9 import argparse
10 from collections import OrderedDict
11 import glob
12 import itertools
13 import logging
14 import os
15
16 from Bio import SeqIO
17 from bintrees import FastRBTree
18
19 # Try importing the matplotlib and numpy for stats.
20 try:
21 from matplotlib import pyplot as plt
22 import numpy
23 can_stats = True
24 except ImportError:
25 can_stats = False
26
27 import vcf
28
29 from phe.variant_filters import IUPAC_CODES
30
31
def plot_stats(pos_stats, total_samples, plots_dir="plots", discarded=None):
    """Write one summary PNG per contig with per-position sample fractions.

    Every figure has four stacked panels sharing both axes. Each panel plots,
    for every position of the contig, one counter divided by total_samples:
    "mut" (SNPs), "N", "mix" (mixed bases) and "gap" (uncallable genotype).

    Args:
        pos_stats: dict mapping contig -> position -> dict of counters with
            the keys "mut", "N", "mix" and "gap".
        total_samples: denominator used to turn the counters into fractions.
        plots_dir: output directory; created when it does not exist.
        discarded: optional dict mapping contig -> positions to leave out of
            the plots.

    The file is written to "<plots_dir>/<contig>.png". Requires the
    module-level matplotlib (plt) and numpy imports to have succeeded.
    """
    # A literal {} default would be one dict shared by every call.
    if discarded is None:
        discarded = {}

    if not os.path.exists(plots_dir):
        os.makedirs(plots_dir)

    total_samples = float(total_samples)

    # (counter key, matplotlib format string, panel title), top to bottom.
    panels = (
        ("mut", 'ro', "Fraction of samples with SNPs"),
        ("N", 'bo', "Fraction of samples with Ns"),
        ("mix", 'go', "Fraction of samples with mixed bases"),
        ("gap", 'yo', "Fraction of samples with uncallable genotype (gap)"),
    )

    for contig, contig_stats in pos_stats.items():

        plt.style.use('ggplot')

        # Build the skip set once instead of scanning a list per position.
        skipped = set(discarded.get(contig, ()))
        positions = [pos for pos in contig_stats if pos not in skipped]
        x = numpy.array(positions)

        f, axes = plt.subplots(4, sharex=True, sharey=True)
        f.set_size_inches(12, 15)

        for ax, (key, marker, title) in zip(axes, panels):
            y = numpy.array([contig_stats[pos][key] / total_samples for pos in positions])
            ax.plot(x, y, marker)
            ax.set_title(title)

        # The y axis is shared, so this fixes the range of all four panels.
        plt.ylim(0, 1.1)

        plt.savefig(os.path.join(plots_dir, "%s.png" % contig), dpi=100)

        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(f)
62
def get_mixture(record, threshold):
    """Collect the allele mixtures supported by the first sample of a record.

    The allelic depths (AD) of record.samples[0] are examined for every pair
    of alleles (i, j) with i < j, where index 0 is REF and index k > 0 is
    ALT[k - 1]. The ratio of a pair is the depth of its second allele (j)
    divided by the summed depth of all alleles.

    * ratio == 1.0: not a mixture; the second allele itself is recorded.
    * ratio >= threshold: the IUPAC code for the two alleles is recorded.
    * otherwise the pair is ignored.

    Args:
        record: VCF record exposing REF, ALT, FILTER and samples[0].data.AD.
        threshold: minimum ratio for a pair to be reported as a mixture.

    Returns:
        dict mapping ratio -> list of codes. Empty when the record has no AD
        field, AD is missing or holds fewer than two values, or the total
        depth is zero.
    """
    try:
        depths = record.samples[0].data.AD
    except AttributeError:
        # No AD in this record's FORMAT: nothing to work out.
        return {}

    # AD may also be None (missing value) or a single number.
    if not isinstance(depths, (list, tuple)) or len(depths) < 2:
        return {}

    # Individual missing depths are counted as zero reads.
    depths = [depth or 0 for depth in depths]
    total_depth = sum(depths)
    if not total_depth:
        # Avoid dividing by zero when no read supports any allele.
        return {}

    mixtures = {}
    # Go over all pairs of alleles; combinations() guarantees i < j, so j is
    # always an ALT allele and only i can be the reference.
    for i, j in itertools.combinations(range(len(depths)), 2):
        first = str(record.REF) if i == 0 else str(record.ALT[i - 1])
        second = str(record.ALT[j - 1])

        ratio = float(depths[j]) / total_depth
        if ratio == 1.0:
            logging.debug("This is only designed for mixtures! %s %s %s %s", record, ratio, depths, record.FILTER)
            mixtures.setdefault(ratio, []).append(second)
        elif ratio >= threshold:
            alleles = [first, second]
            try:
                code = IUPAC_CODES[frozenset(alleles)]
            except KeyError:
                logging.warning("Could not retrieve IUPAC code for %s from %s", alleles, record)
            else:
                mixtures.setdefault(ratio, []).append(code)

    return mixtures
106
def print_stats(stats, pos_stats, total_vars):
    """Print per-sample and per-position statistics as CSV lines on stdout.

    First one line "sample,number_of_N_positions,total_vars" is printed for
    every sample of every contig, then one line "contig,position,N,-,mut"
    for every position recorded for those contigs.

    Args:
        stats: dict mapping contig -> sample -> info dict; the optional
            "n_pos" entry lists the positions where the sample has an N.
        pos_stats: dict mapping contig -> position -> counters dict with the
            optional keys "N", "-" and "mut".
        total_vars: number printed as the last column of every sample line.
    """
    for contig in stats:
        for sample, info in stats[contig].items():
            print("%s,%i,%i" % (sample, len(info.get("n_pos", [])), total_vars))

    for contig in stats:
        # A contig may have samples but no analysed positions.
        for pos, info in pos_stats.get(contig, {}).items():
            # %s rather than %i: a missing counter is reported as "NA".
            print("%s,%i,%s,%s,%s" % (contig, pos, info.get("N", "NA"), info.get("-", "NA"), info.get("mut", "NA")))
115
116
def get_args(argv=None):
    """Parse the command line options of the script.

    Args:
        argv: optional list of argument strings; when None the arguments are
            taken from sys.argv, as argparse does by default.

    Returns:
        argparse.Namespace with the attributes directory, input, out,
        with_mixtures, column_Ns, sample_Ns, reference, include, exclude,
        with_stats and plots_dir.
    """
    parser = argparse.ArgumentParser(description="Combine multiple VCFs into a single FASTA file.")

    # Exactly one source of VCF files has to be given.
    source = parser.add_mutually_exclusive_group(required=True)
    source.add_argument("--directory", "-d", help="Path to the directory with .vcf files.")
    source.add_argument("--input", "-i", type=str, nargs='+', help="List of VCF files to process.")

    parser.add_argument("--out", "-o", required=True, help="Path to the output FASTA file.")

    parser.add_argument("--with-mixtures", type=float, help="Specify this option with a threshold to output mixtures above this threshold.")

    parser.add_argument("--column-Ns", type=float, help="Keeps columns with fraction of Ns below specified threshold.")

    parser.add_argument("--sample-Ns", type=float, help="Keeps samples with fraction of Ns below specified threshold.")

    parser.add_argument("--reference", type=str, help="If path to reference specified (FASTA), then whole genome will be written.")

    # Positions can be restricted either way, but not both at once.
    regions = parser.add_mutually_exclusive_group()
    regions.add_argument("--include", help="Tab-separated file of regions (contig, start, end). Only positions inside these regions are used.")
    regions.add_argument("--exclude", help="Tab-separated file of regions (contig, start, end). Positions inside these regions are ignored.")

    parser.add_argument("--with-stats", help="If a path is specified, then position of the outputted SNPs is stored in this file. Requires numpy and matplotlib.")
    parser.add_argument("--plots-dir", default="plots", help="Where to write summary plots on SNPs extracted. Requires numpy and matplotlib.")

    return parser.parse_args(argv)
143
def _read_bed_positions(bed_file):
    """Load a tab-separated region file into one position tree per contig.

    Every data line must hold at least three columns: contig, start and end.
    Start and end are used exactly as written and are both inclusive. Blank
    lines and "#", "track" and "browser" header lines are skipped.

    Returns a dict mapping contig -> FastRBTree whose keys are all covered
    positions (the stored value is always False).
    """
    positions = {}
    with open(bed_file) as fp:
        for line in fp:
            line = line.strip()
            if not line or line.startswith(("#", "track ", "browser ")):
                continue

            data = line.split("\t")
            # Each line contributes only its own interval to its own contig.
            interval = range(int(data[1]), int(data[2]) + 1)
            positions.setdefault(data[0], []).extend((i, False,) for i in interval)

    return {chrom: FastRBTree(items) for chrom, items in positions.items()}


def _keep_column(pos_stats, contig, pos, n_samples, threshold):
    """Return True when the column passes the --column-Ns filter.

    A column is kept when no threshold is set, when there are no samples to
    judge it by, or when both its "N" and its "-" fraction are below the
    threshold.
    """
    if not threshold or not n_samples:
        return True

    counters = pos_stats[contig][pos]
    return (float(counters["N"]) / n_samples < threshold
            and float(counters["-"]) / n_samples < threshold)


def _assemble_sequence(contigs, columns, reference, base_at):
    """Build the output sequence for one FASTA entry.

    Args:
        contigs: contigs in output order.
        columns: dict contig -> ascending list of (position, keep) tuples.
        reference: dict contig -> reference sequence, or a false value when
            only the variable positions are written.
        base_at: callable (contig, position) -> base to write for a kept
            column.

    With a reference, everything between the analysed positions is copied
    from the reference and discarded columns are dropped; contigs without
    any analysed position are copied whole.
    """
    parts = []
    for contig in contigs:
        if contig not in columns:
            if reference:
                parts.append(reference[contig])
            continue

        # 1-based position of the last column handled so far.
        last = 0
        for pos, keep in columns[contig]:
            if reference:
                parts.append(reference[contig][last:pos - 1])
                last = pos
            if keep:
                parts.append(base_at(contig, pos))

        if reference:
            parts.append(reference[contig][last:])

    return "".join(parts)


def main():
    """
    Process VCF files and merge them into a single fasta file.

    The VCF files are read once to find every position where at least one
    sample has an unfiltered SNP (or, with --with-mixtures, a mixture above
    the threshold). The cached records are then used to work out the base
    of every sample at those positions, and one FASTA entry per sample plus
    a final "reference" entry is written to --out.

    Returns:
        0 on success, 2 when two records give a different reference base for
        the same position.
    """

    logging.basicConfig(level=logging.INFO)

    args = get_args()
    contigs = list()

    sample_stats = dict()

    # All positions available for analysis.
    avail_pos = dict()
    # Stats about each position in each chromosome.
    pos_stats = dict()
    # Cached version of the data.
    vcf_data = dict()
    mixtures = dict()

    empty_tree = FastRBTree()

    exclude = False
    include = False

    if args.reference:
        ref_seq = OrderedDict()
        with open(args.reference) as fp:
            for record in SeqIO.parse(fp, "fasta"):
                ref_seq[record.id] = str(record.seq)

        # From here on args.reference is the parsed contig -> sequence map.
        args.reference = ref_seq

    if args.exclude or args.include:
        bed_file = args.include if args.include is not None else args.exclude
        regions = _read_bed_positions(bed_file)

        if args.include:
            include = regions
        else:
            exclude = regions

    if args.directory is not None and args.input is None:
        # Sorted so that the order of the samples in the output is stable.
        args.input = sorted(glob.glob(os.path.join(args.directory, "*.vcf")))

    # First pass to get the references and the positions to be analysed.
    for vcf_in in args.input:
        sample_name, _ = os.path.splitext(os.path.basename(vcf_in))
        vcf_data[vcf_in] = list()
        reader = vcf.Reader(filename=vcf_in)

        for record in reader:
            # The trees store False for every listed position, so get()
            # returns True only for positions that are NOT in the file.
            if include and include.get(record.CHROM, empty_tree).get(record.POS, True) or exclude and not exclude.get(record.CHROM, empty_tree).get(record.POS, True):
                continue

            vcf_data[vcf_in].append(record)

            if record.CHROM not in contigs:
                contigs.append(record.CHROM)
                avail_pos[record.CHROM] = FastRBTree()
                mixtures[record.CHROM] = {}
                sample_stats[record.CHROM] = {}

            if sample_name not in mixtures[record.CHROM]:
                mixtures[record.CHROM][sample_name] = FastRBTree()

            if sample_name not in sample_stats[record.CHROM]:
                sample_stats[record.CHROM][sample_name] = {}

            if not record.FILTER:
                if record.is_snp:
                    if record.POS in avail_pos[record.CHROM] and avail_pos[record.CHROM][record.POS] != record.REF:
                        logging.critical("SOMETHING IS REALLY WRONG because reference for the same position is DIFFERENT! %s", record.POS)
                        return 2

                    avail_pos[record.CHROM].insert(record.POS, str(record.REF))
                    pos_stats.setdefault(record.CHROM, {})[record.POS] = {"N": 0, "-": 0, "mut": 0, "mix": 0, "gap": 0}

            elif args.with_mixtures and record.is_snp:
                mix = get_mixture(record, args.with_mixtures)

                for codes in mix.values():
                    for code in codes:
                        avail_pos[record.CHROM].insert(record.POS, str(record.REF))
                        pos_stats.setdefault(record.CHROM, {})[record.POS] = {"N": 0, "-": 0, "mut": 0, "mix": 0, "gap": 0}
                        mixtures[record.CHROM][sample_name].insert(record.POS, code)

    all_data = {contig: {} for contig in contigs}
    samples = []

    # Second pass over the cached records to call the base of every sample.
    for vcf_in in args.input:
        sample_name, _ = os.path.splitext(os.path.basename(vcf_in))
        samples.append(sample_name)

        # Initialise the data for this sample to be REF positions.
        for contig in contigs:
            all_data[contig][sample_name] = {pos: avail_pos[contig][pos] for pos in avail_pos[contig]}

        for record in vcf_data[vcf_in]:
            # Only positions selected in the first pass are of interest.
            if not avail_pos.get(record.CHROM, empty_tree).get(record.POS, False):
                continue

            if record.FILTER == "PASS" or not record.FILTER:
                if record.is_snp:
                    if len(record.ALT) > 1:
                        logging.info("POS %s passed filters but has multiple alleles. Inserting N", record.POS)
                        all_data[record.CHROM][sample_name][record.POS] = "N"
                    else:
                        all_data[record.CHROM][sample_name][record.POS] = record.ALT[0].sequence
                        pos_stats[record.CHROM][record.POS]["mut"] += 1
            else:
                # Filtered record: use the mixture code found in the first
                # pass, or N when there is none.
                extended_code = mixtures[record.CHROM][sample_name].get(record.POS, "N")

                # Calculate the stats
                if extended_code == "N":
                    pos_stats[record.CHROM][record.POS]["N"] += 1
                    sample_stats[record.CHROM][sample_name].setdefault("n_pos", []).append(record.POS)
                elif extended_code == "-":
                    pos_stats[record.CHROM][record.POS]["-"] += 1
                else:
                    pos_stats[record.CHROM][record.POS]["mix"] += 1

                # Record if there was uncallable genotype/gap in the data.
                if record.samples[0].data.GT == "./.":
                    pos_stats[record.CHROM][record.POS]["gap"] += 1

                # Save the extended code of the SNP.
                all_data[record.CHROM][sample_name][record.POS] = extended_code

        # The cached records of this file are not needed any more.
        del vcf_data[vcf_in]

    if args.reference:
        # These should be in the same order as the order in reference.
        contigs = list(args.reference)

    if args.sample_Ns:
        delete_samples = set()
        for contig in contigs:
            # Skip contigs without any record or any analysed position.
            if contig not in sample_stats or not len(avail_pos[contig]):
                continue

            total_positions = len(avail_pos[contig])
            for sample in samples:
                # A sample may have no record, or no N at all, in a contig.
                n_positions = sample_stats[contig].get(sample, {}).get("n_pos", [])

                sample_n_ratio = float(len(n_positions)) / total_positions
                if sample_n_ratio > args.sample_Ns:
                    # The removed sample must not count towards column Ns.
                    for pos in n_positions:
                        pos_stats[contig][pos]["N"] -= 1

                    logging.info("Removing %s due to high Ns in sample: %s", sample, sample_n_ratio)

                    delete_samples.add(sample)

        samples = [sample for sample in samples if sample not in delete_samples]

    # Decide once per column whether it is written; the decision is the same
    # for every sample and for the reference.
    columns = {}
    for contig in contigs:
        if contig not in avail_pos:
            continue

        # With a reference only positions that exist in it can be written.
        limit = len(args.reference[contig]) if args.reference else None
        columns[contig] = [
            (pos, _keep_column(pos_stats, contig, pos, len(samples), args.column_Ns))
            for pos in avail_pos[contig].keys()
            if limit is None or 1 <= pos <= limit
        ]

    discarded = {}
    snp_positions = []
    for contig in contigs:
        for pos, keep in columns.get(contig, ()):
            if keep:
                snp_positions.append((contig, pos,))
            else:
                discarded.setdefault(contig, []).append(pos)

    # Output the data to the fasta file.
    # The data is already aligned so simply output it.
    with open(args.out, "w") as fp:

        for sample in samples:
            sample_seq = _assemble_sequence(
                contigs, columns, args.reference,
                lambda contig, pos, sample=sample: all_data[contig][sample][pos])
            fp.write(">%s\n%s\n" % (sample, sample_seq))

        # Do the same for reference data.
        ref_snps = _assemble_sequence(
            contigs, columns, args.reference,
            lambda contig, pos: str(avail_pos[contig][pos]))
        fp.write(">reference\n%s\n" % ref_snps)

    if args.with_stats:
        if can_stats:
            with open(args.with_stats, "w") as fp:
                fp.write("contig\tposition\tmutations\tn_frac\n")
                for contig, pos in snp_positions:
                    fp.write("%s\t%s\t%s\t%s\n" % (contig,
                                                   pos,
                                                   float(pos_stats[contig][pos]["mut"]) / len(args.input),
                                                   float(pos_stats[contig][pos]["N"]) / len(args.input)))
            plot_stats(pos_stats, len(samples), discarded=discarded, plots_dir=os.path.abspath(args.plots_dir))
        else:
            logging.warning("Statistics were requested but numpy or matplotlib could not be imported; skipping them.")

    total_discarded = sum(len(positions) for positions in discarded.values())
    logging.info("Discarded total of %i poor quality columns", total_discarded)
    return 0
405
if __name__ == '__main__':
    import sys

    # Hand main()'s return value to the shell as the exit status. sys.exit is
    # used because the bare exit() helper is only installed by the site
    # module and is missing when Python runs with -S.
    sys.exit(main())