view vsnp_get_snps.py @ 0:ee4ef1fc23c6 draft

Uploaded
author greg
date Tue, 21 Apr 2020 10:14:11 -0400
parents
children 14285a94fb13
line wrap: on
line source

#!/usr/bin/env python

# Collect quality parsimonious SNPs from vcf files and output alignment files in fasta format.

import argparse
import multiprocessing
import os
import pandas
import queue
import shutil
import sys
import time
import vcf
from collections import OrderedDict
from datetime import datetime

# Directory (relative to the working directory) holding the input vcf files.
INPUT_VCF_DIR = 'input_vcf_dir'
# Directory into which every input vcf file is symlinked for the "all" analysis.
ALL_VCFS_DIR = 'all_vcf'
# Output directories for the per-group results.
OUTPUT_JSON_AVG_MQ_DIR = 'output_json_avg_mq_dir'
OUTPUT_JSON_SNPS_DIR = 'output_json_snps_dir'
OUTPUT_SNPS_DIR = 'output_snps_dir'


def get_time_stamp():
    """Return the current local time as a 'YYYY-MM-DD HH-MM-SS' string."""
    stamp_format = '%Y-%m-%d %H-%M-%S'
    now = datetime.fromtimestamp(time.time())
    return now.strftime(stamp_format)


def set_num_cpus(num_files, processes):
    """Return how many worker processes to start.

    ``processes`` is the user-requested count.  It may be None when the
    optional --processes argument is omitted; in that case every CPU may
    be used.

    - If ``num_files`` is below both the CPU count and the request, use
      ``num_files``.
    - Otherwise, if the request exceeds the CPU count, use half the CPUs,
      or ``num_files`` if that is smaller.
    - Otherwise use the request.

    The result is always at least 1.  The caller waits on a joinable queue
    that only worker processes drain, so zero workers would block forever.
    """
    num_cpus = int(multiprocessing.cpu_count())
    if processes is None:
        # Comparing an int with None raises TypeError, so pick a real default.
        processes = num_cpus
    if num_files < num_cpus and num_files < processes:
        num_workers = num_files
    elif num_cpus < processes:
        half_cpus = int(num_cpus / 2)
        num_workers = num_files if num_files < half_cpus else half_cpus
    else:
        num_workers = processes
    # half_cpus is 0 on a single-CPU host; a 0 request or 0 files also lands here.
    return max(1, num_workers)


def setup_all_vcfs(vcf_files, vcf_dirs):
    """Symlink every input vcf file into a newly created ALL_VCFS_DIR.

    ALL_VCFS_DIR is appended to ``vcf_dirs`` in place, and that same list is
    returned.  os.makedirs raises if the directory already exists.
    """
    os.makedirs(ALL_VCFS_DIR)
    vcf_dirs.append(ALL_VCFS_DIR)
    for source_path in vcf_files:
        link_path = os.path.join(ALL_VCFS_DIR, os.path.basename(source_path))
        os.symlink(source_path, link_path)
    return vcf_dirs


