comparison vsnp_add_zero_coverage.py @ 1:b03e88e7bb1d draft

"planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/vsnp commit 2e312886647244b416c64eca91e1a61dd1be939b"
author iuc
date Thu, 10 Dec 2020 15:25:22 +0000
parents 12f2b14549f6
children 57bd5b859e86
comparison
equal deleted inserted replaced
0:12f2b14549f6 1:b03e88e7bb1d
1 #!/usr/bin/env python 1 #!/usr/bin/env python
2 2
3 import argparse 3 import argparse
4 import multiprocessing
5 import os 4 import os
6 import queue
7 import re 5 import re
8 import shutil 6 import shutil
9 7
10 import pandas 8 import pandas
11 import pysam 9 import pysam
12 from Bio import SeqIO 10 from Bio import SeqIO
13 11
14 INPUT_BAM_DIR = 'input_bam_dir'
15 INPUT_VCF_DIR = 'input_vcf_dir'
16 OUTPUT_VCF_DIR = 'output_vcf_dir'
17 OUTPUT_METRICS_DIR = 'output_metrics_dir'
18 12
19 13 def get_sample_name(file_path):
20 def get_base_file_name(file_path):
21 base_file_name = os.path.basename(file_path) 14 base_file_name = os.path.basename(file_path)
22 if base_file_name.find(".") > 0: 15 if base_file_name.find(".") > 0:
23 # Eliminate the extension. 16 # Eliminate the extension.
24 return os.path.splitext(base_file_name)[0] 17 return os.path.splitext(base_file_name)[0]
25 elif base_file_name.endswith("_vcf"):
26 # The "." character has likely
27 # changed to an "_" character.
28 return base_file_name.rstrip("_vcf")
29 return base_file_name 18 return base_file_name
30 19
31 20
32 def get_coverage_and_snp_count(task_queue, reference, output_metrics, output_vcf, timeout): 21 def get_coverage_df(bam_file):
33 while True: 22 # Create a coverage dictionary.
34 try: 23 coverage_dict = {}
35 tup = task_queue.get(block=True, timeout=timeout) 24 coverage_list = pysam.depth(bam_file, split_lines=True)
36 except queue.Empty: 25 for line in coverage_list:
37 break 26 chrom, position, depth = line.split('\t')
38 bam_file, vcf_file = tup 27 coverage_dict["%s-%s" % (chrom, position)] = depth
39 # Create a coverage dictionary. 28 # Convert it to a data frame.
40 coverage_dict = {} 29 coverage_df = pandas.DataFrame.from_dict(coverage_dict, orient='index', columns=["depth"])
41 coverage_list = pysam.depth(bam_file, split_lines=True) 30 return coverage_df
42 for line in coverage_list:
43 chrom, position, depth = line.split('\t')
44 coverage_dict["%s-%s" % (chrom, position)] = depth
45 # Convert it to a data frame.
46 coverage_df = pandas.DataFrame.from_dict(coverage_dict, orient='index', columns=["depth"])
47 # Create a zero coverage dictionary.
48 zero_dict = {}
49 for record in SeqIO.parse(reference, "fasta"):
50 chrom = record.id
51 total_len = len(record.seq)
52 for pos in list(range(1, total_len + 1)):
53 zero_dict["%s-%s" % (str(chrom), str(pos))] = 0
54 # Convert it to a data frame with depth_x
55 # and depth_y columns - index is NaN.
56 zero_df = pandas.DataFrame.from_dict(zero_dict, orient='index', columns=["depth"])
57 coverage_df = zero_df.merge(coverage_df, left_index=True, right_index=True, how='outer')
58 # depth_x "0" column no longer needed.
59 coverage_df = coverage_df.drop(columns=['depth_x'])
60 coverage_df = coverage_df.rename(columns={'depth_y': 'depth'})
61 # Covert the NaN to 0 coverage and get some metrics.
62 coverage_df = coverage_df.fillna(0)
63 coverage_df['depth'] = coverage_df['depth'].apply(int)
64 total_length = len(coverage_df)
65 average_coverage = coverage_df['depth'].mean()
66 zero_df = coverage_df[coverage_df['depth'] == 0]
67 total_zero_coverage = len(zero_df)
68 total_coverage = total_length - total_zero_coverage
69 genome_coverage = "{:.2%}".format(total_coverage / total_length)
70 # Process the associated VCF input.
71 column_names = ["CHROM", "POS", "ID", "REF", "ALT", "QUAL", "FILTER", "INFO", "FORMAT", "Sample"]
72 vcf_df = pandas.read_csv(vcf_file, sep='\t', header=None, names=column_names, comment='#')
73 good_snp_count = len(vcf_df[(vcf_df['ALT'].str.len() == 1) & (vcf_df['REF'].str.len() == 1) & (vcf_df['QUAL'] > 150)])
74 base_file_name = get_base_file_name(vcf_file)
75 if total_zero_coverage > 0:
76 header_file = "%s_header.csv" % base_file_name
77 with open(header_file, 'w') as outfile:
78 with open(vcf_file) as infile:
79 for line in infile:
80 if re.search('^#', line):
81 outfile.write("%s" % line)
82 vcf_df_snp = vcf_df[vcf_df['REF'].str.len() == 1]
83 vcf_df_snp = vcf_df_snp[vcf_df_snp['ALT'].str.len() == 1]
84 vcf_df_snp['ABS_VALUE'] = vcf_df_snp['CHROM'].map(str) + "-" + vcf_df_snp['POS'].map(str)
85 vcf_df_snp = vcf_df_snp.set_index('ABS_VALUE')
86 cat_df = pandas.concat([vcf_df_snp, zero_df], axis=1, sort=False)
87 cat_df = cat_df.drop(columns=['CHROM', 'POS', 'depth'])
88 cat_df[['ID', 'ALT', 'QUAL', 'FILTER', 'INFO']] = cat_df[['ID', 'ALT', 'QUAL', 'FILTER', 'INFO']].fillna('.')
89 cat_df['REF'] = cat_df['REF'].fillna('N')
90 cat_df['FORMAT'] = cat_df['FORMAT'].fillna('GT')
91 cat_df['Sample'] = cat_df['Sample'].fillna('./.')
92 cat_df['temp'] = cat_df.index.str.rsplit('-', n=1)
93 cat_df[['CHROM', 'POS']] = pandas.DataFrame(cat_df.temp.values.tolist(), index=cat_df.index)
94 cat_df = cat_df[['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT', 'Sample']]
95 cat_df['POS'] = cat_df['POS'].astype(int)
96 cat_df = cat_df.sort_values(['CHROM', 'POS'])
97 body_file = "%s_body.csv" % base_file_name
98 cat_df.to_csv(body_file, sep='\t', header=False, index=False)
99 if output_vcf is None:
100 output_vcf_file = os.path.join(OUTPUT_VCF_DIR, "%s.vcf" % base_file_name)
101 else:
102 output_vcf_file = output_vcf
103 with open(output_vcf_file, "w") as outfile:
104 for cf in [header_file, body_file]:
105 with open(cf, "r") as infile:
106 for line in infile:
107 outfile.write("%s" % line)
108 else:
109 if output_vcf is None:
110 output_vcf_file = os.path.join(OUTPUT_VCF_DIR, "%s.vcf" % base_file_name)
111 else:
112 output_vcf_file = output_vcf
113 shutil.copyfile(vcf_file, output_vcf_file)
114 bam_metrics = [base_file_name, "", "%4f" % average_coverage, genome_coverage]
115 vcf_metrics = [base_file_name, str(good_snp_count), "", ""]
116 if output_metrics is None:
117 output_metrics_file = os.path.join(OUTPUT_METRICS_DIR, "%s.tabular" % base_file_name)
118 else:
119 output_metrics_file = output_metrics
120 metrics_columns = ["File", "Number of Good SNPs", "Average Coverage", "Genome Coverage"]
121 with open(output_metrics_file, "w") as fh:
122 fh.write("# %s\n" % "\t".join(metrics_columns))
123 fh.write("%s\n" % "\t".join(bam_metrics))
124 fh.write("%s\n" % "\t".join(vcf_metrics))
125 task_queue.task_done()
126 31
127 32
128 def set_num_cpus(num_files, processes): 33 def get_zero_df(reference):
129 num_cpus = int(multiprocessing.cpu_count()) 34 # Create a zero coverage dictionary.
130 if num_files < num_cpus and num_files < processes: 35 zero_dict = {}
131 return num_files 36 for record in SeqIO.parse(reference, "fasta"):
132 if num_cpus < processes: 37 chrom = record.id
133 half_cpus = int(num_cpus / 2) 38 total_len = len(record.seq)
134 if num_files < half_cpus: 39 for pos in list(range(1, total_len + 1)):
135 return num_files 40 zero_dict["%s-%s" % (str(chrom), str(pos))] = 0
136 return half_cpus 41 # Convert it to a data frame with depth_x
137 return processes 42 # and depth_y columns - index is NaN.
43 zero_df = pandas.DataFrame.from_dict(zero_dict, orient='index', columns=["depth"])
44 return zero_df
45
46
47 def output_zc_vcf_file(base_file_name, vcf_file, zero_df, total_zero_coverage, output_vcf):
48 column_names = ["CHROM", "POS", "ID", "REF", "ALT", "QUAL", "FILTER", "INFO", "FORMAT", "Sample"]
49 vcf_df = pandas.read_csv(vcf_file, sep='\t', header=None, names=column_names, comment='#')
50 good_snp_count = len(vcf_df[(vcf_df['ALT'].str.len() == 1) & (vcf_df['REF'].str.len() == 1) & (vcf_df['QUAL'] > 150)])
51 if total_zero_coverage > 0:
52 header_file = "%s_header.csv" % base_file_name
53 with open(header_file, 'w') as outfile:
54 with open(vcf_file) as infile:
55 for line in infile:
56 if re.search('^#', line):
57 outfile.write("%s" % line)
58 vcf_df_snp = vcf_df[vcf_df['REF'].str.len() == 1]
59 vcf_df_snp = vcf_df_snp[vcf_df_snp['ALT'].str.len() == 1]
60 vcf_df_snp['ABS_VALUE'] = vcf_df_snp['CHROM'].map(str) + "-" + vcf_df_snp['POS'].map(str)
61 vcf_df_snp = vcf_df_snp.set_index('ABS_VALUE')
62 cat_df = pandas.concat([vcf_df_snp, zero_df], axis=1, sort=False)
63 cat_df = cat_df.drop(columns=['CHROM', 'POS', 'depth'])
64 cat_df[['ID', 'ALT', 'QUAL', 'FILTER', 'INFO']] = cat_df[['ID', 'ALT', 'QUAL', 'FILTER', 'INFO']].fillna('.')
65 cat_df['REF'] = cat_df['REF'].fillna('N')
66 cat_df['FORMAT'] = cat_df['FORMAT'].fillna('GT')
67 cat_df['Sample'] = cat_df['Sample'].fillna('./.')
68 cat_df['temp'] = cat_df.index.str.rsplit('-', n=1)
69 cat_df[['CHROM', 'POS']] = pandas.DataFrame(cat_df.temp.values.tolist(), index=cat_df.index)
70 cat_df = cat_df[['CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO', 'FORMAT', 'Sample']]
71 cat_df['POS'] = cat_df['POS'].astype(int)
72 cat_df = cat_df.sort_values(['CHROM', 'POS'])
73 body_file = "%s_body.csv" % base_file_name
74 cat_df.to_csv(body_file, sep='\t', header=False, index=False)
75 with open(output_vcf, "w") as outfile:
76 for cf in [header_file, body_file]:
77 with open(cf, "r") as infile:
78 for line in infile:
79 outfile.write("%s" % line)
80 else:
81 shutil.move(vcf_file, output_vcf)
82 return good_snp_count
83
84
85 def output_metrics_file(base_file_name, average_coverage, genome_coverage, good_snp_count, output_metrics):
86 bam_metrics = [base_file_name, "", "%4f" % average_coverage, genome_coverage]
87 vcf_metrics = [base_file_name, str(good_snp_count), "", ""]
88 metrics_columns = ["File", "Number of Good SNPs", "Average Coverage", "Genome Coverage"]
89 with open(output_metrics, "w") as fh:
90 fh.write("# %s\n" % "\t".join(metrics_columns))
91 fh.write("%s\n" % "\t".join(bam_metrics))
92 fh.write("%s\n" % "\t".join(vcf_metrics))
93
94
95 def output_files(vcf_file, total_zero_coverage, zero_df, output_vcf, average_coverage, genome_coverage, output_metrics):
96 base_file_name = get_sample_name(vcf_file)
97 good_snp_count = output_zc_vcf_file(base_file_name, vcf_file, zero_df, total_zero_coverage, output_vcf)
98 output_metrics_file(base_file_name, average_coverage, genome_coverage, good_snp_count, output_metrics)
99
100
101 def get_coverage_and_snp_count(bam_file, vcf_file, reference, output_metrics, output_vcf):
102 coverage_df = get_coverage_df(bam_file)
103 zero_df = get_zero_df(reference)
104 coverage_df = zero_df.merge(coverage_df, left_index=True, right_index=True, how='outer')
105 # depth_x "0" column no longer needed.
106 coverage_df = coverage_df.drop(columns=['depth_x'])
107 coverage_df = coverage_df.rename(columns={'depth_y': 'depth'})
108 # Covert the NaN to 0 coverage and get some metrics.
109 coverage_df = coverage_df.fillna(0)
110 coverage_df['depth'] = coverage_df['depth'].apply(int)
111 total_length = len(coverage_df)
112 average_coverage = coverage_df['depth'].mean()
113 zero_df = coverage_df[coverage_df['depth'] == 0]
114 total_zero_coverage = len(zero_df)
115 total_coverage = total_length - total_zero_coverage
116 genome_coverage = "{:.2%}".format(total_coverage / total_length)
117 # Output a zero-coverage vcf fil and the metrics file.
118 output_files(vcf_file, total_zero_coverage, zero_df, output_vcf, average_coverage, genome_coverage, output_metrics)
138 119
139 120
140 if __name__ == '__main__': 121 if __name__ == '__main__':
141 parser = argparse.ArgumentParser() 122 parser = argparse.ArgumentParser()
142 123
124 parser.add_argument('--bam_input', action='store', dest='bam_input', help='bam input file')
143 parser.add_argument('--output_metrics', action='store', dest='output_metrics', required=False, default=None, help='Output metrics text file') 125 parser.add_argument('--output_metrics', action='store', dest='output_metrics', required=False, default=None, help='Output metrics text file')
144 parser.add_argument('--output_vcf', action='store', dest='output_vcf', required=False, default=None, help='Output VCF file') 126 parser.add_argument('--output_vcf', action='store', dest='output_vcf', required=False, default=None, help='Output VCF file')
145 parser.add_argument('--reference', action='store', dest='reference', help='Reference dataset') 127 parser.add_argument('--reference', action='store', dest='reference', help='Reference dataset')
146 parser.add_argument('--processes', action='store', dest='processes', type=int, help='User-selected number of processes to use for job splitting') 128 parser.add_argument('--vcf_input', action='store', dest='vcf_input', help='vcf input file')
147 129
148 args = parser.parse_args() 130 args = parser.parse_args()
149 131
150 # The assumption here is that the list of files 132 get_coverage_and_snp_count(args.bam_input, args.vcf_input, args.reference, args.output_metrics, args.output_vcf)
151 # in both INPUT_BAM_DIR and INPUT_VCF_DIR are
152 # equal in number and named such that they are
153 # properly matched if the directories contain
154 # more than 1 file (i.e., hopefully the bam file
155 # names and vcf file names will be something like
156 # Mbovis-01D6_* so they can be # sorted and properly
157 # associated with each other).
158 bam_files = []
159 for file_name in sorted(os.listdir(INPUT_BAM_DIR)):
160 file_path = os.path.abspath(os.path.join(INPUT_BAM_DIR, file_name))
161 bam_files.append(file_path)
162 vcf_files = []
163 for file_name in sorted(os.listdir(INPUT_VCF_DIR)):
164 file_path = os.path.abspath(os.path.join(INPUT_VCF_DIR, file_name))
165 vcf_files.append(file_path)
166
167 multiprocessing.set_start_method('spawn')
168 queue1 = multiprocessing.JoinableQueue()
169 num_files = len(bam_files)
170 cpus = set_num_cpus(num_files, args.processes)
171 # Set a timeout for get()s in the queue.
172 timeout = 0.05
173
174 # Add each associated bam and vcf file pair to the queue.
175 for i, bam_file in enumerate(bam_files):
176 vcf_file = vcf_files[i]
177 queue1.put((bam_file, vcf_file))
178
179 # Complete the get_coverage_and_snp_count task.
180 processes = [multiprocessing.Process(target=get_coverage_and_snp_count, args=(queue1, args.reference, args.output_metrics, args.output_vcf, timeout, )) for _ in range(cpus)]
181 for p in processes:
182 p.start()
183 for p in processes:
184 p.join()
185 queue1.join()
186
187 if queue1.empty():
188 queue1.close()
189 queue1.join_thread()