view cravat_convert/vcf_converter.py @ 18:cde9e74d9fdf draft

Uploaded
author in_silico
date Wed, 18 Jul 2018 09:41:20 -0400
parents c042835a7163
children
line wrap: on
line source

"""
A module originally obtained from the cravat package, modified for use in the VCF
converter Galaxy tool.


Register of changes made (Chris Jacoby):
    1) Changed imports, as the Galaxy tool does not have access to the complete cravat Python package.
    2) Defined BadFormatError in the base_converter module, as the separate BadFormatError module was not available.
"""

from base_converter import BaseConverter, BadFormatError
import re

class CravatConverter(BaseConverter):
    """Convert VCF data lines into CRAVAT variant dictionaries.

    One dictionary is produced per ALT allele for a sites-only VCF (8 columns),
    or per non-reference allele called in each sample when sample columns are
    present.
    """

    def __init__(self):
        self.format_name = 'vcf'
        # Sample names taken from the #CHROM header line by setup().
        self.samples = []
        # Running count of data lines; used to name variants whose ID is '.'.
        self.var_counter = 0
        # Descriptions of the additional per-variant fields that convert_line()
        # emits alongside the core fields (tags, chrom, pos, ref/alt, sample).
        self.addl_cols = [{'name': 'phred',
                           'title': 'Phred',
                           'type': 'string'},
                          {'name': 'filter',
                           'title': 'VCF filter',
                           'type': 'string'},
                          {'name': 'zygosity',
                           'title': 'Zygosity',
                           'type': 'string'},
                          {'name': 'alt_reads',
                           'title': 'Alternate reads',
                           'type': 'int'},
                          {'name': 'tot_reads',
                           'title': 'Total reads',
                           'type': 'int'},
                          {'name': 'af',
                           'title': 'Variant allele frequency',
                           'type': 'float'}]

    def check_format(self, f):
        """Return True if the first line read from ``f`` is a ``##fileformat=VCF`` header."""
        return f.readline().startswith('##fileformat=VCF')

    def setup(self, f):
        """Advance ``f`` past the header, recording sample names from ``#CHROM``.

        Lines are consumed up to and including the ``#CHROM`` line. The columns
        after FORMAT (index 9 onward) become ``self.samples``.
        """
        for line in f:
            if line.startswith('#CHROM'):
                toks = re.split(r'\s+', line.rstrip())
                if len(toks) > 8:
                    self.samples = toks[9:]
                break

    def convert_line(self, l):
        """Convert one VCF line into a list of CRAVAT variant dicts.

        Parameters
        ----------
        l : str
            A raw VCF line; a trailing CR/LF is allowed.

        Returns
        -------
        list of dict or None
            None for lines starting with '#'. Otherwise one dict per ALT allele
            (sites-only VCF) or one per non-reference allele called in each sample.

        Raises
        ------
        BadFormatError
            The line has fewer than 8 columns, a blank ID, no GT field, more
            sample columns than #CHROM sample names, an unparseable genotype,
            or a genotype that refers to a non-existent ALT allele.
        """
        if l.startswith('#'):
            return None
        self.var_counter += 1
        toks = l.strip('\r\n').split('\t')
        if len(toks) < 8:
            raise BadFormatError('Wrong VCF format')
        chrom, pos, tag, ref, alts, qual, vcf_filter, _info = toks[:8]
        if tag == '':
            raise BadFormatError('ID column is blank')
        if tag == '.':
            tag = 'VAR' + str(self.var_counter)
        if not chrom.startswith('chr'):
            chrom = 'chr' + chrom
        alts = alts.split(',')
        all_wdicts = []

        if len(toks) == 8:
            # Sites-only VCF: one record per ALT allele, with no per-sample data.
            for alt in alts:
                newpos, newref, newalt = self.extract_vcf_variant('+', pos, ref, alt)
                all_wdicts.append({'tags': tag,
                                   'chrom': chrom,
                                   'pos': newpos,
                                   'ref_base': newref,
                                   'alt_base': newalt,
                                   'sample_id': 'no_sample',
                                   'phred': qual,
                                   'filter': vcf_filter})
            return all_wdicts

        # Map each FORMAT key (e.g. GT, AD, DP) to its position within a sample column.
        genotype_fields = dict((name, idx) for idx, name in enumerate(toks[8].split(':')))
        if 'GT' not in genotype_fields:
            raise BadFormatError('No GT Field')
        gt_field_no = genotype_fields['GT']
        sample_datas = toks[9:]
        if len(sample_datas) > len(self.samples):
            raise BadFormatError('More sample columns than sample names in #CHROM header')

        for sample, sample_str in zip(self.samples, sample_datas):
            sample_data = sample_str.split(':')
            if gt_field_no >= len(sample_data):
                # No GT value for this sample, so nothing is called.
                continue
            gt_str = sample_data[gt_field_no]
            gts = self._called_alleles(gt_str, len(alts))
            zyg = self.homo_hetro(gt_str)
            for gt in sorted(gts):
                if gt == 0:
                    # Reference allele: not a variant.
                    continue
                alt = alts[gt - 1]
                newpos, newref, newalt = self.extract_vcf_variant('+', pos, ref, alt)
                depth, alt_reads, af = self.extract_read_info(sample_data, gt, gts, genotype_fields)
                all_wdicts.append({'tags': tag,
                                   'chrom': chrom,
                                   'pos': newpos,
                                   'ref_base': newref,
                                   'alt_base': newalt,
                                   'sample_id': sample,
                                   'phred': qual,
                                   'filter': vcf_filter,
                                   'zygosity': zyg,
                                   'tot_reads': depth,
                                   'alt_reads': alt_reads,
                                   'af': af})
        return all_wdicts

    @staticmethod
    def _called_alleles(gt_str, num_alts):
        """Return the set of allele indices called in a GT string such as '0/1' or '1|2'.

        No-call entries ('.' or empty) are ignored.

        Raises BadFormatError for a non-integer allele, or for an index outside
        0..num_alts.
        """
        called = set()
        for allele in gt_str.replace('/', '|').split('|'):
            if allele in ('.', ''):
                continue
            try:
                idx = int(allele)
            except ValueError:
                raise BadFormatError('Invalid genotype: ' + gt_str)
            # A negative index would otherwise silently wrap around the ALT list.
            if idx < 0 or idx > num_alts:
                raise BadFormatError('Genotype allele out of range: ' + gt_str)
            called.add(idx)
        return called

    def homo_hetro(self, gt_str):
        """Classify a GT string as 'hom', 'het', or '' if it contains a no-call.

        Alleles are separated by '/' (unphased) or '|' (phased). The result is
        'hom' when every allele call is identical, otherwise 'het'.
        """
        if '.' in gt_str:
            return ''
        gts = gt_str.strip().replace('/', '|').split('|')
        return 'het' if any(gt != gts[0] for gt in gts) else 'hom'

    @staticmethod
    def _format_value(sample_data, genotype_fields, key):
        """Return the sample's raw value for FORMAT ``key``.

        Returns None when the key is not in FORMAT, when the sample column is
        too short to contain it, or when the value is missing ('.' or empty).
        """
        idx = genotype_fields.get(key)
        if idx is None or idx >= len(sample_data):
            return None
        value = sample_data[idx]
        if value in ('', '.'):
            return None
        return value

    def extract_read_info(self, sample_data, gt, gts, genotype_fields):
        """Extract read depth, alt read count and allele frequency for one allele.

        Parameters
        ----------
        sample_data : list of str
            The sample column split on ':'.
        gt : int
            Index of the non-reference allele being reported (1-based into ALT).
        gts : container of int
            All allele indices called for this sample.
        genotype_fields : dict
            Maps each FORMAT key to its index within ``sample_data``.

        Returns
        -------
        tuple
            ``(depth, alt_reads, af)``. Each element is '' when it cannot be
            determined. ``depth`` is the raw DP string, or an int ref+alt sum
            when DP is absent. ``alt_reads`` is the raw AD entry. ``af`` is a float.
        """
        depth = ''
        alt_reads = ''
        ref_reads = ''
        af = ''

        ad_str = self._format_value(sample_data, genotype_fields, 'AD')
        if ad_str is not None:
            ad = ad_str.split(',')
            if len(ad) > max(gts):
                # AD as the VCF spec defines it (Number=R): one count per allele,
                # REF first, so allele index ``gt`` maps directly to ad[gt].
                alt_reads = ad[gt]
                if 0 in gts:
                    ref_reads = ad[0]
            elif len(ad) >= 2:
                # Two-value AD covering only the called alleles. This is "ref,alt"
                # when REF is called, otherwise "lower alt,higher alt".
                if 0 in gts:
                    ref_reads, alt_reads = ad[0], ad[1]
                elif gt == max(gts):
                    alt_reads = ad[1]
                else:
                    alt_reads = ad[0]
            # '.' marks a missing individual count inside AD.
            if alt_reads == '.':
                alt_reads = ''
            if ref_reads == '.':
                ref_reads = ''

        dp_str = self._format_value(sample_data, genotype_fields, 'DP')
        if dp_str is not None:
            depth = dp_str
        elif alt_reads != '' and ref_reads != '':
            # No usable DP: approximate depth as ref + alt reads.
            depth = int(alt_reads) + int(ref_reads)

        af_str = self._format_value(sample_data, genotype_fields, 'AF')
        if af_str is not None:
            af_vals = af_str.split(',')
            # Several comma-separated values are assumed to be one per ALT
            # allele (Number=A), so pick the entry for ``gt``.
            if len(af_vals) == 1:
                af_val = af_vals[0]
            elif gt <= len(af_vals):
                af_val = af_vals[gt - 1]
            else:
                af_val = '.'
            if af_val != '.':
                af = float(af_val)
        if af == '' and depth != '' and alt_reads != '':
            # No usable AF: compute it from the read counts, guarding against zero depth.
            total = float(depth)
            if total > 0:
                af = float(alt_reads) / total

        return depth, alt_reads, af

    def extract_vcf_variant(self, strand, pos, ref, alt):
        """Normalize a VCF REF/ALT pair by trimming the bases they share.

        Returns ``(pos, ref, alt)`` with ``pos`` as an int on every path. An
        allele trimmed to nothing is reported as '-'. A single-base REF equal to
        ALT is returned untrimmed.
        """
        if len(ref) == 1 and len(alt) == 1 and ref == alt:
            return int(pos), ref, alt

        new_ref, new_alt, new_pos = self.trimming_vcf_input(ref, alt, pos, strand)
        return new_pos, new_ref or '-', new_alt or '-'

    def trimming_vcf_input(self, ref, alt, pos, strand):
        """Trim the suffix, then the prefix, that ``ref`` and ``alt`` share.

        The suffix trim always keeps at least the first base, because in VCF the
        first REF/ALT base sits at ``pos``. Each base removed from the front moves
        the position by one: +1 on the '+' strand, -1 on the '-' strand.

        Returns ``(new_ref, new_alt, new_pos)`` with ``new_pos`` as an int.
        """
        pos = int(pos)
        reflen = len(ref)
        altlen = len(alt)
        new_ref = ref
        new_alt = alt
        new_pos = pos
        # Trim from the end, never removing the first nucleotide.
        # Example: 1:6530968 CTCA -> GTCTCA becomes C -> GTC.
        for nt_pos in range(min(reflen, altlen) - 1):
            if ref[reflen - nt_pos - 1] == alt[altlen - nt_pos - 1]:
                new_ref = ref[:reflen - nt_pos - 1]
                new_alt = alt[:altlen - nt_pos - 1]
            else:
                break
        new_ref2 = new_ref
        new_alt2 = new_alt
        # Trim from the start, shifting the position.
        # Example: 1:6530968 G -> GT becomes 1:6530969 - -> T.
        for nt_pos in range(min(len(new_ref), len(new_alt))):
            if new_ref[nt_pos] == new_alt[nt_pos]:
                if strand == '+':
                    new_pos += 1
                elif strand == '-':
                    new_pos -= 1
                new_ref2 = new_ref[nt_pos + 1:]
                new_alt2 = new_alt[nt_pos + 1:]
            else:
                new_ref2 = new_ref[nt_pos:]
                new_alt2 = new_alt[nt_pos:]
                break
        return new_ref2, new_alt2, new_pos


if __name__ == "__main__":
    # Only instantiates the converter; no conversion is run from the command line.
    converter = CravatConverter()