diff SMART/Java/Python/plotCoverage.py @ 6:769e306b7933

Change the repository level.
author yufei-luo
date Fri, 18 Jan 2013 04:54:14 -0500
parents
children 94ab73e8a190
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/SMART/Java/Python/plotCoverage.py	Fri Jan 18 04:54:14 2013 -0500
@@ -0,0 +1,473 @@
+#! /usr/bin/env python
+#
+# Copyright INRA-URGI 2009-2010
+# 
+# This software is governed by the CeCILL license under French law and
+# abiding by the rules of distribution of free software. You can use,
+# modify and/ or redistribute the software under the terms of the CeCILL
+# license as circulated by CEA, CNRS and INRIA at the following URL
+# "http://www.cecill.info".
+# 
+# As a counterpart to the access to the source code and rights to copy,
+# modify and redistribute granted by the license, users are provided only
+# with a limited warranty and the software's author, the holder of the
+# economic rights, and the successive licensors have only limited
+# liability.
+# 
+# In this respect, the user's attention is drawn to the risks associated
+# with loading, using, modifying and/or developing or reproducing the
+# software by the user in light of its specific status of free software,
+# that may mean that it is complicated to manipulate, and that also
+# therefore means that it is reserved for developers and experienced
+# professionals having in-depth computer knowledge. Users are therefore
+# encouraged to load and test the software's suitability as regards their
+# requirements in conditions enabling the security of their systems and/or
+# data to be ensured and, more generally, to use and operate it in the
+# same conditions as regards security.
+# 
+# The fact that you are presently reading this means that you have had
+# knowledge of the CeCILL license and that you accept its terms.
+#
+import os, subprocess, glob, random
+from optparse import OptionParser
+from SMART.Java.Python.structure.Interval import Interval
+from SMART.Java.Python.structure.Transcript import Transcript
+from SMART.Java.Python.structure.TranscriptContainer import TranscriptContainer
+from SMART.Java.Python.misc.RPlotter import RPlotter
+from SMART.Java.Python.misc.Progress import Progress
+from commons.core.parsing.FastaParser import FastaParser
+
+strands = [-1, 1]
+colors  = {-1: "blue", 1: "red", 0: "black"}
+colorLine = "black"
+
+def parseTargetField(field):
+    strand             = "+"
+    splittedFieldSpace = field.split()
+    splittedFieldPlus  = field.split("+", 4)
+    if len(splittedFieldSpace) == 3:
+        id, start, end = splittedFieldSpace
+    elif len(splittedFieldSpace) == 4:
+        id, start, end, strand = splittedFieldSpace
+    elif len(splittedFieldPlus) == 3:
+        id, start, end = splittedFieldPlus
+    elif len(splittedFieldPlus) == 4:
+        id, start, end, strand = splittedFieldPlus
+    else:
+        raise Exception("Cannot parse Target field '%s'." % (field))
+    return (id, int(start), int(end), strand)
+
+
+class SimpleTranscript(object):
+    def __init__(self, transcript1, transcript2, color = None):
+        self.start  = max(0, transcript1.getStart() - transcript2.getStart())
+        self.end    = min(transcript2.getEnd() - transcript2.getStart(), transcript1.getEnd() - transcript2.getStart())
+        self.strand = transcript1.getDirection() * transcript2.getDirection()
+        self.exons  = []
+        for exon in transcript1.getExons():
+            if exon.getEnd() >= transcript2.getStart() and exon.getStart() <= transcript2.getEnd():
+                start = max(0, exon.getStart() - transcript2.getStart())
+                end   = min(transcript2.getEnd() - transcript2.getStart(), exon.getEnd() - transcript2.getStart())
+                self.addExon(start, end, self.strand, color)
+
+    def addExon(self, start, end, strand, color):
+        exon = SimpleExon(start, end, strand, color)
+        self.exons.append(exon)
+
+    def getRScript(self, yOffset, height):
+        rString     = ""
+        previousEnd = None
+        for exon in sorted(self.exons, key=lambda exon: exon.start):
+            if previousEnd != None:
+                rString += "segments(%.1f, %.1f, %.1f, %.1f, col = \"%s\")\n" % (previousEnd, yOffset + height / 4.0, exon.start, yOffset + height / 4.0, colorLine)
+            rString    += exon.getRScript(yOffset, height)
+            previousEnd = exon.end
+        return rString
+
+
+class SimpleExon(object):
+    def __init__(self, start, end, strand, color = None):
+        self.start  = start
+        self.end    = end
+        self.strand = strand
+        self.color  = color
+        
+    def getRScript(self, yOffset, height):
+        color = self.color if self.color != None else colors[self.strand]
+        return "rect(%.1f, %.1f, %.1f, %.1f, col=\"%s\", border = \"%s\")\n" % (self.start, yOffset, self.end, yOffset + height / 2.0, color, colorLine)
+
+
+class Plotter(object):
+    
+    def __init__(self, seed, index, verbosity):
+        self.seed        = seed
+        self.index       = index
+        self.verbosity   = verbosity
+        self.maxCoverage = 0
+        self.maxOverlap  = 0
+        self.log         = ""
+        self.merge       = False
+        self.width       = 1500
+        self.heigth      = 1000
+        self.xLabel      = ""
+        self.yLabel      = ""
+        self.title       = None
+        self.absPath     = os.getcwd()
+        self.coverageDataFileName    = "tmpFile_%d_%s.dat" % (seed, index)
+        self.coverageScript          = ""
+        self.overlapScript           = ""
+        self.outputFileName          = None
+
+    def setOutputFileName(self, fileName):
+        self.outputFileName = fileName
+
+    def setTranscript(self, transcript):
+        self.transcript = transcript
+        self.name       = transcript.getName()
+        self.size       = transcript.getEnd() - transcript.getStart() + 1
+        if self.title == None:
+            self.title = self.name
+        else:
+            self.title += " " + self.name
+
+    def setTitle(self, title):
+        self.title = title + " " + self.name
+
+    def setPlotSize(self, width, height):
+        self.width  = width
+        self.height = height
+
+    def setLabels(self, xLabel, yLabel):
+        self.xLabel = xLabel
+        self.yLabel = yLabel
+
+    def setMerge(self, merge):
+        self.merge = merge
+
+    def setCoverageData(self, coverage):
+        outputCoveragePerStrand = dict([strand, 0] for strand in strands)
+        outputCoverage          = 0
+        dataFile = open(os.path.abspath(self.coverageDataFileName), "w")
+        for position in range(self.size+1):
+            sumValue = 0
+            found    = False
+            dataFile.write("%d\t" % (position))
+            for strand in strands:
+                value     = coverage[strand].get(position, 0)
+                sumValue += value
+                dataFile.write("%d\t" % (value))
+                if value > 0:
+                    found = True
+                    outputCoveragePerStrand[strand] += 1
+            self.maxCoverage = max(self.maxCoverage, sumValue)
+            dataFile.write("%d\n" % (sumValue))
+            if found:
+                outputCoverage += 1
+        dataFile.close()
+        self.log += "%s (%d nt):\n - both strands: %d (%.0f%%)\n - (+) strand: %d (%.0f%%)\n - (-) strand: %d (%.0f%%)\n" % (self.name, self.size, outputCoverage, float(outputCoverage) / self.size * 100, outputCoveragePerStrand[1], float(outputCoveragePerStrand[1]) / self.size * 100, outputCoveragePerStrand[-1], float(outputCoveragePerStrand[-1]) / self.size * 100) 
+        self.coverageScript += "data = scan(\"%s\", list(pos = -666, minus = -666, plus = -666, sumValue = -666), sep=\"\t\")\n" % (os.path.abspath(self.coverageDataFileName))
+        self.coverageScript += "lines(x = data$pos, y = data$minus,    col = \"%s\")\n" % (colors[-1])
+        self.coverageScript += "lines(x = data$pos, y = data$plus,     col = \"%s\")\n" % (colors[1])
+        self.coverageScript += "lines(x = data$pos, y = data$sumValue, col = \"%s\")\n" % (colors[0])
+
+    def setOverlapData(self, overlap):
+        height              = 1
+        self.maxOverlap     = (len(overlap) + 1) * height
+        thisElement         = SimpleTranscript(self.transcript, self.transcript, "black")
+        self.overlapScript += thisElement.getRScript(0, height)
+        for cpt, transcript in enumerate(sorted(overlap, cmp=lambda c1, c2: c1.start - c2.start if c1.start != c2.start else c1.end - c2.end)):
+            self.overlapScript += transcript.getRScript((cpt + 1) * height, height)
+
+    def getFirstLine(self, suffix = None):
+        return "png(file = \"%s_%s%s.png\", width = %d, height = %d, bg = \"white\")\n" % (self.outputFileName, self.name, "" if suffix == None or self.merge else "_%s" % (suffix), self.width, self.height)
+
+    def getLastLine(self):
+        return "dev.off()\n"
+
+    def startR(self, fileName, script):
+        scriptFile = open(fileName, "w")
+        scriptFile.write(script)
+        scriptFile.close()
+        command = "R CMD BATCH %s" % (fileName)
+        status  = subprocess.call(command, shell=True)
+        if status != 0:
+            raise Exception("Problem with the execution of script file %s, status is: %s" % (fileName, status))
+
+    def plot(self):
+        print "outputfileName is written in :", self.outputFileName
+        if self.merge:
+            fileName = "%s_%d_%s.R" % (self.outputFileName, self.seed, self.index)
+            plotLine = "plot(x = NA, y = NA, xlab=\"%s\", ylab=\"%s\", panel.first = grid(lwd = 1.0), xlim = c(0, %d), ylim = c(0, %d), cex.axis = 2, cex.lab = 2, cex.main=2, main = \"%s\")\n" % (self.xLabel, self.yLabel, self.size, max(self.maxCoverage, self.maxOverlap), self.title)
+            script   = self.getFirstLine() + plotLine + self.overlapScript + self.coverageScript + self.getLastLine()
+            self.startR(fileName, script)
+        else:
+            fileName = "%s_%d_%s_overlap.R" % (self.outputFileName, self.seed, self.index)
+            print "overlap file is written in :", fileName
+            plotLine = "plot(x = NA, y = NA, xlab=\"%s\", ylab=\"%s\", panel.first = grid(lwd = 1.0), xlim = c(0, %d), ylim = c(0, %d), cex.axis = 2, cex.lab = 2, cex.main=2, main = \"%s\")\n" % (self.xLabel, self.yLabel, self.size, self.maxOverlap, self.title)
+            script   = self.getFirstLine("overlap") + plotLine + self.overlapScript + self.getLastLine()
+            self.startR(fileName, script)
+            fileName = "%s_%d_%s_coverage.R" % (self.outputFileName, self.seed, self.index)
+            plotLine = "plot(x = NA, y = NA, xlab=\"%s\", ylab=\"%s\", panel.first = grid(lwd = 1.0), xlim = c(0, %d), ylim = c(0, %d), cex.axis = 2, cex.lab = 2, cex.main=2, main = \"%s\")\n" % (self.xLabel, self.yLabel, self.size, self.maxCoverage, self.title)
+            script   = self.getFirstLine("coverage") + plotLine + self.coverageScript + self.getLastLine()
+            self.startR(fileName, script)
+
+
+class PlotParser(object):
+
+    def __init__(self, verbosity):
+        self.verbosity      = verbosity
+        self.parsers        = [None, None]
+        self.sequenceParser = None
+        self.seed           = random.randint(0, 10000)
+        self.title          = ""
+        self.merge          = False
+
+    def __del__(self):
+        for fileName in glob.glob("tmpFile_%d*.dat" % (self.seed)):
+            os.remove(fileName)
+        for fileName in glob.glob("%s*.R" % (os.path.abspath(self.outputFileName))):
+            os.remove(fileName)
+        for fileName in glob.glob("%s*.Rout" % (os.path.abspath(self.outputFileName))):
+            os.remove(fileName)
+
+    def addInput(self, inputNb, fileName, fileFormat):
+        if fileName == None:
+            return
+        self.parsers[inputNb] = TranscriptContainer(fileName, fileFormat, self.verbosity)
+        if inputNb == 0:
+            self.parsers[1] = self.parsers[0]
+
+    def addSequence(self, fileName):
+        if fileName == None:
+            return
+        self.sequenceParser = FastaParser(fileName, self.verbosity)
+
+    def setOutput(self, fileName):
+        self.outputFileName = fileName
+
+    def setPlotSize(self, width, height):
+        self.width  = width
+        self.height = height
+
+    def setLabels(self, xLabel, yLabel):
+        self.xLabel = xLabel
+        self.yLabel = yLabel
+
+    def setTitle(self, title):
+        self.title = title
+
+    def setMerge(self, merge):
+        self.merge = merge
+
+    def initializeDataFromSequences(self):
+        self.sizes    = {}
+        self.coverage = {}
+        self.overlap  = {}
+        for region in self.sequenceParser.getRegions():
+            self.sizes[region]    = self.sequenceParser.getSizeOfRegion(region)
+            self.coverage[region] = {}
+            self.overlap[region]  = []
+            for strand in strands:
+                self.coverage[region][strand] = {}
+                self.coverage[region][strand][1] = 0
+                self.coverage[region][strand][self.sizes[region]] = 0
+
+
+    def initializeDataFromTranscripts(self):
+        self.coverage = dict([i, None] for i in range(self.parsers[1].getNbTranscripts()))
+        self.overlap  = dict([i, None] for i in range(self.parsers[1].getNbTranscripts()))
+        self.sizes    = dict([i, 0]    for i in range(self.parsers[1].getNbTranscripts()))
+        self.parsers[0].findData()
+        progress = Progress(self.parsers[1].getNbTranscripts(), "Reading regions", self.verbosity)
+        for cpt, transcript in enumerate(self.parsers[1].getIterator()):
+            self.coverage[cpt] = {}
+            self.overlap[cpt]  = []
+            for strand in strands:
+                self.coverage[cpt][strand] = {}
+                self.coverage[cpt][strand][0] = 0
+                self.coverage[cpt][strand][transcript.getEnd() - transcript.getStart()] = 0
+            for exon in transcript.getExons():
+                self.sizes[cpt] += exon.getSize()
+            progress.inc()
+        progress.done()
+
+    def initialize(self):
+        if self.sequenceParser == None:
+            self.initializeDataFromTranscripts()
+        else:
+            self.initializeDataFromSequences()
+
+    def computeCoverage(self, transcript1, transcript2, id):
+        strand = transcript1.getDirection() * transcript2.getDirection()
+        for exon1 in transcript1.getExons():
+            for exon2 in transcript2.getExons():
+                if exon1.overlapWith(exon2):
+                    for position in range(max(exon1.getStart(), exon2.getStart()), min(exon1.getEnd(), exon2.getEnd()) + 1):
+                        relativePosition = position - transcript2.getStart() + 1
+                        self.coverage[id][strand][relativePosition] = self.coverage[id][strand].get(relativePosition, 0) + 1
+
+    def computeOverlap(self, transcript1, transcript2, id):
+        simpleTranscript = SimpleTranscript(transcript1, transcript2)
+        self.overlap[id].append(simpleTranscript)
+        
+    def compute2TranscriptFiles(self):
+        progress = Progress(self.parsers[1].getNbTranscripts(), "Comparing regions", self.verbosity)
+        for cpt2, transcript2 in enumerate(self.parsers[1].getIterator()):
+            for transcript1 in self.parsers[0].getIterator():
+                if transcript1.overlapWithExon(transcript2):
+                    self.computeCoverage(transcript1, transcript2, cpt2)
+                    self.computeOverlap(transcript1, transcript2, cpt2)
+            progress.inc()
+        progress.done()
+
+    def extractReferenceQuery(self, inputTranscript):
+        if "Target" not in inputTranscript.getTagNames():
+            raise Exception("Cannot extract Target field in line '%s'." % (inputTranscript))
+        id, start, end, strand = parseTargetField(inputTranscript.getTagValue("Target"))
+        if id not in self.sizes:
+            raise Exception("Target id '%s' of transcript '%s' does not correspond to anything in FASTA file." % (id, inputTranscript))
+        referenceTranscript = Transcript()
+        referenceTranscript.setChromosome(id)
+        referenceTranscript.setName(id)
+        referenceTranscript.setDirection("+")
+        referenceTranscript.setEnd(self.sizes[id])
+        referenceTranscript.setStart(1)
+        queryTranscript = Transcript()
+        queryTranscript.setChromosome(id)
+        queryTranscript.setName(id)
+        queryTranscript.setStart(start)
+        queryTranscript.setEnd(end)
+        queryTranscript.setDirection(strand)
+        if inputTranscript.getNbExons() > 1:
+            factor = float(end - start) / (inputTranscript.getEnd() - inputTranscript.getStart())
+            for exon in inputTranscript.getExons():
+                newExon = Interval()
+                newExon.setChromosome(id)
+                newExon.setDirection(strand)
+                if "Target" in inputTranscript.getTagNames():
+                    id, start, end, strand = parseTargetField(exon.getTagValue("Target"))
+                    newExon.setStart(start)
+                    newExon.setEnd(end)
+                else:
+                    newExon.setStart(int(round((exon.getStart() - inputTranscript.getStart()) * factor)) + start)
+                    newExon.setEnd(  int(round((exon.getEnd() -   inputTranscript.getStart()) * factor)) + start)
+                queryTranscript.addExon(newExon)
+        return (referenceTranscript, queryTranscript)
+
+    def compute1TranscriptFiles(self):
+        progress = Progress(self.parsers[1].getNbTranscripts(), "Comparing regions", self.verbosity)
+        for transcript in self.parsers[1].getIterator():
+            referenceTranscript, queryTranscript = self.extractReferenceQuery(transcript)
+            self.computeCoverage(queryTranscript, referenceTranscript, referenceTranscript.getName())
+            self.computeOverlap(queryTranscript, referenceTranscript, referenceTranscript.getName())
+            progress.inc()
+        progress.done()
+
+    def compute(self):
+        if self.sequenceParser == None:
+            self.compute2TranscriptFiles()
+        else:
+            self.compute1TranscriptFiles()
+
+    def plotTranscript(self, index, transcript):
+        plotter = Plotter(self.seed, index, self.verbosity)
+        plotter.setOutputFileName(self.outputFileName)
+        plotter.setTranscript(transcript)
+        plotter.setTitle(self.title)
+        plotter.setLabels(self.xLabel, self.yLabel)
+        plotter.setPlotSize(self.width, self.height)
+        plotter.setCoverageData(self.coverage[index])
+        plotter.setOverlapData(self.overlap[index])
+        plotter.setMerge(self.merge)
+        plotter.plot()
+        output = plotter.log
+        return output
+        
+    def plot1TranscriptFile(self):
+        self.outputCoverage          = {}
+        self.outputCoveragePerStrand = {}
+        output   = ""
+        progress = Progress(len(self.sequenceParser.getRegions()), "Plotting regions", self.verbosity)
+        for cpt2, region in enumerate(self.sequenceParser.getRegions()):
+            transcript = Transcript()
+            transcript.setName(region)
+            transcript.setDirection("+")
+            transcript.setEnd(self.sizes[region])
+            transcript.setStart(1)
+            output += self.plotTranscript(region, transcript)
+            progress.inc()
+        progress.done()
+        if self.verbosity > 0:
+            print output
+
+    def plot2TranscriptFiles(self):
+        self.outputCoverage          = [0] * self.parsers[1].getNbTranscripts()
+        self.outputCoveragePerStrand = [None] * self.parsers[1].getNbTranscripts()
+        for cpt in range(self.parsers[1].getNbTranscripts()):
+            self.outputCoveragePerStrand[cpt] = dict([strand, 0] for strand in strands)
+        progress = Progress(self.parsers[1].getNbTranscripts(), "Plotting regions", self.verbosity)
+        output = ""
+        for cpt2, transcript2 in enumerate(self.parsers[1].getIterator()):
+            output += self.plotTranscript(cpt2, transcript2)
+            progress.inc()
+        progress.done()
+        if self.verbosity > 0:
+            print output
+
+    def plot(self):
+        if self.sequenceParser == None:
+            self.plot2TranscriptFiles()
+        else:
+            self.plot1TranscriptFile()
+
+    def start(self):
+        self.initialize()
+        self.compute()
+        self.plot()
+
+
+if __name__ == "__main__":
+    
+    # parse command line
+    description = "Plot Coverage v1.0.1: Plot the coverage of the first data with respect to the second one. [Category: Visualization]"
+
+    parser = OptionParser(description = description)
+    parser.add_option("-i", "--input1",       dest="inputFileName1", action="store",                       type="string", help="input file 1 [compulsory] [format: file in transcript format given by -f]")
+    parser.add_option("-f", "--inputFormat1", dest="inputFormat1",   action="store",                       type="string", help="format of input file 1 [compulsory] [format: transcript file format]")
+    parser.add_option("-j", "--input2",       dest="inputFileName2", action="store",                       type="string", help="input file 2 [compulsory] [format: file in transcript format given by -g]")
+    parser.add_option("-g", "--inputFormat2", dest="inputFormat2",   action="store",                       type="string", help="format of input file 2 [compulsory] [format: transcript file format]")
+    parser.add_option("-q", "--sequence",     dest="inputSequence",  action="store",      default=None,    type="string", help="input sequence file [format: file in FASTA format] [default: None]")
+    parser.add_option("-o", "--output",       dest="outputFileName", action="store",                       type="string", help="output file [compulsory] [format: output file in PNG format]")
+    parser.add_option("-w", "--width",        dest="width",          action="store",      default=1500,    type="int",    help="width of the plots (in px) [format: int] [default: 1500]")
+    parser.add_option("-e", "--height",       dest="height",         action="store",      default=1000,    type="int",    help="height of the plots (in px) [format: int] [default: 1000]")
+    parser.add_option("-t", "--title",        dest="title",          action="store",      default="",      type="string", help="title of the plots [format: string]")
+    parser.add_option("-x", "--xlab",         dest="xLabel",         action="store",      default="",      type="string", help="label on the x-axis [format: string]")
+    parser.add_option("-y", "--ylab",         dest="yLabel",         action="store",      default="",      type="string", help="label on the y-axis [format: string]")
+    parser.add_option("-p", "--plusColor",    dest="plusColor",      action="store",      default="red",   type="string", help="color for the elements on the plus strand [format: string] [default: red]")
+    parser.add_option("-m", "--minusColor",   dest="minusColor",     action="store",      default="blue",  type="string", help="color for the elements on the minus strand [format: string] [default: blue]")
+    parser.add_option("-s", "--sumColor",     dest="sumColor",       action="store",      default="black", type="string", help="color for 2 strands coverage line [format: string] [default: black]")
+    parser.add_option("-l", "--lineColor",    dest="lineColor",      action="store",      default="black", type="string", help="color for the lines [format: string] [default: black]")
+    parser.add_option("-1", "--merge",        dest="merge",          action="store_true", default=False,                  help="merge the 2 plots in 1 [format: boolean] [default: false]")
+    parser.add_option("-D", "--directory",    dest="working_Dir",    action="store",      default=os.getcwd(), type="string", help="the directory to store the results [format: directory]")
+    parser.add_option("-v", "--verbosity",    dest="verbosity",      action="store",      default=1,       type="int",    help="trace level [format: int]")
+    (options, args) = parser.parse_args()
+
+    colors[1]  = options.plusColor
+    colors[-1] = options.minusColor
+    colors[0]  = options.sumColor
+    colorLine  = options.lineColor
+
+    pp = PlotParser(options.verbosity)
+    pp.addInput(0, options.inputFileName1, options.inputFormat1)
+    pp.addInput(1, options.inputFileName2, options.inputFormat2)
+    pp.addSequence(options.inputSequence)
+    if options.working_Dir[-1] != '/':
+        path = options.working_Dir + '/'
+    pp.setOutput(path + options.outputFileName)
+    pp.setPlotSize(options.width, options.height)
+    pp.setLabels(options.xLabel, options.yLabel)
+    pp.setTitle(options.title)
+    pp.setMerge(options.merge)
+    pp.start()
+
+