# Copyright INRA-URGI 2009-2010
# This software is governed by the CeCILL license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "".
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.

import os
import subprocess
import random
import math

# Stand-in for a null coordinate on a log-scaled axis, where 0 cannot be
# drawn (note that 10e-6 == 1e-5).
minPositiveValue = 1e-5

"""
Plot simple curves in R
"""

class RPlotter(object):
    """
    Plot some curves with R
    Each curve is written to a temporary data file "tmpData-<seed>-<i>.dat",
    then an R script "tmpScript-<seed>.R" is generated and run with
    "R CMD BATCH". The temporary files are deleted when the object is
    destroyed, unless C{keep} is set.
    @ivar nbColors: number of different colors
    @type nbColors: int
    @ivar fileName: name of the file
    @type fileName: string
    @ivar lines: lines to be plotted
    @type lines: array of dict
    @ivar names: name of the lines
    @type names: array of strings
    @ivar colors: color of the lines (R expressions)
    @type colors: array of strings
    @ivar types: type of the lines (solid or dashed)
    @type types: array of strings
    @ivar format: format of the picture
    @type format: string
    @ivar width: width of the picture
    @type width: int
    @ivar height: height of the picture
    @type height: int
    @ivar lineWidth: width of the line in a xy-plot
    @type lineWidth: int
    @ivar xMin: minimum value taken on the x-axis
    @type xMin: int
    @ivar xMax: maximum value taken on the x-axis
    @type xMax: int
    @ivar yMin: minimum value taken on the y-axis
    @type yMin: int
    @ivar yMax: maximum value taken on the y-axis
    @type yMax: int
    @ivar minimumX: minimum value allowed on the x-axis
    @type minimumX: int
    @ivar maximumX: maximum value allowed on the x-axis
    @type maximumX: int
    @ivar minimumY: minimum value allowed on the y-axis
    @type minimumY: int
    @ivar maximumY: maximum value allowed on the y-axis
    @type maximumY: int
    @ivar leftMargin:   add some margin in the left part of the plot
    @type leftMargin:   float
    @ivar rightMargin:  add some margin in the right part of the plot
    @type rightMargin:  float
    @ivar topMargin:    add some margin at the top of the plot
    @type topMargin:    float
    @ivar bottomMargin: add some margin at the bottom of the plot
    @type bottomMargin: float
    @ivar logX: use log scale on the x-axis
    @type logX: boolean
    @ivar logY: use log scale on the y-axis
    @type logY: boolean
    @ivar logZ: use log scale on the z-axis (the color)
    @type logZ: boolean
    @ivar fill: if a value is not given, fill it with given value
    @type fill: int
    @ivar bucket: cluster the data into buckets of given size
    @type bucket: int
    @ivar seed: a random number, used to name the temporary files
    @type seed: int
    @ivar regression: plot a linear regression
    @type regression: boolean
    @ivar legend: set the legend
    @type legend: boolean
    @ivar legendBySide: set the legend outside of the plot
    @type legendBySide: boolean
    @ivar xLabel: label for the x-axis
    @type xLabel: string
    @ivar yLabel: label for the y-axis
    @type yLabel: string
    @ivar title: title of the plot
    @type title: string
    @ivar barplot: use a barplot representation instead
    @type barplot: boolean
    @ivar points: use a point cloud instead
    @type points: boolean
    @ivar heatPoints: use a colored point cloud instead
    @type heatPoints: boolean
    @ivar axesLabels: change the names of the axes
    @type axesLabels: vector of 2 int to string dict
    @ivar rotateAxesLabels: rotate the axes labels
    @type rotateAxesLabels: dict of 2 boolean
    @ivar linesToAddBox: R commands drawing the boxes requested with addBox
    @type linesToAddBox: string
    @ivar verbosity: verbosity of the class
    @type verbosity: int
    @ivar keep: keep temporary files
    @type keep: boolean
    """

    def __init__(self, fileName, verbosity = 0, keep = False):
        """
        @param fileName: name of the file to produce
        @type  fileName: string
        @param verbosity: verbosity
        @type  verbosity: int
        @param keep: keep temporary files
        @type  keep: boolean
        """
        self.nbColors = 9
        self.fileName = fileName
        self.verbosity = verbosity
        self.keep = keep
        self.format = "png"
        self.fill = None
        self.bucket = None
        self.lines = []
        self.names = []
        self.colors = []
        self.types = []
        self.lineWidth = 1
        self.xMin = None
        self.xMax = None
        self.yMin = None
        self.yMax = None
        self.seed = random.randint(0, 10000)
        self.minimumX = None
        self.maximumX = None
        self.minimumY = None
        self.maximumY = None
        self.leftMargin   = 0
        self.rightMargin  = 0
        self.topMargin    = 0
        self.bottomMargin = 0
        self.logX = False
        self.logY = False
        self.logZ = False
        self.regression = False
        self.width = 1000
        self.height = 500
        self.legend = False
        self.legendBySide = False
        self.xLabel = ""
        self.yLabel = ""
        self.title = None
        self.points = False
        self.heatPoints = False
        self.barplot = False
        self.axesLabels = {1: None, 2: None}
        self.rotateAxesLabels = {1: False, 2: False}
        self.linesToAddBox = ""

    def __del__(self):
        """
        Remove tmp files (unless they should be kept)
        """
        if self.keep:
            return
        scriptFileName = "tmpScript-%d.R" % (self.seed)
        self._removeIfExists(scriptFileName)
        # the output of R is read from "<script>out" (see getCorrelationData)
        self._removeIfExists("%sout" % (scriptFileName))
        # addHeatLine writes one extra data file without registering a line
        nbLines = len(self.lines) + (1 if self.heatPoints else 0)
        for i in range(nbLines):
            self._removeIfExists("tmpData-%d-%d.dat" % (self.seed, i))

    @staticmethod
    def _removeIfExists(fileName):
        """
        Delete a file if it exists
        @param fileName: name of the file
        @type  fileName: string
        """
        if os.path.exists(fileName):
            os.remove(fileName)

    @staticmethod
    def _quoteR(text):
        """
        Escape a value so that it can be put inside a double-quoted R string
        @param text: the value (converted to a string)
        @return: the escaped string
        """
        text = "%s" % (text,)
        return text.replace("\\", "\\\\").replace("\"", "\\\"")

    def setMinimumX(self, xMin):
        """
        Set the minimum value on the x-axis
        @param xMin: minimum value on the x-axis
        @type  xMin: int
        """
        self.minimumX = xMin

    def setMaximumX(self, xMax):
        """
        Set the maximum value on the x-axis
        @param xMax: maximum value on the x-axis
        @type  xMax: int
        """
        self.maximumX = xMax

    def setMinimumY(self, yMin):
        """
        Set the minimum value on the y-axis
        @param yMin: minimum value on the y-axis
        @type  yMin: int
        """
        self.minimumY = yMin

    def setMaximumY(self, yMax):
        """
        Set the maximum value on the y-axis
        @param yMax: maximum value on the y-axis
        @type  yMax: int
        """
        self.maximumY = yMax

    def setFill(self, fill):
        """
        Fill empty data with given value
        @param fill: the value to fill with
        @type  fill: int
        """
        self.fill = fill

    def setBuckets(self, bucket):
        """
        Cluster the data into buckets of given size
        @param bucket: the size of the buckets
        @type  bucket: int
        """
        self.bucket = bucket

    def setRegression(self, regression):
        """
        Plot a linear regression line
        @param regression: whether to plot the regression
        @type  regression: bool
        """
        self.regression = regression

    def setFormat(self, format):
        """
        Set the format of the picture
        @param format: the format: "png", "pdf", "jpeg", "bmp" or "tiff"
        @type  format: string
        @raise Exception: if the format is not supported
        """
        if format not in ("png", "pdf", "jpeg", "bmp", "tiff"):
            raise Exception("Format '%s' is not supported by RPlotter" % (format))
        self.format = format

    def setWidth(self, width):
        """
        Set the width of the image produced
        @param width: width of the image
        @type  width: int
        """
        self.width = width

    def setHeight(self, height):
        """
        Set the height of the image produced
        @param height: height of the image
        @type  height: int
        """
        self.height = height

    def setImageSize(self, width, height):
        """
        Set the dimensions of the image produced
        @param width: width of the image
        @type  width: int
        @param height: height of the image
        @type  height: int
        """
        self.setWidth(width)
        self.setHeight(height)

    def setLegend(self, legend, bySide = False):
        """
        Print a legend or not
        @param legend: print a legend
        @type  legend: boolean
        @param bySide: put the legend outside of the plot
        @type  bySide: boolean
        """
        self.legend       = legend
        self.legendBySide = bySide

    def setXLabel(self, label):
        """
        Print a label for the x-axis (underscores are displayed as spaces)
        @param label: the label
        @type  label: string
        """
        self.xLabel = label
        if self.xLabel is not None:
            self.xLabel = self.xLabel.replace("_", " ")

    def setYLabel(self, label):
        """
        Print a label for the y-axis (underscores are displayed as spaces)
        @param label: the label
        @type  label: string
        """
        self.yLabel = label
        if self.yLabel is not None:
            self.yLabel = self.yLabel.replace("_", " ")

    def addLeftMargin(self, margin):
        """
        Increase the size of the space on the left part of the graph
        @param margin: the space added
        @type  margin: float
        """
        self.leftMargin = margin

    def addRightMargin(self, margin):
        """
        Increase the size of the space on the right part of the graph
        @param margin: the space added
        @type  margin: float
        """
        self.rightMargin = margin

    def addTopMargin(self, margin):
        """
        Increase the size of the space at the top of the graph
        TopMargin is a percentage if 0 < TopMargin < 1.
        TopMargin is a value if TopMargin >= 1.
        @param margin: the space added
        @type  margin: float
        """
        self.topMargin = margin

    def addBottomMargin(self, margin):
        """
        Increase the size of the space at the bottom of the graph
        @param margin: the space added
        @type  margin: float
        """
        self.bottomMargin = margin

    def getNewYMaxWithTopMargin(self):
        """
        Return the maximum y-coordinate, top margin included
        A top margin in ]0, 1[ is a fraction of |yMax|, a top margin >= 1 is
        an absolute value; any other top margin is ignored.
        @return: the new maximum y-coordinate
        @rtype:  float
        """
        yMax = self.yMax
        if 0 < self.topMargin < 1:
            # abs() so that the margin also goes upwards for negative data
            yMax = self.yMax + self.topMargin * abs(self.yMax)
        elif self.topMargin >= 1:
            yMax = self.yMax + self.topMargin
        return yMax

    def setTitle(self, title):
        """
        Print a title for graph (underscores are displayed as spaces)
        @param title: a title
        @type  title: string
        """
        self.title = title
        if self.title is not None:
            self.title = self.title.replace("_", " ")

    @staticmethod
    def _getAxisIndex(i, action):
        """
        Convert an axis name into the R axis side (1 for x, 2 for y)
        @param i: "x" or "y" (case insensitive)
        @param action: what is being done, for the error message
        @raise Exception: if the axis name is neither "x" nor "y"
        """
        axis = i.lower()
        if axis not in ("x", "y"):
            raise Exception("Label name '" + axis + "' should be 'x' or 'y' while " + action + " axis labels.")
        return {"x": 1, "y": 2}[axis]

    def setAxisLabel(self, i, labels):
        """
        Change x- or y-labels
        @param i: x for x-label, y for y-label
        @type  i: string
        @param labels: new labels
        @type  labels: int to string dict
        """
        self.axesLabels[self._getAxisIndex(i, "changing")] = labels

    def rotateAxisLabel(self, i, b = True):
        """
        Rotate x- or y-labels
        @param i: x for x-label, y for y-label
        @type  i: string
        @param b: whether the labels should be rotated
        @type  b: boolean
        """
        self.rotateAxesLabels[self._getAxisIndex(i, "rotating")] = b

    def setLineWidth(self, width):
        """
        Set the line width in a xy-plot
        @param width: the new line width
        @type  width: int
        """
        self.lineWidth = width

    def setLog(self, log):
        """
        Use log-scale for axes
        @param log: the axes to use a log scale on, e.g. "xy" ("z" is the color); None means none
        @type  log: string
        """
        log = log or ""
        self.logX = ("x" in log)
        self.logY = ("y" in log)
        self.logZ = ("z" in log)

    def setBarplot(self, barplot):
        """
        Use barplot representation instead
        @param barplot: barplot representation
        @type  barplot: boolean
        """
        self.barplot = barplot

    def setPoints(self, points):
        """
        Use points cloud representation instead
        @param points: points cloud representation
        @type  points: boolean
        """
        self.points = points

    def setHeatPoints(self, heatPoints):
        """
        Use points cloud representation with color representing another variable instead
        @param heatPoints: colored points cloud representation
        @type  heatPoints: boolean
        """
        self.heatPoints = heatPoints

    def addBox(self, lXCoordList, minY, maxY):
        """
        Shade some x-intervals with grey rectangles (drawn in the default line-plot mode only)
        @param lXCoordList: the intervals, as (start, end) pairs of x-coordinates
        @type  lXCoordList: list of pairs
        @param minY: bottom of the rectangles
        @param maxY: top of the rectangles
        """
        for lXCoord in lXCoordList:
            self.linesToAddBox += "rect(%s,%s,%s,%s,density=50, col='grey',border='transparent')\n" % (lXCoord[0], minY, lXCoord[1], maxY)

    def addLine(self, line, name = "", color = None):
        """
        Add a line
        In the points and heatPoints modes, the values of the dict are (x, y)
        pairs; otherwise the dict maps x-values to y-values. A y-value may be
        the string "NA" (a missing value for R).
        @param line: a line to plot
        @type  line: dict
        @param name: name of the line, printed in the legend
        @type  name: string
        @param color: R color of the line, or None to use the next color of the panel
        @type  color: string
        """
        # prepare data (this also updates the extrema of the axes)
        coordinates = self._getCoordinates(line)

        # cluster the data into buckets
        if self.bucket is not None:
            coordinates = self._getBuckets(line)

        # the heat file gives the colors by position: keep the original order then
        if not self.heatPoints:
            coordinates.sort(key = lambda point: point[0])

        # write file
        dataFileName = "tmpData-%d-%d.dat" % (self.seed, len(self.lines))
        with open(dataFileName, "w") as dataHandle:
            for (x, y) in coordinates:
                if y != "NA":
                    dataHandle.write("%f\t%f\n" % (x, y))
                else:
                    dataHandle.write("%f\t%s\n" % (x, y))

        self.lines.append(line)
        self.names.append(name)
        self._addStyle(color)

    def _getCoordinates(self, line):
        """
        Return the (x, y) points of a line within the x-bounds, and update the extrema of the axes
        @param line: a line, as given to addLine
        @return: the points
        @rtype:  list of pairs
        @raise Exception: if a coordinate is None
        """
        if self.points or self.heatPoints:
            values = list(line.values())
        elif self.fill is None:
            values = sorted(line.keys())
        else:
            # every integer x of the range, missing ones being filled later
            values = range(min(line.keys()), max(line.keys()) + 1)

        coordinates = []
        for element in values:
            if self.points or self.heatPoints:
                x = element[0]
                y = element[1]
            else:
                x = element
                y = line.get(x, self.fill)
            if x is None:
                raise Exception("Problem! x is None. Aborting...")
            if self.minimumX is not None and x < self.minimumX:
                continue
            if self.maximumX is not None and x > self.maximumX:
                continue
            if y is None:
                raise Exception("Problem! y is None. Aborting...")
            # log10(0) is undefined: use a tiny positive value instead
            if x == 0 and self.logX:
                x = minPositiveValue
            if y == 0 and self.logY:
                y = minPositiveValue
            self._updateExtrema(x, y)
            coordinates.append((x, y))
        return coordinates

    def _updateExtrema(self, x, y):
        """
        Extend the observed ranges of the axes with a point; missing ("NA") y-values are ignored
        """
        self.xMin = x if self.xMin is None else min(self.xMin, x)
        self.xMax = x if self.xMax is None else max(self.xMax, x)
        if y != "NA":
            self.yMin = y if self.yMin is None else min(self.yMin, y)
            self.yMax = y if self.yMax is None else max(self.yMax, y)

    def _getBuckets(self, line):
        """
        Sum the values of a line into buckets of size self.bucket, and set yMax to the largest sum
        @param line: a line mapping x-values to counts
        @return: the (bucket start, sum) pairs
        @rtype:  list of pairs
        """
        size = int(self.bucket)

        def getBucket(value):
            # floor division, as the Python 2 integer division used originally
            return (int(value) // size) * self.bucket

        buckets = dict((getBucket(value), 0) for value in range(min(line.keys()), max(line.keys()) + 1))
        for distance, nb in line.items():
            buckets[getBucket(distance)] += nb
        self.yMax = max(buckets.values())
        return list(buckets.items())

    def _addStyle(self, color):
        """
        Record the R color and the line type of the line being added
        @param color: a color name, or None to use the next color of the panel
        @type  color: string
        """
        if color is None:
            # colorPanel[1] to colorPanel[nbColors - 1] are used in turn
            nbUsedColors = self.nbColors - 1
            colorNumber = len(self.colors) % nbUsedColors + 1
            # once every color has been used, tell the lines apart with dashes
            lineType = "dashed" if len(self.colors) >= nbUsedColors else "solid"
            color = "colorPanel[%d]" % (colorNumber)
        else:
            color = "\"%s\"" % (self._quoteR(color))
            lineType = "solid"
        self.colors.append(color)
        self.types.append(lineType)

    def addHeatLine(self, line, name = "", color = None):
        """
        Add the heat line
        Each value is turned into a color from green (lowest) to red (highest),
        on a log scale if the z-axis is logarithmic. The input dict is not modified.
        @param line: the line which gives the color of the points
        @type  line: dict
        @param name: name of the line, printed in the legend
        @type  name: string
        @param color: R color used in the legend, or None to use the next color of the panel
        @type  color: string
        @raise Exception: if the heatPoints mode is not set
        """
        if not self.heatPoints:
            raise Exception("Error! Trying to add a heat point whereas not mentioned to earlier! Aborting.")
        minLogValue = 0.00001
        log = self.logZ
        # with a log scale, null heats are shifted so that their logarithm is defined
        shift = minLogValue if (log and min(line.values()) == 0) else 0
        # NOTE(review): colors are matched to the points by position, so this assumes
        # this dict iterates in the same order as the one given to addLine — confirm
        heats = [line[element] + shift for element in line]
        minimumHeat = min(heats)
        maximumHeat = max(heats)
        if log:
            minimumHeat = math.log10(minimumHeat)
            maximumHeat = math.log10(maximumHeat)
        # a constant heat has a null range: every point then gets the lowest color
        coeff = 255.0 / (maximumHeat - minimumHeat) if maximumHeat != minimumHeat else 0.0

        # the heat file does not count as a line (see getScript and __del__)
        dataFileName = "tmpData-%d-%d.dat" % (self.seed, len(self.lines))
        with open(dataFileName, "w") as dataHandle:
            for value in heats:
                if log:
                    value = math.log10(max(minLogValue, value))
                level = int((value - minimumHeat) * coeff)
                dataHandle.write("\"#%02X%02X00\"\n" % (level, 255 - level))

        self.names.append(name)
        self._addStyle(color)

    def _getRanges(self):
        """
        Return the limits (xMin, xMax, yMin, yMax) of the plot, margins and bounds included
        """
        xMin = self.xMin - self.leftMargin
        if self.minimumX is not None:
            xMin = max(xMin, self.minimumX)
        xMax = self.xMax + self.rightMargin
        if self.maximumX is not None:
            xMax = min(xMax, self.maximumX)
        yMin = self.yMin - self.bottomMargin
        if self.minimumY is not None:
            yMin = self.minimumY
        yMax = self.getNewYMaxWithTopMargin()
        # leave some room (at most one unit) above the highest point
        yMax += min(1, abs(yMax) / 100.0)
        if self.maximumY is not None:
            yMax = self.maximumY
        return xMin, xMax, yMin, yMax

    def _getRegressionModel(self):
        """
        Return the R expression of the linear regression of y on x (on the log-scaled axes if any)
        """
        x = "log10(data$x)" if self.logX else "data$x"
        y = "log10(data$y)" if self.logY else "data$y"
        return "lm(%s ~ %s)" % (y, x)

    def _getAxesScript(self):
        """
        Return the R commands drawing the axes, with the custom labels and rotations
        """
        script = ""
        for i in sorted(self.axesLabels):
            rotation = ", las = 2" if self.rotateAxesLabels[i] else ""
            labels = self.axesLabels[i]
            if labels is None:
                script += "axis(%d, cex.axis = 2, cex.lab = 2%s)\n" % (i, rotation)
            else:
                keys = sorted(labels.keys())
                oldKeys = ", ".join(["%d" % (key) for key in keys])
                newKeys = ", ".join(["\"%s\"" % (self._quoteR(labels[key])) for key in keys])
                script += "axis(%d, at=c(%s), lab=c(%s), cex.axis = 2, cex.lab = 2%s)\n" % (i, oldKeys, newKeys, rotation)
        return script

    def _getLegendScript(self):
        """
        Return the R commands drawing the legend, inside the plot or in the cell at its side
        """
        script = ""
        if self.legendBySide:
            # move to the second cell of the layout, and use all of it
            script += "plot.new()\n"
            script += "par(mar=c(0,0,0,0))\n"
            script += "plot.window(c(0,1), c(0,1))\n"
        script += "legends   = c(%s)\n" % ", ".join(["\"%s\"" % (self._quoteR(name)) for name in self.names])
        script += "colors    = c(%s)\n" % ", ".join(self.colors)
        script += "lineTypes = c(%s)\n" % ", ".join(["\"%s\"" % (lineType) for lineType in self.types])
        if self.legendBySide:
            script += "legend(0, 1, legend = legends, xjust = 0, yjust = 1, col = colors, lty = lineTypes, lwd = %d, cex = 1.5, ncol = 1, bg = \"white\")\n" % (self.lineWidth)
        else:
            script += "legend(\"topright\", legend = legends, xjust = 0, yjust = 1, col = colors, lty = lineTypes, lwd = %d, cex = 1.5, ncol = 1, bg = \"white\")\n" % (self.lineWidth)
        return script

    def getScript(self):
        """
        Write (unfinished) R script
        The commands draw the plot, its axes and its legend; opening and
        closing the graphics device is left to the caller (see plot).
        @return: the R commands
        @rtype:  string
        @raise Exception: if no data has been added, or if the heatPoints mode does not get exactly one line
        """
        if self.xMin is None or self.yMin is None or self.yMax is None:
            raise Exception("Error! No data to plot! Aborting...")
        xMin, xMax, yMin, yMax = self._getRanges()
        xLabel = self._quoteR(self.xLabel)
        yLabel = self._quoteR(self.yLabel)

        log = ""
        if self.logX:
            log += "x"
        if self.logY:
            log += "y"
        if log != "":
            log = ", log=\"%s\"" % (log)

        title = ""
        if self.title is not None:
            title = ", main = \"%s\"" % (self._quoteR(self.title))

        script = ""
        if self.legend and self.legendBySide:
            # two cells side by side: the plot, then a narrower one for the legend
            script += "layout(matrix(c(1,2), 1, 2), widths=c(5,1))\n"

        # rotated y-labels need a wider left margin
        if self.rotateAxesLabels[2]:
            script += "par(mar=c(5,12,4,2))\n"
        else:
            script += "par(mar=c(5,5,4,2))\n"

        addAxes = True
        if self.barplot:
            script += "data = scan(\"tmpData-%d-0.dat\", list(x = -666, y = -666))\n" % (self.seed)
            if len(self.lines) == 1:
                script += "barplot(data$y, name = data$x, xlab=\"%s\", ylab=\"%s\", ylim = c(%f, %f), cex.axis = 2, cex.names = 2, cex.lab = 2%s%s)\n" % (xLabel, yLabel, yMin, yMax, title, log)
                # this barplot call keeps R's default axes
                addAxes = False
            else:
                # the first two lines only, drawn side by side
                script += "data1 = scan(\"tmpData-%d-1.dat\", list(x = -666, y = -666))\n" % (self.seed)
                script += "barplot(rbind(data$y, data1$y), name = data$x, xlab=\"%s\", ylab=\"%s\", cex.axis = 2, cex.names = 2, cex.lab = 2%s, beside = TRUE, space=c(-1,0), axes = FALSE%s)\n" % (xLabel, yLabel, title, log)
        elif self.points:
            script += "data = scan(\"tmpData-%d-0.dat\", list(x = -666, y = -666))\n" % (self.seed)
            script += "plot(data$x, data$y, xlab=\"%s\", ylab=\"%s\", cex.axis = 2, cex.lab = 2, axes = FALSE%s%s)\n" % (xLabel, yLabel, title, log)
            if self.regression:
                script += "abline(%s)\n" % (self._getRegressionModel())
        elif self.heatPoints:
            if len(self.lines) != 1:
                raise Exception("Error! Bad number of input data! Aborting...")
            script += "data = scan(\"tmpData-%d-0.dat\", list(x = -666, y = -666))\n" % (self.seed)
            script += "heatData = scan(\"tmpData-%d-1.dat\", list(x = \"\"))\n" % (self.seed)
            script += "plot(data$x, data$y, col=heatData$x, xlab=\"%s\", ylab=\"%s\", cex.axis = 2, cex.lab = 2, axes = FALSE%s%s)\n" % (xLabel, yLabel, title, log)
            if self.regression:
                script += "abline(%s)\n" % (self._getRegressionModel())
        else:
            # empty frame with a grid, then every line on top of it
            script += "plot(x = NA, y = NA, panel.first = grid(lwd = 1.0), xlab=\"%s\", ylab=\"%s\", xlim = c(%f, %f), ylim = c(%f, %f), cex.axis = 2, cex.lab = 2, axes = FALSE%s%s)\n" % (xLabel, yLabel, xMin, xMax, yMin, yMax, title, log)
            for i in range(len(self.lines)):
                script += "data = scan(\"tmpData-%d-%d.dat\", list(x = -666.666, y = -666.666))\n" % (self.seed, i)
                script += "lines(x = data$x, y = data$y, col = %s, lty = \"%s\", lwd = %d)\n" % (self.colors[i], self.types[i], self.lineWidth)
            script += self.linesToAddBox

        if addAxes:
            script += self._getAxesScript()
        script += "box()\n"

        if self.legend:
            script += self._getLegendScript()

        return script

    def _runR(self, scriptFileName, context = ""):
        """
        Run R in batch mode on a script; R writes its output to "<script>out"
        The executable is "R", or the value of the SMARTRPATH environment variable.
        @param scriptFileName: the R script
        @type  scriptFileName: string
        @param context: text inserted in the error message
        @type  context: string
        @raise Exception: if R cannot be started or fails; the temporary files are then kept
        """
        rCommand = os.environ.get("SMARTRPATH", "R")
        try:
            # an argument list (no shell) keeps paths with spaces or shell characters intact
            status = subprocess.call([rCommand, "CMD", "BATCH", scriptFileName])
        except OSError as error:
            status = error
        if status != 0:
            self.keep = True
            raise Exception("Problem with the execution of script file %s%s, status is: %s" % (scriptFileName, context, status))

    def plot(self):
        """
        Plot the lines into the picture file
        @raise Exception: if R fails
        """
        # build the commands first, so that no partial script is written on error
        script = self.getScript()
        scriptFileName = "tmpScript-%d.R" % (self.seed)
        with open(scriptFileName, "w") as scriptHandle:
            # brewer.pal comes from the RColorBrewer package
            scriptHandle.write("library(RColorBrewer)\n")
            scriptHandle.write("colorPanel = brewer.pal(n=%d, name=\"Set1\")\n" % (self.nbColors))
            # NOTE(review): R's pdf() takes its width and height in inches, not pixels
            scriptHandle.write("%s(%s = \"%s\", width = %d, height = %d, bg = \"white\")\n" % (self.format, "filename" if self.format != "pdf" else "file", self._quoteR(self.fileName), self.width, self.height))
            scriptHandle.write(script)
            scriptHandle.write("dev.off()\n")
        self._runR(scriptFileName)

    def getCorrelationData(self):
        """
        Get the summary of the linear regression of the first line, as printed by R
        @return: the lines printed by R's summary(lm(...)), or "" if no regression is plotted
        @rtype:  string
        @raise Exception: if R fails
        """
        if not self.regression:
            return ""
        scriptFileName = "tmpScript-%d.R" % (self.seed)
        # the script must be closed (flushed) before R reads it
        with open(scriptFileName, "w") as rScript:
            rScript.write("data = scan(\"tmpData-%d-0.dat\", list(x = -0.000000, y = -0.000000))\n" % (self.seed))
            rScript.write("summary(%s)\n" % (self._getRegressionModel()))
        self._runR(scriptFileName, " computing the correlation")

        output = []
        start = False
        with open("%sout" % (scriptFileName)) as outputRFile:
            for line in outputRFile:
                # the summary stops at the next R prompt
                if start and "> " in line:
                    break
                if start:
                    output.append(line)
                # the summary starts after the echoed summary() command
                if "summary" in line:
                    start = True
        return "".join(output)

    def getSpearmanRho(self):
        """
        Get the Spearman rho correlation using R
        @return: None, since the computation is currently disabled
        """
        # NOTE(review): this early return disables the computation below, which is
        # never executed; it is kept for reference.
        return None
        if not self.points and not self.barplot and not self.heatPoints:
            raise Exception("Cannot compute Spearman rho correlation whereas not in 'points' or 'bar' mode.")
        scriptFileName = "tmpScript-%d.R" % (self.seed)
        with open(scriptFileName, "w") as rScript:
            rScript.write("data = scan(\"tmpData-%d-0.dat\", list(x = -0.000000, y = -0.000000))\n" % (self.seed))
            rScript.write("spearman(data$x, data$y)\n")
        self._runR(scriptFileName)

        with open("%sout" % (scriptFileName)) as outputRFile:
            nextLine = False
            for line in outputRFile:
                line = line.strip()
                # the value is printed on the line after the "rho" header
                if nextLine:
                    return None if line == "NA" else float(line)
                if line == "rho":
                    nextLine = True
        return None