# -*- coding: utf-8 -*-
"""
Created on Wed Sep  4 18:41:42 2013

@author: chmaramis
"""

from __future__ import division
import string as strpy
import numpy as np
from pandas import *
from numpy import nan as NA
import time
import sys


def filter_condition_AAjunction(x):
    """Return the AA JUNCTION sequence without any trailing text.

    Surrounding whitespace is stripped and only the text before the first
    space is kept, e.g. ``'CASSLGF (see comment)'`` -> ``'CASSLGF'``.

    :param x: one AA JUNCTION value (must be a string).
    :return: the first space-separated token of ``x``.
    """
    return x.strip().split(' ')[0]


#-----------frame creation---------------------
# Column names (IMGT summary layout) used throughout the filters.
_AA_COL = 'AA JUNCTION'
_V_COL = 'V-GENE and allele'
_J_COL = 'J-GENE and allele'
_D_COL = 'D-GENE and allele'
_IDENTITY_COL = 'V-REGION identity %'


def _is_yes(flag):
    """Return True when a yes/no option string starts with 'y' or 'Y'."""
    return flag.startswith(('y', 'Y'))


def _as_int(bound):
    """Parse a CDR3 length bound; the 'null' placeholder counts as 0."""
    return 0 if bound == 'null' else int(bound)


def _read_imgt_summary(path):
    """Load a tab-separated table and return ``(frame, columns)``.

    The first column is read as the index and then discarded. A trailing
    'Unnamed...' column (typically produced by a trailing tab on every line)
    is dropped, and rows are renumbered 1..N. ``columns`` lists the kept
    columns in file order.
    """
    reader = read_csv(path, iterator=True, chunksize=5000, sep='\t', index_col=0)
    try:
        frame = concat(list(reader))
    finally:
        reader.close()  # release the file handle even if parsing fails
    columns = list(frame.columns)
    if 'Unnamed' in columns[-1]:
        del columns[-1]
    frame = frame[columns]
    frame.index = range(1, len(frame) + 1)
    return frame, columns


def _merge_ambiguous_v(allele):
    """Shorten an ambiguous 'A or B ...' V call to 'geneA or geneB ...'.

    Each candidate is expected to look like '<species> <gene>*<status>'.
    The first candidate's gene is always kept. Later candidates are kept
    only if their gene is not listed yet and their status contains no 'P'.

    Raises IndexError when a candidate has no '<species> <gene>' prefix, or
    when a later candidate has no '*'.
    """
    genes = []
    for position, candidate in enumerate(allele.split('or')):
        fields = candidate.split('*')
        gene = fields[0].strip().split(' ')[1]
        if position == 0:
            genes.append(gene)
            continue
        # Read the status unconditionally: a star-less candidate must raise
        # even when its gene is a duplicate, so the caller leaves the whole
        # value unchanged.
        is_pseudo = 'P' in fields[1]
        if gene not in genes and not is_pseudo:
            genes.append(gene)
    return ' or '.join(genes)


def _split_functional_v(frame):
    """Separate pseudogene/ORF V alleles and shorten the remaining V names.

    Returns ``(kept, rejected)``. An allele whose status (the text after the
    first '*') contains 'P' or 'ORF' is rejected. Rejected rows are grouped
    per allele, in order of first appearance. Other alleles are reduced to
    the gene name (e.g. 'Homsap TRBV5-1*01 F' -> 'TRBV5-1'); ambiguous
    '... or ...' calls go through ``_merge_ambiguous_v``. Values that cannot
    be parsed (e.g. already shortened, no '*') are left unchanged and a
    message is printed.
    """
    shortened = {}
    pseudo = []
    for allele in frame[_V_COL].unique():
        fields = allele.split('*')
        try:
            status = fields[1]
            if 'P' in status or 'ORF' in status:
                pseudo.append(allele)
            elif 'or' in allele:
                shortened[allele] = _merge_ambiguous_v(allele)
            else:
                shortened[allele] = fields[0].split(' ')[1]
        except IndexError:
            print('V-gene is already been formed')

    if pseudo:
        rejected = concat([frame[frame[_V_COL] == allele] for allele in pseudo])
    else:
        rejected = frame.iloc[0:0]
    kept = frame[~frame[_V_COL].isin(pseudo)]
    kept = kept.assign(**{_V_COL: kept[_V_COL].map(lambda v: shortened.get(v, v))})
    return kept, rejected


def _shorten_gene_column(frame, column, message, na_value=None):
    """Return ``frame`` with allele strings in ``column`` reduced to gene names.

    'Homsap TRBJ2-7*01 F' becomes 'TRBJ2-7' (the second space-separated
    token before the first '*'). Values without such a token are left as-is
    and ``message`` is printed once per distinct value. Missing values are
    replaced by ``na_value`` when it is given, otherwise they stay missing.
    """
    shortened = {}
    for allele in frame[column].dropna().unique():
        try:
            shortened[allele] = allele.split('*')[0].split(' ')[1]
        except IndexError:
            print(message)

    def shorten(value):
        if isnull(value):
            return value if na_value is None else na_value
        return shortened.get(value, value)

    return frame.assign(**{column: frame[column].map(shorten)})


