#! /usr/bin/env python
#
# Copyright INRA-URGI 2009-2012
# 
# This software is governed by the CeCILL license under French law and
# abiding by the rules of distribution of free software. You can use,
# modify and/ or redistribute the software under the terms of the CeCILL
# license as circulated by CEA, CNRS and INRIA at the following URL
# "http://www.cecill.info".
# 
# As a counterpart to the access to the source code and rights to copy,
# modify and redistribute granted by the license, users are provided only
# with a limited warranty and the software's author, the holder of the
# economic rights, and the successive licensors have only limited
# liability.
# 
# In this respect, the user's attention is drawn to the risks associated
# with loading, using, modifying and/or developing or reproducing the
# software by the user in light of its specific status of free software,
# that may mean that it is complicated to manipulate, and that also
# therefore means that it is reserved for developers and experienced
# professionals having in-depth computer knowledge. Users are therefore
# encouraged to load and test the software's suitability as regards their
# requirements in conditions enabling the security of their systems and/or
# data to be ensured and, more generally, to use and operate it in the
# same conditions as regards security.
# 
# The fact that you are presently reading this means that you have had
# knowledge of the CeCILL license and that you accept its terms.
#

import os, struct, time, shutil
from optparse import OptionParser
from pyRepetUnit.commons.parsing.ParserChooser import ParserChooser
from pyRepetUnit.commons.writer.Gff3Writer import Gff3Writer
from SMART.Java.Python.structure.Transcript import Transcript
from SMART.Java.Python.structure.Interval import Interval
from SMART.Java.Python.ncList.NCList import NCList
from SMART.Java.Python.ncList.ConvertToNCList import ConvertToNCList
from SMART.Java.Python.ncList.NCListParser import NCListParser
from SMART.Java.Python.ncList.NCListCursor import NCListCursor
from SMART.Java.Python.ncList.NCListFilePickle import NCListFilePickle, NCListFileUnpickle
from SMART.Java.Python.ncList.FileSorter import FileSorter
from SMART.Java.Python.ncList.NCListHandler import NCListHandler
from SMART.Java.Python.misc.Progress import Progress
from SMART.Java.Python.misc.UnlimitedProgress import UnlimitedProgress
try:
   import cPickle as pickle
except:
   import pickle

# Keys identifying the two inputs of a comparison; used to index the
# per-input dictionaries of FindOverlapsOptim.
REFERENCE, QUERY = 0, 1
TYPES = (REFERENCE, QUERY)
# Human-readable label for each input key.
TYPETOSTRING = {REFERENCE: "reference", QUERY: "query"}

