import unittest
from SMART.Java.Python.structure.Interval import Interval

class Test_Interval(unittest.TestCase):
    """Unit tests for the S-MART ``Interval`` structure.

    Covers the accessors, copy(), strand handling, the start/end setters and
    the pairwise operations (overlap, inclusion, distance, merge, difference).
    """

    def setUp(self):
        self.iInterval = Interval()
        self.iInterval1 = Interval()
        self.iInterval2 = Interval()

    def _buildInterval(self, name, chromosome, start, end, direction):
        """Return a new Interval populated through its public setters.

        The setters are called in a fixed order (name, chromosome, start,
        end, then direction), so the strand is always applied after the
        coordinates.
        """
        iInterval = Interval()
        iInterval.setName(name)
        iInterval.setChromosome(chromosome)
        iInterval.setStart(start)
        iInterval.setEnd(end)
        iInterval.setDirection(direction)
        return iInterval

    def _getRaisedMessage(self, function, *args):
        """Call ``function(*args)`` and return the message of the exception it raises.

        Fails the current test if no exception is raised. The message is read
        inside the handler because the exception name is not usable after it.
        """
        try:
            function(*args)
        except Exception as e:
            return str(e)
        # Outside the try block, so this failure cannot be swallowed by the handler.
        self.fail("%s did not raise an exception" % function.__name__)

    def test__init__(self):
        """Start and end set through the setters are read back unchanged."""
        self.iInterval.setChromosome("chromosome")
        self.iInterval.setName("sequence")
        self.iInterval.setStart(0)
        self.iInterval.setEnd(123)
        obsStart = self.iInterval.getStart()
        obsEnd = self.iInterval.getEnd()
        expStart = 0
        expEnd = 123

        self.assertEqual(expStart, obsStart)
        self.assertEqual(expEnd, obsEnd)

    def test_copy(self):
        """copy() duplicates every field and the copy stays independent of its source."""
        self.iInterval1 = self._buildInterval("interval1", "chr1", 100, 300, "+")

        self.iInterval2.copy(self.iInterval1)
        self.assertEqual(self.iInterval2.getName(), "interval1")
        self.assertEqual(self.iInterval2.getChromosome(), "chr1")
        self.assertEqual(self.iInterval2.getStart(), 100)
        self.assertEqual(self.iInterval2.getEnd(), 300)
        self.assertEqual(self.iInterval2.getDirection(), 1)

        # Changing the source afterwards must not affect the copy.
        self.iInterval1.setStart(200)
        self.assertEqual(self.iInterval2.getStart(), 100)

    def test_getDirection(self):
        """The "+" strand is reported as direction 1."""
        self.iInterval1 = self._buildInterval("interval1", "chr1", 100, 300, "+")
        expDirect = 1
        self.assertEqual(expDirect, self.iInterval1.getDirection())

    # getStart() and getEnd() return the lower and upper coordinate of the
    # interval whatever its strand: after reverse() (direction -1) the checks
    # below still expect getStart() < getEnd().
    def test_setStartEnd(self):
        """Successive setStart()/setEnd() calls are reflected by the getters on both strands."""
        self.iInterval1 = self._buildInterval("interval1", "chr1", 100, 300, "+")

        self.assertEqual(self.iInterval1.getName(), "interval1")
        self.assertEqual(self.iInterval1.getChromosome(), "chr1")
        self.assertEqual(self.iInterval1.getStart(), 100)
        self.assertEqual(self.iInterval1.getEnd(), 300)
        self.assertEqual(self.iInterval1.getDirection(), 1)

        self.iInterval1.setStart(200)
        self.assertEqual(self.iInterval1.getStart(), 200)
        self.assertEqual(self.iInterval1.getEnd(), 300)

        self.iInterval1.setEnd(300)
        self.iInterval1.setStart(100)
        self.assertEqual(self.iInterval1.getStart(), 100)
        self.assertEqual(self.iInterval1.getEnd(), 300)

        self.iInterval1.setEnd(1200)
        self.iInterval1.setStart(1000)
        self.assertEqual(self.iInterval1.getStart(), 1000)
        self.assertEqual(self.iInterval1.getEnd(), 1200)

        # Switch to the "-" strand; coordinates must be unchanged.
        self.iInterval1.reverse()
        self.assertEqual(self.iInterval1.getDirection(), -1)
        self.assertEqual(self.iInterval1.getStart(), 1000)
        self.assertEqual(self.iInterval1.getEnd(), 1200)

        self.iInterval1.setStart(1100)
        self.assertEqual(self.iInterval1.getStart(), 1100)
        self.assertEqual(self.iInterval1.getEnd(), 1200)

        self.iInterval1.setEnd(2200)
        self.iInterval1.setStart(2000)
        self.assertEqual(self.iInterval1.getStart(), 2000)
        self.assertEqual(self.iInterval1.getEnd(), 2200)

        self.iInterval1.setStart(1000)
        self.iInterval1.setEnd(1200)
        self.assertEqual(self.iInterval1.getStart(), 1000)
        self.assertEqual(self.iInterval1.getEnd(), 1200)

    def test_reverse(self):
        """reverse() flips the direction and keeps the coordinates."""
        self.iInterval1 = self._buildInterval("interval1", "chr1", 100, 200, "+")
        self.iInterval1.reverse()
        self.assertEqual(self.iInterval1.getStart(), 100)
        self.assertEqual(self.iInterval1.getEnd(), 200)
        self.assertEqual(self.iInterval1.getDirection(), -1)

    def test_overlapWith(self):
        """overlapWith() on identical, cross-chromosome, disjoint and touching intervals."""
        self.iInterval1 = self._buildInterval("interval1", "chr1", 100, 200, "+")

        self.iInterval2 = Interval()
        self.iInterval2.copy(self.iInterval1)
        self.iInterval2.setName("interval2")

        self.assertTrue(self.iInterval1.overlapWith(self.iInterval2))

        # Intervals on different chromosomes are expected to raise.
        self.iInterval2.setChromosome("chr2")
        self.assertRaises(Exception, self.iInterval1.overlapWith, self.iInterval2)

        self.iInterval2.setChromosome("chr1")
        self.iInterval2.setEnd(400)
        self.iInterval2.setStart(300)
        self.assertFalse(self.iInterval1.overlapWith(self.iInterval2))

        # Sharing only the boundary coordinate 200 counts as an overlap.
        self.iInterval2.setStart(200)
        self.assertTrue(self.iInterval1.overlapWith(self.iInterval2))

    def test_isIncludeIn(self):
        """An interval lying inside a larger one on the same chromosome is included in it."""
        self.iInterval1 = self._buildInterval("interval1", "chr1", 100, 200, "+")
        self.iInterval2 = self._buildInterval("interval2", "chr1", 80, 280, "+")
        self.assertTrue(self.iInterval1.isIncludeIn(self.iInterval2))

    def test_getDistance(self):
        """getDistance() is 0 for identical intervals, symmetric, and raises across chromosomes."""
        self.iInterval1 = self._buildInterval("interval1", "chr1", 100, 200, "+")

        self.iInterval2 = Interval()
        self.iInterval2.copy(self.iInterval1)
        self.iInterval2.setName("interval2")

        self.assertEqual(self.iInterval1.getDistance(self.iInterval2), 0)
        self.assertEqual(self.iInterval2.getDistance(self.iInterval1), 0)

        self.iInterval2.setChromosome("chr2")
        self.assertRaises(Exception, self.iInterval1.getDistance, self.iInterval2)

        self.iInterval2.setChromosome("chr1")
        self.iInterval2.setEnd(400)
        self.iInterval2.setStart(300)
        self.assertEqual(self.iInterval1.getDistance(self.iInterval2), 100)
        self.assertEqual(self.iInterval2.getDistance(self.iInterval1), 100)

    def test_getRelativeDistance(self):
        """getRelativeDistance() is signed: positive when the other interval lies after self."""
        self.iInterval1 = self._buildInterval("interval1", "chr1", 100, 200, "+")

        self.iInterval2 = Interval()
        self.iInterval2.copy(self.iInterval1)
        self.iInterval2.setName("interval2")

        # NOTE(review): these preliminary checks exercise getDistance(), not
        # getRelativeDistance(); this looks like copy-paste from test_getDistance.
        # Confirm the intended method before switching.
        self.assertEqual(self.iInterval1.getDistance(self.iInterval2), 0)
        self.assertEqual(self.iInterval2.getDistance(self.iInterval1), 0)

        self.iInterval2.setChromosome("chr2")
        self.assertRaises(Exception, self.iInterval1.getDistance, self.iInterval2)

        self.iInterval2.setChromosome("chr1")
        self.iInterval2.setEnd(400)
        self.iInterval2.setStart(300)
        self.assertEqual(self.iInterval1.getRelativeDistance(self.iInterval2), 100)
        self.assertEqual(self.iInterval2.getRelativeDistance(self.iInterval1), -100)

    def test_merge(self):
        """merge() is a no-op on identical intervals, rejects other chromosomes, and spans both."""
        self.iInterval1 = self._buildInterval("interval1", "chr1", 100, 200, "+")

        self.iInterval2 = Interval()
        self.iInterval2.copy(self.iInterval1)
        self.iInterval2.setName("interval2")
        self.iInterval2.merge(self.iInterval1)

        self.assertEqual(self.iInterval1, self.iInterval2)

        self.iInterval2.setChromosome("chr2")
        expMessage = "Cannot merge '%s' and '%s' for they are on different chromosomes." % (
            str(self.iInterval2), str(self.iInterval1))
        obsMessage = self._getRaisedMessage(self.iInterval2.merge, self.iInterval1)
        self.assertEqual(expMessage, obsMessage)

        # Merging requires both intervals to be on the same chromosome and strand.
        self.iInterval2.setChromosome("chr1")
        self.iInterval2.setStart(300)
        self.iInterval2.setEnd(400)
        self.iInterval2.merge(self.iInterval1)
        self.assertEqual(self.iInterval2.getStart(), 100)
        self.assertEqual(self.iInterval2.getEnd(), 400)
        self.assertEqual(self.iInterval2.getChromosome(), "chr1")

    def test_include(self):
        """include() holds for identical or enclosing intervals on the same chromosome only."""
        iInterval1 = self._buildInterval("interval1", "chr1", 100, 200, "+")

        iInterval2 = Interval()
        iInterval2.copy(iInterval1)
        iInterval2.setName("interval2")
        self.assertTrue(iInterval1.include(iInterval2))
        self.assertTrue(iInterval2.include(iInterval1))

        iInterval2.setChromosome("chr2")
        self.assertFalse(iInterval1.include(iInterval2))
        self.assertFalse(iInterval2.include(iInterval1))

        iInterval2.setChromosome("chr1")
        iInterval1.setStart(1)
        self.assertTrue(iInterval1.include(iInterval2))
        self.assertFalse(iInterval2.include(iInterval1))

        iInterval1.setStart(100)
        iInterval1.setEnd(300)
        self.assertTrue(iInterval1.include(iInterval2))
        self.assertFalse(iInterval2.include(iInterval1))

    def test_getDifference(self):
        """getDifference() returns the parts of self not covered by the other interval."""
        iInterval1 = self._buildInterval("interval1", "chr1", 100, 400, "+")

        iInterval2 = Interval()
        iInterval2.copy(iInterval1)
        iInterval2.setName("interval2")
        self.assertEqual(iInterval1.getDifference(iInterval2), [])
        self.assertEqual(iInterval2.getDifference(iInterval1), [])

        # Different chromosome: self is returned whole.
        iInterval2.setChromosome("chr2")
        results = iInterval1.getDifference(iInterval2)
        self.assertEqual(len(results), 1)
        resultInterval = results[0]
        self.assertEqual(resultInterval.getStart(),      iInterval1.getStart())
        self.assertEqual(resultInterval.getEnd(),        iInterval1.getEnd())
        self.assertEqual(resultInterval.getDirection(),  iInterval1.getDirection())
        self.assertEqual(resultInterval.getChromosome(), iInterval1.getChromosome())

        # Left part covered: only the right remainder is left.
        iInterval2.setChromosome("chr1")
        iInterval2.setEnd(300)
        results = iInterval1.getDifference(iInterval2)
        self.assertEqual(len(results), 1)
        resultInterval = results[0]
        self.assertEqual(resultInterval.getStart(),      301)
        self.assertEqual(resultInterval.getEnd(),        iInterval1.getEnd())
        self.assertEqual(resultInterval.getDirection(),  iInterval1.getDirection())
        self.assertEqual(resultInterval.getChromosome(), iInterval1.getChromosome())

        # Opposite strand with the second argument True: self is returned whole.
        iInterval2.setDirection("-")
        results = iInterval1.getDifference(iInterval2, True)
        self.assertEqual(len(results), 1)
        resultInterval = results[0]
        self.assertEqual(resultInterval.getStart(),      iInterval1.getStart())
        self.assertEqual(resultInterval.getEnd(),        iInterval1.getEnd())
        self.assertEqual(resultInterval.getDirection(),  iInterval1.getDirection())
        self.assertEqual(resultInterval.getChromosome(), iInterval1.getChromosome())

        # Middle part covered: two remainders, one on each side.
        iInterval2.setDirection("+")
        iInterval2.setStart(200)
        results = iInterval1.getDifference(iInterval2)
        self.assertEqual(len(results), 2)
        resultInterval1, resultInterval2 = results
        self.assertEqual(resultInterval1.getStart(),      iInterval1.getStart())
        self.assertEqual(resultInterval1.getEnd(),        199)
        self.assertEqual(resultInterval1.getDirection(),  iInterval1.getDirection())
        self.assertEqual(resultInterval1.getChromosome(), iInterval1.getChromosome())
        self.assertEqual(resultInterval2.getStart(),      301)
        self.assertEqual(resultInterval2.getEnd(),        iInterval1.getEnd())
        self.assertEqual(resultInterval2.getDirection(),  iInterval1.getDirection())
        self.assertEqual(resultInterval2.getChromosome(), iInterval1.getChromosome())

        # Disjoint intervals: self is returned whole.
        iInterval2.setEnd(2000)
        iInterval2.setStart(1000)
        results = iInterval1.getDifference(iInterval2)
        self.assertEqual(len(results), 1)
        resultInterval = results[0]
        self.assertEqual(resultInterval.getStart(),      iInterval1.getStart())
        self.assertEqual(resultInterval.getEnd(),        iInterval1.getEnd())
        self.assertEqual(resultInterval.getDirection(),  iInterval1.getDirection())
        self.assertEqual(resultInterval.getChromosome(), iInterval1.getChromosome())

    def test_mergeWithDifferentStrand(self):
        """merge() rejects intervals lying on opposite strands."""
        self.iInterval1 = self._buildInterval("interval1", "chr1", 100, 200, "+")
        self.iInterval2 = self._buildInterval("interval2", "chr1", 300, 400, "-")

        expMessage = "Cannot merge '%s' and '%s' for they are on different strands." % (
            str(self.iInterval2), str(self.iInterval1))
        obsMessage = self._getRaisedMessage(self.iInterval2.merge, self.iInterval1)
        self.assertEqual(expMessage, obsMessage)

if __name__ == "__main__":
    # Only when executed as a script (not on import): collect and run
    # every test case defined in this module.
    unittest.main(module="__main__")