def filtering(inp, cells, psorf, con, prod, CF, Vper, Vgene, laa1, laa2, conaa, Jgene, Dgene, fname):
    """Apply the clonotype quality filters to one IMGT summary table.

    Parameters
    ----------
    inp : str
        Path of a tab-separated summary table (first column = row index).
    cells : str
        'TCR' (CDR3 must match C...F) or 'BCR' (C...W). Only used when CF is
        enabled; any other value just prints a message.
    psorf, con, prod, CF : str
        Options enabled by a 'y'/'Y' prefix. They respectively remove
        pseudogene/ORF V alleles (and shorten V names), remove junctions
        containing X, # or *, keep only productive rows, and check the
        C..F / C..W landmarks.
    Vper : float
        Minimum 'V-REGION identity %'. Rows with a missing identity are
        rejected as well.
    Vgene, Jgene, Dgene : str
        Exact gene name to keep, or 'null' to skip that filter.
    laa1, laa2 : str or int
        CDR3 length bounds (junction length minus 2); 'null' counts as 0.
        If laa2 is 0, only the lower bound laa1 applies (upper bound 100).
        Otherwise the two values are used as an inclusive range in either
        order.
    conaa : str
        Pattern (regular expression) the AA JUNCTION must contain, or 'null'.
        An all-lowercase pattern is upper-cased first.
    fname : str
        Column name for the counts in the summary table.

    Returns
    -------
    (frame, filtall, summ_df)
        ``frame``: the kept rows, with V/J/D columns renamed to 'V-GENE',
        'J-GENE', 'D-GENE' and rows renumbered 1..N.
        ``filtall``: every rejected row, with a 'Reason' column (this column
        is present even when nothing was rejected).
        ``summ_df``: per-step counts plus a '%' column relative to the raw
        total.
        If a required column is missing, a message is printed and the
        partial ``frame``, the rows rejected so far and an empty summary are
        returned.
    """
    frame = DataFrame()
    rejected_parts = []  # rejected rows of each step, tagged with 'Reason'
    labels = []          # summary row labels
    counts = []          # summary counts, parallel to ``labels``

    def n_rejected():
        return sum(len(part) for part in rejected_parts)

    def record(label, kept, dropped, reason):
        """Log one filter step in the summary and keep its rejected rows."""
        labels.extend([label, 'filter out'])
        counts.extend([len(kept), len(dropped)])
        if len(dropped):
            rejected_parts.append(dropped.assign(Reason=reason))

    try:
        frame, frcol = _read_imgt_summary(inp)
        labels.append('Total reads of raw data')
        counts.append(len(frame))

        #------------drop rows without CDR3 or V assignment--------------------
        has_result = frame[_AA_COL].notnull() & frame[_V_COL].notnull()
        frame, dropped = frame[has_result], frame[~has_result]
        record('Not Null CDR3/V', frame, dropped, 'NoResults')

        if _is_yes(psorf):
            frame, dropped = _split_functional_v(frame)
            record('Functional TRBV', frame, dropped, 'P or ORF')

        #------------FILTERING for data quality--------------------
        if _is_yes(con):
            clean = ~frame[_AA_COL].str.contains('[X#*]')
            frame, dropped = frame[clean], frame[~clean]
            record('Not Containing X,#,*', frame, dropped, 'X,#,*')

        # Label of the functionality column: current & past IMGT Summary name.
        functionality_label = 'Functionality'
        if 'V-DOMAIN Functionality' in frame.columns:
            functionality_label = 'V-DOMAIN Functionality'

        if _is_yes(prod):
            # na=False: a missing functionality counts as "not productive".
            productive = frame[functionality_label].str.startswith('productive', na=False)
            frame, dropped = frame[productive], frame[~productive]
            record('Productive', frame, dropped, 'not productive')

        frame = frame.assign(**{_AA_COL: frame[_AA_COL].map(filter_condition_AAjunction)})

        if _is_yes(CF):
            last_residue = {'TCR': 'F', 'BCR': 'W'}.get(cells)
            if last_residue is None:
                print('TCR or BCR type')
            else:
                junction = frame[_AA_COL]
                landmarks = junction.str.startswith('C') & junction.str.endswith(last_residue)
                frame, dropped = frame[landmarks], frame[~landmarks]
                record('CDR3 landmarks C-{}'.format(last_residue), frame, dropped,
                       'Not C..{}'.format(last_residue))

        # Rows with a missing identity fail ">=" and are rejected here too,
        # so every row is accounted for in either the kept or rejected output.
        identical = frame[_IDENTITY_COL] >= Vper
        frame, dropped = frame[identical], frame[~identical]
        record('Identity >= {iden}%'.format(iden=Vper), frame, dropped,
               'identity < {iden}%'.format(iden=Vper))

        labels.extend(['Total filter out A', 'Total filter in A'])
        counts.extend([n_rejected(), len(frame)])

        ###############################
        if Vgene != 'null':
            match = frame[_V_COL] == Vgene
            frame, dropped = frame[match], frame[~match]
            record('V-GENE = {} '.format(Vgene), frame, dropped,
                   'V-GENE != {} '.format(Vgene))

        ###############################
        if laa1 != 'null' or laa2 != 'null':
            lower, upper = _as_int(laa1), _as_int(laa2)
            open_ended = upper == 0
            if open_ended:
                low, high = lower, 100
            else:
                low, high = min(lower, upper), max(lower, upper)
            # CDR3 length = junction length without the two flanking residues.
            cdr3_len = frame[_AA_COL].str.len() - 2
            in_range = (cdr3_len >= low) & (cdr3_len <= high)
            frame, dropped = frame[in_range], frame[~in_range]
            if open_ended:
                record('CDR3 length bigger than {}'.format(low), frame, dropped,
                       'CDR3 length not bigger than {}'.format(low))
            else:
                record('CDR3 length from {} to {} '.format(low, high), frame, dropped,
                       'CDR3 length not from {} to {}'.format(low, high))

        ###############################
        if conaa != 'null':
            if conaa.islower():
                conaa = conaa.upper()
            has_motif = frame[_AA_COL].str.contains(conaa)
            frame, dropped = frame[has_motif], frame[~has_motif]
            record('CDR3 containing {}'.format(conaa), frame, dropped,
                   'CDR3 not containing {}'.format(conaa))

        #------------keep the small J / D gene names--------------------
        frame = _shorten_gene_column(frame, _J_COL, 'J-Gene has been formed')
        frame = _shorten_gene_column(frame, _D_COL, 'D-gene has been formed', na_value='none')

        if Jgene != 'null':
            match = frame[_J_COL] == Jgene
            frame, dropped = frame[match], frame[~match]
            record('J-GENE = {} '.format(Jgene), frame, dropped,
                   'J-GENE not {} '.format(Jgene))

        if Dgene != 'null':
            match = frame[_D_COL] == Dgene
            frame, dropped = frame[match], frame[~match]
            record('D-GENE = {} '.format(Dgene), frame, dropped,
                   'D-GENE not {} '.format(Dgene))

        labels.extend(['Total filter out', 'Total filter in'])
        counts.extend([n_rejected(), len(frame)])

        summ_df = DataFrame({fname: counts}, index=labels)
        # Percentages relative to the first row (total raw reads).
        summ_df['%'] = (100 * summ_df[fname] / summ_df[fname].iloc[0]).map('{:.4f}'.format)

        #--------------outputs---------------------------
        frame = frame.rename(columns={_V_COL: 'V-GENE', _J_COL: 'J-GENE', _D_COL: 'D-GENE'})
        frame.index = range(1, len(frame) + 1)

        out_columns = frcol + ['Reason']
        if rejected_parts:
            filtall = concat(rejected_parts).reindex(columns=out_columns)
        else:
            filtall = DataFrame(columns=out_columns)
        return (frame, filtall, summ_df)
    except KeyError as e:
        print('This file has no ' + str(e) + ' column')
        partial = concat(rejected_parts) if rejected_parts else DataFrame()
        return (frame, partial, DataFrame())