class FindOverlapsOptim(object):
	"""Report the query intervals that overlap at least one reference interval.

	Both inputs are loaded through NCListHandler, one NCList per chromosome.
	An input whose format is not "nclist" is first converted with
	ConvertToNCList.  Every query that overlaps some reference element is
	written to the output with two extra tags: "overlapsWith" (the sorted
	reference identifiers joined by "--") and "nbOverlaps".
	"""

	def __init__(self, verbosity = 1):
		"""Create an empty comparison.

		verbosity -- trace level; statistics are printed when it is > 0.
		"""
		self._verbosity = verbosity
		self._outputFileName = "outputOverlaps.gff3"
		self._iWriter = None
		# Per-input settings and data, keyed by REFERENCE / QUERY.
		self._inputFileNames = {REFERENCE: None, QUERY: None}
		self._inputFileFormats = {REFERENCE: None, QUERY: None}
		# True when the input must be converted to NCList format first.
		self._converted = {REFERENCE: False, QUERY: False}
		# Name of the converted file, or False when no conversion took place.
		self._convertedFileNames = {REFERENCE: False, QUERY: False}
		self._ncListHandlers = {REFERENCE: None, QUERY: None}
		self._splittedFileNames = {REFERENCE: {}, QUERY: {}}
		self._tmpDirectories = {REFERENCE: False, QUERY: False}
		self._nbLines = {REFERENCE: 0, QUERY: 0}
		self._parsers = {}
		self._sortedFileNames = {}
		self._ncLists = {}
		self._indices = {}
		self._cursors = {}
		self._nbElementsPerChromosome = {}
		# Statistics reported by run().
		self._nbOverlappingQueries = 0
		self._nbOverlaps = 0
		self._timeSpent = 0

	def close(self):
		"""Close the output writer and delete the temporary files of this run."""
		# The writer only exists once an output has been set or written to.
		if self._iWriter is not None:
			self._iWriter.close()
		for fileName in self._sortedFileNames.values():
			if os.path.exists(fileName):
				os.remove(fileName)
		for fileName in self._convertedFileNames.values():
			# Entries stay False for inputs that needed no conversion, and the
			# file may be missing if the conversion did not complete.
			if fileName and os.path.exists(fileName):
				os.remove(fileName)

	def setRefFileName(self, fileName, format):
		"""Register the reference input file and its format."""
		self.setFileName(fileName, format, REFERENCE)

	def setQueryFileName(self, fileName, format):
		"""Register the query input file and its format."""
		self.setFileName(fileName, format, QUERY)

	def setFileName(self, fileName, format, type):
		"""Register an input file.

		fileName -- path of the input file
		format   -- its format; anything but "nclist" (case-insensitive)
		            marks the file for conversion in createNCLists()
		type     -- REFERENCE or QUERY
		"""
		self._inputFileNames[type] = fileName
		self._inputFileFormats[type] = format
		if format.lower() != "nclist":
			self._converted[type] = True

	def setOutputFileName(self, outputFileName):
		"""Set the output file name and open a GFF3 writer on it."""
		self._outputFileName = outputFileName
		self._iWriter = Gff3Writer(self._outputFileName)

	def _getWriter(self):
		"""Return the output writer, creating it on the current output file name if needed."""
		if self._iWriter is None:
			self._iWriter = Gff3Writer(self._outputFileName)
		return self._iWriter

	def createNCLists(self):
		"""Load both inputs as NCLists (converting them if needed) and create one cursor per chromosome."""
		startTime = time.time()
		if self._verbosity > 1:
			print("Building database")
		self._ncLists = dict((inputType, {}) for inputType in TYPES)
		self._indices = dict((inputType, {}) for inputType in TYPES)
		self._cursors = dict((inputType, {}) for inputType in TYPES)
		for inputType in TYPES:
			handler = NCListHandler(self._verbosity-3)
			self._ncListHandlers[inputType] = handler
			if self._converted[inputType]:
				# The converted file is written next to the input; it is
				# recorded before the conversion runs so that close() can
				# remove it.
				convertedFileName = "%s_%d.ncl" % (os.path.splitext(self._inputFileNames[inputType])[0], inputType)
				self._convertedFileNames[inputType] = convertedFileName
				converter = ConvertToNCList(self._verbosity-3)
				converter.setInputFileName(self._inputFileNames[inputType], self._inputFileFormats[inputType])
				converter.setOutputFileName(convertedFileName)
				# Only the reference is indexed: checkIndex() reads that index.
				if inputType == REFERENCE:
					converter.setIndex(True)
				converter.run()
				handler.setFileName(convertedFileName)
			else:
				handler.setFileName(self._inputFileNames[inputType])
			handler.loadData()
			self._nbLines[inputType] = handler.getNbElements()
			self._nbElementsPerChromosome[inputType] = handler.getNbElementsPerChromosome()
			self._ncLists[inputType] = handler.getNCLists()
			for chromosome, ncList in self._ncLists[inputType].items():
				self._cursors[inputType][chromosome] = NCListCursor(None, ncList, 0, self._verbosity)
				if inputType == REFERENCE:
					self._indices[REFERENCE][chromosome] = ncList.getIndex()
		endTime = time.time()
		if self._verbosity > 1:
			print("done (%.2gs)" % (endTime - startTime))

	def compare(self):
		"""Compare every query with the reference, chromosome by chromosome.

		Query chromosomes absent from the reference are skipped.  On the
		others, the query cursor moves to the next element when the current
		query matched nothing or has no children, and down to its first
		child otherwise; the walk stops as soon as either cursor runs out.
		The elapsed time is stored in self._timeSpent.
		"""
		nbSkips, nbMoves = 0, 0
		startTime = time.time()
		progress = Progress(len(self._ncLists[QUERY]), "Checking overlap", self._verbosity)
		for chromosome in self._ncLists[QUERY]:
			# NOTE(review): the returned parser is not used here; the call is
			# kept in case getParser has side effects -- confirm and remove.
			self._ncListHandlers[QUERY].getParser(chromosome)
			if chromosome not in self._ncLists[REFERENCE]:
				# Still counts as one processed chromosome for the progress.
				progress.inc()
				continue
			queryCursor = self._cursors[QUERY][chromosome]
			refCursor = self._cursors[REFERENCE][chromosome]
			# "done" describes the position of the reference cursor, so it
			# must not leak from the previous chromosome.
			done = False
			while True:
				queryTranscript = queryCursor.getTranscript()
				newRefLaddr = self.checkIndex(queryTranscript, refCursor)
				if newRefLaddr is not None:
					# The index points further than the cursor: jump there.
					nbMoves += 1
					refCursor.setLIndex(newRefLaddr)
					done = False
				refCursor, done, unmatched = self.findOverlapIter(queryTranscript, refCursor, done)
				if refCursor.isOut():
					break
				if unmatched or not queryCursor.hasChildren():
					queryCursor.moveNext()
					nbSkips += 1
				else:
					queryCursor.moveDown()
				if queryCursor.isOut():
					break
			progress.inc()
		progress.done()
		endTime = time.time()
		self._timeSpent = endTime - startTime
		if self._verbosity >= 10:
			print("# skips:   %d" % (nbSkips))
			print("# moves:   %d" % (nbMoves))

	def findOverlapIter(self, queryTranscript, cursor, done):
		"""Collect the reference elements overlapping one query and write the query if any.

		queryTranscript -- the query element
		cursor          -- reference cursor where the scan starts (it is moved)
		done            -- when true, an overlapping element found at the
		                   starting position is stepped over (moveNext)
		                   instead of being descended into

		Returns a (cursor, done, unmatched) triple: the reference cursor to
		resume from, the "done" flag to pass to the next call, and whether
		no overlapping reference element was found.
		"""
		chromosome = queryTranscript.getChromosome()
		if chromosome not in self._ncLists[REFERENCE]:
			# No reference data at all: cursor untouched, nothing matched.
			return cursor, False, True
		overlappingNames = {}
		nextDone = False
		# An index of -1 marks "no first overlap recorded yet" (cursor is out).
		firstOverlapLAddr = NCListCursor(cursor)
		firstOverlapLAddr.setLIndex(-1)
		if cursor.isOut():
			return firstOverlapLAddr, False, True
		# First pass: walk up the ancestors of the starting position.
		parentCursor = NCListCursor(cursor)
		parentCursor.moveUp()
		firstParentAfter = False
		while not parentCursor.isOut():
			overlap = self.isOverlapping(queryTranscript, parentCursor)
			if overlap == 0:
				overlappingNames.update(self._extractID(parentCursor.getTranscript()))
				if firstOverlapLAddr.isOut():
					firstOverlapLAddr.copy(parentCursor)
					nextDone = True
			elif overlap == 1:
				# This ancestor starts after the query ends.
				firstParentAfter = NCListCursor(parentCursor)
			parentCursor.moveUp()
		if firstParentAfter:
			self._writeIntervalInNewGFF3(queryTranscript, overlappingNames)
			return firstParentAfter, False, not overlappingNames
		# Second pass: scan from the starting position.
		while True:
			parentCursor = NCListCursor(cursor)
			parentCursor.moveUp()
			overlap = self.isOverlapping(queryTranscript, cursor)
			if overlap == -1:
				# The query starts after this reference element ends.
				cursor.moveNext()
			elif overlap == 0:
				# The query overlaps this reference element.
				overlappingNames.update(self._extractID(cursor.getTranscript()))
				if firstOverlapLAddr.compare(parentCursor):
					firstOverlapLAddr.copy(cursor)
					nextDone = True
				if done:
					cursor.moveNext()
				elif cursor.hasChildren():
					cursor.moveDown()
				else:
					cursor.moveNext()
					if cursor.isOut():
						break
			else:
				# The query ends before this reference element starts.
				if firstOverlapLAddr.isOut() or firstOverlapLAddr.compare(parentCursor):
					firstOverlapLAddr.copy(cursor)
					nextDone = False
				break
			# "done" only applies to the starting position.
			done = False
			if cursor.isOut():
				break
		self._writeIntervalInNewGFF3(queryTranscript, overlappingNames)
		return firstOverlapLAddr, nextDone, not overlappingNames

	def isOverlapping(self, queryTranscript, refTranscript):
		"""Compare the positions of two elements exposing getStart() and getEnd().

		Returns 0 when they overlap (bounds included), 1 when the query ends
		before the reference starts, and -1 when the query starts after the
		reference ends.
		"""
		if (queryTranscript.getStart() <= refTranscript.getEnd() and queryTranscript.getEnd() >= refTranscript.getStart()):
			return 0
		if queryTranscript.getEnd() < refTranscript.getStart():
			return 1
		return -1

	def checkIndex(self, transcript, cursor):
		"""Return the reference position given by the index for this query, or None.

		The position is only returned when its GFF address is greater than
		the GFF address of the current cursor.
		"""
		chromosome = transcript.getChromosome()
		nextLIndex = self._indices[REFERENCE][chromosome].getIndex(transcript)
		if nextLIndex is None:
			return None
		ncList = self._ncLists[REFERENCE][chromosome]
		nextGffAddress = ncList.getRefGffAddr(nextLIndex)
		thisGffAddress = cursor.getGffAddress()
		if nextGffAddress > thisGffAddress:
			return nextLIndex
		return None

	def _writeIntervalInNewGFF3(self, transcript, names):
		"""Tag and write a query that overlaps some reference elements.

		names -- dict mapping the identifier of each overlapping reference
		         element to its number of elements; nothing is written when
		         it is empty
		"""
		if not names:
			return
		nbOverlaps = sum(names.values())
		transcript.setTagValue("overlapsWith", "--".join(sorted(names.keys())))
		transcript.setTagValue("nbOverlaps", nbOverlaps)
		writer = self._getWriter()
		writer.addTranscript(transcript)
		writer.write()
		self._nbOverlappingQueries += 1
		self._nbOverlaps += nbOverlaps

	def _extractID(self, transcript):
		"""Return {identifier: number of elements} for a reference element.

		The identifier is the "ID" tag when present, the unique name
		otherwise; the count is the "nbElements" tag as a float when
		present, 1 otherwise.
		"""
		tagNames = transcript.getTagNames()
		nbElements = float(transcript.getTagValue("nbElements")) if "nbElements" in tagNames else 1
		identifier = transcript.getTagValue("ID") if "ID" in tagNames else transcript.getUniqueName()
		return {identifier: nbElements}

	def run(self):
		"""Load the inputs, compare them, clean up and print the statistics."""
		self.createNCLists()
		self.compare()
		self.close()
		if self._verbosity > 0:
			print("# queries: %d" % (self._nbLines[QUERY]))
			print("# refs:    %d" % (self._nbLines[REFERENCE]))
			print("# written: %d (%d overlaps)" % (self._nbOverlappingQueries, self._nbOverlaps))
			print("time:      %.2gs" % (self._timeSpent))


