#!/usr/bin/env python3

## Generate basic summary stats for SRST2 (v2) allele score output. A summary row is generated for each profile defined in the MLST definitions database.
## author: errol strain, estrain@gmail.com

from argparse import (ArgumentParser, FileType)
import sys
import glob 
from decimal import Decimal

def parse_args(argv=None):
  """Parse the command-line arguments; use '-h' for help.

  argv: optional list of argument strings. When None (the default),
  argparse reads sys.argv[1:], exactly as before.

  Returns an argparse.Namespace whose attributes are:
    mlst_definitions, output -- one-element lists (because of nargs=1)
    profile_cov, profile_max_mismatch -- strings; the caller converts them
  """
  parser = ArgumentParser(description='Generate Summary Scores for SRST2 Allele Score Output')

  # Read inputs
  parser.add_argument('--mlst_definitions', type=str, required=True, nargs=1, help='MLST Definitions')
  parser.add_argument('--output', type=str, required=True, nargs=1, help='Output File')
  # Thresholds stay strings so existing callers that float() them keep working.
  parser.add_argument('--profile_cov', type=str, required=False, default="98",
                      help='Minimum Mean Percent Coverage to Report ST Profile (default: %(default)s)')
  parser.add_argument('--profile_max_mismatch', type=str, required=False, default="1",
                      help='Maximum Total Number of Mismatches (SNPs) to Report ST Profile (default: %(default)s)')

  return parser.parse_args(argv)

args = parse_args()

# Map of allele name -> raw line from the SRST2 scores file.
allHash = {}

# Convert the thresholds before opening any file, so an invalid value
# aborts the run before the output file is created or truncated.
# Minimum mean percent coverage (averaged over all loci) for reporting a profile
min_per = float(args.profile_cov)
# Maximum total number of mismatches (summed over all loci) for reporting a profile
max_mis = float(args.profile_max_mismatch)

# Read in Profile Database; open it before the output so a missing
# definitions file also leaves any previous output untouched.
profiles = open(args.mlst_definitions[0], "r")
output = open(args.output[0], "w")

# Read in Allele Scores
# Scores are read from the srst2*.scores file in the current directory.
# Columns used later: 0:Allele, 1:Score, 2:Avg Depth, 5:% Coverage,
# 7:Mismatches, 8:Indels
score_files = sorted(glob.glob("srst2*.scores"))
if not score_files:
  sys.exit("ERROR: no srst2*.scores file found in the current directory")
# If several files match, the first in sorted order is used (deterministic).
with open(score_files[0], "r") as scoreFile:
  scoreFile.readline()  # skip the header row
  for line in scoreFile:
    # Keep the raw line, keyed by column 0; profiles later look keys up as
    # "<locus>_<allele number>".
    allHash[line.split("\t")[0]] = line
  

# Allele names in first row: column 0 is the ST and every later column names
# a locus. The summary column names are appended to that header.
als = profiles.readline().rstrip()
filehead = als + "\tMean_Score\tMean_Depth\tMean_%_Coverage\tTotal_Mismatches\tTotal_Indels\n"
output.write(filehead)

genes = als.split("\t")
# Means are taken over every locus column in the header. A locus whose allele
# is absent from the scores file contributes 0 to each sum.
# NOTE(review): a non-locus column (e.g. clonal_complex) in the header would be
# counted here too and dilute the means -- confirm the definitions file lacks one.
n_loci = len(genes) - 1

# Using both handles as context managers closes (and flushes) them even if a
# row fails to parse.
with profiles, output:
  for line in profiles:
    line = line.rstrip()
    if not line:
      # Skip blank lines (e.g. a trailing empty line) instead of crashing.
      continue
    els = line.split("\t")
    # rstrip() removes empty trailing fields; pad so every locus has a value
    # (an empty value simply matches no allele).
    els += [""] * (len(genes) - len(els))

    mscore = 0.0
    mdepth = 0.0
    mcover = 0.0
    mmisma = 0
    mindel = 0
    for gene, allele in zip(genes[1:], els[1:]):
      score_line = allHash.get(gene + "_" + allele)
      if score_line is None:
        continue
      vals = score_line.split("\t")
      mscore += float(vals[1])
      mdepth += float(vals[2])
      mcover += float(vals[5])
      mmisma += int(vals[7])
      mindel += int(vals[8])

    mscore = round(Decimal(mscore / n_loci), 5)
    mdepth = round(Decimal(mdepth / n_loci), 2)
    mcover = round(Decimal(mcover / n_loci), 2)

    # Report only profiles within the mismatch limit and above the coverage
    # floor. The coverage test uses the rounded mean, matching what is printed.
    if mmisma <= max_mis and mcover >= min_per:
      output.write("\t".join([line, str(mscore), str(mdepth), str(mcover),
                              str(mmisma), str(mindel)]) + "\n")