class SnpFinder:
    """Collect parsimony-informative SNPs from vcf files.

    An instance holds the SNP-calling thresholds supplied by the caller and
    builds an HTML run summary in ``summary_str``.  When an Excel grouper
    file is used, ``groups`` records the group directories that samples are
    copied into.
    """

    # IUPAC ambiguity code for a mixed (AC == 1) call, keyed by the ALT base
    # followed by the REF base.  Any other pair is reported as "N".
    _AMBIGUITY_CODES = {
        "AG": "R", "GA": "R",
        "CT": "Y", "TC": "Y",
        "GC": "S", "CG": "S",
        "AT": "W", "TA": "W",
        "GT": "K", "TG": "K",
        "AC": "M", "CA": "M",
    }

    def __init__(self, num_files, reference, excel_grouper_file,
                 all_isolates, ac, mq_val, n_threshold, qual_threshold, output_summary):
        # Allele count (AC) a call must have for its position to count as a SNP.
        self.ac = ac
        self.all_isolates = all_isolates
        # Set per group directory in get_snps(): "CHROM:POS" -> REF, sorted by key.
        self.all_positions = None
        # Optional Excel file used to group samples and filter positions.
        self.excel_grouper_file = excel_grouper_file
        # Group names.  Populated from the defining SNP columns of the
        # Excel grouper file when it is used.
        self.groups = []
        # Map quality a SNP call must exceed.
        self.mq_val = mq_val
        # QUAL above which a call in decide_snps() is reported as a base.
        self.n_threshold = n_threshold
        # QUAL a SNP call must exceed in find_initial_positions().
        self.qual_threshold = qual_threshold
        self.reference = reference
        self.start_time = get_time_stamp()
        self.summary_str = ""
        self.timer_start = datetime.now()
        self.num_files = num_files
        self.initiate_summary(output_summary)

    def append_to_summary(self, html_str):
        """Append ``html_str`` to the HTML run summary."""
        self.summary_str = "%s%s" % (self.summary_str, html_str)

    def bin_input_files(self, filename, samples_groups_dict, defining_snps, inverted_defining_snps, found_positions, found_positions_mix):
        """Assign one vcf file to groups based on its defining SNPs.

        The file is copied into the directory of every group whose defining
        SNP position was found in it (as a clean or a mixed call).  If none of
        the inverted defining SNP positions were found, the file is also
        copied into every inverted group.

        The sample's groups, or a "No defining SNP" message, are stored in
        ``samples_groups_dict`` under the file's base name.  The name is
        flagged [[MIXED]] when a defining SNP was found as a mixed call.
        On TypeError the error message is stored and added to the summary
        instead.

        Returns the updated ``samples_groups_dict``.
        """
        sample_groups_list = []
        table_name = self.get_base_file_name(filename)
        try:
            defining_snp = False
            # Positions found in this file, whether clean or mixed calls.
            all_found = found_positions.keys() | found_positions_mix.keys()
            # Mixed positions are a subset of all_found, so a mixed hit
            # implies the defining SNP loop below runs at least once.
            if defining_snps.keys() & found_positions_mix.keys():
                table_name = '%s<font color="red">[[MIXED]]</font>' % table_name
            for abs_position in defining_snps.keys() & all_found:
                group = defining_snps[abs_position]
                sample_groups_list.append(group)
                self.check_add_group(group)
                self.copy_file(filename, group)
                defining_snp = True
            # Inverted groups match samples that lack all of their positions.
            if not inverted_defining_snps.keys() & all_found:
                for group in inverted_defining_snps.values():
                    sample_groups_list.append(group)
                    self.check_add_group(group)
                    self.copy_file(filename, group)
                    defining_snp = True
            if defining_snp:
                samples_groups_dict[table_name] = sorted(sample_groups_list)
            else:
                samples_groups_dict[table_name] = ['<font color="red">No defining SNP</font>']
        except TypeError as e:
            msg = "<br/>Error processing file %s to generate samples_groups_dict: %s<br/>" % (filename, str(e))
            self.append_to_summary(msg)
            samples_groups_dict[table_name] = [msg]
        return samples_groups_dict

    def check_add_group(self, group):
        """Record ``group`` in ``self.groups`` if it is not already there."""
        if group not in self.groups:
            self.groups.append(group)

    def copy_file(self, filename, dir):
        """Copy ``filename`` into directory ``dir``, creating it if needed."""
        # exist_ok avoids the race between an exists() check and makedirs().
        os.makedirs(dir, exist_ok=True)
        shutil.copy(filename, dir)

    def decide_snps(self, filename):
        """Build one sample's SNP row over the current group's positions.

        Only records at positions in ``self.all_positions`` are considered.
        Positions the file does not report keep the REF value stored in
        ``self.all_positions``.

        Returns ``(sample_df, file_name_base, sample_map_qualities)``:
        - sample_df: a one-row data frame indexed by the file's base name;
        - file_name_base: that base name;
        - sample_map_qualities: position -> map quality, with "NaN" values
          skipped.
        """
        positions_dict = self.all_positions
        sample_map_qualities = {}
        sample_dict = {}
        # Eliminate the path and extension.
        file_name_base = self.get_base_file_name(filename)
        with open(filename, 'r') as fh:
            for record in vcf.Reader(fh):
                alt = str(record.ALT[0])
                record_position = "%s:%s" % (str(record.CHROM), str(record.POS))
                if record_position not in positions_dict:
                    continue
                if alt == "None":
                    # ALT[0] is None: no alternate allele at this position.
                    sample_dict[record_position] = "-"
                    continue
                # MQM is occasionally "NaN"; such values are left out so they
                # cannot break the later average.
                mq_val = self.get_mq_val(record.INFO, filename)
                if str(mq_val).lower() not in ["nan"]:
                    sample_map_qualities[record_position] = mq_val
                if len(alt) != 1:
                    # Only single-base ALT alleles are represented.
                    continue
                qual_val = self.val_as_int(record.QUAL)
                ac = record.INFO['AC'][0]
                ref = str(record.REF[0])
                if ac == 2 and qual_val > self.n_threshold:
                    sample_dict[record_position] = alt
                elif ac == 1 and qual_val > self.n_threshold:
                    sample_dict[record_position] = self._AMBIGUITY_CODES.get("%s%s" % (alt, ref), "N")
                # NOTE(review): 50 is hard-coded.  With n_threshold == 50 (the
                # value main uses), the "N" branch below is unreachable;
                # confirm intent.
                elif qual_val <= 50:
                    # Poor call: do not coerce record.REF[0] to a string!
                    sample_dict[record_position] = record.REF[0]
                elif qual_val <= self.n_threshold:
                    sample_dict[record_position] = "N"
                else:
                    # Insurance: still report on a possible SNP missed above.
                    # Do not coerce record.REF[0] to a string!
                    sample_dict[record_position] = record.REF[0]
        # Start from abs_pos:REF, then overwrite with this sample's calls.
        merge_dict = dict(positions_dict)
        merge_dict.update(sample_dict)
        sample_df = pandas.DataFrame(merge_dict, index=[file_name_base])
        return sample_df, file_name_base, sample_map_qualities

    def df_to_fasta(self, parsimonious_df, group):
        """Write ``parsimonious_df`` as a SNP alignment in FASTA format.

        Each row becomes one record: the header is the row name and the
        sequence is the row's values concatenated.  Only the first row with a
        given name is written.  The file is OUTPUT_SNPS_DIR/<group>.fasta and
        is created only when at least one cell is non-empty.

        Returns True if the file was written.
        """
        snps_file = os.path.join(OUTPUT_SNPS_DIR, "%s.fasta" % group)
        has_sequence_data = any(len(pos) > 0 for _, row in parsimonious_df.iterrows() for pos in row)
        if has_sequence_data:
            # A set makes the duplicate-name check O(1) per row.
            written_names = set()
            with open(snps_file, 'w') as fh:
                for _, row in parsimonious_df.iterrows():
                    if row.name in written_names:
                        continue
                    written_names.add(row.name)
                    fh.write(f'>{row.name}\n')
                    fh.write(''.join(str(pos) for pos in row))
                    fh.write('\n')
        return has_sequence_data

    def find_initial_positions(self, filename):
        """Find candidate SNP positions in one vcf file.

        Returns ``(found_positions, found_positions_mix)``.  Both are dicts
        mapping "CHROM:POS" to the record's REF, for single-base-REF records
        whose QUAL exceeds ``self.qual_threshold`` and whose map quality
        exceeds ``self.mq_val``:
        - found_positions holds records with AC == ``self.ac``;
        - found_positions_mix holds records with AC == 1.

        On a parsing error, the error is added to the summary and
        ``({'': ''}, {'': ''})`` is returned.
        """
        found_positions = {}
        found_positions_mix = {}
        try:
            with open(filename, 'r') as fh:
                vcf_reader = vcf.Reader(fh)
                try:
                    for record in vcf_reader:
                        qual_val = self.val_as_int(record.QUAL)
                        absolute_position = "%s:%s" % (str(record.CHROM), str(record.POS))
                        alt = str(record.ALT[0])
                        if alt != "None":
                            mq_val = self.get_mq_val(record.INFO, filename)
                            ac = record.INFO['AC'][0]
                            len_ref = len(record.REF)
                            # Keep the short-circuit order: mq_val is only
                            # compared once the cheaper tests pass.
                            if ac == self.ac and len_ref == 1 and qual_val > self.qual_threshold and mq_val > self.mq_val:
                                found_positions[absolute_position] = record.REF
                            if ac == 1 and len_ref == 1 and qual_val > self.qual_threshold and mq_val > self.mq_val:
                                found_positions_mix[absolute_position] = record.REF
                except (ZeroDivisionError, ValueError, UnboundLocalError, TypeError) as e:
                    self.append_to_summary("<br/>Error parsing record in file %s: %s<br/>" % (filename, str(e)))
                    return {'': ''}, {'': ''}
            return found_positions, found_positions_mix
        except (SyntaxError, AttributeError) as e:
            self.append_to_summary("<br/>Error attempting to read file %s: %s<br/>" % (filename, str(e)))
            return {'': ''}, {'': ''}

    def gather_and_filter(self, prefilter_df, mq_averages, group_dir):
        """Filter one group's SNP table and write its output files.

        With an Excel grouper file, two sets of positions are dropped: those
        listed in its first column (applied to every group) and those in the
        column named ``group_dir``.

        When the parsimonious table has at least 4 rows (the reference row
        counts) and any sequence data, it is written as:
        - FASTA to OUTPUT_SNPS_DIR;
        - JSON to OUTPUT_JSON_SNPS_DIR;
        - ``mq_averages`` as JSON to OUTPUT_JSON_AVG_MQ_DIR.

        Otherwise a note is added to the summary.
        """
        if self.excel_grouper_file is None:
            filtered_all_df = prefilter_df
        else:
            with pandas.ExcelFile(self.excel_grouper_file) as xl:
                sheet_names = xl.sheet_names
            # The first column filters "all" positions; the group's own
            # column adds group-specific exclusions.
            exclusion_list_all = self.get_position_list(sheet_names, 0)
            exclusion_list_group = self.get_position_list(sheet_names, group_dir)
            exclusion_list = exclusion_list_all + exclusion_list_group
            filtered_all_df = prefilter_df.drop(columns=exclusion_list, errors='ignore')
        json_snps_file = os.path.join(OUTPUT_JSON_SNPS_DIR, "%s.json" % group_dir)
        parsimonious_df = self.get_parsimonious_df(filtered_all_df)
        samples_number = parsimonious_df.shape[0]
        if samples_number >= 4:
            has_sequence_data = self.df_to_fasta(parsimonious_df, group_dir)
            if has_sequence_data:
                json_avg_mq_file = os.path.join(OUTPUT_JSON_AVG_MQ_DIR, "%s.json" % group_dir)
                mq_averages.to_json(json_avg_mq_file, orient='split')
                parsimonious_df.to_json(json_snps_file, orient='split')
            else:
                msg = "<br/>No sequence data"
                if group_dir is not None:
                    msg = "%s for group: %s" % (msg, group_dir)
                self.append_to_summary("%s<br/>\n" % msg)
        else:
            msg = "<br/>Too few samples to build tree"
            if group_dir is not None:
                msg = "%s for group: %s" % (msg, group_dir)
            self.append_to_summary("%s<br/>\n" % msg)

    def get_base_file_name(self, file_path):
        """Return the file name of ``file_path`` without directory or extension."""
        base_file_name = os.path.basename(file_path)
        if base_file_name.find(".") > 0:
            # Eliminate the extension.
            return os.path.splitext(base_file_name)[0]
        elif base_file_name.find("_") > 0:
            # The dot extension was likely changed to the "_"
            # character, so drop the last "_"-separated part.
            items = base_file_name.split("_")
            return "_".join(items[0:-1])
        else:
            return base_file_name

    def get_mq_val(self, record_info, filename):
        """Return the MQM (freebayes) or, failing that, MQ (gatk) value.

        ``record_info`` is a vcf record's INFO mapping.  The process exits
        with an error message when neither value is available.
        """
        try:
            return self.return_val(record_info['MQM'])
        except (KeyError, IndexError, TypeError):
            try:
                return self.return_val(record_info['MQ'])
            except (KeyError, IndexError, TypeError):
                msg = "Invalid or unsupported vcf header %s in file: %s\n" % (str(record_info), filename)
                sys.exit(msg)

    def get_parsimonious_df(self, filtered_all_df):
        """Return the parsimony-informative columns of ``filtered_all_df``.

        A column is kept when any sample differs from the first sample.  If
        ``filtered_all_df`` contains a 'root' (reference) row, that row is
        excluded from the comparison and appended as the last row.
        Otherwise only the sample rows are returned.
        """
        ref_series = None
        if 'root' in filtered_all_df.index:
            ref_series = filtered_all_df.loc['root']
            # The reference row must not take part in the parsimony test.
            filtered_all_df = filtered_all_df.drop(['root'])
        varying = (filtered_all_df != filtered_all_df.iloc[0]).any()
        parse_df = filtered_all_df.loc[:, varying]
        if ref_series is None:
            return parse_df
        ref_df = ref_series.to_frame().T
        # join='inner' keeps only the parsimony-informative columns of the reference row.
        return pandas.concat([parse_df, ref_df], join='inner')

    def get_position_list(self, sheet_names, group):
        """Return the positions listed in one column of the Excel grouper file.

        The column is selected by ``group`` (position or header name), using
        the file's second row as the header.  ``sheet_names`` is unused.
        - "CHROM:POS" entries are returned as-is.
        - "CHROM:START-END" ranges (commas allowed in numbers) are expanded
          to every position from START to END inclusive.

        Returns [] when a ValueError is raised (presumably when the column
        does not exist).  Note that a malformed range number also raises
        ValueError, which discards the whole list.
        """
        exclusion_list = []
        try:
            filter_to_all = pandas.read_excel(self.excel_grouper_file, header=1, usecols=[group])
            for value in filter_to_all.values:
                value = str(value[0])
                if "-" not in value.split(":")[-1]:
                    exclusion_list.append(value)
                    continue
                try:
                    chrom, sequence_range = value.split(":")
                except ValueError as e:
                    sys.exit(str(e))
                bounds = sequence_range.split("-")
                start = int(bounds[0].replace(',', ''))
                end = int(bounds[1].replace(',', ''))
                exclusion_list.extend("%s:%s" % (chrom, position) for position in range(start, end + 1))
            return exclusion_list
        except ValueError:
            return []

    def get_snps(self, task_queue, timeout):
        """Worker loop: process each group directory taken from ``task_queue``.

        Exits when no item arrives within ``timeout`` seconds.  Every item
        taken is marked done, even if processing it raises, so the parent's
        JoinableQueue.join() is not left waiting on it.

        NOTE(review): this runs in a worker process started by main.  Changes
        made here to ``summary_str`` are presumably not visible to the
        parent; confirm.
        """
        while True:
            try:
                group_dir = task_queue.get(block=True, timeout=timeout)
            except queue.Empty:
                break
            try:
                self._process_group_dir(group_dir)
            finally:
                task_queue.task_done()

    def _process_group_dir(self, group_dir):
        """Build, filter and write the SNP table for one directory of vcf files."""
        group_files = [os.path.abspath(os.path.join(group_dir, file_name))
                       for file_name in os.listdir(os.path.abspath(group_dir))]
        positions_dict = {}
        for file_name in group_files:
            try:
                found_positions, _ = self.find_initial_positions(file_name)
                positions_dict.update(found_positions)
            except Exception as e:
                self.append_to_summary("Error updating the positions_dict dictionary when processing file %s:\n%s\n" % (file_name, str(e)))
        # Sort so the reference row and every sample share the same column order.
        # all_positions is abs_pos:REF.
        self.all_positions = OrderedDict(sorted(positions_dict.items()))
        ref_positions_df = pandas.DataFrame(self.all_positions, index=['root'])
        all_map_qualities = {}
        df_list = []
        for file_name in group_files:
            sample_df, file_name_base, sample_map_qualities = self.decide_snps(file_name)
            df_list.append(sample_df)
            all_map_qualities[file_name_base] = sample_map_qualities
        all_sample_df = pandas.concat(df_list)
        # Reference row first, followed by every sample.
        prefilter_df = pandas.concat([ref_positions_df, all_sample_df], join='inner')
        all_mq_df = pandas.DataFrame.from_dict(all_map_qualities)
        mq_averages = all_mq_df.mean(axis=1).astype(int)
        self.gather_and_filter(prefilter_df, mq_averages, group_dir)

    def group_vcfs(self, vcf_files):
        """Copy vcf files into group directories and summarise the grouping.

        In the first sheet of the Excel grouper file:
        - column headers, except the first, are defining SNP positions;
          a '!' in a header marks an inverted one;
        - the first data row holds the matching group names.

        Each file in ``vcf_files`` is binned with bin_input_files().  The
        resulting table is appended to the HTML summary.
        """
        with pandas.ExcelFile(self.excel_grouper_file) as xl:
            sheet_names = xl.sheet_names
        ws = pandas.read_excel(self.excel_grouper_file, sheet_name=sheet_names[0])
        # Series.iteritems() was removed in pandas 2.0; items() is equivalent.
        defsnp_iterator = iter(ws.iloc[0].items())
        # Skip the first column.
        next(defsnp_iterator)
        defining_snps = {}
        inverted_defining_snps = {}
        for abs_pos, group in defsnp_iterator:
            if '!' in abs_pos:
                inverted_defining_snps[abs_pos.replace('!', '')] = group
            else:
                defining_snps[abs_pos] = group
        samples_groups_dict = {}
        for vcf_file in vcf_files:
            found_positions, found_positions_mix = self.find_initial_positions(vcf_file)
            samples_groups_dict = self.bin_input_files(vcf_file, samples_groups_dict, defining_snps, inverted_defining_snps, found_positions, found_positions_mix)
        # Output summary grouping table.
        self.append_to_summary('<br/>')
        self.append_to_summary('<b>Groupings with %d listed:</b><br/>\n' % len(samples_groups_dict))
        self.append_to_summary('<table  cellpadding="5" cellspacing="5" border="1">\n')
        for key, value in samples_groups_dict.items():
            self.append_to_summary('<tr align="left"><th>Sample Name</th>\n')
            self.append_to_summary('<td>%s</td>' % key)
            for group in value:
                self.append_to_summary('<td>%s</td>\n' % group)
            self.append_to_summary('</tr>\n')
        self.append_to_summary('</table><br/>\n')

    def initiate_summary(self, output_summary):
        """Start the HTML summary with the run's parameters.

        ``output_summary`` is not used here; main writes the summary out
        at the end of the run.
        """
        self.append_to_summary('<html>\n')
        self.append_to_summary('<head></head>\n')
        self.append_to_summary('<body style="font-size:12px;">')
        self.append_to_summary("<b>Time started:</b> %s<br/>" % str(get_time_stamp()))
        self.append_to_summary("<b>Number of VCF inputs:</b> %d<br/>" % self.num_files)
        self.append_to_summary("<b>Reference:</b> %s<br/>" % str(self.reference))
        self.append_to_summary("<b>All isolates:</b> %s<br/>" % str(self.all_isolates))

    def return_val(self, val, index=0):
        """Return ``val[index]`` if ``val`` is a list, otherwise ``val`` itself."""
        if isinstance(val, list):
            return val[index]
        return val

    def val_as_int(self, val):
        """Return ``int(val)``, or 0 when ``val`` cannot be converted (e.g. None)."""
        try:
            return int(val)
        except TypeError:
            # val is likely None here.
            return 0


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--all_isolates', action='store', dest='all_isolates', required=False,
                        default="No", help='Create table with all isolates')
    parser.add_argument('--excel_grouper_file', action='store', dest='excel_grouper_file', required=False,
                        default=None, help='Optional Excel filter file')
    parser.add_argument('--output_summary', action='store', dest='output_summary',
                        help='Output summary html file')
    parser.add_argument('--reference', action='store', dest='reference', help='Reference file')
    parser.add_argument('--processes', action='store', dest='processes', type=int,
                        help='User-selected number of processes to use for job splitting')
    args = parser.parse_args()

    # SNP calling thresholds handed to SnpFinder.
    # TODO: should these be passed in as command line args?
    ac, mq_val, n_threshold, qual_threshold = 2, 56, 50, 150

    # Absolute paths of the sample vcf files for the current run.
    vcf_files = [os.path.abspath(os.path.join(INPUT_VCF_DIR, entry))
                 for entry in os.listdir(INPUT_VCF_DIR)]
    num_files = len(vcf_files)

    multiprocessing.set_start_method('spawn')
    task_queue = multiprocessing.JoinableQueue()
    cpus = set_num_cpus(num_files, args.processes)
    # Seconds a worker waits on an empty queue before it exits.
    timeout = 0.05

    snp_finder = SnpFinder(num_files, args.reference, args.excel_grouper_file, args.all_isolates,
                           ac, mq_val, n_threshold, qual_threshold, args.output_summary)

    # Directories of vcf files to analyse.  "all_vcf" is included when no
    # grouper file is given or all isolates are requested.  With a grouper
    # file, one directory per group is added as well.
    vcf_dirs = []
    use_grouper = args.excel_grouper_file is not None
    if not use_grouper or args.all_isolates.lower() == "yes":
        vcf_dirs = setup_all_vcfs(vcf_files, vcf_dirs)
    if use_grouper:
        # group_vcfs copies each sample into its group directories under the cwd.
        snp_finder.group_vcfs(vcf_files)
        vcf_dirs.extend(entry for entry in os.listdir(os.getcwd())
                        if os.path.isdir(entry) and entry in snp_finder.groups)

    # One queue item per directory; workers drain the queue in parallel.
    for vcf_dir in vcf_dirs:
        task_queue.put(vcf_dir)

    workers = [multiprocessing.Process(target=snp_finder.get_snps, args=(task_queue, timeout))
               for _ in range(cpus)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    task_queue.join()

    # Finish and write the summary log.
    snp_finder.append_to_summary("<br/><b>Time finished:</b> %s<br/>\n" % get_time_stamp())
    total_run_time = datetime.now() - snp_finder.timer_start
    snp_finder.append_to_summary("<br/><b>Total run time:</b> %s<br/>\n" % str(total_run_time))
    snp_finder.append_to_summary('</body>\n</html>\n')
    with open(args.output_summary, "w") as fh:
        fh.write("%s" % snp_finder.summary_str)