import unittest
from writer.mySqlTranscriptWriter import *
from structure.transcriptContainer import *


class TestMySqlTranscriptTable(unittest.TestCase):
  
  def setUp(self):
    pass
        

  def tearDown(self):
    pass


  def testGetRange(self):
    transcript = Transcript()
    transcript.setName("test1.1")
    transcript.setChromosome("arm_X")
    transcript.setStart(1000)
    transcript.setEnd(4000)
    transcript.setSize(2000)
    transcript.setDirection("+")

    exon1 = Interval()
    exon1.setName("test1.1")
    exon1.setChromosome("arm_X")
    exon1.setStart(1000)
    exon1.setEnd(2000)
    exon1.setSize(1000)

    exon2 = Interval()
    exon2.setName("test1.1")
    exon2.setChromosome("arm_X")
    exon2.setStart(3000)
    exon2.setEnd(4000)
    exon2.setSize(1000)

    transcript.addExon(exon1)
    transcript.addExon(exon2)

    writer = MySqlTranscriptWriter("testMySqlTranscriptTableGetRange")
    writer.addTranscript(transcript)
    writer.write()

    transcriptContainer = TranscriptContainer("testMySqlTranscriptTableGetRange", "sql")
    self.assertEqual(transcriptContainer.getNbTranscripts(), 1)
    for transcript in transcriptContainer.getIterator():
      self.assertEqual(transcript.name, "test1.1")
      self.assertEqual(transcript.chromosome, "arm_X")
      self.assertEqual(transcript.start, 1000)
      self.assertEqual(transcript.end, 4000)
      self.assertEqual(transcript.getSize(), 2000)
      self.assertEqual(transcript.getNbExons(), 2)
      exons = transcript.getExons()
      self.assertEqual(exons[0].start, 1000)
      self.assertEqual(exons[0].end, 2000)
      self.assertEqual(exons[1].start, 3000)
      self.assertEqual(exons[1].end, 4000)


  def testSetDefaultTagValue(self):
    transcript1 = Transcript()
    transcript1.setName("test1.1")
    transcript1.setChromosome("arm_X")
    transcript1.setStart(1000)
    transcript1.setEnd(2000)
    transcript1.setDirection("+")

    exon1 = Interval()
    exon1.setName("test1.1")
    exon1.setChromosome("arm_X")
    exon1.setStart(1000)
    exon1.setEnd(2000)

    transcript1.addExon(exon1)

    transcript2 = Transcript()
    transcript2.setName("test2.1")
    transcript2.setChromosome("arm_X")
    transcript2.setStart(1000)
    transcript2.setEnd(2000)
    transcript2.setDirection("+")
    transcript2.setTagValue("nbOccurrences", "2")

    exon2 = Interval()
    exon2.setName("test2.1")
    exon2.setChromosome("arm_X")
    exon2.setStart(1000)
    exon2.setEnd(2000)

    transcript2.addExon(exon2)

    transcript3 = Transcript()
    transcript3.setName("test3.1")
    transcript3.setChromosome("arm_X")
    transcript3.setStart(1000)
    transcript3.setEnd(2000)
    transcript3.setDirection("+")
    transcript3.setTagValue("occurrences", "2")

    exon3 = Interval()
    exon3.setName("test3.1")
    exon3.setChromosome("arm_X")
    exon3.setStart(1000)
    exon3.setEnd(2000)

    transcript3.addExon(exon3)

    table = MySqlTranscriptTable("testMySqlTranscriptTableSetDefaultTagValue")
    table.createTranscriptTable()
    table.addTranscript(transcript1)
    table.addTranscript(transcript2)
    table.addTranscript(transcript3)
    table.setDefaultTagValue("occurrence", "1")

    cpt = 0
    for transcript in table.getIterator():
      cpt += 1
      self.assert_(cpt != 4)
      if cpt == 1:
        self.assertEqual(transcript.name, "test1.1")
        self.assertEqual(transcript.chromosome, "arm_X")
        self.assertEqual(transcript.start, 1000)
        self.assertEqual(transcript.end, 2000)
        self.assertEqual(transcript.getSize(), 1001)
        self.assertEqual(transcript.getNbExons(), 1)
        exons = transcript.getExons()
        self.assertEqual(exons[0].start, 1000)
        self.assertEqual(exons[0].end, 2000)
        self.assertEqual(transcript.getTagValue("occurrence"), 1)
      elif cpt == 2:
        self.assertEqual(transcript.name, "test2.1")
        self.assertEqual(transcript.chromosome, "arm_X")
        self.assertEqual(transcript.start, 1000)
        self.assertEqual(transcript.end, 2000)
        self.assertEqual(transcript.getSize(), 1001)
        self.assertEqual(transcript.getNbExons(), 1)
        exons = transcript.getExons()
        self.assertEqual(exons[0].start, 1000)
        self.assertEqual(exons[0].end, 2000)
        self.assertEqual(transcript.getTagValue("nbOccurrences"), 2)
        self.assertEqual(transcript.getTagValue("occurrence"), 1)
      elif cpt == 2:
        self.assertEqual(transcript.name, "test3.1")
        self.assertEqual(transcript.chromosome, "arm_X")
        self.assertEqual(transcript.start, 1000)
        self.assertEqual(transcript.end, 2000)
        self.assertEqual(transcript.getSize(), 1001)
        self.assertEqual(transcript.getNbExons(), 1)
        exons = transcript.getExons()
        self.assertEqual(exons[0].start, 1000)
        self.assertEqual(exons[0].end, 2000)
        self.assertEqual(transcript.getTagValue("occurrence"), 2)
    
    table.remove()


if __name__ == '__main__':
  # Run directly rather than imported: hand control to unittest's command-line
  # runner, which collects and runs the TestCase classes in this module.
  unittest.main()
