import unittest
from structure.transcriptListsComparator import *


class TestTranscriptListsComparator(unittest.TestCase):
  """Tests for TranscriptListsComparator driven by fixture files in TestFiles/.

  Each test builds fresh TranscriptContainer inputs and a new comparator. It
  then checks either the dict returned by compareTranscriptListDistance or
  the transcripts left in the comparator's output container.
  """

  def _getOutputTranscripts(self, comparator, expectedNb):
    """Return the comparator's output transcripts as a list, checking the count.

    Both the count reported by the output container and the number of items
    actually yielded by its iterator are asserted. That way the
    per-transcript assertions made by the caller cannot pass vacuously when
    the iterator yields fewer transcripts than advertised.
    """
    outputContainer = comparator.getOutputTranscripts()
    self.assertEqual(outputContainer.getNbTranscripts(), expectedNb)
    transcripts = list(outputContainer.getIterator())
    self.assertEqual(len(transcripts), expectedNb)
    return transcripts

  def _assertTranscript(self, transcript, chromosome, start, end, direction=None, size=None, nbExons=None):
    """Assert a transcript's location and, optionally, direction, size and exon count.

    Optional checks are skipped when the corresponding argument is None.
    """
    self.assertEqual(transcript.chromosome, chromosome)
    self.assertEqual(transcript.start, start)
    self.assertEqual(transcript.end, end)
    if direction is not None:
      self.assertEqual(transcript.direction, direction)
    if size is not None:
      self.assertEqual(transcript.getSize(), size)
    if nbExons is not None:
      self.assertEqual(transcript.getNbExons(), nbExons)


  def testCompareTranscriptList(self):
    """compareTranscriptList outputs two arm_X 1000-1999 transcripts, direction 1 then -1."""
    container1 = TranscriptContainer("TestFiles/testTranscriptListsComparatorCompareTranscriptList1.bed", "bed")
    container2 = TranscriptContainer("TestFiles/testTranscriptListsComparatorCompareTranscriptList2.bed", "bed")

    comparator = TranscriptListsComparator()
    comparator.computeOdds(True)
    comparator.setInputTranscriptContainer(comparator.QUERY, container1)
    comparator.setInputTranscriptContainer(comparator.REFERENCE, container2)
    comparator.compareTranscriptList()

    transcripts = self._getOutputTranscripts(comparator, 2)
    # Output order is part of the expectation: direction 1 first, then -1.
    self._assertTranscript(transcripts[0], "arm_X", 1000, 1999, direction=1)
    self._assertTranscript(transcripts[1], "arm_X", 1000, 1999, direction=-1)


  def testCompareTranscriptListDistanceSimple(self):
    """compareTranscriptListDistance with max distance 1000, in both query/reference orders.

    The expected dicts differ between the two orders, so the comparison is
    not symmetric in its inputs.
    """
    container1 = TranscriptContainer("TestFiles/testCompareTranscriptListDistanceSimple1.gff3", "gff")
    container2 = TranscriptContainer("TestFiles/testCompareTranscriptListDistanceSimple2.gff3", "gff")

    comparator = TranscriptListsComparator()
    comparator.computeOdds(True)
    comparator.setMaxDistance(1000)
    comparator.setInputTranscriptContainer(comparator.QUERY, container1)
    comparator.setInputTranscriptContainer(comparator.REFERENCE, container2)
    distances = comparator.compareTranscriptListDistance()

    self.assertEqual(distances, {0: 1})

    # Swap query and reference using a fresh comparator.
    comparator = TranscriptListsComparator()
    comparator.computeOdds(True)
    comparator.setMaxDistance(1000)
    comparator.setInputTranscriptContainer(comparator.QUERY, container2)
    comparator.setInputTranscriptContainer(comparator.REFERENCE, container1)
    distances = comparator.compareTranscriptListDistance()

    self.assertEqual(distances, {0: 1, -1000: 1})


  def testCompareTranscriptListDistanceAntisense(self):
    """compareTranscriptListDistance with getAntisenseOnly(True) and max distance 10000."""
    container1 = TranscriptContainer("TestFiles/testCompareTranscriptListDistanceAntisense1.gff3", "gff")
    container2 = TranscriptContainer("TestFiles/testCompareTranscriptListDistanceAntisense2.gff3", "gff")

    comparator = TranscriptListsComparator()
    comparator.computeOdds(True)
    comparator.setMaxDistance(10000)
    # NOTE(review): getAntisenseOnly(True) is used as a setter here; presumably
    # it restricts comparison to antisense pairs -- confirm in the comparator.
    comparator.getAntisenseOnly(True)
    comparator.setInputTranscriptContainer(comparator.QUERY, container1)
    comparator.setInputTranscriptContainer(comparator.REFERENCE, container2)
    distances = comparator.compareTranscriptListDistance()

    self.assertEqual(distances, {1000: 1})


  def testCompareTranscriptListMergeSimple(self):
    """compareTranscriptListMerge outputs one arm_X transcript spanning 1000-3999 (size 3000)."""
    container1 = TranscriptContainer("TestFiles/testTranscriptListsComparatorCompareTranscriptListMergeSimple1.bed", "bed")
    container2 = TranscriptContainer("TestFiles/testTranscriptListsComparatorCompareTranscriptListMergeSimple2.bed", "bed")

    comparator = TranscriptListsComparator()
    comparator.computeOdds(True)
    comparator.setInputTranscriptContainer(comparator.QUERY, container1)
    comparator.setInputTranscriptContainer(comparator.REFERENCE, container2)
    comparator.compareTranscriptListMerge()

    transcripts = self._getOutputTranscripts(comparator, 1)
    self._assertTranscript(transcripts[0], "arm_X", 1000, 3999, direction=1, size=3000)


  def testCompareTranscriptListMergeSenseAntiSenseAway(self):
    """Merge with restrictToStart(1) on both sides, 150 bp 5' extension of the reference and antisense-only."""
    container1 = TranscriptContainer("TestFiles/testTranscriptListsComparatorCompareTranscriptListMergeSenseAntiSenseAway1.bed", "bed")
    container2 = TranscriptContainer("TestFiles/testTranscriptListsComparatorCompareTranscriptListMergeSenseAntiSenseAway2.bed", "bed")

    comparator = TranscriptListsComparator()
    comparator.restrictToStart(comparator.QUERY, 1)
    comparator.restrictToStart(comparator.REFERENCE, 1)
    comparator.extendFivePrime(comparator.REFERENCE, 150)
    comparator.getAntisenseOnly(True)
    comparator.computeOdds(True)
    comparator.setInputTranscriptContainer(comparator.QUERY, container1)
    comparator.setInputTranscriptContainer(comparator.REFERENCE, container2)
    comparator.compareTranscriptListMerge()

    transcripts = self._getOutputTranscripts(comparator, 1)
    self._assertTranscript(transcripts[0], "arm_X", 10000049, 10000199, size=151, nbExons=1)


  def testCompareTranscriptListMergeAggregation(self):
    """Merge with getColinearOnly(True) and aggregate(True) outputs one 200 bp single-exon transcript."""
    container1 = TranscriptContainer("TestFiles/testTranscriptListsComparatorCompareTranscriptListMergeAggregation1.bed", "bed")
    container2 = TranscriptContainer("TestFiles/testTranscriptListsComparatorCompareTranscriptListMergeAggregation2.bed", "bed")

    comparator = TranscriptListsComparator()
    comparator.getColinearOnly(True)
    comparator.computeOdds(True)
    comparator.setInputTranscriptContainer(comparator.QUERY, container1)
    comparator.setInputTranscriptContainer(comparator.REFERENCE, container2)
    comparator.aggregate(True)
    comparator.compareTranscriptListMerge()

    transcripts = self._getOutputTranscripts(comparator, 1)
    self._assertTranscript(transcripts[0], "arm_X", 10000000, 10000199, size=200, nbExons=1)


  def testCompareTranscriptListSelfMerge(self):
    """Self-merge of a GFF3 query outputs one transcript with nbElements tag 3."""
    container1 = TranscriptContainer("TestFiles/testTranscriptListsComparatorCompareTranscriptListSelfMerge1.gff3", "gff")

    comparator = TranscriptListsComparator()
    comparator.computeOdds(True)
    comparator.setInputTranscriptContainer(comparator.QUERY, container1)
    comparator.compareTranscriptListSelfMerge()

    transcripts = self._getOutputTranscripts(comparator, 1)
    transcript = transcripts[0]
    self._assertTranscript(transcript, "arm_X", 1000, 2000, direction=1, size=1001, nbExons=1)
    self.assertEqual(transcript.getTagValue("nbElements"), 3)


  def testCompareTranscriptListSelfMergeSense(self):
    """Colinear-only self-merge outputs one 3-exon transcript spanning 1000-5999 (size 3000)."""
    container1 = TranscriptContainer("TestFiles/testTranscriptListsComparatorCompareTranscriptListSelfMergeSense1.bed", "bed")

    comparator = TranscriptListsComparator()
    comparator.getColinearOnly(True)
    comparator.computeOdds(True)
    comparator.setInputTranscriptContainer(comparator.QUERY, container1)
    comparator.compareTranscriptListSelfMerge()

    transcripts = self._getOutputTranscripts(comparator, 1)
    self._assertTranscript(transcripts[0], "arm_X", 1000, 5999, direction=1, size=3000, nbExons=3)


  def testCompareTranscriptListSelfMergeDifferentClusters(self):
    """Self-merge without computeOdds outputs one single-exon transcript spanning 100-100099."""
    container1 = TranscriptContainer("TestFiles/testTranscriptListsComparatorCompareTranscriptListSelfMergeDifferentClusters1.bed", "bed")

    comparator = TranscriptListsComparator()
    comparator.setInputTranscriptContainer(comparator.QUERY, container1)
    comparator.compareTranscriptListSelfMerge()

    transcripts = self._getOutputTranscripts(comparator, 1)
    self._assertTranscript(transcripts[0], "arm_X", 100, 100099, direction=1, size=100000, nbExons=1)

# Run this module's tests when executed directly as a script.
if "__main__" == __name__:
  unittest.main()
