import sys
from commons.core.seq.BioseqDB import BioseqDB
from commons.core.seq.Bioseq import Bioseq
from commons.core.coord.Align import Align
from commons.core.coord.Range import Range
from commons.core.stat.Stat import Stat
from math import log

## Multiple Sequence Alignment Representation   
class AlignedBioseqDB( BioseqDB ):
    """Multiple sequence alignment (MSA) stored as a BioseqDB.

    Each sequence of the database is one row of the alignment: all rows are
    expected to have the same length and gaps are written as '-'.
    """

    ## Load the MSA and warn when its rows do not all have the same length
    # @param name string forwarded unchanged to BioseqDB.__init__
    def __init__( self, name="" ):
        BioseqDB.__init__( self, name )
        seqLength = self.getLength()
        if self.getSize() > 1:
            # the first row defines the alignment length; check the others against it
            for bs in self.db[1:]:
                if bs.getLength() != seqLength:
                    print( "ERROR: aligned sequences have different length" )

    ## Get length of the alignment
    # @return length of the first row, or 0 when the MSA is empty
    # @warning name before migration was 'length'
    def getLength( self ):
        if not self.db:
            return 0
        return self.db[0].getLength()

    ## Get the true length of a given sequence (without gaps)
    # @param header string header of the sequence to analyze
    # @return length integer (number of non '-' symbols)
    # @warning  name before migration was 'true_length'
    def getSeqLengthWithoutGaps( self, header ):
        sequence = self.fetch( header ).sequence
        return len( sequence ) - sequence.count( "-" )

    ## Clean the MSA: for every pair of rows sharing no aligned position (no
    #  column where both have a non-gap symbol), drop the row with fewer
    #  non-gap symbols (the second row of the pair on ties). The header index
    #  'self.idx' is rebuilt afterwards.
    def cleanMSA( self ):
        """clean the MSA"""
        nbSeq = self.getSize()
        i2del = set()

        # for each sequence in the MSA, use it as the reference
        for seqi in range( nbSeq ):
            if seqi in i2del:
                continue
            ref = self.db[seqi].sequence
            refLength = len( ref ) - ref.count( "-" )

            # compare it with each following sequence still kept
            for seq_next in range( seqi + 1, nbSeq ):
                if seq_next in i2del:
                    continue
                seq = self.db[seq_next].sequence
                overlap = False
                for ntRef, ntSeq in zip( ref, seq ):
                    if ntRef != "-" and ntSeq != "-":
                        overlap = True
                        break
                if overlap:
                    continue

                # no common aligned position: discard the shortest of the two
                if refLength < len( seq ) - seq.count( "-" ):
                    i2del.add( seqi )
                    # the reference is discarded: it must not cause further
                    # deletions of rows that may overlap the remaining ones
                    break
                else:
                    i2del.add( seq_next )

        # delete from the MSA each seq recorded in "i2del" (last index first,
        # so that the remaining indices stay valid)
        for i in sorted( i2del, reverse=True ):
            del self.db[i]

        self.idx = dict( ( bs.header, i ) for i, bs in enumerate( self.db ) )

    ## Record the occurrences of symbols (A, T, G, C, N, -, ...) at each site
    #  Symbols are upper-cased; rows whose sequence is None are skipped.
    # @return: list of dico (one per column) whose keys are symbols and values are their occurrences
    def getListOccPerSite( self ):
        lOccPerSite = []   # list of dictionaries, one per position on the sequence

        for bs in self.db:
            if bs.sequence is None:
                continue
            sequence = bs.sequence.upper()
            # create one dico per site (also grows if a row is longer than the previous ones)
            while len( lOccPerSite ) < len( sequence ):
                lOccPerSite.append( {} )
            # for each site, add its symbol
            for i, nuc in enumerate( sequence ):
                dOcc = lOccPerSite[i]
                dOcc[nuc] = dOcc.get( nuc, 0 ) + 1

        return lOccPerSite

    #TODO: review minNbNt !!! It should be at least 2 nucleotides to build a consensus...
    ## Make a consensus from the MSA
    #  All-gap columns are skipped. The process exits (status 1) when the MSA
    #  has fewer than 2 sequences or when minPropNt >= 1.0.
    # @param minNbNt: minimum nb of nucleotides to edit a consensus (lowered to nbSeq-1 if too high)
    # @param minPropNt: minimum proportion for the major nucleotide to be used, otherwise add 'N' (default=0.0)
    # @param verbose: level of information sent to stdout (default=0/1)
    # @param isHeaderSAtannot: if True, extract 'pile' and 'pyramid' from the first header ('Gr...Cl...' format)
    # @return: consensus as a Bioseq, or None when no consensus can be built
    def getConsensus( self, minNbNt, minPropNt=0.0, verbose=0 , isHeaderSAtannot=False):

        maxPropN = 0.40  # discard consensus if more than 40% of N's

        nbInSeq = self.getSize()
        if verbose > 0:
            print( "nb of aligned sequences: %i" % ( nbInSeq ) )
            sys.stdout.flush()
        if nbInSeq < 2:
            print( "ERROR: can't make a consensus with less than 2 sequences" )
            sys.exit(1)
        if minNbNt >= nbInSeq:
            minNbNt = nbInSeq - 1
            print( "minNbNt=%i" % ( minNbNt ) )
        if minPropNt >= 1.0:
            print( "ERROR: minPropNt=%.2f should be a proportion (below 1.0)" % ( minPropNt ) )
            sys.exit(1)

        lOccPerSite = self.getListOccPerSite()
        nbSites = len(lOccPerSite)
        if verbose > 0:
            print( "nb of sites: %i" % ( nbSites ) )
            sys.stdout.flush()

        lConsensusNts = []
        nbRmvColumns = 0
        siteWidth = len( str( nbSites ) )

        # for each site (i.e. each column of the MSA)
        for countSites, dNt2Occ in enumerate( lOccPerSite, 1 ):
            if verbose > 1:
                print( "site %s / %i" % ( str(countSites).zfill( siteWidth ), nbSites ) )
            occMaxNt = 0   # occurrences of the predominant nucleotide at this site
            lBestNt = []
            nbNt = 0   # total nb of non-gap symbols at this site

            # for each distinct symbol at this site (A, T, G, C, N, -,...)
            for j in dNt2Occ:
                if j == "-":
                    continue
                occ = dNt2Occ[j]
                nbNt += occ
                if verbose > 1:
                    print( "%s: %i" % ( j, occ ) )
                if occ > occMaxNt:
                    occMaxNt = occ
                    lBestNt = [ j ]
                elif occ == occMaxNt:
                    lBestNt.append( j )

            if nbNt == 0:   # some MSA programs can remove some sequences (e.g. Muscle after Recon) or when using Refalign (non-alignable TE fragments put together via a refseq)
                # all-gap column: nothing to add (avoids reusing a stale bestNt)
                nbRmvColumns += 1
                continue

            bestNt = lBestNt[0]   # non-empty since nbNt > 0
            # if the predominant nucleotide occurs in less than x% of the sequences, put a "N"
            if minPropNt > 0.0 and float(occMaxNt) / float(nbNt) < minPropNt:
                bestNt = "N"

            if int(nbNt) >= int(minNbNt):
                lConsensusNts.append( bestNt )
                if verbose > 1:
                    print( "-> %s" % ( bestNt ) )

        seqConsensus = "".join( lConsensusNts )

        if nbRmvColumns:
            if nbRmvColumns == 1:
                print( "WARNING: 1 site was removed (%.2f%%)" % (nbRmvColumns / float(nbSites) * 100) )
            else:
                print( "WARNING: %i sites were removed (%.2f%%)" % ( nbRmvColumns, nbRmvColumns / float(nbSites) * 100 ) )
        # checked unconditionally: an empty consensus would make the N proportion divide by zero
        if seqConsensus == "":
            print( "WARNING: no consensus can be built (no sequence left)" )
            return None

        propN = seqConsensus.count("N") / float(len(seqConsensus))
        if propN >= maxPropN:
            print( "WARNING: no consensus can be built (%i%% of N's >= %i%%)" % ( propN * 100, maxPropN * 100 ) )
            return None
        elif propN >= maxPropN * 0.5:
            print( "WARNING: %i%% of N's" % ( propN * 100 ) )

        consensus = Bioseq()
        consensus.sequence = seqConsensus
        # NOTE(review): assumes BioseqDB.__init__ stores its 'name' argument as self.name - confirm
        if isHeaderSAtannot:
            header = self.db[0].header
            pyramid = header.split("Gr")[1].split("Cl")[0]
            pile = header.split("Cl")[1].split(" ")[0]
            consensus.header = "consensus=%s length=%i nbAlign=%i pile=%s pyramid=%s" % ( self.name, len(seqConsensus), self.getSize(), pile, pyramid )
        else:
            consensus.header = "consensus=%s length=%i nbAlign=%i" % ( self.name, len(seqConsensus), self.getSize() )

        if verbose > 0:
            statEntropy = self.getEntropy( verbose - 1 )
            print( "entropy: %s" % ( statEntropy.stringQuantiles() ) )

        return consensus

    ## Get the entropy of the whole multiple alignment (only for A, T, G and C)
    #  Each non-gap symbol is mapped with getATGCNFromIUPAC: A/T/G/C count for
    #  themselves, anything else is spread evenly (0.25) over the four bases.
    # @param verbose level of verbosity
    # @return statistics (Stat) about the per-site entropy (in bits) of the MSA
    def getEntropy( self, verbose=0 ):

        stats = Stat()

        # get the occurrences of symbols at each site
        lOccPerSite = self.getListOccPerSite()

        # for each site
        for countSite, dSymbol2Occ in enumerate( lOccPerSite, 1 ):

            # count the number of nucleotides (doesn't count gap '-')
            nbNt = 0
            dATGC2Occ = {}
            for base in ["A","T","G","C"]:
                dATGC2Occ[ base ] = 0.0
            for nt in dSymbol2Occ:
                if nt == "-":
                    continue
                occ = dSymbol2Occ[ nt ]   # occurrences of THIS symbol (not of its mapped base)
                nbNt += occ
                checkedNt = self.getATGCNFromIUPAC( nt )
                if checkedNt in dATGC2Occ:
                    dATGC2Occ[ checkedNt ] += occ
                else:   # for 'N'
                    for base in dATGC2Occ:
                        dATGC2Occ[ base ] += 0.25 * occ
            if verbose > 2:
                for base in dATGC2Occ:
                    print( "%s: %i" % ( base, dATGC2Occ[ base ] ) )

            # compute the entropy for the site
            entropySite = 0.0
            for base in dATGC2Occ:
                entropySite += self.computeEntropy( dATGC2Occ[ base ], nbNt )
            if verbose > 1:
                print( "site %i (%i nt): entropy = %.3f" % ( countSite, nbNt, entropySite ) )
            stats.add( entropySite )

        return stats

    ## Get A, T, G, C or N from an IUPAC letter
    #  IUPAC = ['A','T','G','C','U','R','Y','M','K','W','S','B','D','H','V','N']
    # @return A, T, G, C or N (delegates to Bioseq.getATGCNFromIUPAC)
    def getATGCNFromIUPAC( self, nt ):
        iBs = Bioseq()
        return iBs.getATGCNFromIUPAC( nt )

    ## Compute the entropy term based on the occurrences of a certain nucleotide and the total number of nucleotides
    # @param nbOcc occurrences of the nucleotide
    # @param nbNt total number of nucleotides at the site
    # @return -f*log2(f) with f = nbOcc/nbNt, or 0.0 when nbOcc is 0
    def computeEntropy( self, nbOcc, nbNt ):
        if nbOcc == 0.0:
            return 0.0
        freq = nbOcc / float(nbNt)
        return - freq * log(freq) / log(2)

    ## Save the multiple alignment as a matrix with '0' if gap, '1' otherwise
    #  One line per sequence: header then one tab-separated cell per column.
    # @param outFile string name of the output file (overwritten)
    def saveAsBinaryMatrix( self, outFile ):
        with open( outFile, "w" ) as outFileHandler:
            for bs in self.db:
                lCells = [ "%s" % ( bs.header, ) ]
                for nt in bs.sequence:
                    lCells.append( "0" if nt == "-" else "1" )
                outFileHandler.write( "%s\n" % ( "\t".join( lCells ) ) )

    ## Return a list of Align instances corresponding to the aligned regions (without gaps)
    #  Coordinates are 1-based positions on the ungapped query/subject; the
    #  identity of each Align is converted to an integer percentage.
    # @param query string header of the sequence considered as query
    # @param subject string header of the sequence considered as subject
    # @return list of Align instances
    def getAlignList( self, query, subject ):
        lAligns = []
        alignQ = self.fetch( query ).sequence
        alignS = self.fetch( subject ).sequence
        createNewAlign = True
        indexQ = 0
        indexS = 0
        for ntQ, ntS in zip( alignQ, alignS ):
            if ntQ != "-" and ntS != "-":
                indexQ += 1
                indexS += 1
                isMatch = int( ntQ == ntS )
                if createNewAlign:
                    # NOTE(review): argument order kept as-is; confirm against Align's constructor
                    iAlign = Align( Range( query, indexQ, indexQ ),
                                    Range( subject, indexS, indexS ),
                                    isMatch,
                                    isMatch )
                    lAligns.append( iAlign )
                    createNewAlign = False
                else:
                    lAligns[-1].range_query.end += 1
                    lAligns[-1].range_subject.end += 1
                    lAligns[-1].score += isMatch
                    lAligns[-1].identity += isMatch
            else:
                # a gap closes the current aligned region, if any
                if not createNewAlign:
                    lAligns[-1].identity = 100 * lAligns[-1].identity // lAligns[-1].getLengthOnQuery()
                    createNewAlign = True
                if ntQ != "-":
                    indexQ += 1
                elif ntS != "-":
                    indexS += 1
        if not createNewAlign:
            lAligns[-1].identity = 100 * lAligns[-1].identity // lAligns[-1].getLengthOnQuery()
        return lAligns

    ## Remove the gap symbol '-' from every sequence of the MSA (in place)
    def removeGaps(self):
        for iBs in self.db:
            iBs.removeSymbol( "-" )

    ## Compute mean per cent identity for MSA.
    # First sequence in MSA is considered as reference sequence.
    # @return rounded mean of the identities of the other rows against the reference
    # @warning requires at least 2 sequences (ZeroDivisionError otherwise)
    def computeMeanPcentIdentity(self):
        seqRef = self.db[0]
        lOthers = self.db[1:]
        sumPcentIdentity = sum( self._computePcentIdentityBetweenSeqRefAndCurrentSeq( seqRef, seq )
                                for seq in lOthers )
        return round( sumPcentIdentity / len( lOthers ) )

    ## Per cent identity between the reference and another row: number of
    #  columns where both have the same non-gap symbol, over the reference
    #  length (gaps included).
    def _computePcentIdentityBetweenSeqRefAndCurrentSeq(self, seqRef, seq):
        sumIdentity = 0
        for nuclRef, nuclSeq in zip( seqRef.sequence, seq.sequence ):
            if nuclRef != "-" and nuclRef == nuclSeq:
                sumIdentity += 1
        return float(sumIdentity) / float(seqRef.getLength()) * 100