if __name__ == "__main__":
	# Command-line entry point: parse the options, then run one comparison.
	description = "Find Overlaps Optim v1.0.0: Finds overlaps with several query intervals. [Category: Data Comparison]"

	parser = OptionParser(description = description)
	parser.add_option("-i", "--query", dest="inputQueryFileName", action="store", type="string", help="Query input file [compulsory] [format: file in transcript or other format given by -f]")
	parser.add_option("-f", "--queryFormat", dest="queryFormat", action="store", type="string", help="format of previous file (possibly in NCL format) [compulsory] [format: transcript or other file format]")
	parser.add_option("-j", "--ref", dest="inputRefFileName", action="store", type="string", help="Reference input file [compulsory] [format: file in transcript or other format given by -g]")
	parser.add_option("-g", "--refFormat", dest="refFormat", action="store", type="string", help="format of previous file (possibly in NCL format) [compulsory] [format: transcript or other file format]")
	parser.add_option("-o", "--output", dest="outputFileName", action="store", type="string", help="Output file [compulsory] [format: output file in GFF3 format]")
	parser.add_option("-v", "--verbosity", dest="verbosity", action="store", default=1, type="int", help="Trace level [format: int] [default: 1]")
	(options, args) = parser.parse_args()

	# optparse does not enforce the options documented as [compulsory]:
	# report a missing one as a usage error instead of failing later with
	# an AttributeError on a None value.
	compulsoryOptions = (
		("inputQueryFileName", "-i/--query"),
		("queryFormat", "-f/--queryFormat"),
		("inputRefFileName", "-j/--ref"),
		("refFormat", "-g/--refFormat"),
		("outputFileName", "-o/--output"),
	)
	for optionDest, optionFlags in compulsoryOptions:
		if getattr(options, optionDest) is None:
			parser.error("option %s is compulsory" % optionFlags)

	iFOO = FindOverlapsOptim(options.verbosity)
	iFOO.setRefFileName(options.inputRefFileName, options.refFormat)
	iFOO.setQueryFileName(options.inputQueryFileName, options.queryFormat)
	iFOO.setOutputFileName(options.outputFileName)
	iFOO.run()
