Mercurial > repos > greg > vsnp_get_snps
view vsnp_get_snps.py @ 1:770834ba75e4 draft
Uploaded
author | greg |
---|---|
date | Sat, 14 Nov 2020 09:07:53 +0000 |
parents | ee4ef1fc23c6 |
children | 14285a94fb13 |
line wrap: on
line source
#!/usr/bin/env python

# Collect quality parsimonious SNPs from vcf files and output alignment files in fasta format.

import argparse
import multiprocessing
import os
import queue
import shutil
import sys
import time
from collections import OrderedDict
from datetime import datetime

import pandas
import vcf

ALL_VCFS_DIR = 'all_vcf'
INPUT_VCF_DIR = 'input_vcf_dir'
OUTPUT_JSON_AVG_MQ_DIR = 'output_json_avg_mq_dir'
OUTPUT_JSON_SNPS_DIR = 'output_json_snps_dir'
OUTPUT_SNPS_DIR = 'output_snps_dir'

# Maps the concatenation "<ALT><REF>" of a single-copy (AC == 1) call
# to the one-letter ambiguity code written to the SNP table.  Any pair
# not listed here is reported as "N".
IUPAC_AMBIGUITY_CODES = {
    "AG": "R", "GA": "R",
    "CT": "Y", "TC": "Y",
    "GC": "S", "CG": "S",
    "AT": "W", "TA": "W",
    "GT": "K", "TG": "K",
    "AC": "M", "CA": "M",
}


def get_time_stamp():
    """Return the current local time formatted as 'YYYY-MM-DD HH-MM-SS'."""
    return datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H-%M-%S')


def set_num_cpus(num_files, processes):
    """Return the number of worker processes to start.

    num_files is the number of input files and processes is the
    user-requested process count.  The result never exceeds the number
    of files, is capped at half the available CPUs when more processes
    were requested than CPUs exist, and is always at least 1: a result
    of 0 would start no workers and leave queued tasks unprocessed.
    """
    num_cpus = multiprocessing.cpu_count()
    if num_files < num_cpus and num_files < processes:
        cpus = num_files
    elif num_cpus < processes:
        half_cpus = num_cpus // 2
        cpus = num_files if num_files < half_cpus else half_cpus
    else:
        cpus = processes
    return max(1, cpus)


def setup_all_vcfs(vcf_files, vcf_dirs):
    """Create the all_vcf directory and symlink every input vcf file into it.

    ALL_VCFS_DIR is appended to vcf_dirs, which is returned.
    Raises FileExistsError if the directory already exists.
    """
    os.makedirs(ALL_VCFS_DIR)
    vcf_dirs.append(ALL_VCFS_DIR)
    for vcf_file in vcf_files:
        file_name_base = os.path.basename(vcf_file)
        dst_file = os.path.join(ALL_VCFS_DIR, file_name_base)
        os.symlink(vcf_file, dst_file)
    return vcf_dirs