if __name__ == '__main__':

    start = time.time()

    # Positional command-line arguments, in the order they must be given.
    arg_names = ('inp', 'cells', 'psorf', 'con', 'prod', 'CF', 'Vper', 'Vgene',
                 'laa1', 'conaa', 'filterin', 'filterout', 'Sum_table',
                 'Jgene', 'Dgene', 'laa2', 'fname')
    if len(sys.argv) <= len(arg_names):
        # Fail with a usage line instead of an IndexError traceback.
        sys.stderr.write('usage: {} {}\n'.format(sys.argv[0], ' '.join(arg_names)))
        sys.exit(1)

    # Parse input arguments
    inp = sys.argv[1]
    cells = sys.argv[2]
    psorf = sys.argv[3]
    con = sys.argv[4]
    prod = sys.argv[5]
    CF = sys.argv[6]
    Vper = float(sys.argv[7])
    Vgene = sys.argv[8]
    laa1 = sys.argv[9]
    conaa = sys.argv[10]
    filterin = sys.argv[11]
    filterout = sys.argv[12]
    Sum_table = sys.argv[13]
    Jgene = sys.argv[14]
    Dgene = sys.argv[15]
    laa2 = sys.argv[16]
    fname = sys.argv[17]

    # Execute basic function
    fin, fout, summ = filtering(inp, cells, psorf, con, prod, CF, Vper, Vgene,
                                laa1, laa2, conaa, Jgene, Dgene, fname)

    # Save output to CSV files; empty tables are not written.
    if not summ.empty:
        summ.to_csv(Sum_table, sep='\t')
    if not fin.empty:
        fin.to_csv(filterin, sep='\t')
    if not fout.empty:
        fout.to_csv(filterout, sep='\t')

    # Print execution time
    stop = time.time()
    print('Runtime:' + str(stop - start))

