view SMART/Java/Python/test/Test_F_mergeTranscriptLists.py @ 31:0ab839023fe4

Uploaded
author m-zytnicki
date Tue, 30 Apr 2013 14:33:21 -0400
parents 94ab73e8a190
children
line wrap: on
line source

import unittest
import os, os.path, glob
from SMART.Java.Python.structure.Transcript import Transcript
from SMART.Java.Python.mergeTranscriptLists import MergeLists
from commons.core.writer.Gff3Writer import Gff3Writer
from commons.core.parsing.GffParser import GffParser

class Test_F_mergeTranscriptLists(unittest.TestCase):
    """Functional tests for ``MergeLists`` (mergeTranscriptLists).

    Each test writes a small query GFF3 file and a reference GFF3 file,
    runs ``MergeLists`` with the query as input 0 and the reference as
    input 1, then parses the GFF3 output and checks the merged transcripts.
    """

    def setUp(self):
        self.queryFileName     = "testQuery.gff3"
        self.referenceFileName = "testReference.gff3"
        self.outputFileName    = "testOutput.gff3"

    def tearDown(self):
        # Delete every file whose name starts with one of the test file names.
        for fileRoot in (self.queryFileName, self.referenceFileName, self.outputFileName):
            for fileName in glob.glob("%s*" % (fileRoot)):
                os.remove(fileName)
        # Also delete any tmp_*.gff3 files in the working directory
        # (presumably temporary files left by the merge run -- confirm).
        for fileName in glob.glob("tmp_*.gff3"):
            os.remove(fileName)

    def test_run_simple(self):
        """Without aggregation, expect one merged transcript chr1:1000-4000."""
        self._writeSimpleInputs()
        self._runMerge(False)
        self._checkOutput([("chr1", 1000, 4000, "+")])

    def test_run_simple_aggregate(self):
        """With aggregation, expect chr1:1000-4000 followed by chr1:5000-6000."""
        self._writeSimpleInputs()
        self._runMerge(True)
        self._checkOutput([("chr1", 1000, 4000, "+"),
                           ("chr1", 5000, 6000, "+")])

    def _writeTranscripts(self, fileName, transcripts):
        """Write *transcripts*, in order, to the GFF3 file *fileName*."""
        writer = Gff3Writer(fileName, 0)
        try:
            for transcript in transcripts:
                writer.addTranscript(transcript)
        finally:
            # Release the file handle even if writing fails, so tearDown can
            # still remove the file.
            writer.close()

    def _writeSimpleInputs(self):
        """Write the fixture shared by the "simple" tests.

        Reference file: chr1:1000-2000, chr1:3000-4000, chr1:5000-6000 (all "+").
        Query file: chr1:1500-3500 ("+"), which overlaps the first two references.
        """
        self._writeTranscripts(self.referenceFileName, [
            self._createTranscript("chr1", 1000, 2000, "+", "ref1"),
            self._createTranscript("chr1", 3000, 4000, "+", "ref2"),
            self._createTranscript("chr1", 5000, 6000, "+", "ref3"),
        ])
        self._writeTranscripts(self.queryFileName, [
            self._createTranscript("chr1", 1500, 3500, "+", "query1"),
        ])

    def _runMerge(self, aggregate):
        """Run MergeLists on the query (input 0) and reference (input 1) files.

        @param aggregate: when true, call setAggregate(True) before running
        """
        ml = MergeLists(0)
        ml.setInputFileName(self.queryFileName, 'gff3', 0)
        ml.setInputFileName(self.referenceFileName, 'gff3', 1)
        ml.setOutputFileName(self.outputFileName)
        if aggregate:
            ml.setAggregate(True)
        ml.run()

    def _checkOutput(self, expected):
        """Parse the output file and compare it, in order, with *expected*.

        @param expected: list of (chromosome, start, end, strand) tuples
        """
        parser = GffParser(self.outputFileName)
        self.assertEqual(parser.getNbTranscripts(), len(expected))
        nbSeen = 0
        # Check each transcript as it is yielded instead of materializing the
        # iterator, in case the parser reuses its transcript objects.
        for transcript in parser.getIterator():
            self.assertTrue(nbSeen < len(expected), "more transcripts than expected")
            chromosome, start, end, strand = expected[nbSeen]
            self._checkTranscript(transcript, chromosome, start, end, strand, None)
            nbSeen += 1
        self.assertEqual(nbSeen, len(expected))

    def _createTranscript(self, chromosome, start, end, strand, name):
        """Build a Transcript with the given location, strand and name."""
        transcript = Transcript()
        transcript.setChromosome(chromosome)
        transcript.setStart(start)
        transcript.setEnd(end)
        transcript.setDirection(strand)
        transcript.setName(name)
        return transcript

    def _checkTranscript(self, transcript, chromosome, start, end, strand, name):
        """Assert that *transcript* has the given location and strand.

        The name is compared only when *name* is not None.
        """
        self.assertEqual(transcript.getChromosome(), chromosome)
        self.assertEqual(transcript.getStart(), start)
        self.assertEqual(transcript.getEnd(), end)
        self.assertEqual(transcript.getStrand(), strand)
        if name is not None:
            self.assertEqual(transcript.getName(), name)

        
if __name__ == "__main__":
    # Script entry point: discover and run the test cases defined in this module.
    unittest.main(module="__main__")