diff SMART/Java/Python/plotCoverage.py @ 36:44d5973c188c

Uploaded
author m-zytnicki
date Tue, 30 Apr 2013 15:02:29 -0400
parents
children 169d364ddd91
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/SMART/Java/Python/plotCoverage.py	Tue Apr 30 15:02:29 2013 -0400
@@ -0,0 +1,481 @@
+#! /usr/bin/env python
+#
+# Copyright INRA-URGI 2009-2010
+# 
+# This software is governed by the CeCILL license under French law and
+# abiding by the rules of distribution of free software. You can use,
+# modify and/ or redistribute the software under the terms of the CeCILL
+# license as circulated by CEA, CNRS and INRIA at the following URL
+# "http://www.cecill.info".
+# 
+# As a counterpart to the access to the source code and rights to copy,
+# modify and redistribute granted by the license, users are provided only
+# with a limited warranty and the software's author, the holder of the
+# economic rights, and the successive licensors have only limited
+# liability.
+# 
+# In this respect, the user's attention is drawn to the risks associated
+# with loading, using, modifying and/or developing or reproducing the
+# software by the user in light of its specific status of free software,
+# that may mean that it is complicated to manipulate, and that also
+# therefore means that it is reserved for developers and experienced
+# professionals having in-depth computer knowledge. Users are therefore
+# encouraged to load and test the software's suitability as regards their
+# requirements in conditions enabling the security of their systems and/or
+# data to be ensured and, more generally, to use and operate it in the
+# same conditions as regards security.
+# 
+# The fact that you are presently reading this means that you have had
+# knowledge of the CeCILL license and that you accept its terms.
+#
+import os, os.path, subprocess, glob, random
+from optparse import OptionParser
+from SMART.Java.Python.structure.Interval import Interval
+from SMART.Java.Python.structure.Transcript import Transcript
+from commons.core.parsing.ParserChooser import ParserChooser
+from SMART.Java.Python.misc.RPlotter import RPlotter
+from SMART.Java.Python.misc.Progress import Progress
+from commons.core.parsing.FastaParser import FastaParser
+
+strands = [-1, 1]
+colors  = {-1: "blue", 1: "red", 0: "black"}
+colorLine = "black"
+
+def parseTargetField(field):
+	strand             = "+"
+	splittedFieldSpace = field.split()
+	splittedFieldPlus  = field.split("+", 4)
+	if len(splittedFieldSpace) == 3:
+		id, start, end = splittedFieldSpace
+	elif len(splittedFieldSpace) == 4:
+		id, start, end, strand = splittedFieldSpace
+	elif len(splittedFieldPlus) == 3:
+		id, start, end = splittedFieldPlus
+	elif len(splittedFieldPlus) == 4:
+		id, start, end, strand = splittedFieldPlus
+	else:
+		raise Exception("Cannot parse Target field '%s'." % (field))
+	return (id, int(start), int(end), strand)
+
+
+class SimpleTranscript(object):
+	def __init__(self, transcript1, transcript2, color = None):
+		self.start  = max(0, transcript1.getStart() - transcript2.getStart())
+		self.end    = min(transcript2.getEnd() - transcript2.getStart(), transcript1.getEnd() - transcript2.getStart())
+		self.strand = transcript1.getDirection() * transcript2.getDirection()
+		self.exons  = []
+		for exon in transcript1.getExons():
+			if exon.getEnd() >= transcript2.getStart() and exon.getStart() <= transcript2.getEnd():
+				start = max(0, exon.getStart() - transcript2.getStart())
+				end   = min(transcript2.getEnd() - transcript2.getStart(), exon.getEnd() - transcript2.getStart())
+				self.addExon(start, end, self.strand, color)
+
+	def addExon(self, start, end, strand, color):
+		exon = SimpleExon(start, end, strand, color)
+		self.exons.append(exon)
+
+	def getRScript(self, yOffset, height):
+		rString     = ""
+		previousEnd = None
+		for exon in sorted(self.exons, key=lambda exon: exon.start):
+			if previousEnd != None:
+				rString += "segments(%.1f, %.1f, %.1f, %.1f, col = \"%s\")\n" % (previousEnd, yOffset + height / 4.0, exon.start, yOffset + height / 4.0, colorLine)
+			rString    += exon.getRScript(yOffset, height)
+			previousEnd = exon.end
+		return rString
+
+
+class SimpleExon(object):
+	def __init__(self, start, end, strand, color = None):
+		self.start  = start
+		self.end    = end
+		self.strand = strand
+		self.color  = color
+		
+	def getRScript(self, yOffset, height):
+		color = self.color if self.color != None else colors[self.strand]
+		return "rect(%.1f, %.1f, %.1f, %.1f, col=\"%s\", border = \"%s\")\n" % (self.start, yOffset, self.end, yOffset + height / 2.0, color, colorLine)
+
+
+class Plotter(object):
+	
+	def __init__(self, seed, index, verbosity):
+		self.seed        = seed
+		self.index       = index
+		self.verbosity   = verbosity
+		self.maxCoverage = 0
+		self.maxOverlap  = 0
+		self.log         = ""
+		self.merge       = False
+		self.width       = 1500
+		self.heigth      = 1000
+		self.xLabel      = ""
+		self.yLabel      = ""
+		self.title       = None
+		self.absPath     = os.getcwd()
+		self.coverageDataFileName    = "tmpFile_%d_%s.dat" % (seed, index)
+		self.coverageScript          = ""
+		self.overlapScript           = ""
+		self.outputFileName          = None
+
+	def setOutputFileName(self, fileName):
+		self.outputFileName = fileName
+
+	def setTranscript(self, transcript):
+		self.transcript = transcript
+		self.name       = transcript.getName()
+		self.size       = transcript.getEnd() - transcript.getStart() + 1
+		if self.title == None:
+			self.title = self.name
+		else:
+			self.title += " " + self.name
+
+	def setTitle(self, title):
+		self.title = title + " " + self.name
+
+	def setPlotSize(self, width, height):
+		self.width  = width
+		self.height = height
+
+	def setLabels(self, xLabel, yLabel):
+		self.xLabel = xLabel
+		self.yLabel = yLabel
+
+	def setMerge(self, merge):
+		self.merge = merge
+
+	def setCoverageData(self, coverage):
+		outputCoveragePerStrand = dict([strand, 0] for strand in strands)
+		outputCoverage          = 0
+		dataFile = open(os.path.abspath(self.coverageDataFileName), "w")
+		for position in range(self.size+1):
+			sumValue = 0
+			found    = False
+			dataFile.write("%d\t" % (position))
+			for strand in strands:
+				value     = coverage[strand].get(position, 0)
+				sumValue += value
+				dataFile.write("%d\t" % (value))
+				if value > 0:
+					found = True
+					outputCoveragePerStrand[strand] += 1
+			self.maxCoverage = max(self.maxCoverage, sumValue)
+			dataFile.write("%d\n" % (sumValue))
+			if found:
+				outputCoverage += 1
+		dataFile.close()
+		self.log += "%s (%d nt):\n - both strands: %d (%.0f%%)\n - (+) strand: %d (%.0f%%)\n - (-) strand: %d (%.0f%%)\n" % (self.name, self.size, outputCoverage, float(outputCoverage) / self.size * 100, outputCoveragePerStrand[1], float(outputCoveragePerStrand[1]) / self.size * 100, outputCoveragePerStrand[-1], float(outputCoveragePerStrand[-1]) / self.size * 100) 
+		self.coverageScript += "data = scan(\"%s\", list(pos = -666, minus = -666, plus = -666, sumValue = -666), sep=\"\t\")\n" % (os.path.abspath(self.coverageDataFileName))
+		self.coverageScript += "lines(x = data$pos, y = data$minus,    col = \"%s\")\n" % (colors[-1])
+		self.coverageScript += "lines(x = data$pos, y = data$plus,     col = \"%s\")\n" % (colors[1])
+		self.coverageScript += "lines(x = data$pos, y = data$sumValue, col = \"%s\")\n" % (colors[0])
+
+	def setOverlapData(self, overlap):
+		height              = 1
+		self.maxOverlap     = (len(overlap) + 1) * height
+		thisElement         = SimpleTranscript(self.transcript, self.transcript, "black")
+		self.overlapScript += thisElement.getRScript(0, height)
+		for cpt, transcript in enumerate(sorted(overlap, cmp=lambda c1, c2: c1.start - c2.start if c1.start != c2.start else c1.end - c2.end)):
+			self.overlapScript += transcript.getRScript((cpt + 1) * height, height)
+
+	def getFirstLine(self, suffix = None):
+		return "png(file = \"%s_%s%s.png\", width = %d, height = %d, bg = \"white\")\n" % (self.outputFileName, self.name, "" if suffix == None or self.merge else "_%s" % (suffix), self.width, self.height)
+
+	def getLastLine(self):
+		return "dev.off()\n"
+
+	def startR(self, fileName, script):
+		scriptFile = open(fileName, "w")
+		scriptFile.write(script)
+		scriptFile.close()
+		command = "R CMD BATCH %s" % (fileName)
+		status  = subprocess.call(command, shell=True)
+		if status != 0:
+			raise Exception("Problem with the execution of script file %s, status is: %s" % (fileName, status))
+
+	def plot(self):
+		if self.merge:
+			fileName = "%s_%d_%s.R" % (self.outputFileName, self.seed, self.index)
+			plotLine = "plot(x = NA, y = NA, xlab=\"%s\", ylab=\"%s\", panel.first = grid(lwd = 1.0), xlim = c(0, %d), ylim = c(0, %d), cex.axis = 2, cex.lab = 2, cex.main=2, main = \"%s\")\n" % (self.xLabel, self.yLabel, self.size, max(self.maxCoverage, self.maxOverlap), self.title)
+			script   = self.getFirstLine() + plotLine + self.overlapScript + self.coverageScript + self.getLastLine()
+			self.startR(fileName, script)
+		else:
+			fileName = "%s_%d_%s_overlap.R" % (self.outputFileName, self.seed, self.index)
+			plotLine = "plot(x = NA, y = NA, xlab=\"%s\", ylab=\"%s\", panel.first = grid(lwd = 1.0), xlim = c(0, %d), ylim = c(0, %d), cex.axis = 2, cex.lab = 2, cex.main=2, main = \"%s\")\n" % (self.xLabel, self.yLabel, self.size, self.maxOverlap, self.title)
+			script   = self.getFirstLine("overlap") + plotLine + self.overlapScript + self.getLastLine()
+			self.startR(fileName, script)
+			fileName = "%s_%d_%s_coverage.R" % (self.outputFileName, self.seed, self.index)
+			plotLine = "plot(x = NA, y = NA, xlab=\"%s\", ylab=\"%s\", panel.first = grid(lwd = 1.0), xlim = c(0, %d), ylim = c(0, %d), cex.axis = 2, cex.lab = 2, cex.main=2, main = \"%s\")\n" % (self.xLabel, self.yLabel, self.size, self.maxCoverage, self.title)
+			script   = self.getFirstLine("coverage") + plotLine + self.coverageScript + self.getLastLine()
+			self.startR(fileName, script)
+
+
+class PlotParser(object):
+
+	def __init__(self, verbosity):
+		self.verbosity      = verbosity
+		self.parsers        = [None, None]
+		self.sequenceParser = None
+		self.seed           = random.randint(0, 10000)
+		self.title          = ""
+		self.merge          = False
+
+	def __del__(self):
+		for fileName in glob.glob("tmpFile_%d*.dat" % (self.seed)):
+			os.remove(fileName)
+		for fileName in glob.glob("%s*.R" % (os.path.abspath(self.outputFileName))):
+			os.remove(fileName)
+		for fileName in glob.glob("%s*.Rout" % (os.path.abspath(self.outputFileName))):
+			os.remove(fileName)
+
+	def addInput(self, inputNb, fileName, fileFormat):
+		if fileName == None:
+			return
+		chooser = ParserChooser(self.verbosity)
+		chooser.findFormat(fileFormat)
+		self.parsers[inputNb] = chooser.getParser(fileName)
+		if inputNb == 0:
+			self.parsers[1] = self.parsers[0]
+
+	def addSequence(self, fileName):
+		if fileName == None:
+			return
+		self.sequenceParser = FastaParser(fileName, self.verbosity)
+
+	def setOutput(self, fileName):
+		self.outputFileName = fileName
+
+	def setPlotSize(self, width, height):
+		self.width  = width
+		self.height = height
+
+	def setLabels(self, xLabel, yLabel):
+		self.xLabel = xLabel
+		self.yLabel = yLabel
+
+	def setTitle(self, title):
+		self.title = title
+
+	def setMerge(self, merge):
+		self.merge = merge
+
+	def initializeDataFromSequences(self):
+		self.sizes    = {}
+		self.coverage = {}
+		self.overlap  = {}
+		for region in self.sequenceParser.getRegions():
+			self.sizes[region]    = self.sequenceParser.getSizeOfRegion(region)
+			self.coverage[region] = {}
+			self.overlap[region]  = []
+			for strand in strands:
+				self.coverage[region][strand] = {}
+				self.coverage[region][strand][1] = 0
+				self.coverage[region][strand][self.sizes[region]] = 0
+
+	def initializeDataFromTranscripts(self):
+		self.coverage = dict([i, None] for i in range(self.parsers[1].getNbTranscripts()))
+		self.overlap  = dict([i, None] for i in range(self.parsers[1].getNbTranscripts()))
+		self.sizes    = dict([i, 0]    for i in range(self.parsers[1].getNbTranscripts()))
+		progress = Progress(self.parsers[1].getNbTranscripts(), "Reading regions", self.verbosity)
+		for cpt, transcript in enumerate(self.parsers[1].getIterator()):
+			self.coverage[cpt] = {}
+			self.overlap[cpt]  = []
+			for strand in strands:
+				self.coverage[cpt][strand] = {}
+				self.coverage[cpt][strand][0] = 0
+				self.coverage[cpt][strand][transcript.getEnd() - transcript.getStart()] = 0
+			for exon in transcript.getExons():
+				self.sizes[cpt] += exon.getSize()
+			progress.inc()
+		progress.done()
+
+	def initialize(self):
+		if self.sequenceParser == None:
+			self.initializeDataFromTranscripts()
+		else:
+			self.initializeDataFromSequences()
+
+	def computeCoverage(self, transcript1, transcript2, id):
+		strand = transcript1.getDirection() * transcript2.getDirection()
+		for exon1 in transcript1.getExons():
+			for exon2 in transcript2.getExons():
+				if exon1.overlapWith(exon2):
+					for position in range(max(exon1.getStart(), exon2.getStart()), min(exon1.getEnd(), exon2.getEnd()) + 1):
+						relativePosition = position - transcript2.getStart() + 1
+						self.coverage[id][strand][relativePosition] = self.coverage[id][strand].get(relativePosition, 0) + 1
+
+	def computeOverlap(self, transcript1, transcript2, id):
+		simpleTranscript = SimpleTranscript(transcript1, transcript2)
+		self.overlap[id].append(simpleTranscript)
+		
+	def compute2TranscriptFiles(self):
+		progress = Progress(self.parsers[1].getNbTranscripts(), "Comparing regions", self.verbosity)
+		for cpt2, transcript2 in enumerate(self.parsers[1].getIterator()):
+			for transcript1 in self.parsers[0].getIterator():
+				if transcript1.overlapWithExon(transcript2):
+					self.computeCoverage(transcript1, transcript2, cpt2)
+					self.computeOverlap(transcript1, transcript2, cpt2)
+			progress.inc()
+		progress.done()
+
+	def extractReferenceQueryMapping(self, mapping):
+		queryTranscript = mapping.getTranscript()
+		referenceTranscript = Transcript()
+		referenceTranscript.setChromosome(queryTranscript.getChromosome())
+		referenceTranscript.setName(queryTranscript.getChromosome())
+		referenceTranscript.setDirection("+")
+		referenceTranscript.setEnd(self.sizes[queryTranscript.getChromosome()])
+		referenceTranscript.setStart(1)
+		return (referenceTranscript, queryTranscript)
+
+	def extractReferenceQuery(self, inputTranscript):
+		if "Target" not in inputTranscript.getTagNames():
+			raise Exception("Cannot extract Target field in line '%s'." % (inputTranscript))
+		id, start, end, strand = parseTargetField(inputTranscript.getTagValue("Target"))
+		if id not in self.sizes:
+			raise Exception("Target id '%s' of transcript '%s' does not correspond to anything in FASTA file." % (id, inputTranscript))
+		referenceTranscript = Transcript()
+		referenceTranscript.setChromosome(id)
+		referenceTranscript.setName(id)
+		referenceTranscript.setDirection("+")
+		referenceTranscript.setEnd(self.sizes[id])
+		referenceTranscript.setStart(1)
+		queryTranscript = Transcript()
+		queryTranscript.setChromosome(id)
+		queryTranscript.setName(id)
+		queryTranscript.setStart(start)
+		queryTranscript.setEnd(end)
+		queryTranscript.setDirection(strand)
+		if inputTranscript.getNbExons() > 1:
+			factor = float(end - start) / (inputTranscript.getEnd() - inputTranscript.getStart())
+			for exon in inputTranscript.getExons():
+				newExon = Interval()
+				newExon.setChromosome(id)
+				newExon.setDirection(strand)
+				if "Target" in inputTranscript.getTagNames():
+					id, start, end, strand = parseTargetField(exon.getTagValue("Target"))
+					newExon.setStart(start)
+					newExon.setEnd(end)
+				else:
+					newExon.setStart(int(round((exon.getStart() - inputTranscript.getStart()) * factor)) + start)
+					newExon.setEnd(  int(round((exon.getEnd() -   inputTranscript.getStart()) * factor)) + start)
+				queryTranscript.addExon(newExon)
+		return (referenceTranscript, queryTranscript)
+
+	def compute1TranscriptFiles(self):
+		progress = Progress(self.parsers[1].getNbItems(), "Comparing regions", self.verbosity)
+		for transcript in self.parsers[1].getIterator():
+			if transcript.__class__.__name__ == "Mapping":
+				referenceTranscript, queryTranscript = self.extractReferenceQueryMapping(transcript)
+			else:
+				referenceTranscript, queryTranscript = self.extractReferenceQuery(transcript)
+			self.computeCoverage(queryTranscript, referenceTranscript, referenceTranscript.getName())
+			self.computeOverlap(queryTranscript, referenceTranscript, referenceTranscript.getName())
+			progress.inc()
+		progress.done()
+
+	def compute(self):
+		if self.sequenceParser == None:
+			self.compute2TranscriptFiles()
+		else:
+			self.compute1TranscriptFiles()
+
+	def plotTranscript(self, index, transcript):
+		plotter = Plotter(self.seed, index, self.verbosity)
+		plotter.setOutputFileName(self.outputFileName)
+		plotter.setTranscript(transcript)
+		plotter.setTitle(self.title)
+		plotter.setLabels(self.xLabel, self.yLabel)
+		plotter.setPlotSize(self.width, self.height)
+		plotter.setCoverageData(self.coverage[index])
+		plotter.setOverlapData(self.overlap[index])
+		plotter.setMerge(self.merge)
+		plotter.plot()
+		output = plotter.log
+		return output
+		
+	def plot1TranscriptFile(self):
+		self.outputCoverage          = {}
+		self.outputCoveragePerStrand = {}
+		output   = ""
+		progress = Progress(len(self.sequenceParser.getRegions()), "Plotting regions", self.verbosity)
+		for cpt2, region in enumerate(self.sequenceParser.getRegions()):
+			transcript = Transcript()
+			transcript.setName(region)
+			transcript.setDirection("+")
+			transcript.setEnd(self.sizes[region])
+			transcript.setStart(1)
+			output += self.plotTranscript(region, transcript)
+			progress.inc()
+		progress.done()
+		if self.verbosity > 0:
+			print output
+
+	def plot2TranscriptFiles(self):
+		self.outputCoverage          = [0] * self.parsers[1].getNbTranscripts()
+		self.outputCoveragePerStrand = [None] * self.parsers[1].getNbTranscripts()
+		for cpt in range(self.parsers[1].getNbTranscripts()):
+			self.outputCoveragePerStrand[cpt] = dict([strand, 0] for strand in strands)
+		progress = Progress(self.parsers[1].getNbTranscripts(), "Plotting regions", self.verbosity)
+		output = ""
+		for cpt2, transcript2 in enumerate(self.parsers[1].getIterator()):
+			output += self.plotTranscript(cpt2, transcript2)
+			progress.inc()
+		progress.done()
+		if self.verbosity > 0:
+			print output
+
+	def plot(self):
+		if self.sequenceParser == None:
+			self.plot2TranscriptFiles()
+		else:
+			self.plot1TranscriptFile()
+
+	def start(self):
+		self.initialize()
+		self.compute()
+		self.plot()
+
+
+if __name__ == "__main__":
+	
+	# parse command line
+	description = "Plot Coverage v1.0.1: Plot the coverage of the first data with respect to the second one. [Category: Visualization]"
+
+	parser = OptionParser(description = description)
+	parser.add_option("-i", "--input1",       dest="inputFileName1", action="store",                       type="string", help="input file 1 [compulsory] [format: file in transcript or mapping format given by -f]")
+	parser.add_option("-f", "--inputFormat1", dest="inputFormat1",   action="store",                       type="string", help="format of input file 1 [compulsory] [format: transcript or mapping file format]")
+	parser.add_option("-j", "--input2",       dest="inputFileName2", action="store",                       type="string", help="input file 2 [compulsory] [format: file in transcript format given by -g]")
+	parser.add_option("-g", "--inputFormat2", dest="inputFormat2",   action="store",                       type="string", help="format of input file 2 [compulsory] [format: transcript file format]")
+	parser.add_option("-q", "--sequence",     dest="inputSequence",  action="store",      default=None,    type="string", help="input sequence file [format: file in FASTA format] [default: None]")
+	parser.add_option("-o", "--output",       dest="outputFileName", action="store",                       type="string", help="output file [compulsory] [format: output file in PNG format]")
+	parser.add_option("-w", "--width",        dest="width",          action="store",      default=1500,    type="int",    help="width of the plots (in px) [format: int] [default: 1500]")
+	parser.add_option("-e", "--height",       dest="height",         action="store",      default=1000,    type="int",    help="height of the plots (in px) [format: int] [default: 1000]")
+	parser.add_option("-t", "--title",        dest="title",          action="store",      default="",      type="string", help="title of the plots [format: string]")
+	parser.add_option("-x", "--xlab",         dest="xLabel",         action="store",      default="",      type="string", help="label on the x-axis [format: string]")
+	parser.add_option("-y", "--ylab",         dest="yLabel",         action="store",      default="",      type="string", help="label on the y-axis [format: string]")
+	parser.add_option("-p", "--plusColor",    dest="plusColor",      action="store",      default="red",   type="string", help="color for the elements on the plus strand [format: string] [default: red]")
+	parser.add_option("-m", "--minusColor",   dest="minusColor",     action="store",      default="blue",  type="string", help="color for the elements on the minus strand [format: string] [default: blue]")
+	parser.add_option("-s", "--sumColor",     dest="sumColor",       action="store",      default="black", type="string", help="color for 2 strands coverage line [format: string] [default: black]")
+	parser.add_option("-l", "--lineColor",    dest="lineColor",      action="store",      default="black", type="string", help="color for the lines [format: string] [default: black]")
+	parser.add_option("-1", "--merge",        dest="merge",          action="store_true", default=False,                  help="merge the 2 plots in 1 [format: boolean] [default: false]")
+	parser.add_option("-D", "--directory",    dest="working_Dir",    action="store",      default=os.getcwd(), type="string", help="the directory to store the results [format: directory]")
+	parser.add_option("-v", "--verbosity",    dest="verbosity",      action="store",      default=1,       type="int",    help="trace level [format: int]")
+	(options, args) = parser.parse_args()
+
+	colors[1]  = options.plusColor
+	colors[-1] = options.minusColor
+	colors[0]  = options.sumColor
+	colorLine  = options.lineColor
+
+	pp = PlotParser(options.verbosity)
+	pp.addInput(0, options.inputFileName1, options.inputFormat1)
+	pp.addInput(1, options.inputFileName2, options.inputFormat2)
+	pp.addSequence(options.inputSequence)
+	pp.setOutput(options.outputFileName if os.path.isabs(options.outputFileName) else os.path.join(options.working_Dir, options.outputFileName))
+	pp.setPlotSize(options.width, options.height)
+	pp.setLabels(options.xLabel, options.yLabel)
+	pp.setTitle(options.title)
+	pp.setMerge(options.merge)
+	pp.start()
+