class SnpFinder:
    """Find SNPs in sets of vcf files and write alignment / JSON outputs.

    An HTML summary is accumulated in self.summary_str while processing.
    """

    def __init__(self, num_files, reference, excel_grouper_file, all_isolates, ac, mq_val, n_threshold, qual_threshold, output_summary):
        """Store the run settings and start the HTML summary.

        ac, mq_val and qual_threshold are the allele count, map quality
        and QUAL cut-offs used by find_initial_positions; n_threshold is
        the QUAL cut-off used by decide_snps.  output_summary is only
        passed through to initiate_summary.
        """
        self.ac = ac
        self.all_isolates = all_isolates
        # Set by get_snps for the group currently being processed.
        self.all_positions = None
        # Filter based on the contents of an Excel file.
        self.excel_grouper_file = excel_grouper_file
        # This will be populated from the columns
        # in the Excel filter file if it is used.
        self.groups = []
        self.mq_val = mq_val
        self.n_threshold = n_threshold
        self.qual_threshold = qual_threshold
        self.reference = reference
        self.start_time = get_time_stamp()
        self.summary_str = ""
        self.timer_start = datetime.now()
        self.num_files = num_files
        self.initiate_summary(output_summary)

    def append_to_summary(self, html_str):
        """Append html_str to the accumulated summary text."""
        self.summary_str = "%s%s" % (self.summary_str, html_str)

    def bin_input_files(self, filename, samples_groups_dict, defining_snps, inverted_defining_snps, found_positions, found_positions_mix):
        """Copy filename into the directory of every group it belongs to.

        A file belongs to a group when one of the group's defining
        positions is among its found positions, or when none of the
        inverted defining positions is.  The groups are recorded in
        samples_groups_dict under the file's base name, and that dict
        is returned.
        """
        sample_groups_list = []
        table_name = self.get_base_file_name(filename)
        try:
            defining_snp = False
            called_positions = found_positions.keys() | found_positions_mix.keys()
            # A defining position found among the mixed calls flags the
            # sample as mixed.  Whenever this is true, the loop below runs
            # at least once, so the flag can be applied up front.
            if defining_snps.keys() & found_positions_mix.keys():
                table_name = '%s<font color="red">[[MIXED]]</font>' % table_name
            # Defining positions that were called in this file.
            for abs_position in defining_snps.keys() & called_positions:
                group = defining_snps[abs_position]
                sample_groups_list.append(group)
                self.check_add_group(group)
                self.copy_file(filename, group)
                defining_snp = True
            if not set(inverted_defining_snps.keys()).intersection(called_positions):
                for group in inverted_defining_snps.values():
                    sample_groups_list.append(group)
                    self.check_add_group(group)
                    self.copy_file(filename, group)
                    defining_snp = True
            if defining_snp:
                samples_groups_dict[table_name] = sorted(sample_groups_list)
            else:
                samples_groups_dict[table_name] = ['<font color="red">No defining SNP</font>']
        except TypeError as e:
            msg = "<br/>Error processing file %s to generate samples_groups_dict: %s<br/>" % (filename, str(e))
            self.append_to_summary(msg)
            samples_groups_dict[table_name] = [msg]
        return samples_groups_dict

    def check_add_group(self, group):
        """Record group in self.groups unless it is already there."""
        if group not in self.groups:
            self.groups.append(group)

    def copy_file(self, filename, dir):
        """Copy filename into directory dir, creating dir if needed."""
        os.makedirs(dir, exist_ok=True)
        shutil.copy(filename, dir)

    def decide_snps(self, filename):
        """Decide the call at every position of interest for one vcf file.

        Requires self.all_positions to be set (get_snps does this).
        Returns a tuple (sample_df, file_name_base, sample_map_qualities)
        where sample_df is a one-row data frame indexed by the base file
        name and sample_map_qualities maps positions to map quality.
        """
        positions_dict = self.all_positions
        sample_map_qualities = {}
        # Eliminate the path.
        file_name_base = self.get_base_file_name(filename)
        sample_dict = {}
        with open(filename, 'r') as fh:
            for record in vcf.Reader(fh):
                alt = str(record.ALT[0])
                record_position = "%s:%s" % (str(record.CHROM), str(record.POS))
                if record_position not in positions_dict:
                    continue
                if alt == "None":
                    sample_dict[record_position] = "-"
                    continue
                # Not sure this is the best place to capture MQM average
                # may be faster after parsimony SNPs are decided, but
                # then it will require opening the files again.
                # On rare occasions MQM gets called "NaN", thus passing
                # a string when a number is expected when calculating average.
                mq_val = self.get_mq_val(record.INFO, filename)
                if str(mq_val).lower() != "nan":
                    sample_map_qualities[record_position] = mq_val
                # Only single-character ALT values are decided here; for
                # anything longer the reference value from positions_dict
                # is kept by the merge below.
                if len(alt) != 1:
                    continue
                qual_val = self.val_as_int(record.QUAL)
                ac = record.INFO['AC'][0]
                ref = str(record.REF[0])
                if ac == 2 and qual_val > self.n_threshold:
                    sample_dict[record_position] = alt
                elif ac == 1 and qual_val > self.n_threshold:
                    alt_ref = "%s%s" % (alt, ref)
                    sample_dict[record_position] = IUPAC_AMBIGUITY_CODES.get(alt_ref, "N")
                # Poor calls
                elif qual_val <= 50:
                    # Do not coerce record.REF[0] to a string!
                    sample_dict[record_position] = record.REF[0]
                elif qual_val <= self.n_threshold:
                    sample_dict[record_position] = "N"
                else:
                    # Insurance -- Will still report on a possible
                    # SNP even if missed with above statement
                    # Do not coerce record.REF[0] to a string!
                    sample_dict[record_position] = record.REF[0]
        # Merge dictionaries and order
        merge_dict = {}
        # abs_pos:REF
        merge_dict.update(positions_dict)
        # abs_pos:ALT replacing all_positions, because keys must be unique
        merge_dict.update(sample_dict)
        sample_df = pandas.DataFrame(merge_dict, index=[file_name_base])
        return sample_df, file_name_base, sample_map_qualities

    def df_to_fasta(self, parsimonious_df, group):
        """Write the SNP alignment for group as <OUTPUT_SNPS_DIR>/<group>.fasta.

        Nothing is written when no cell of parsimonious_df holds a
        non-empty value.  Rows whose name was already written are
        skipped.  Returns True when the file was written.
        """
        snps_file = os.path.join(OUTPUT_SNPS_DIR, "%s.fasta" % group)
        has_sequence_data = any(len(pos) > 0 for _, row in parsimonious_df.iterrows() for pos in row)
        if has_sequence_data:
            seen_names = set()
            with open(snps_file, 'w') as fh:
                for _, row in parsimonious_df.iterrows():
                    if row.name in seen_names:
                        continue
                    seen_names.add(row.name)
                    print(f'>{row.name}', file=fh)
                    for pos in row:
                        print(pos, end='', file=fh)
                    print("", file=fh)
        return has_sequence_data

    def find_initial_positions(self, filename):
        """Find SNP positions in a vcf file.

        Returns (found_positions, found_positions_mix), two dicts
        mapping "CHROM:POS" to record.REF for records that pass the
        QUAL / map-quality thresholds with AC == self.ac and AC == 1
        respectively.  On a read or parse error the error is added to
        the summary and ({'': ''}, {'': ''}) is returned.
        """
        found_positions = {}
        found_positions_mix = {}
        try:
            with open(filename, 'r') as fh:
                vcf_reader = vcf.Reader(fh)
                try:
                    for record in vcf_reader:
                        qual_val = self.val_as_int(record.QUAL)
                        absolute_position = "%s:%s" % (str(record.CHROM), str(record.POS))
                        alt = str(record.ALT[0])
                        if alt == "None":
                            continue
                        mq_val = self.get_mq_val(record.INFO, filename)
                        ac = record.INFO['AC'][0]
                        passes = len(record.REF) == 1 and qual_val > self.qual_threshold and mq_val > self.mq_val
                        if passes and ac == self.ac:
                            found_positions[absolute_position] = record.REF
                        if passes and ac == 1:
                            found_positions_mix[absolute_position] = record.REF
                    return found_positions, found_positions_mix
                except (ZeroDivisionError, ValueError, UnboundLocalError, TypeError) as e:
                    self.append_to_summary("<br/>Error parsing record in file %s: %s<br/>" % (filename, str(e)))
                    return {'': ''}, {'': ''}
        except (SyntaxError, AttributeError) as e:
            self.append_to_summary("<br/>Error attempting to read file %s: %s<br/>" % (filename, str(e)))
            return {'': ''}, {'': ''}

    def gather_and_filter(self, prefilter_df, mq_averages, group_dir):
        """Filter the SNP data frame of a group and write its outputs.

        When an Excel filter file is in use, the positions listed in its
        first column and in the column named group_dir are dropped.  The
        fasta and JSON outputs are written only when the parsimonious
        data frame has at least 4 rows and holds sequence data; otherwise
        a message is added to the summary.
        """
        if self.excel_grouper_file is None:
            filtered_all_df = prefilter_df
        else:
            # Filter positions to be removed from all.
            xl = pandas.ExcelFile(self.excel_grouper_file)
            sheet_names = xl.sheet_names
            # Use the first column to filter "all" positions.
            exclusion_list_all = self.get_position_list(sheet_names, 0)
            exclusion_list_group = self.get_position_list(sheet_names, group_dir)
            exclusion_list = exclusion_list_all + exclusion_list_group
            # Filters for all applied.
            filtered_all_df = prefilter_df.drop(columns=exclusion_list, errors='ignore')
        json_snps_file = os.path.join(OUTPUT_JSON_SNPS_DIR, "%s.json" % group_dir)
        parsimonious_df = self.get_parsimonious_df(filtered_all_df)
        samples_number, _ = parsimonious_df.shape
        if samples_number >= 4:
            if self.df_to_fasta(parsimonious_df, group_dir):
                json_avg_mq_file = os.path.join(OUTPUT_JSON_AVG_MQ_DIR, "%s.json" % group_dir)
                mq_averages.to_json(json_avg_mq_file, orient='split')
                parsimonious_df.to_json(json_snps_file, orient='split')
                return
            msg = "<br/>No sequence data"
        else:
            msg = "<br/>Too few samples to build tree"
        if group_dir is not None:
            msg = "%s for group: %s" % (msg, group_dir)
        self.append_to_summary("%s<br/>\n" % msg)

    def get_base_file_name(self, file_path):
        """Return the file name of file_path without its extension."""
        base_file_name = os.path.basename(file_path)
        if base_file_name.find(".") > 0:
            # Eliminate the extension.
            return os.path.splitext(base_file_name)[0]
        if base_file_name.find("_") > 0:
            # The dot extension was likely changed to
            # the underscore character.
            items = base_file_name.split("_")
            return "_".join(items[0:-1])
        return base_file_name

    def get_mq_val(self, record_info, filename):
        """Return the MQM (freebayes) or MQ (gatk) value from record.INFO.

        Exits the program with a message when neither can be read.
        """
        for key in ('MQM', 'MQ'):
            try:
                return self.return_val(record_info[key])
            except Exception:
                # Deliberately broad: any failure to read this key
                # simply moves on to the next one.
                continue
        msg = "Invalid or unsupported vcf header %s in file: %s\n" % (str(record_info), filename)
        sys.exit(msg)

    def get_parsimonious_df(self, filtered_all_df):
        """Return the parsimonious SNPs data frame.

        Only columns in which at least one sample differs from the
        first sample are kept.  The 'root' row, when present, is
        excluded from that comparison and appended as the last row.
        """
        ref_series = None
        try:
            ref_series = filtered_all_df.loc['root']
            # In all_vcf root needs to be removed.
            filtered_all_df = filtered_all_df.drop(['root'])
        except KeyError:
            pass
        parsimony = filtered_all_df.loc[:, (filtered_all_df != filtered_all_df.iloc[0]).any()]
        parsimony_positions = list(parsimony)
        parse_df = filtered_all_df[parsimony_positions]
        if ref_series is None:
            # No reference row to append.
            return parse_df
        ref_df = ref_series.to_frame().T
        return pandas.concat([parse_df, ref_df], join='inner')

    def get_position_list(self, sheet_names, group):
        """Return the positions listed in column group of the Excel file.

        Entries of the form "chrom:start-end" are expanded to one
        "chrom:position" entry per position, end inclusive.  An empty
        list is returned when reading or parsing raises ValueError.
        sheet_names is accepted but not used.
        """
        exclusion_list = []
        try:
            filter_to_all = pandas.read_excel(self.excel_grouper_file, header=1, usecols=[group])
            for row in filter_to_all.values:
                value = str(row[0])
                if "-" not in value.split(":")[-1]:
                    exclusion_list.append(value)
                    continue
                try:
                    chrom, sequence_range = value.split(":")
                except Exception as e:
                    sys.exit(str(e))
                bounds = sequence_range.split("-")
                first = int(bounds[0].replace(',', ''))
                last = int(bounds[1].replace(',', ''))
                for position in range(first, last + 1):
                    exclusion_list.append(chrom + ":" + str(position))
            return exclusion_list
        except ValueError:
            return []

    def get_snps(self, task_queue, timeout):
        """Process group directories from task_queue until it stays empty.

        Each get() waits at most timeout seconds.  Every task taken from
        the queue is marked done even when processing it raises, so a
        join() on the queue cannot block on a failed task.
        """
        while True:
            try:
                group_dir = task_queue.get(block=True, timeout=timeout)
            except queue.Empty:
                break
            try:
                self._process_group(group_dir)
            finally:
                task_queue.task_done()

    def _process_group(self, group_dir):
        """Build the SNP data frame for the vcf files in group_dir and output it."""
        # Parse all vcf files to accumulate SNPs into a
        # data frame.
        positions_dict = {}
        group_files = [
            os.path.abspath(os.path.join(group_dir, file_name))
            for file_name in os.listdir(os.path.abspath(group_dir))
        ]
        for file_name in group_files:
            try:
                found_positions, _ = self.find_initial_positions(file_name)
                positions_dict.update(found_positions)
            except Exception as e:
                self.append_to_summary("Error updating the positions_dict dictionary when processing file %s:\n%s\n" % (file_name, str(e)))
        # Order before adding to file to match
        # with ordering of individual samples.
        # all_positions is abs_pos:REF
        self.all_positions = OrderedDict(sorted(positions_dict.items()))
        ref_positions_df = pandas.DataFrame(self.all_positions, index=['root'])
        all_map_qualities = {}
        df_list = []
        for file_name in group_files:
            sample_df, file_name_base, sample_map_qualities = self.decide_snps(file_name)
            df_list.append(sample_df)
            all_map_qualities[file_name_base] = sample_map_qualities
        all_sample_df = pandas.concat(df_list)
        # All positions have now been selected for each sample,
        # so select parsimony informative SNPs.  This removes
        # columns where all fields are the same.
        # Add reference to top row.
        prefilter_df = pandas.concat([ref_positions_df, all_sample_df], join='inner')
        all_mq_df = pandas.DataFrame.from_dict(all_map_qualities)
        mq_averages = all_mq_df.mean(axis=1).astype(int)
        self.gather_and_filter(prefilter_df, mq_averages, group_dir)

    def group_vcfs(self, vcf_files):
        """Bin vcf_files into groups defined by the Excel file.

        The first data row of the first sheet maps positions (the column
        headers, first column skipped) to group names; a '!' in the
        header marks an inverted defining position.  A grouping table is
        appended to the summary.
        """
        xl = pandas.ExcelFile(self.excel_grouper_file)
        sheet_names = xl.sheet_names
        ws = pandas.read_excel(self.excel_grouper_file, sheet_name=sheet_names[0])
        # Series.items() is used because Series.iteritems() is not
        # available in pandas 2.x.
        defsnp_iterator = iter(ws.iloc[0].items())
        # Skip the first column.
        next(defsnp_iterator)
        defining_snps = {}
        inverted_defining_snps = {}
        for abs_pos, group in defsnp_iterator:
            if '!' in abs_pos:
                inverted_defining_snps[abs_pos.replace('!', '')] = group
            else:
                defining_snps[abs_pos] = group
        samples_groups_dict = {}
        for vcf_file in vcf_files:
            found_positions, found_positions_mix = self.find_initial_positions(vcf_file)
            samples_groups_dict = self.bin_input_files(vcf_file, samples_groups_dict, defining_snps, inverted_defining_snps, found_positions, found_positions_mix)
        # Output summary grouping table.
        self.append_to_summary('<br/>')
        self.append_to_summary('<b>Groupings with %d listed:</b><br/>\n' % len(samples_groups_dict))
        self.append_to_summary('<table cellpadding="5" cellspacing="5" border="1">\n')
        for key, value in samples_groups_dict.items():
            self.append_to_summary('<tr align="left"><th>Sample Name</th>\n')
            self.append_to_summary('<td>%s</td>' % key)
            for group in value:
                self.append_to_summary('<td>%s</td>\n' % group)
            self.append_to_summary('</tr>\n')
        self.append_to_summary('</table><br/>\n')

    def initiate_summary(self, output_summary):
        """Write the opening HTML and run settings to the summary.

        output_summary is accepted but not used here.
        """
        self.append_to_summary('<html>\n')
        self.append_to_summary('<head></head>\n')
        self.append_to_summary('<body style=\"font-size:12px;">')
        self.append_to_summary("<b>Time started:</b> %s<br/>" % str(get_time_stamp()))
        self.append_to_summary("<b>Number of VCF inputs:</b> %d<br/>" % self.num_files)
        self.append_to_summary("<b>Reference:</b> %s<br/>" % str(self.reference))
        self.append_to_summary("<b>All isolates:</b> %s<br/>" % str(self.all_isolates))

    def return_val(self, val, index=0):
        """Return val[index] when val is a list, otherwise val itself."""
        if isinstance(val, list):
            return val[index]
        return val

    def val_as_int(self, val):
        """Return int(val), or 0 when the conversion raises TypeError."""
        try:
            return int(val)
        except TypeError:
            # val is likely None here.
            return 0


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--all_isolates', action='store', dest='all_isolates', required=False, default="No", help='Create table with all isolates')
    parser.add_argument('--excel_grouper_file', action='store', dest='excel_grouper_file', required=False, default=None, help='Optional Excel filter file')
    parser.add_argument('--output_summary', action='store', dest='output_summary', help='Output summary html file')
    parser.add_argument('--reference', action='store', dest='reference', help='Reference file')
    # Required: set_num_cpus compares this value with integers, so
    # omitting it could only ever end in a TypeError.
    parser.add_argument('--processes', action='store', dest='processes', type=int, required=True, help='User-selected number of processes to use for job splitting')

    args = parser.parse_args()

    # Initializations - TODO: should these be passed in as command line args?
    ac = 2
    mq_val = 56
    n_threshold = 50
    qual_threshold = 150

    # Build the list of sample vcf files for the current run.
    vcf_files = [
        os.path.abspath(os.path.join(INPUT_VCF_DIR, file_name))
        for file_name in os.listdir(INPUT_VCF_DIR)
    ]

    multiprocessing.set_start_method('spawn')
    queue1 = multiprocessing.JoinableQueue()
    num_files = len(vcf_files)
    cpus = set_num_cpus(num_files, args.processes)
    # Set a timeout for get()s in the queue.
    timeout = 0.05

    # Initialize the snp_finder object.
    snp_finder = SnpFinder(num_files, args.reference, args.excel_grouper_file, args.all_isolates, ac, mq_val, n_threshold, qual_threshold, args.output_summary)

    # Initialize the set of directories containing vcf files for analysis.
    vcf_dirs = []
    if args.excel_grouper_file is None:
        vcf_dirs = setup_all_vcfs(vcf_files, vcf_dirs)
    else:
        if args.all_isolates.lower() == "yes":
            vcf_dirs = setup_all_vcfs(vcf_files, vcf_dirs)
        # Parse the Excel file to determine groups for filtering.
        snp_finder.group_vcfs(vcf_files)
        # Append the list of group directories created by
        # the above call to the set of directories containing
        # vcf files for analysis
        group_dirs = [d for d in os.listdir(os.getcwd()) if os.path.isdir(d) and d in snp_finder.groups]
        vcf_dirs.extend(group_dirs)

    # Populate the queue for job splitting.
    for vcf_dir in vcf_dirs:
        queue1.put(vcf_dir)

    # Complete the get_snps task.
    # NOTE(review): each worker is spawned with its own copy of
    # snp_finder, so summary text appended inside get_snps presumably
    # never reaches the summary written below -- confirm.
    processes = [multiprocessing.Process(target=snp_finder.get_snps, args=(queue1, timeout, )) for _ in range(cpus)]
    for p in processes:
        p.start()
    for p in processes:
        p.join()
    queue1.join()

    # Finish summary log.
    snp_finder.append_to_summary("<br/><b>Time finished:</b> %s<br/>\n" % get_time_stamp())
    total_run_time = datetime.now() - snp_finder.timer_start
    snp_finder.append_to_summary("<br/><b>Total run time:</b> %s<br/>\n" % str(total_run_time))
    snp_finder.append_to_summary('</body>\n</html>\n')
    with open(args.output_summary, "w") as fh:
        fh.write("%s" % snp_finder.summary_str)