from commons.core.writer.MySqlTranscriptWriter import MySqlTranscriptWriter
from SMART.Java.Python.structure.TranscriptContainer import TranscriptContainer
from SMART.Java.Python.structure.Transcript import Transcript
from SMART.Java.Python.structure.Interval import Interval
from SMART.Java.Python.mySql.MySqlConnection import MySqlConnection
from SMART.Java.Python.mySql.MySqlTranscriptTable import MySqlTranscriptTable
import unittest


class Test_MySqlTranscriptTable(unittest.TestCase):
    """Tests for MySqlTranscriptTable and friends.

    These tests open a MySqlConnection() and create real tables through it,
    so they need a reachable database rather than running in isolation.
    """

    def _makeSingleExonTranscript(self, name, tagName=None, tagValue=None):
        """Build a '+' strand transcript on arm_X spanning 1000-2000 with one exon.

        :param name: name given to both the transcript and its single exon
        :param tagName: optional tag name to set on the transcript
        :param tagValue: value stored under tagName (ignored if tagName is None)
        :return: the new Transcript
        """
        transcript = Transcript()
        transcript.setName(name)
        transcript.setChromosome("arm_X")
        transcript.setStart(1000)
        transcript.setEnd(2000)
        transcript.setDirection("+")
        if tagName is not None:
            transcript.setTagValue(tagName, tagValue)

        exon = Interval()
        exon.setName(name)
        exon.setChromosome("arm_X")
        exon.setStart(1000)
        exon.setEnd(2000)

        transcript.addExon(exon)
        return transcript

    def _assertSingleExonTranscript(self, transcript, name):
        """Check the fields shared by every transcript from _makeSingleExonTranscript.

        :param transcript: transcript read back from the table
        :param name: expected transcript name
        """
        self.assertEqual(transcript.getName(), name)
        self.assertEqual(transcript.getChromosome(), "arm_X")
        self.assertEqual(transcript.getStart(), 1000)
        self.assertEqual(transcript.getEnd(), 2000)
        # 1000..2000 with both bounds counted.
        self.assertEqual(transcript.getSize(), 1001)
        self.assertEqual(transcript.getNbExons(), 1)
        exons = transcript.getExons()
        self.assertEqual(exons[0].getStart(), 1000)
        self.assertEqual(exons[0].getEnd(), 2000)

    def test_getRange(self):
        """Write a two-exon transcript with MySqlTranscriptWriter and read it back.

        The transcript spans 1000-4000 on arm_X with exons 1000-2000 and
        3000-4000. The size read back is expected to be 2002, presumably
        recomputed from the exons with inclusive bounds (1001 + 1001),
        rather than the 2000 passed to setSize().
        """
        transcript = Transcript()
        transcript.setName("test1.1")
        transcript.setChromosome("arm_X")
        transcript.setStart(1000)
        transcript.setEnd(4000)
        transcript.setSize(2000)
        transcript.setDirection("+")

        exon1 = Interval()
        exon1.setName("test1.1")
        exon1.setChromosome("arm_X")
        exon1.setStart(1000)
        exon1.setEnd(2000)
        exon1.setSize(1000)

        exon2 = Interval()
        exon2.setName("test1.1")
        exon2.setChromosome("arm_X")
        exon2.setStart(3000)
        exon2.setEnd(4000)
        exon2.setSize(1000)

        transcript.addExon(exon1)
        transcript.addExon(exon2)

        connection = MySqlConnection()
        writer = MySqlTranscriptWriter(connection, "testMySqlTranscriptTableGetRange")
        writer.addTranscript(transcript)
        writer.write()
        # NOTE(review): the table(s) created by the writer are never dropped.
        # The writer's table naming is not visible here. TODO: add cleanup.

        transcriptContainer = TranscriptContainer("testMySqlTranscriptTableGetRange", "sql")
        transcriptContainer.mySqlConnection = connection
        self.assertEqual(transcriptContainer.getNbTranscripts(), 1)
        for readTranscript in transcriptContainer.getIterator():
            self.assertEqual(readTranscript.getName(), "test1.1")
            self.assertEqual(readTranscript.getChromosome(), "arm_X")
            self.assertEqual(readTranscript.getStart(), 1000)
            self.assertEqual(readTranscript.getEnd(), 4000)
            self.assertEqual(readTranscript.getSize(), 2002)
            self.assertEqual(readTranscript.getNbExons(), 2)
            exons = readTranscript.getExons()
            self.assertEqual(exons[0].getStart(), 1000)
            self.assertEqual(exons[0].getEnd(), 2000)
            self.assertEqual(exons[1].getStart(), 3000)
            self.assertEqual(exons[1].getEnd(), 4000)

    def test_setDefaultTagValue(self):
        """setDefaultTagValue gives every stored transcript the 'occurrence' tag.

        Three transcripts are stored:
        - one with no tags;
        - one with an unrelated 'nbOccurrences' tag;
        - one with an 'occurrences' tag, whose name deliberately differs from
          'occurrence' by one letter.

        After setting the default 'occurrence' to "1", each transcript must
        report occurrence == 1 and keep its own tag unchanged.
        """
        transcript1 = self._makeSingleExonTranscript("test1.1")
        transcript2 = self._makeSingleExonTranscript("test2.1", "nbOccurrences", "2")
        transcript3 = self._makeSingleExonTranscript("test3.1", "occurrences", "2")

        connection = MySqlConnection()
        table      = MySqlTranscriptTable(connection, "testMySqlTranscriptTableSetDefaultTagValue")
        table.createTranscriptTable()
        # Drop the table even when an assertion below fails.
        self.addCleanup(table.remove)
        table.addTranscript(transcript1)
        table.addTranscript(transcript2)
        table.addTranscript(transcript3)
        table.setDefaultTagValue("occurrence", "1")

        cpt = 0
        for transcript in table.getIterator():
            cpt += 1
            self.assertLessEqual(cpt, 3)
            if cpt == 1:
                self._assertSingleExonTranscript(transcript, "test1.1")
                self.assertEqual(transcript.getTagValue("occurrence"), 1)
            elif cpt == 2:
                self._assertSingleExonTranscript(transcript, "test2.1")
                self.assertEqual(transcript.getTagValue("nbOccurrences"), 2)
                self.assertEqual(transcript.getTagValue("occurrence"), 1)
            else:
                self._assertSingleExonTranscript(transcript, "test3.1")
                self.assertEqual(transcript.getTagValue("occurrences"), 2)
                self.assertEqual(transcript.getTagValue("occurrence"), 1)
        # Every inserted transcript must have been returned.
        self.assertEqual(cpt, 3)

if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase in this module.
    unittest.main()
