import sys
from misc.utils import *
from structure.interval import *
from structure.sequence import *


class Transcript(Interval):
  """
  A class that models a transcript, considered as a specialized interval (the bounds of the transcript) that contains exons (also represented as intervals)
  @ivar exons:         a list of exons (intervals)
  @type exons:         list of L{Interval{Interval}}
  @ivar size:          size of the transcript (in number of nucleotides) [computed]
  @type size:          int
  """

  def __init__(self, transcript = None, verbosity = 0):
    """
    Constructor
    Build an empty transcript or, if a transcript is given, a copy of it
    @param transcript: transcript to be copied (nothing is copied if None)
    @type  transcript: class L{Transcript<Transcript>}
    @param verbosity:  verbosity
    @type  verbosity:  int
    """
    # None is given to the parent constructor: the copy, if any, is done below,
    # once the attributes specific to the transcript exist
    super(Transcript, self).__init__(None, verbosity)
    self.exons         = []
    self.introns       = None
    self.size          = None
    # The parameter used to be silently ignored, although it is documented as
    # the transcript to be copied
    if transcript is not None:
      self.copy(transcript)


  def copy(self, transcript):
    """
    Copy method
    Copy the interval part through the parent class, then replace the exons of
    this transcript by copies of the exons of the given transcript
    @param transcript: transcript to be copied
    @type  transcript: class L{Transcript<Transcript>} or L{Interval<Interval>}
    """
    super(Transcript, self).copy(transcript)
    exonCopies = []
    # isinstance also accepts sub-classes of Transcript, which a comparison of
    # class names rejects
    if isinstance(transcript, Transcript):
      # Copies are built before the exon list is reset, so that copying a
      # transcript onto itself does not lose its exons
      for exon in transcript.getExons():
        exonCopy = Interval()
        exonCopy.copy(exon)
        exonCopies.append(exonCopy)
    # Exons and cached introns held before the copy do not describe the
    # copied element: drop them
    self.exons   = []
    self.introns = None
    for exonCopy in exonCopies:
      self.addExon(exonCopy)


  def setDirection(self, direction):
    """
    Set the direction of the transcript, and of each of its exons
    The value is first handed to the parent class, then forwarded unchanged
    to every exon
    @param direction: direction of the transcript (+ / -)
    @type  direction: int or string
    """
    parent = super(Transcript, self)
    parent.setDirection(direction)
    for currentExon in self.exons:
      currentExon.setDirection(direction)
      

  def setChromosome(self, chromosome):
    """
    Set the chromosome of the transcript, and of each of its exons
    @param chromosome: chromosome on which the transcript is
    @type  chromosome: string
    """
    parent = super(Transcript, self)
    parent.setChromosome(chromosome)
    for currentExon in self.exons:
      currentExon.setChromosome(chromosome)

  
  def addExon(self, exon):
    """
    Add an exon to the list of exons
    The cached size and the cached introns are discarded, since both are
    computed from the exon list
    @param exon: a new exon
    @type  exon: class L{Interval<Interval>}
    """
    self.exons.append(exon)
    self.size    = None
    # getIntrons() returns its cache as soon as it is set: without this reset,
    # an exon added after a call to getIntrons() would never be seen
    self.introns = None
    
    
  def getUniqueName(self):
    """
    Give the name of the transcript, suffixed by its occurrence number when
    the tags "nbOccurrences" and "occurrence" are both present and the number
    of occurrences is not 1
    @return: a string
    """
    tags              = self.tags
    hasOccurrenceTags = "nbOccurrences" in tags and "occurrence" in tags
    if not hasOccurrenceTags or tags["nbOccurrences"] == 1:
      return self.name
    return "%s-%d" % (self.name, tags["occurrence"])


  def getNbExons(self):
    """
    Get the number of exons
    A transcript without any stored exon counts as one exon
    @return: an int, at least 1
    """
    nbStoredExons = len(self.exons)
    if nbStoredExons < 1:
      return 1
    return nbStoredExons


  def getExon(self, i):
    """
    Get a specific exon
    If the transcript stores no exon, the transcript itself is given for
    rank 0, and the program exits for any other rank
    @param i: the rank of the exon
    @type  i: int
    """
    if self.exons:
      return self.exons[i]
    if i == 0:
      return self
    sys.exit("Cannot get exon #%i while there is no exon in the transcript" % (i))


  def getExons(self):
    """
    Get all the exons
    If the transcript stores no exon, the size is set from the bounds of the
    transcript, and a new list holding one interval built from the transcript
    is given
    @return: a list of intervals
    """
    if self.exons:
      return self.exons
    self.size = self.end - self.start + 1
    return [Interval(self)]


  def getIntrons(self):
    """
    Get all the introns
    Compute introns on the fly, between each pair of consecutive exons (once
    sorted), and keep them in a cache
    Introns with a non-positive size (touching or overlapping exons) are skipped
    @return: a list of intervals
    """
    if self.introns is not None:
      return self.introns
    self.sortExons()
    self.introns = []
    previousExon = self.getExon(0)
    for nextExon in self.exons[1:]:
      # The intron starts as a copy of the previous exon, then gets its own bounds
      intron = Interval()
      intron.copy(previousExon)
      if self.direction == 1:
        intron.setStart(previousExon.end + 1)
        intron.setEnd(nextExon.start - 1)
      else:
        # sortExons() gives decreasing positions on strand -1: the next exon
        # is located before the previous one
        intron.setStart(nextExon.end + 1)
        intron.setEnd(previousExon.start - 1)
      # The size is set from the new bounds *before* it is tested: the test
      # used to be performed first, on the size copied from the exon
      intronSize = intron.end - intron.start + 1
      intron.setSize(intronSize)
      if intronSize > 0:
        self.introns.append(intron)
      previousExon = nextExon
    return self.introns
  
  
  def getSize(self):
    """
    Get the size of the transcript (i.e. the number of nucleotides)
    The size is the sum of the sizes of the exons, or the distance between the
    bounds if there is no exon; it is computed once and cached
    @return: an int
    """
    if self.size is None:
      if self.exons:
        self.size = sum(exon.size for exon in self.exons)
      else:
        self.size = self.end - self.start + 1
    return self.size


  def getSizeWithIntrons(self):
    """
    Get the span of the transcript (i.e. the distance from start to end,
    both bounds included), whatever its exons
    @return: an int
    """
    lowerBound = self.start
    upperBound = self.end
    return upperBound - lowerBound + 1


  def overlapWithExon(self, transcript):
    """
    Check whether at least one exon of this transcript overlaps with one exon
    of another transcript
    The exons are only compared if the two transcripts themselves overlap
    @param transcript: transcript to be compared to
    @type  transcript: class L{Transcript<Transcript>}
    @return: a boolean
    """
    if self.overlapWith(transcript):
      for ownExon in self.getExons():
        if any(ownExon.overlapWith(otherExon) for otherExon in transcript.getExons()):
          return True
    return False
  
  
  def merge(self, transcript):
    """
    Merge with another transcript
    Merge exons if they overlap, otherwise add exons
    The name becomes "<this unique name>--<that unique name>", the bounds are
    extended to cover every exon, the "nbElements" tags are summed, and the
    tags which only describe a single mapping are removed
    The exons of the other transcript are copied: they are neither modified
    nor shared with this transcript
    @param transcript: transcript to be merged to
    @type  transcript: class L{Transcript<Transcript>}
    """
    start = self.start
    end   = self.end

    # A transcript without stored exon stands for a single exon covering it
    # (see getExons): store it, otherwise the region of this transcript would
    # be lost as soon as it does not overlap an exon of the other transcript
    if not self.exons:
      self.exons = self.getExons()

    for thatExon in transcript.getExons():
      newExon = Interval()
      newExon.copy(thatExon)
      keptExons = []
      for thisExon in self.exons:
        if thisExon.overlapWith(newExon):
          # The overlapping exon is absorbed by the new one, and not kept
          newExon.merge(thisExon)
        else:
          keptExons.append(thisExon)
      # The list is rebuilt rather than edited with list.remove(), which
      # relies on equality and may drop another, equal, exon
      self.exons = keptExons
      self.addExon(newExon)
      start = min(start, newExon.start)
      end   = max(end, newExon.end)

    self.setName("%s--%s" % (self.getUniqueName(), transcript.getUniqueName()))
    self.setStart(start)
    self.setEnd(end)
    # Cached values depend on the exon list, which has just changed
    self.size    = None
    self.introns = None
    if "nbElements" not in self.tags:
      self.setTagValue("nbElements", 1)
    if "nbElements" not in transcript.tags:
      transcript.setTagValue("nbElements", 1)
    self.setTagValue("nbElements", self.getTagValue("nbElements") + transcript.getTagValue("nbElements"))
    for tagName in ("identity", "nbOccurrences", "occurrence", "nbMismatches", "nbGaps", "rank", "evalue", "bestRegion"):
      if tagName in self.getTagNames():
        del self.tags[tagName]


  @classmethod
  def getSqlVariables(cls):
    """
    Get the properties of the object that should be saved in a database
    These are the ones given by the Interval class
    """
    return Interval.getSqlVariables()


  def setSqlValues(self, array):
    """
    Set the values of the properties of this object as given by a results line of a SQL query
    The work is fully delegated to the parent class
    @param array: the values to be copied
    @type  array: a list
    """
    parent = super(Transcript, self)
    parent.setSqlValues(array)


  def getSqlValues(self):
    """
    Get the values of the properties that should be saved in a database
    These are the values given by the parent class, where "size" is set to the
    size of the transcript
    """
    sqlValues         = super(Transcript, self).getSqlValues()
    sqlValues["size"] = self.getSize()
    return sqlValues


  @classmethod
  def getSqlTypes(cls):
    """
    Get the types of the properties that should be saved in a database
    These are the ones given by the Interval class
    """
    sqlTypes = Interval.getSqlTypes()
    return sqlTypes

  
  @classmethod
  def getSqlSizes(cls):
    """
    Get the sizes of the properties that should be saved in a database
    These are the ones given by the Interval class
    """
    sqlSizes = Interval.getSqlSizes()
    return sqlSizes
  
    
  def sortExons(self):
    """
    Sort the exons
    The exons are first put in increasing order; the order is then reversed
    if the direction of the transcript is -1
    """
    self.sortExonsIncreasing()
    if self.direction != -1:
      return
    reversedExons = self.getExons()
    reversedExons.reverse()
    self.exons = reversedExons
    
    
  def sortExonsIncreasing(self):
    """
    Sort the exons
    Increasing order of start position; exons with the same start keep their
    relative order
    As getExons() is used, a transcript which stores no exon ends up with the
    single exon given by getExons()
    """
    # sorted() builds a new list: the list handed out by getExons(), which a
    # caller may still hold, is no longer emptied element by element
    self.exons = sorted(self.getExons(), key = lambda exon: exon.start)
    
    
  def extendStart(self, size):
    """
    Extend the transcript by the 5' end
    On direction 1, the start is decreased (but not below 0); otherwise, the
    end is increased
    The outermost exon on the extended side, if any, is extended too
    @param size: the size to be exended
    @type  size: int
    """
    if self.direction == 1:
      self.setStart(max(0, self.start - size))
    else:
      self.setEnd(self.end + size)
    if len(self.exons) != 0:
      self.sortExons()
      # The exon is chosen from its coordinates, not from its rank: sortExons()
      # gives a decreasing order on strand -1, so that the last exon of the
      # list is not the one which ends the transcript
      if self.direction == 1:
        outerExon = min(self.exons, key = lambda exon: exon.start)
        outerExon.setStart(max(0, outerExon.start - size))
      else:
        outerExon = max(self.exons, key = lambda exon: exon.end)
        outerExon.setEnd(outerExon.end + size)
    self.size = None
    self.bin  = None
    
    
  def extendEnd(self, size):
    """
    Extend the transcript by the 3' end
    On direction 1, the end is increased; otherwise, the start is decreased
    (but not below 0)
    The outermost exon on the extended side, if any, is extended too
    @param size: the size to be exended
    @type  size: int
    """
    if self.direction == 1:
      self.setEnd(self.end + size)
    else:
      self.setStart(max(0, self.start - size))
    if len(self.exons) != 0:
      self.sortExons()
      # The exon is chosen from its coordinates, not from its rank: sortExons()
      # gives a decreasing order on strand -1, so that the first exon of the
      # list is not the one which starts the transcript
      if self.direction == 1:
        outerExon = max(self.exons, key = lambda exon: exon.end)
        outerExon.setEnd(outerExon.end + size)
      else:
        outerExon = min(self.exons, key = lambda exon: exon.start)
        outerExon.setStart(max(0, outerExon.start - size))
    self.size = None
    self.bin  = None
    
    
  def restrictStart(self, size = 1):
    """
    Keep only the given number of nucleotides of the transcript, counted from
    its start position on direction 1, and from its end position otherwise
    Remove the exons
    @param size: the size to be restricted to
    @type  size: int
    """
    if self.direction != 1:
      self.setStart(self.end - size + 1)
    else:
      self.setEnd(self.start + size - 1)
    self.removeExons()
    
    
  def restrictEnd(self, size = 1):
    """
    Keep only the given number of nucleotides of the transcript, counted from
    its end position on direction 1, and from its start position otherwise
    Remove the exons
    @param size: the size to be restricted to
    @type  size: int
    """
    if self.direction != 1:
      self.setEnd(self.start + size - 1)
    else:
      self.setStart(self.end - size + 1)
    self.removeExons()
    
    
  def removeExons(self):
    """
    Remove the exons and transforms the current transcript into a mere interval
    Every value computed from the exons (size, bin, introns) is discarded
    """
    self.exons   = []
    self.size    = None
    self.bin     = None
    # getIntrons() gives its cache as soon as it is set: introns computed from
    # the removed exons must not outlive them
    self.introns = None


  def printGff2(self, title):
    """
    Export this transcript using GFF2 format
    One "match_set" line describes the transcript, followed by one
    "match_part" line per exon (in increasing order)
    @param title: the title of the transcripts
    @type  title: string
    @return: a string
    """
    direction = "-" if self.direction == -1 else "+"
    self.sortExonsIncreasing()
    comment = self.getTagValues()
    # No comment (None or empty string) gives no suffix at all, instead of
    # the text "None" or a dangling ";"
    comment = ";%s" % (comment) if comment else ""
    lines = ["%s\t%s\tmatch_set\t%d\t%d\t%d\t%s\t.\tGENE %s%s\n" % (self.chromosome, title, self.start, self.end, (self.end - self.start + 1), direction, self.name, comment)]
    for exon in self.getExons():
      # The feature used to be written "%match_part": "%m" is not a valid
      # conversion, and the formatting raised a ValueError
      lines.append("%s\t%s\tmatch_part\t%d\t%d\t%d\t%s\t.\tGENE %s\n" % (self.chromosome, title, exon.start, exon.end, exon.size, direction, self.name))
    return "".join(lines)
  

  def printGff3(self, title):
    """
    Export this transcript using GFF3 format
    One "match" line describes the transcript ("match_set" if it has several
    exons), followed in that case by one "match_part" line per exon
    The tag "ID" is set to the unique name of the transcript if it is missing
    @param title: the title of the transcripts
    @type  title: string
    @return: a string
    """
    direction = "-" if self.direction == -1 else "+"
    self.sortExonsIncreasing()
    # The presence of the tag is looked for among the tag names: it used to be
    # looked for as a piece of text in the formatted tags, where any name or
    # value containing "ID" was mistaken for the tag
    if "ID" not in self.getTagNames():
      self.setTagValue("ID", self.getUniqueName())
    comment = self.getTagValues(";", "=")
    comment = ";%s" % (comment) if comment else ""
    exons     = self.getExons()
    isSpliced = len(exons) > 1
    lines = ["%s\t%s\t%s\t%d\t%d\t%d\t%s\t.\tName=%s%s\n" % (self.chromosome, title, "match_set" if isSpliced else "match", self.start, self.end, (self.end - self.start + 1), direction, self.name, comment)]
    if isSpliced:
      identifier = self.getTagValue("ID")
      for rank, exon in enumerate(exons):
        # Exons are numbered from 1
        number = rank + 1
        lines.append("%s\t%s\tmatch_part\t%d\t%d\t%d\t%s\t.\tID=%s-exon%d;Name=%s-exon%d;Parent=%s\n" % (self.chromosome, title, exon.start, exon.end, exon.size, direction, identifier, number, self.name, number, identifier))
    return "".join(lines)


  def printBed(self):
    """
    Export this transcript using BED format
    The exons are written in increasing order, with starts relative to the
    start of the transcript
    @return: a string
    """
    # getUniqueName() holds the naming rule, and does not fail when the tag
    # "occurrence" is missing
    name      = self.getUniqueName()
    direction = "-" if self.direction == -1 else "+"
    self.sortExonsIncreasing()
    exons  = self.getExons()
    sizes  = ["%d" % (exon.size) for exon in exons]
    starts = ["%d" % (exon.start - self.start) for exon in exons]
    return "%s\t%d\t%d\t%s\t1000\t%s\t%d\t%d\t0\t%d\t%s,\t%s,\n" % (self.chromosome, self.start, self.end+1, name, direction, self.start, self.end+1, self.getNbExons(), ",".join(sizes), ",".join(starts))


  def printSam(self):
    """
    Export this transcript using SAM format
    The CIGAR string holds one match ("M") per exon, in increasing order of
    position, and one skipped region ("N") per gap between two consecutive exons
    No mate, sequence nor quality is given
    @return: a string
    """
    name            = self.name
    flag            = 0 if self.direction == 1 else 0x10
    chromosome      = self.chromosome
    genomeStart     = self.start
    quality         = 255
    mate            = "*"
    mateGenomeStart = 0
    templateSize    = 0
    sequence        = "*"
    qualityString   = "*"
    tags            = "NM:i:0"

    # The CIGAR string is read from the given start position towards
    # increasing positions, whatever the strand
    self.sortExonsIncreasing()
    cigarParts  = []
    lastExonEnd = None
    for exon in self.getExons():
      if lastExonEnd is not None:
        # The gap is written before the exon which follows it; it used to be
        # written after it, with a negative size
        gap = exon.start - lastExonEnd - 1
        if gap > 0:
          cigarParts.append("%dN" % (gap))
      cigarParts.append("%dM" % (exon.getSize()))
      lastExonEnd = exon.end
    cigar = "".join(cigarParts)

    return "%s\t%d\t%s\t%d\t%d\t%s\t%s\t%d\t%d\t%s\t%s\t%s\n" % (name, flag, chromosome, genomeStart, quality, cigar, mate, mateGenomeStart, templateSize, sequence, qualityString, tags)


  def printUcsc(self):
    """
    Export this transcript using UCSC BED format
    Nothing is written if the name of the chromosome contains "Het"; "arm_" is
    replaced by "chr" in the name of the chromosome
    @return: a string
    """
    if "Het" in self.chromosome:
      return ""
    direction = "-" if self.direction == -1 else "+"
    self.sortExonsIncreasing()
    exons  = self.getExons()
    sizes  = ["%d" % (exon.size) for exon in exons]
    starts = ["%d" % (exon.start - self.start) for exon in exons]
    return "%s\t%d\t%d\t%s\t1000\t%s\t%d\t%d\t0\t%d\t%s,\t%s,\n" % (self.chromosome.replace("arm_", "chr"), self.start, self.end+1, self.name, direction, self.start, self.end+1, self.getNbExons(), ",".join(sizes), ",".join(starts))


  def printGBrowseReference(self):
    """
    Export this transcript using GBrowse format (1st line only)
    This line gives the chromosome as the reference
    @return: a string
    """
    chromosomeName = self.chromosome
    return "reference = %s\n" % chromosomeName


  def printGBrowseLine(self):
    """
    Export this transcript using GBrowse format (2nd line only)
    This line gives the name, the coordinates of the exons (comma-separated)
    and, if there are some tags, a last quoted column with these tags
    @return: a string
    """
    self.sortExons()
    coordinatesString = ",".join([exon.printCoordinates() for exon in self.getExons()])
    comment = self.getTagValues()
    # The test used to be inverted: an empty comment gave an empty quoted
    # column, whereas a real comment was stuck to the coordinates, with
    # neither tabulation nor quotes
    comment = "\t\"%s\"" % (comment) if comment else ""
    return "READS\t%s\t%s%s\n" % (self.name, coordinatesString, comment)

  
  def printGBrowse(self):
    """
    Export this transcript using GBrowse format
    The reference line comes first, directly followed by the line of the transcript
    @return: a string
    """
    referenceLine  = self.printGBrowseReference()
    transcriptLine = self.printGBrowseLine()
    return "%s%s" % (referenceLine, transcriptLine)


  def extractSequence(self, parser):
    """
    Get the sequence corresponding to this transcript
    The sequences of the exons, once sorted, are concatenated into a sequence
    which has the name of the transcript
    @param parser: a parser to a FASTA file
    @type  parser: class L{SequenceListParser<SequenceListParser>}
    @return      : a instance of L{Sequence<Sequence>}
    """
    self.sortExons()
    transcriptSequence = Sequence(self.name)
    for exon in self.getExons():
      exonSequence = exon.extractSequence(parser)
      transcriptSequence.concatenate(exonSequence)
    return transcriptSequence
  
  
  def extractWigData(self, parser):
    """
    Get some wig data corresponding to this transcript
    The data of the exons, once sorted, are put one after the other in a list
    @param parser: a parser to a wig file
    @type  parser: class L{WigParser<WigParser>}
    @return: a list of the values given by the exons
    """
    self.sortExons()
    wigValues = []
    for exon in self.getExons():
      wigValues += exon.extractWigData(parser)
    return wigValues
