# Copyright INRA-URGI 2009-2010
# This software is governed by the CeCILL license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.
import re
import sys

class MySqlTable(object):
    """
    Store a table of a SQLite database (historically named "mySQL"), used
    for transcripts or exons.
    Record a name, a type (int, float, double, varchar) and a size for each
    column. Every table also has an implicit "id INTEGER PRIMARY KEY" column.
    @ivar name:            name of the table
    @type name:            string
    @ivar variables:       name of the columns (without "id")
    @type variables:       list of string
    @ivar types:           type of the columns
    @type types:           dict of string
    @ivar sizes:           size of the columns (None when the declared type has no size)
    @type sizes:           dict of int
    @ivar mySqlConnection: connection to a database
    @type mySqlConnection: class L{MySqlConnection<MySqlConnection>}
    @ivar nbLines:         number of rows
    @type nbLines:         int
    @ivar verbosity:       verbosity
    @type verbosity:       int
    @ivar created:         whether the table exists in the database
    @type created:         boolean
    """

    # Declared column types look like "varchar(255)": base type + size.
    _COLUMN_TYPE = re.compile(r"(\w+)\((\d+)\)")

    def __init__(self, connection, name, verbosity = 0):
        """
        Constructor.
        Possibly retrieve column names, types and sizes if the table exists.
        @param connection: connection to a database
        @type  connection: class L{MySqlConnection<MySqlConnection>}
        @param name:       name of the table
        @type  name:       string
        @param verbosity:  verbosity
        @type  verbosity:  int
        """
        self.name      = name
        self.variables = []
        self.types     = {}
        self.sizes     = {}
        self.nbLines   = None
        self.verbosity = verbosity
        self.mySqlConnection = connection
        queryTables = self.mySqlConnection.executeQuery("SELECT name FROM sqlite_master WHERE type LIKE 'table' AND name LIKE '%s'" % (self.name))
        self.created = not queryTables.isEmpty()
        if self.created:
            queryFields = self.mySqlConnection.executeQuery("PRAGMA table_info('%s')" % (name))
            # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk):
            # the size is embedded in the declared type, not in a separate column.
            for field in queryFields.getIterator():
                if field[1] != "id":
                    self.variables.append(field[1])
                    parsed = MySqlTable._parseColumnType(field[2])
                    if parsed is None:
                        # Type without "(size)": keep it verbatim, no size limit.
                        self.types[field[1]] = field[2]
                        self.sizes[field[1]] = None
                    else:
                        self.types[field[1]], self.sizes[field[1]] = parsed

    @staticmethod
    def _parseColumnType(declaredType):
        """
        Split a declared column type such as "varchar(255)" into type and size
        @param declaredType: the declared SQL type
        @type  declaredType: string
        @return:             (type, size), or None if there is no "(size)" part
        @rtype:              tuple of (string, int) or None
        """
        m = MySqlTable._COLUMN_TYPE.search(declaredType)
        if m is None:
            return None
        return m.group(1), int(m.group(2))

    def getName(self):
        """
        @return: the name of the table
        @rtype:  string
        """
        return self.name

    def create(self, variables, types, sizes):
        """
        Create the table using the given column names, types and sizes.
        An existing table with the same name is dropped first.
        @param variables: names of the columns
        @type  variables: list of string
        @param types:     types of the columns
        @type  types:     dict of string
        @param sizes:     sizes of the types (None: no size suffix)
        @type  sizes:     dict of int
        """
        self.variables = variables
        self.types = types
        self.sizes = sizes
        if self.created:
            # Otherwise CREATE TABLE fails on the already existing table.
            self.remove()
        columns = ["id INTEGER PRIMARY KEY"]
        for variable in variables:
            size = sizes[variable]
            if size is None:
                columns.append("%s %s" % (variable, types[variable]))
            else:
                columns.append("%s %s(%d)" % (variable, types[variable], size))
        query = "CREATE TABLE '%s' (%s)" % (self.name, ", ".join(columns))
        self.mySqlConnection.executeQuery(query)
        self.created = True

    def _insertCommand(self, values):
        """
        Build the INSERT statement storing one row
        @param values: the values of the row, keyed by column name
        @type  values: dict
        @return:       the SQL command
        @rtype:        string
        """
        sqlValues = [self.formatSql(values[variable], self.types[variable], self.sizes[variable]) for variable in self.variables]
        return "INSERT INTO '%s' (%s) VALUES (%s)" % (self.name, ", ".join(self.variables), ", ".join(sqlValues))

    def insertMany(self, lines):
        """
        Insert many lines
        @param lines: the list of values, one dict (column name -> value) per row
        @type  lines: list of dict
        """
        # NOTE(review): executes one statement per row through executeQuery;
        # switch to a batched connection call if the connection offers one.
        for values in lines:
            self.mySqlConnection.executeQuery(self._insertCommand(values))

    def rename(self, name):
        """
        Rename the table
        @param name: the new name
        @type  name: string
        """
        # SQLite has no "RENAME TABLE" statement; renaming goes through ALTER TABLE.
        self.mySqlConnection.executeQuery("ALTER TABLE '%s' RENAME TO '%s'" % (self.name, name))
        self.name = name

    def copy(self, table):
        """
        Copy the given table into this one (structure and content).
        @param table: the table to be copied
        @type  table: class L{MySqlTable<MySqlTable>}
        @raise Exception: if a column type of the source has no "(size)" part
        """
        variables = []
        types = {}
        sizes = {}
        fields = self.mySqlConnection.executeQuery("PRAGMA table_info('%s')" % (table.name))
        for field in fields.getIterator():
            if field[1] != "id":
                parsed = MySqlTable._parseColumnType(field[2])
                if parsed is None:
                    raise Exception("\nFormat %s in table %s is strange." % (field[2], table.name))
                variables.append(field[1])
                types[field[1]], sizes[field[1]] = parsed
        self.create(variables, types, sizes)
        self.mySqlConnection.executeQuery("INSERT INTO '%s' SELECT * FROM %s" % (self.name, table.name))

    def add(self, table):
        """
        Add the content of a table to this one
        @param table: the table to be added
        @type  table: class L{MySqlTable<MySqlTable>}
        """
        # NOTE(review): "SELECT *" also copies the id column, so ids already
        # present in this table will make the insertion fail.
        self.mySqlConnection.executeQuery("INSERT INTO '%s' SELECT * FROM %s" % (self.name, table.name))
        self.created = True

    def exists(self):
        """
        Check if the table exists in the database
        @return: true if it exists
        @rtype:  boolean
        """
        return self.created

    def remove(self):
        """
        Remove (drop) this table
        """
        if self.exists():
            self.mySqlConnection.executeQuery("DROP TABLE IF EXISTS '%s'" % (self.name))
        self.created = False

    def clear(self):
        """
        Clear the content of this table (the table itself is kept)
        """
        self.mySqlConnection.executeQuery("DELETE FROM '%s'" % (self.name))

    def getNbElements(self):
        """
        Count the number of rows in the table
        @return: the number of rows
        @rtype:  int
        """
        query = self.mySqlConnection.executeQuery("SELECT COUNT(id) FROM '%s'" % (self.name))
        return int(query.getLine()[0])

    @classmethod
    def formatSql(cls, value, type, size):
        """
        Format a value as an SQL literal according to its column type
        @param value: the value to be formatted
        @param type:  type of the column (checked for int, float, double, varchar)
        @type  type:  string
        @param size:  maximum length kept for varchar values (None: no limit)
        @type  size:  int or None
        @return:      the SQL literal
        @rtype:       string
        @raise Exception: if the type is not understood
        """
        if "int" in type:
            return "%d" % value
        if "float" in type:
            return "%.10f" % value
        if "double" in type:
            return "%.20f" % value
        if "varchar" in type:
            if size is not None and len(value) > size:
                value = value[0:size]
            # Double embedded quotes (after truncation, so a pair is never cut)
            # so values such as "O'Brien" do not break the statement.
            return "'%s'" % value.replace("'", "''")
        raise Exception("Do not understand type %s" % (type))

    def addLine(self, values):
        """
        Add a row to this table
        @param values: the values of the row, keyed by column name
        @type  values: dict
        @return:       the id of the added row
        """
        # The second argument asks the connection for the id of the new row.
        return self.mySqlConnection.executeQuery(self._insertCommand(values), True)

    def retrieveFromId(self, id):
        """
        Retrieve a row from its id
        @param id: the id of the row
        @type  id: int
        @return:   the row
        @raise Exception: if the id is not in the table
        """
        query = self.mySqlConnection.executeQuery("SELECT * FROM '%s' WHERE id = %d" % (self.name, id))
        result = query.getLine()
        if result is None:
            raise Exception("Error! Id %d is not in the table %s!" % (id, self.name))
        return result

    def retrieveBulkFromId(self, ids):
        """
        Retrieve several rows from their ids, querying by batches of 1000 ids
        @param ids: the ids of the rows
        @type  ids: list of int
        @return:    the rows
        @rtype:     list
        @raise Exception: if some ids are not in the table
        """
        if not ids:
            return []
        MAXSIZE = 1000
        results = []
        for start in range(0, len(ids), MAXSIZE):
            theseIds = ids[start : start + MAXSIZE]
            idList = ", ".join(["%d" % (id) for id in theseIds])
            query = self.mySqlConnection.executeQuery("SELECT * FROM '%s' WHERE id IN (%s)" % (self.name, idList))
            lines = query.getLines()
            if len(lines) != len(theseIds):
                raise Exception("Error! Some Ids of (%s) are missing in the table '%s' (got %d instead of %d)!" % (idList, self.name, len(lines), len(theseIds)))
            results.extend(lines)
        return results

    def removeFromId(self, id):
        """
        Remove a row from its id
        @param id: the id of the row
        @type  id: int
        """
        self.mySqlConnection.executeQuery("DELETE FROM '%s' WHERE id = %d" % (self.name, id))

    def getIterator(self):
        """
        Iterate on the content of the table, fetching 1000 rows at a time
        @return: iterator to the rows of the table (empty if the table does not exist)
        """
        if not self.created:
            return
        MAXSIZE = 1000
        query = self.mySqlConnection.executeQuery("SELECT count(id) FROM '%s'" % (self.name))
        nbRows = int(query.getLine()[0])
        for offset in range(0, nbRows, MAXSIZE):
            # ORDER BY makes LIMIT paging deterministic across chunks.
            query = self.mySqlConnection.executeQuery("SELECT * FROM '%s' ORDER BY id LIMIT %d, %d" % (self.name, offset, MAXSIZE))
            for line in query.getIterator():
                yield line

    def createIndex(self, indexName, values, unique = False, fullText = False):
        """
        Add an index on the table
        @param indexName: name of the index
        @type  indexName: string
        @param values:    columns to be indexed
        @type  values:    list of string
        @param unique:    if the index is unique
        @type  unique:    boolean
        @param fullText:  whether full text should be indexed
        @type  fullText:  boolean
        """
        # NOTE(review): "FULLTEXT" is MySQL syntax; SQLite is expected to reject it.
        self.mySqlConnection.executeQuery("CREATE %s%sINDEX '%s' ON '%s' (%s)" % ("UNIQUE " if unique else "", "FULLTEXT " if fullText else "", indexName, self.name, ", ".join(values)))

    def setDefaultTagValue(self, field, name, value):
        """
        Add a tag "name=value" to every row which does not have this tag yet
        @param field: index, in a row, of the "tags" column (the update always writes "tags")
        @type  field: int
        @param name:  name of the tag
        @type  name:  string
        @param value: value of the tag
        @type  value: string or int
        """
        newData = {}
        for line in MySqlTable.getIterator(self):
            id = line[0]
            tags = line[field] or ''
            if name in [tag.split("=")[0] for tag in tags.split(";")]:
                continue
            if tags == '':
                newTag = "%s=%s" % (name, value)
            else:
                newTag = "%s;%s=%s" % (tags, name, value)
            newData[id] = newTag
        # Updates are applied after the iteration, so the table is not
        # modified while it is being paged through.
        for id, tag in newData.items():
            self.mySqlConnection.executeQuery("UPDATE '%s' SET tags = '%s' WHERE id = %i" % (self.name, tag.replace("'", "''"), id))

    def show(self):
        """
        Print the content of the current table
        """
        query = self.mySqlConnection.executeQuery("SELECT * FROM '%s'" % (self.name))
        print(query.getLines())