/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.differentiable;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.DiscreteAlphabet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractVariableLengthDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.SamplingDifferentiableStatisticalModel;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.TreeMap;

public class CyclicMarkovModelDiffSM
extends AbstractVariableLengthDiffSM
implements SamplingDifferentiableStatisticalModel {
    private boolean freeParams;
    private boolean plugIn;
    private boolean optimize;
    private boolean optimizeFrame;
    private int order;
    private int period;
    private int starts;
    private int initFrame;
    private int[] powers;
    private double logGammaSum;
    private double[] frameHyper;
    private double[][][] params;
    private double[][][] probs;
    private double[][][] logNorm;
    private double[] frameLogScore;
    private double[] frameParams;
    private double[] frameProbs;
    private double[] partDer;
    private double[][][] hyper;
    private double logFrameNorm;
    private int[][][] counter;
    private int[][][] distCounter;
    private int[][] offset;

    /**
     * Derives Dirichlet hyper-parameters for every frame and order by pushing an
     * expected-count mass through the cyclic Markov chain described by {@code prob}.
     *
     * <p>For each start frame {@code s}, a mass of {@code ess * frameProb[s]} is put on the empty
     * context and propagated for {@code length} positions. At each position the mass flowing into
     * a (context, symbol) cell is added to the hyper-parameter of the current frame and order
     * (the order grows with the position until it reaches the maximal order). The frame then
     * advances cyclically.
     *
     * @param alphabetSize number of symbols, must be positive
     * @param length number of positions to propagate, must be positive
     * @param ess total equivalent sample size, must be positive
     * @param frameProb start-frame probabilities, one entry per frame
     * @param prob conditional probabilities indexed {@code [frame][order][context * alphabetSize + symbol]};
     *        the maximal order is {@code prob[0].length - 1}
     * @return hyper-parameters in the same {@code [frame][order][cell]} layout; an entry stays
     *         {@code null} if that frame/order pair is never reached (e.g. when {@code length <= order})
     * @throws IllegalArgumentException if a numeric argument is out of range or
     *         {@code frameProb.length != prob.length}
     */
    public static double[][][] getHyperParams(int alphabetSize, int length, double ess, double[] frameProb, double[][][] prob) {
        if (alphabetSize <= 0 || length <= 0 || ess <= 0.0 || frameProb.length != prob.length) {
            throw new IllegalArgumentException();
        }
        final int periods = frameProb.length;
        final int maxOrder = prob[0].length - 1;
        final double[][][] result = new double[prob.length][prob[0].length][];
        final int contexts = (int)Math.pow(alphabetSize, maxOrder);
        double[] mass = new double[contexts];
        double[] next = new double[contexts];
        for (int startFrame = 0; startFrame < periods; ++startFrame) {
            // only index 0 is read at the first position, so leftovers elsewhere are harmless
            mass[0] = ess * frameProb[startFrame];
            int frame = startFrame;
            for (int pos = 0; pos < length; ++pos) {
                final int ord = Math.min(maxOrder, pos);
                final int numContexts = (int)Math.pow(alphabetSize, ord);
                if (result[frame][ord] == null) {
                    result[frame][ord] = new double[numContexts * alphabetSize];
                }
                final double[] target = result[frame][ord];
                final double[] cond = prob[frame][ord];
                for (int ctx = 0; ctx < numContexts; ++ctx) {
                    for (int sym = 0; sym < alphabetSize; ++sym) {
                        final int cell = ctx * alphabetSize + sym;
                        final double flow = mass[ctx] * cond[cell];
                        target[cell] += flow;
                        // dropping the oldest symbol keeps the context within maxOrder symbols
                        next[cell % contexts] += flow;
                    }
                }
                frame = (frame + 1) % periods;
                final double[] swap = mass;
                mass = next;
                next = swap;
                Arrays.fill(next, 0.0);
            }
        }
        return result;
    }

    /**
     * Builds uniform hyper-parameters. For each order {@code o}, the mass {@code sumOfHyper[o]}
     * is divided evenly among all {@code period} frames and all {@code alphabetSize^(o+1)}
     * (context, symbol) cells of that order.
     *
     * @param period number of frames
     * @param alphabetSize number of symbols
     * @param sumOfHyper total hyper-parameter mass per order
     * @return hyper-parameters indexed {@code [frame][order][cell]}
     */
    private static double[][][] getHyper(int period, int alphabetSize, double[] sumOfHyper) {
        final double[][][] result = new double[period][sumOfHyper.length][];
        for (int ord = 0; ord < sumOfHyper.length; ++ord) {
            final int cells = (int)Math.pow(alphabetSize, ord + 1);
            final double perCell = sumOfHyper[ord] / (double)period / (double)cells;
            for (double[][] perFrame : result) {
                final double[] row = new double[cells];
                Arrays.fill(row, perCell);
                perFrame[ord] = row;
            }
        }
        return result;
    }

    /**
     * Splits an equivalent sample size evenly over all frames.
     *
     * @param period number of frames
     * @param ess total equivalent sample size
     * @return an array of length {@code period} whose entries are all {@code ess / period}
     */
    private static double[] getHyper(int period, double ess) {
        final double share = ess / (double)period;
        final double[] result = new double[period];
        for (int frame = 0; frame < period; ++frame) {
            result[frame] = share;
        }
        return result;
    }

    /**
     * Creates a cyclic Markov model with uniform hyper-parameters.
     *
     * <p>The frame hyper-parameters are {@code classEss / period} each. For every order {@code o},
     * {@code sumOfHyperParams[o]} is spread evenly over all frames and all (context, symbol) cells
     * of that order. The alphabet size is read from position 0 of {@code alphabets}.
     *
     * @param alphabets the alphabet container
     * @param order the Markov order; must equal {@code sumOfHyperParams.length - 1}
     * @param period the number of frames
     * @param classEss the total equivalent sample size of the frame distribution
     * @param sumOfHyperParams the total hyper-parameter mass for each order {@code 0..order}
     * @param plugIn whether {@code initializeFunction} estimates the parameters from data
     * @param optimize whether the parameters are optimized
     * @param starts the number of starts, must be positive
     * @param initFrame the frame used for plug-in initialization, or a negative value; must be smaller than {@code period}
     * @throws IllegalArgumentException if {@code order} is negative or inconsistent with
     *         {@code sumOfHyperParams}, or if a check of the delegated constructor fails
     */
    public CyclicMarkovModelDiffSM(AlphabetContainer alphabets, int order, int period, double classEss, double[] sumOfHyperParams, boolean plugIn, boolean optimize, int starts, int initFrame) {
        this(alphabets, CyclicMarkovModelDiffSM.getHyper(period, classEss), CyclicMarkovModelDiffSM.getHyper(period, (int)alphabets.getAlphabetLengthAt(0), CyclicMarkovModelDiffSM.checkSumOfHyperParams(order, sumOfHyperParams)), plugIn, optimize, starts, initFrame);
    }

    /**
     * Ensures that the declared Markov order matches the number of per-order hyper-parameter sums.
     * Previously a mismatching {@code order} was silently ignored.
     *
     * @param order the declared Markov order
     * @param sumOfHyperParams the per-order hyper-parameter sums
     * @return {@code sumOfHyperParams} unchanged
     * @throws IllegalArgumentException if {@code order < 0} or {@code sumOfHyperParams.length != order + 1}
     */
    private static double[] checkSumOfHyperParams(int order, double[] sumOfHyperParams) {
        if (order < 0 || sumOfHyperParams.length != order + 1) {
            throw new IllegalArgumentException("The order (" + order + ") has to be non-negative and equal to sumOfHyperParams.length - 1 (" + (sumOfHyperParams.length - 1) + ").");
        }
        return sumOfHyperParams;
    }

    /**
     * Creates a cyclic Markov model from explicit hyper-parameters.
     *
     * <p>The period is {@code frameHyper.length} and the Markov order is {@code hyper[0].length - 1}.
     * The hyper-parameters are copied defensively. All frame and symbol probabilities start
     * out uniform, so every log-normalizer is initially 0.
     *
     * @param alphabets the alphabet container; the alphabet size is read from position 0
     * @param frameHyper non-negative hyper-parameters of the frame distribution, one per frame
     * @param hyper non-negative hyper-parameters indexed {@code [frame][order][context * alphabetSize + symbol]};
     *        must have one entry per frame, {@code order + 1} orders per frame and
     *        {@code alphabetSize^(o+1)} cells for order {@code o}
     * @param plugIn whether {@code initializeFunction} estimates the parameters from data
     * @param optimize whether the parameters are optimized
     * @param starts the number of starts, must be positive
     * @param initFrame the frame used for plug-in initialization, or a negative value for
     *        random initialization followed by frame-posterior refinement; must be smaller than the period
     * @throws IllegalArgumentException if any dimension does not match, a hyper-parameter is
     *         negative, {@code starts <= 0} or {@code initFrame >= period}
     */
    public CyclicMarkovModelDiffSM(AlphabetContainer alphabets, double[] frameHyper, double[][][] hyper, boolean plugIn, boolean optimize, int starts, int initFrame) {
        super(alphabets);
        this.order = hyper[0].length - 1;
        this.period = frameHyper.length;
        if (hyper.length != this.period) {
            throw new IllegalArgumentException("The number of frames in hyper (" + hyper.length + ") has to match the length of frameHyper (" + this.period + ").");
        }
        this.createArrays();
        this.frameHyper = new double[this.period];
        this.hyper = new double[this.period][this.order + 1][];
        for (int p = 0; p < this.period; ++p) {
            if (frameHyper[p] < 0.0) {
                throw new IllegalArgumentException("The ess for the class has to be non-negative.");
            }
            this.frameHyper[p] = frameHyper[p];
            if (hyper[p].length != this.order + 1) {
                throw new IllegalArgumentException("Every frame of hyper has to contain " + (this.order + 1) + " orders, but frame " + p + " contains " + hyper[p].length + ".");
            }
            for (int o = 0; o <= this.order; ++o) {
                // powers[o + 1] == alphabetSize^(o+1): one entry per (context, symbol) pair
                if (hyper[p][o].length != this.powers[o + 1]) {
                    throw new IllegalArgumentException("hyper[" + p + "][" + o + "] has to have length " + this.powers[o + 1] + ", but has length " + hyper[p][o].length + ".");
                }
                for (int i = 0; i < hyper[p][o].length; ++i) {
                    if (hyper[p][o][i] < 0.0) {
                        throw new IllegalArgumentException("The hyper-parameters have to be non-negative.");
                    }
                }
                this.hyper[p][o] = hyper[p][o].clone();
            }
        }
        // uniform frame distribution: exp(-log(period)) sums to 1, so logFrameNorm stays 0
        this.frameParams = new double[this.period];
        Arrays.fill(this.frameParams, -Math.log(this.period));
        Arrays.fill(this.frameProbs, 1.0 / (double)this.period);
        // uniform symbol distributions: exp(log(1/A)) sums to 1 per context, so logNorm stays 0
        this.params = new double[this.period][this.order + 1][];
        double uniform = 1.0 / (double)this.powers[1];
        double logUniform = Math.log(uniform);
        for (int p = 0; p < this.period; ++p) {
            for (int o = 0; o <= this.order; ++o) {
                this.params[p][o] = new double[this.powers[o + 1]];
                this.probs[p][o] = new double[this.powers[o + 1]];
                this.logNorm[p][o] = new double[this.powers[o]];
                Arrays.fill(this.params[p][o], logUniform);
                Arrays.fill(this.probs[p][o], uniform);
            }
        }
        this.plugIn = plugIn;
        this.optimize = optimize;
        this.optimizeFrame = true;
        if (starts <= 0) {
            throw new IllegalArgumentException("The number of starts has to be positive.");
        }
        this.starts = starts;
        this.setFreeParams(false);
        this.computeConstantsOfLogPrior();
        if (initFrame >= this.period) {
            throw new IllegalArgumentException("Check initFrame.");
        }
        this.initFrame = initFrame;
    }

    /**
     * Creates a model from an XML representation. Parsing is delegated to the superclass
     * constructor (presumably ending up in this class's {@code fromXML} — confirm in the superclass).
     *
     * @param source the XML representation
     * @throws NonParsableException if the XML representation cannot be parsed
     */
    public CyclicMarkovModelDiffSM(StringBuffer source) throws NonParsableException {
        super(source);
    }

    /**
     * Allocates the working arrays whose shapes depend only on the order, the period and the
     * alphabet size (read from position 0 of the alphabet container).
     *
     * <p>The arrays are: the powers of the alphabet size, the containers for probabilities and
     * log-normalizers, the frame buffers, the counters used for the gradient, the parameter
     * offset table and the per-symbol partial-derivative buffer.
     */
    private void createArrays() {
        final int alphabetSize = (int)this.alphabets.getAlphabetLengthAt(0);
        final int[] pw = new int[this.order + 2];
        pw[0] = 1;
        pw[1] = alphabetSize;
        for (int i = 2; i < pw.length; ++i) {
            pw[i] = pw[i - 1] * alphabetSize;
        }
        this.powers = pw;
        final int frames = this.period;
        final int orders = this.order + 1;
        this.probs = new double[frames][orders][];
        this.logNorm = new double[frames][orders][];
        this.frameProbs = new double[frames];
        this.counter = new int[frames][frames][pw[this.order + 1]];
        this.distCounter = new int[frames][frames][pw[this.order]];
        this.offset = new int[frames][this.order + 2];
        this.frameLogScore = new double[frames];
        this.partDer = new double[pw[1]];
    }

    /**
     * Returns a deep copy of this model. Every per-frame array is duplicated so the copy can be
     * modified independently; {@code powers} is shared with the original.
     */
    @Override
    public CyclicMarkovModelDiffSM clone() throws CloneNotSupportedException {
        final CyclicMarkovModelDiffSM copy = (CyclicMarkovModelDiffSM)super.clone();
        copy.frameHyper = this.frameHyper.clone();
        copy.hyper = (double[][][])ArrayHandler.clone((Cloneable[])this.hyper);
        final int frames = this.period;
        final int orders = this.order + 1;
        copy.params = new double[frames][orders][];
        copy.probs = new double[frames][orders][];
        copy.logNorm = new double[frames][orders][];
        copy.offset = new int[frames][];
        copy.counter = new int[frames][frames][];
        copy.distCounter = new int[frames][frames][];
        for (int frame = 0; frame < frames; ++frame) {
            for (int ord = 0; ord < orders; ++ord) {
                copy.params[frame][ord] = this.params[frame][ord].clone();
                copy.probs[frame][ord] = this.probs[frame][ord].clone();
                copy.logNorm[frame][ord] = this.logNorm[frame][ord].clone();
            }
            copy.offset[frame] = this.offset[frame].clone();
            for (int other = 0; other < frames; ++other) {
                copy.counter[frame][other] = this.counter[frame][other].clone();
                copy.distCounter[frame][other] = this.distCounter[frame][other].clone();
            }
        }
        copy.frameParams = this.frameParams.clone();
        copy.frameProbs = this.frameProbs.clone();
        copy.frameLogScore = this.frameLogScore.clone();
        copy.partDer = this.partDer.clone();
        return copy;
    }

    /**
     * Returns a short name of the form {@code cMM(order, period)}.
     */
    @Override
    public String getInstanceName() {
        return new StringBuilder("cMM(").append(this.order).append(", ").append(this.period).append(')').toString();
    }

    /**
     * Fills {@code frameLogScore[p]} with the unnormalized log joint score of the subsequence
     * {@code seq[start .. start+length-1]} when the first symbol lies in frame {@code p}.
     *
     * <p>Position {@code pos} is scored with the parameters of frame {@code (p + pos) % period}.
     * Its order is {@code min(pos, order)}, so the context grows until it reaches the full order.
     *
     * @param seq the sequence
     * @param start the first position to score
     * @param length the number of positions to score
     */
    private void fillFrameLogScores(Sequence seq, int start, int length) {
        System.arraycopy(this.frameParams, 0, this.frameLogScore, 0, this.period);
        final int alphabetSize = this.powers[1];
        final int maxContext = this.powers[this.order];
        int current = 0;
        for (int pos = 0; pos < length; ++pos) {
            final int ord = pos < this.order ? pos : this.order;
            // while pos < order the encoded context is already below maxContext, so this only
            // drops the oldest symbol once the context is full
            final int context = current % maxContext;
            current = context * alphabetSize + seq.discreteVal(start + pos);
            for (int frame = 0; frame < this.period; ++frame) {
                final int shifted = (frame + pos) % this.period;
                this.frameLogScore[frame] += this.params[shifted][ord][current] - this.logNorm[shifted][ord][context];
            }
        }
    }

    /**
     * Returns the log score of {@code seq[start..end]} (both inclusive), marginalized over the
     * frame of the first symbol and normalized by the frame log-normalizer.
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start, int end) {
        final int len = end - start + 1;
        this.fillFrameLogScores(seq, start, len);
        final double logJoint = Normalisation.getLogSum(this.frameLogScore);
        return logJoint - this.logFrameNorm;
    }

    /**
     * Returns the log score of {@code seq[start..end]} (both inclusive) and appends the non-zero
     * partial derivatives with respect to the current parameters to {@code indices} and
     * {@code dList}. If the model is not optimized, only the score is computed.
     *
     * <p>Position {@code l} of the sequence, for start frame {@code p}, is scored with the
     * parameters of frame {@code (p + l) % period}. Its gradient entries must refer to that
     * frame's parameters too. Earlier versions used frame {@code p} for the positions with an
     * incomplete context, which misattributed gradients for models with order &gt;= 2 and
     * period &gt; 1.
     *
     * @param seq the sequence
     * @param start the first position
     * @param end the last position (inclusive)
     * @param indices receives the parameter indices of the appended derivatives
     * @param dList receives the derivatives; entries already present are left untouched
     * @return the log score
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, int end, IntList indices, DoubleList dList) {
        if (!this.optimize) {
            return this.getLogScoreFor(seq, start, end);
        }
        final int length = end - start + 1;
        final int alphabetSize = this.powers[1];
        final int stop = alphabetSize - (this.freeParams ? 1 : 0);
        final int o = Math.min(this.order, length);
        // entries appended by this call start here; anything before belongs to the caller
        final int firstEntry = dList.length();
        for (int p = 0; p < this.period; ++p) {
            for (int z = 0; z < this.period; ++z) {
                Arrays.fill(this.counter[p][z], 0);
                Arrays.fill(this.distCounter[p][z], 0);
            }
            this.frameLogScore[p] = this.frameParams[p];
        }
        int indexNew = 0;
        int l = 0;
        // incomplete contexts: emit gradient entries directly (scaled by frame posteriors below)
        for (; l < o; ++l) {
            final int indexOld = indexNew;
            final int z = indexOld * alphabetSize;
            indexNew = z + seq.discreteVal(start++);
            // index of the first free parameter of this context within its (frame, order) slot
            final int h = z - (this.freeParams ? indexOld : 0);
            for (int p = 0; p < this.period; ++p) {
                final int frame = (p + l) % this.period;
                this.frameLogScore[p] += this.params[frame][l][indexNew] - this.logNorm[frame][l][indexOld];
                for (int index = 0; index < stop; ++index) {
                    indices.add(this.offset[frame][l] + h + index);
                    if (z + index == indexNew) {
                        dList.add(1.0 - this.probs[frame][l][z + index]);
                    } else {
                        dList.add(-this.probs[frame][l][z + index]);
                    }
                }
            }
        }
        // full contexts: only count occurrences; gradients are aggregated per context below
        for (; l < length; ++l) {
            final int indexOld = indexNew % this.powers[this.order];
            indexNew = indexOld * alphabetSize + seq.discreteVal(start++);
            for (int p = 0; p < this.period; ++p) {
                final int frame = (p + l) % this.period;
                this.frameLogScore[p] += this.params[frame][this.order][indexNew] - this.logNorm[frame][this.order][indexOld];
                this.distCounter[p][frame][indexOld]++;
                this.counter[p][frame][indexNew]++;
            }
        }
        // frameLogScore becomes the frame posterior P(start frame = p | seq) in place
        final double erg = Normalisation.logSumNormalisation(this.frameLogScore, 0, this.period, this.frameLogScore, 0) - this.logFrameNorm;
        // scale each chunk appended in the first loop (ordered by l, then p) by its frame posterior
        int chunkStart = firstEntry;
        for (l = 0; l < o; ++l) {
            for (int p = 0; p < this.period; ++p) {
                dList.multiply(chunkStart, chunkStart + stop, this.frameLogScore[p]);
                chunkStart += stop;
            }
        }
        // full-context gradient for each (context, parameter frame), summed over start frames
        final int numContexts = this.distCounter[0][0].length;
        for (int ctx = 0; ctx < numContexts; ++ctx) {
            final int h = ctx * stop;
            final int base = ctx * alphabetSize;
            for (int z = 0; z < this.period; ++z) {
                Arrays.fill(this.partDer, 0.0);
                boolean used = false;
                for (int p = 0; p < this.period; ++p) {
                    final int seen = this.distCounter[p][z][ctx];
                    if (seen > 0) {
                        used = true;
                        for (int index = 0; index < stop; ++index) {
                            this.partDer[index] += this.frameLogScore[p] * ((double)this.counter[p][z][base + index] - (double)seen * this.probs[z][this.order][base + index]);
                        }
                    }
                }
                if (used) {
                    for (int index = 0; index < stop; ++index) {
                        indices.add(this.offset[z][this.order] + h + index);
                        dList.add(this.partDer[index]);
                    }
                }
            }
        }
        if (this.optimizeFrame) {
            // frame parameters occupy the global indices 0 .. period-1 (minus the fixed last one)
            final int frameStop = this.period - (this.freeParams ? 1 : 0);
            for (int p = 0; p < frameStop; ++p) {
                indices.add(p);
                dList.add(this.frameLogScore[p] - this.frameProbs[p]);
            }
        }
        return erg;
    }

    /**
     * Returns the number of parameters. This is the end offset of the highest order of the
     * last frame in the offset table.
     */
    @Override
    public int getNumberOfParameters() {
        final int[] lastFrameOffsets = this.offset[this.period - 1];
        return lastFrameOffsets[this.order + 1];
    }

    /**
     * Reads the parameters from {@code params}, beginning at {@code start}, and recomputes the
     * derived probabilities and log-normalizers. Does nothing unless the model is optimized.
     *
     * <p>The layout read is as follows. First come the frame parameters, only if the frame
     * distribution is optimized; the last frame is skipped when {@code freeParams} is set. Then,
     * for every frame, order and context, come the symbol parameters; the last symbol is skipped
     * when {@code freeParams} is set. Skipped parameters keep their current value but still
     * contribute to the normalization.
     */
    @Override
    public void setParameters(double[] params, int start) {
        if (!this.optimize) {
            return;
        }
        int pos = start;
        if (this.optimizeFrame) {
            final int givenFrames = this.period - (this.freeParams ? 1 : 0);
            this.logFrameNorm = 0.0;
            for (int frame = 0; frame < this.period; ++frame) {
                if (frame < givenFrames) {
                    this.frameParams[frame] = params[pos++];
                }
                this.frameProbs[frame] = Math.exp(this.frameParams[frame]);
                this.logFrameNorm += this.frameProbs[frame];
            }
            for (int frame = 0; frame < this.period; ++frame) {
                this.frameProbs[frame] /= this.logFrameNorm;
            }
            this.logFrameNorm = Math.log(this.logFrameNorm);
        }
        final int alphabetSize = this.powers[1];
        final int perContext = alphabetSize - (this.freeParams ? 1 : 0);
        for (int frame = 0; frame < this.period; ++frame) {
            for (int ord = 0; ord <= this.order; ++ord) {
                final double[] par = this.params[frame][ord];
                final double[] pr = this.probs[frame][ord];
                final double[] norm = this.logNorm[frame][ord];
                for (int ctx = 0; ctx < norm.length; ++ctx) {
                    final int base = ctx * alphabetSize;
                    norm[ctx] = 0.0;
                    for (int sym = 0; sym < alphabetSize; ++sym) {
                        if (sym < perContext) {
                            par[base + sym] = params[pos++];
                        }
                        pr[base + sym] = Math.exp(par[base + sym]);
                        norm[ctx] += pr[base + sym];
                    }
                    for (int sym = 0; sym < alphabetSize; ++sym) {
                        pr[base + sym] /= norm[ctx];
                    }
                    norm[ctx] = Math.log(norm[ctx]);
                }
            }
        }
    }

    /**
     * Serializes the model to XML. The output contains length, alphabets, order, period, the
     * hyper-parameters, the frame parameters, one {@code params} element per frame (tagged with
     * a {@code frame} attribute) and the configuration flags, all wrapped in a tag named after
     * the class.
     */
    @Override
    public StringBuffer toXML() {
        final StringBuffer xml = new StringBuffer(10000);
        XMLParser.appendObjectWithTags(xml, this.length, "length");
        XMLParser.appendObjectWithTags(xml, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(xml, this.order, "order");
        XMLParser.appendObjectWithTags(xml, this.period, "period");
        XMLParser.appendObjectWithTags(xml, this.frameHyper, "frameEss");
        XMLParser.appendObjectWithTags(xml, this.hyper, "hyper");
        XMLParser.appendObjectWithTags(xml, this.frameParams, "frameParams");
        for (int frame = 0; frame < this.period; ++frame) {
            final String attribute = "frame=\"" + frame + "\"";
            XMLParser.appendObjectWithTagsAndAttributes(xml, this.params[frame], "params", attribute);
        }
        XMLParser.appendObjectWithTags(xml, this.plugIn, "plugIn");
        XMLParser.appendObjectWithTags(xml, this.optimize, "optimize");
        XMLParser.appendObjectWithTags(xml, this.optimizeFrame, "optimizeFrame");
        XMLParser.appendObjectWithTags(xml, this.starts, "starts");
        XMLParser.appendObjectWithTags(xml, this.freeParams, "freeParams");
        XMLParser.appendObjectWithTags(xml, this.initFrame, "initFrame");
        XMLParser.addTags(xml, this.getClass().getSimpleName());
        return xml;
    }

    /**
     * Returns the current parameters in the layout read by {@code setParameters}. This is
     * empty if the model is not optimized.
     */
    @Override
    public double[] getCurrentParameterValues() {
        if (!this.optimize) {
            return new double[0];
        }
        final double[] values = new double[this.offset[this.period - 1][this.order + 1]];
        int next = 0;
        if (this.optimizeFrame) {
            final int frameCount = this.period - (this.freeParams ? 1 : 0);
            for (; next < frameCount; ++next) {
                values[next] = this.frameParams[next];
            }
        }
        final int alphabetSize = this.powers[1];
        final int perContext = alphabetSize - (this.freeParams ? 1 : 0);
        for (int frame = 0; frame < this.period; ++frame) {
            for (int ord = 0; ord <= this.order; ++ord) {
                final double[] source = this.params[frame][ord];
                // copy the free symbols of each context block of length alphabetSize
                for (int base = 0; base < source.length; base += alphabetSize) {
                    System.arraycopy(source, base, values, next, perContext);
                    next += perContext;
                }
            }
        }
        return values;
    }

    /**
     * Initializes the parameters, either from data (plug-in estimate) or randomly.
     *
     * <p>Plug-in estimation is used if the model is optimized, plug-in initialization is
     * enabled and {@code data[index]} is present. Each symbol probability is then set to
     * (hyper-parameter + frame-weighted count) divided by the sum of these values over its
     * context. The frame probabilities are set analogously from the frame hyper-parameters and
     * the frame weights. With a non-negative {@code initFrame}, every sequence is assigned to
     * that frame and one pass is made. Otherwise, the model is first drawn randomly and three
     * passes are made, each using the frame posteriors of the current model. In all other cases
     * the parameters are drawn randomly.
     *
     * <p>Fixes compared to the previous version:
     * <ul>
     * <li>the pseudo-counts are written to {@code probs} (they were copied into {@code params},
     *     which is overwritten afterwards, so the estimates did not sum to 1 per context);</li>
     * <li>sequence weights are applied even when the frame distribution is not optimized;</li>
     * <li>frame posteriors use the whole sequence, consistent with the counting step.</li>
     * </ul>
     *
     * @param index index of the data set (and weights) to use
     * @param freeParams passed on to {@code setFreeParams} after initialization
     * @param data the data sets, may be {@code null}
     * @param weights per-sequence weights, may be {@code null} (all weights are then 1)
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) {
        if (this.optimize && this.plugIn && data != null && data[index] != null) {
            final DataSet ds = data[index];
            final int anz = ds.getNumberOfElements();
            final boolean externalWeights = weights != null && weights[index] != null;
            final int alphabetSize = this.powers[1];
            final int rMax = this.initFrame >= 0 ? 1 : 3;
            final double[][] frameP = new double[anz][this.period];
            if (this.initFrame < 0) {
                this.initializeFunctionRandomly(freeParams);
            }
            for (int r = 0; r < rMax; ++r) {
                // 1) frame weights of each sequence
                for (int i = 0; i < anz; ++i) {
                    if (this.initFrame < 0) {
                        final Sequence seq = ds.getElementAt(i);
                        this.fillFrameLogScores(seq, 0, seq.getLength());
                        Normalisation.logSumNormalisation(this.frameLogScore, 0, this.period, frameP[i], 0);
                    } else {
                        Arrays.fill(frameP[i], 0.0);
                        frameP[i][this.initFrame] = 1.0;
                    }
                }
                // 2) reset the statistics to the pseudo-counts; logNorm temporarily holds per-context sums
                for (int p = 0; p < this.period; ++p) {
                    for (int o = 0; o <= this.order; ++o) {
                        final double[] h = this.hyper[p][o];
                        System.arraycopy(h, 0, this.probs[p][o], 0, h.length);
                        final double[] norm = this.logNorm[p][o];
                        for (int n = 0; n < norm.length; ++n) {
                            norm[n] = 0.0;
                            for (int j = 0; j < alphabetSize; ++j) {
                                norm[n] += h[n * alphabetSize + j];
                            }
                        }
                    }
                }
                if (this.optimizeFrame) {
                    System.arraycopy(this.frameHyper, 0, this.frameProbs, 0, this.period);
                    // NOTE(review): assumes getESS() equals the sum of frameHyper — confirm
                    this.logFrameNorm = this.getESS();
                }
                // 3) add the weighted counts of every sequence
                for (int i = 0; i < anz; ++i) {
                    final Sequence seq = ds.getElementAt(i);
                    final int len = seq.getLength();
                    final int o = Math.min(len, this.order);
                    final double w = externalWeights ? weights[index][i] : 1.0;
                    final double[] fp = frameP[i];
                    for (int p = 0; p < this.period; ++p) {
                        fp[p] *= w;
                    }
                    if (this.optimizeFrame) {
                        for (int p = 0; p < this.period; ++p) {
                            this.frameProbs[p] += fp[p];
                        }
                        this.logFrameNorm += w;
                    }
                    int indexNew = 0;
                    int l = 0;
                    for (; l < o; ++l) {
                        final int indexOld = indexNew;
                        indexNew = indexOld * alphabetSize + seq.discreteVal(l);
                        for (int p = 0; p < this.period; ++p) {
                            final int frame = (p + l) % this.period;
                            this.probs[frame][l][indexNew] += fp[p];
                            this.logNorm[frame][l][indexOld] += fp[p];
                        }
                    }
                    for (; l < len; ++l) {
                        final int indexOld = indexNew % this.powers[this.order];
                        indexNew = indexOld * alphabetSize + seq.discreteVal(l);
                        for (int p = 0; p < this.period; ++p) {
                            final int frame = (p + l) % this.period;
                            this.probs[frame][this.order][indexNew] += fp[p];
                            this.logNorm[frame][this.order][indexOld] += fp[p];
                        }
                    }
                }
                // 4) normalize; afterwards every log-normalizer is 0 because the probabilities sum to 1
                if (this.optimizeFrame) {
                    for (int p = 0; p < this.period; ++p) {
                        this.frameProbs[p] /= this.logFrameNorm;
                        this.frameParams[p] = Math.log(this.frameProbs[p]);
                    }
                    this.logFrameNorm = 0.0;
                }
                for (int p = 0; p < this.period; ++p) {
                    for (int o = 0; o <= this.order; ++o) {
                        final double[] pr = this.probs[p][o];
                        final double[] par = this.params[p][o];
                        final double[] norm = this.logNorm[p][o];
                        for (int ctx = 0; ctx < norm.length; ++ctx) {
                            final int base = ctx * alphabetSize;
                            for (int a = 0; a < alphabetSize; ++a) {
                                pr[base + a] /= norm[ctx];
                                par[base + a] = Math.log(pr[base + a]);
                            }
                            norm[ctx] = 0.0;
                        }
                    }
                }
            }
        } else {
            this.initializeFunctionRandomly(freeParams);
        }
        this.setFreeParams(freeParams);
    }

    /**
     * Draws all parameters from Dirichlet distributions given by the hyper-parameters, if the
     * model is optimized. The frame distribution is drawn first (only if it is optimized). Then
     * one symbol distribution is drawn per frame, order and context, in that nesting order. All
     * log-normalizers are set to 0, since the drawn probabilities are already normalized.
     *
     * @param freeParams passed on to {@code setFreeParams} afterwards
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) {
        if (!this.optimize) {
            return;
        }
        final int alphabetSize = this.powers[1];
        final double[] alpha = new double[alphabetSize];
        if (this.optimizeFrame) {
            final double[] drawn = DirichletMRG.DEFAULT_INSTANCE.generate(this.period, new DirichletMRGParams(this.frameHyper));
            for (int frame = 0; frame < this.period; ++frame) {
                this.frameProbs[frame] = drawn[frame];
                this.frameParams[frame] = Math.log(drawn[frame]);
            }
            this.logFrameNorm = 0.0;
        }
        final double[] sample = new double[alphabetSize];
        for (int frame = 0; frame < this.period; ++frame) {
            for (int ord = 0; ord <= this.order; ++ord) {
                final double[] h = this.hyper[frame][ord];
                final double[] pr = this.probs[frame][ord];
                final double[] par = this.params[frame][ord];
                final double[] norm = this.logNorm[frame][ord];
                for (int ctx = 0; ctx < norm.length; ++ctx) {
                    final int base = ctx * alphabetSize;
                    norm[ctx] = 0.0;
                    System.arraycopy(h, base, alpha, 0, alphabetSize);
                    DirichletMRG.DEFAULT_INSTANCE.generate(sample, 0, alphabetSize, new DirichletMRGParams(alpha));
                    for (int a = 0; a < alphabetSize; ++a) {
                        pr[base + a] = sample[a];
                        par[base + a] = Math.log(sample[a]);
                    }
                }
            }
        }
        this.setFreeParams(freeParams);
    }

    /**
     * Restores the state of this model from its XML representation
     * (presumably the output of the corresponding {@code toXML} method - TODO confirm).
     * <p>
     * Reads length, alphabets, order, period, hyperparameters, frame parameters and the
     * per-frame Markov parameters. It then recomputes the derived probabilities and
     * log-normalization constants, the parameter offsets and the constant part of the log prior.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        StringBuffer b = XMLParser.extractForTag(xml, this.getClass().getSimpleName());
        this.length = XMLParser.extractObjectForTags(b, "length", Integer.TYPE);
        this.alphabets = (AlphabetContainer)XMLParser.extractObjectForTags(b, "alphabets");
        this.order = XMLParser.extractObjectForTags(b, "order", Integer.TYPE);
        this.period = XMLParser.extractObjectForTags(b, "period", Integer.TYPE);
        // allocate the arrays that depend on order/period before they are filled below
        this.createArrays();
        StringBuffer help = XMLParser.extractForTag(b, "frameEss");
        if (help == null) {
            // no "frameEss" tag: hyperparameters are derived from "classEss" (split uniformly over the
            // frames) and "sumOfHyperParams"; NOTE(review): presumably an older XML layout - confirm
            this.frameHyper = CyclicMarkovModelDiffSM.getHyper(this.period, XMLParser.extractObjectForTags(b, "classEss", Double.TYPE) / (double)this.period);
            this.hyper = CyclicMarkovModelDiffSM.getHyper(this.period, this.powers[1], XMLParser.extractObjectForTags(b, "sumOfHyperParams", double[].class));
        } else {
            // the extracted content lacks its enclosing tags; re-add them so it can be parsed as an object
            XMLParser.addTags(help, "frameEss");
            this.frameHyper = (double[])XMLParser.extractObjectForTags(help, "frameEss");
            this.hyper = (double[][][])XMLParser.extractObjectForTags(b, "hyper");
        }
        this.frameParams = XMLParser.extractObjectForTags(b, "frameParams", double[].class);
        // accumulates sum_p exp(frameParams[p]) first; the log is taken after the loops
        this.logFrameNorm = 0.0;
        int p = 0;
        this.params = new double[this.period][][];
        TreeMap<String, String> map = new TreeMap<String, String>();
        while (p < this.period) {
            // NOTE(review): exp without subtracting the maximum may overflow for large parameters
            this.frameProbs[p] = Math.exp(this.frameParams[p]);
            this.logFrameNorm += this.frameProbs[p];
            // the Markov parameters of frame p are stored in a "params" tag with attribute frame="p"
            map.clear();
            map.put("frame", "" + p);
            this.params[p] = XMLParser.extractObjectAndAttributesForTags(b, "params", null, map, double[][].class);
            int o = 0;
            while (o <= this.order) {
                this.probs[p][o] = new double[this.params[p][o].length];
                // one normalization constant per context of order o
                this.logNorm[p][o] = new double[this.powers[o]];
                int index = 0;
                int n = 0;
                while (n < this.logNorm[p][o].length) {
                    // first pass: exponentiate the powers[1] parameters of context n and sum them up
                    this.logNorm[p][o][n] = 0.0;
                    int j = 0;
                    while (j < this.powers[1]) {
                        this.probs[p][o][index + j] = Math.exp(this.params[p][o][index + j]);
                        double[] dArray = this.logNorm[p][o];
                        int n2 = n;
                        dArray[n2] = dArray[n2] + this.probs[p][o][index + j];
                        ++j;
                    }
                    // second pass: normalize; index advances to the next context here
                    j = 0;
                    while (j < this.powers[1]) {
                        double[] dArray = this.probs[p][o];
                        int n3 = index++;
                        dArray[n3] = dArray[n3] / this.logNorm[p][o][n];
                        ++j;
                    }
                    this.logNorm[p][o][n] = Math.log(this.logNorm[p][o][n]);
                    ++n;
                }
                ++o;
            }
            ++p;
        }
        // normalize the frame probabilities and store the log of their normalizer
        p = 0;
        while (p < this.period) {
            int n = p++;
            this.frameProbs[n] = this.frameProbs[n] / this.logFrameNorm;
        }
        this.logFrameNorm = Math.log(this.logFrameNorm);
        this.plugIn = XMLParser.extractObjectForTags(b, "plugIn", Boolean.TYPE);
        // optimize/optimizeFrame must be restored before setFreeParams, which reads both
        this.optimize = XMLParser.extractObjectForTags(b, "optimize", Boolean.TYPE);
        this.optimizeFrame = XMLParser.extractObjectForTags(b, "optimizeFrame", Boolean.TYPE);
        this.starts = XMLParser.extractObjectForTags(b, "starts", Integer.TYPE);
        this.setFreeParams(XMLParser.extractObjectForTags(b, "freeParams", Boolean.TYPE));
        this.initFrame = XMLParser.extractObjectForTags(b, "initFrame", Integer.TYPE);
        // hyperparameters are known now, so the constant part of the log prior can be computed
        this.computeConstantsOfLogPrior();
    }

    /**
     * Sets whether the free-parameter parameterization is used and recomputes the parameter offsets.
     * <p>
     * {@code offset[p][o]} is the index of the first parameter of order {@code o} in frame {@code p}.
     * {@code offset[p][order + 1]} is the end of frame {@code p}, and therefore the start of frame
     * {@code p + 1}. The frame parameters (if optimized) occupy the indices before {@code offset[0][0]}.
     * If the parameters are not optimized, only the total parameter count
     * {@code offset[period - 1][order + 1]} is reset to 0.
     *
     * @param freeParams whether one parameter per distribution is dropped
     */
    private void setFreeParams(boolean freeParams) {
        this.freeParams = freeParams;
        if (!this.optimize) {
            this.offset[this.period - 1][this.order + 1] = 0;
            return;
        }
        for (int frame = 0; frame < this.period; frame++) {
            int[] frameOffsets = this.offset[frame];
            if (frame > 0) {
                // a frame starts where the previous one ends
                frameOffsets[0] = this.offset[frame - 1][this.order + 1];
            } else if (this.optimizeFrame) {
                // the frame distribution comes first; with free parameters one frame parameter is dropped
                frameOffsets[0] = freeParams ? this.period - 1 : this.period;
            } else {
                frameOffsets[0] = 0;
            }
            for (int ord = 0; ord <= this.order; ord++) {
                // with free parameters every one of the powers[ord] contexts loses one parameter
                int dropped = freeParams ? this.powers[ord] : 0;
                frameOffsets[ord + 1] = frameOffsets[ord] + this.params[frame][ord].length - dropped;
            }
        }
    }

    /**
     * Returns the size of the event space of the random variable governed by parameter {@code index}.
     * <p>
     * For the leading frame parameters this is the period. For the Markov parameters, the frame and
     * the order {@code o} containing the index are located via {@code offset}, and {@code powers[o]}
     * is returned.
     * NOTE(review): the conditional distributions range over {@code powers[1]} symbols, whereas
     * {@code powers[o]} is the number of contexts of order {@code o}. Confirm this is intended.
     *
     * @param index the parameter index
     * @return the event-space size
     */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        if (this.offset[0][0] > index) {
            return this.period;
        }
        int frame = 0;
        while (this.offset[frame][this.order + 1] <= index) {
            frame++;
        }
        int bound = 1;
        while (this.offset[frame][bound] <= index) {
            bound++;
        }
        // index lies in [offset[frame][bound - 1], offset[frame][bound]), i.e. it belongs to order bound - 1
        return this.powers[bound - 1];
    }

    /**
     * Returns the log normalization constant, which is always {@code log(1) = 0} because the model
     * is normalized by construction.
     *
     * @param length the sequence length (ignored)
     * @return 0
     */
    @Override
    public double getLogNormalizationConstant(int length) {
        final double logOfOne = 0;
        return logOfOne;
    }

    /**
     * Returns the log partial derivative of the normalization constant with respect to a parameter.
     * Because the normalization constant is constant, this is {@code log(0)}.
     *
     * @param parameterIndex the parameter index
     * @param length the sequence length (ignored)
     * @return {@link Double#NEGATIVE_INFINITY} for every index below the total number of parameters
     * @throws IndexOutOfBoundsException if {@code parameterIndex} is not below the number of parameters
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex, int length) throws Exception {
        int numberOfParameters = this.offset[this.period - 1][this.order + 1];
        if (parameterIndex >= numberOfParameters) {
            throw new IndexOutOfBoundsException();
        }
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * Returns the equivalent sample size, i.e. the sum of the frame hyperparameters.
     *
     * @return the sum over {@code frameHyper}
     */
    @Override
    public double getESS() {
        double total = 0.0;
        for (double frameEss : this.frameHyper) {
            total += frameEss;
        }
        return total;
    }

    /**
     * Renders the model as a table.
     * <p>
     * For each frame, the output contains the frame probability, a header with the alphabet symbols
     * and, for every order, one line per context. Each line lists the context symbols followed by
     * each conditional probability and its hyperparameter in parentheses.
     *
     * @param nf the format used for all numbers
     * @return the textual representation
     */
    @Override
    public String toString(NumberFormat nf) {
        DiscreteAlphabet abc = (DiscreteAlphabet)this.alphabets.getAlphabetAt(0);
        int alphabetSize = (int)abc.length();
        int maxSymbol = alphabetSize - 1;
        StringBuffer info = new StringBuffer((int)Math.pow(alphabetSize, this.order) * alphabetSize * this.period * 15);
        String[] sym = new String[alphabetSize];
        for (int s = 0; s < alphabetSize; s++) {
            sym[s] = abc.getSymbolAt(s);
        }
        int[] context = new int[this.order + 1];
        for (int p = 0; p < this.period; p++) {
            info.append("frame " + p + ": p(" + p + ") = " + nf.format(this.frameProbs[p]) + "\n");
            for (int s = 0; s < alphabetSize; s++) {
                info.append("\t" + sym[s]);
            }
            info.append("\n");
            for (int o = 0; o <= this.order; o++) {
                // header of the form P(X_o|X_0 X_1 ...)
                info.append("P(X_" + o);
                for (int c = 0; c < o; c++) {
                    info.append(c == 0 ? "|" : " ");
                    info.append("X_" + c);
                }
                info.append(")\n");
                Arrays.fill(context, 0);
                double[] conditional = this.probs[p][o];
                double[] prior = this.hyper[p][o];
                int index = 0;
                while (index < conditional.length) {
                    for (int c = 0; c < o; c++) {
                        info.append(sym[context[c]]);
                    }
                    for (int s = 0; s < alphabetSize; s++, index++) {
                        info.append("\t" + nf.format(conditional[index]) + "\t(" + nf.format(prior[index]) + ")");
                    }
                    info.append("\n");
                    // advance the context like an odometer; the last position changes fastest
                    int pos = o - 1;
                    while (pos >= 0 && context[pos] == maxSymbol) {
                        context[pos] = 0;
                        pos--;
                    }
                    if (pos >= 0) {
                        context[pos]++;
                    }
                }
                info.append("\n");
            }
        }
        return info.toString();
    }

    /**
     * Returns the logarithm of the Dirichlet prior density of the current parameters.
     * <p>
     * The frame term is {@code sum_f frameHyper[f] * frameParams[f] - ESS * logFrameNorm}. Each
     * block of A consecutive Markov parameters (A = alphabet size) forms one conditional
     * distribution and contributes {@code sum_a hyper * param - logNorm * sum_a hyper}. The constant
     * {@code logGammaSum} is added at the end.
     * <p>
     * Fix: the innermost loop used to be bounded by {@code params[f].length}, the number of orders,
     * instead of {@code params[f][o].length}, the number of parameters of order {@code o}. That
     * skipped most parameters and, when order + 1 < A, never subtracted any normalization term.
     *
     * @return the log prior, or 0 if the parameters are not optimized
     */
    @Override
    public double getLogPriorTerm() {
        if (!this.optimize) {
            return 0.0;
        }
        double val = -this.getESS() * this.logFrameNorm;
        int A = (int)this.alphabets.getAlphabetLengthAt(0);
        for (int f = 0; f < this.params.length; f++) {
            val += this.frameParams[f] * this.frameHyper[f];
            for (int o = 0; o < this.params[f].length; o++) {
                double[] par = this.params[f][o];
                double[] hyp = this.hyper[f][o];
                double[] norm = this.logNorm[f][o];
                double sum = 0.0;
                int inGroup = 0;
                for (int i = 0; i < par.length; i++) {
                    val += par[i] * hyp[i];
                    sum += hyp[i];
                    if (++inGroup == A) {
                        // end of one conditional distribution: subtract its log normalizer weighted by its ESS
                        val -= norm[i / A] * sum;
                        sum = 0.0;
                        inGroup = 0;
                    }
                }
            }
        }
        return val + this.logGammaSum;
    }

    /**
     * Returns the log Dirichlet prior of the current parameters, where the hyperparameters of frame
     * {@code (f + offset) % period} are used for the parameters of frame {@code f}.
     * <p>
     * Fix: the innermost loop used to be bounded by {@code params[f].length}, the number of orders,
     * instead of {@code params[f][o].length}, so most parameters and normalization terms were ignored.
     * {@code logGammaSum} is a sum over all frames and hence does not depend on the shift.
     *
     * @param offset the frame shift applied to the hyperparameters (expected to be non-negative)
     * @return the log prior, or 0 if the parameters are not optimized
     */
    private double getLogPriorTerm(int offset) {
        if (!this.optimize) {
            return 0.0;
        }
        double val = -this.getESS() * this.logFrameNorm;
        int A = (int)this.alphabets.getAlphabetLengthAt(0);
        for (int f = 0; f < this.params.length; f++) {
            int shifted = (f + offset) % this.period;
            val += this.frameParams[f] * this.frameHyper[shifted];
            for (int o = 0; o < this.params[f].length; o++) {
                double[] par = this.params[f][o];
                double[] hyp = this.hyper[shifted][o];
                double[] norm = this.logNorm[f][o];
                double sum = 0.0;
                int inGroup = 0;
                for (int i = 0; i < par.length; i++) {
                    val += par[i] * hyp[i];
                    sum += hyp[i];
                    if (++inGroup == A) {
                        val -= norm[i / A] * sum;
                        sum = 0.0;
                        inGroup = 0;
                    }
                }
            }
        }
        return val + this.logGammaSum;
    }

    /**
     * Computes {@code logGammaSum}, the parameter-independent part of the log Dirichlet prior.
     * <p>
     * The frame distribution contributes {@code logGamma(ESS) - sum_f logGamma(frameHyper[f])}. Every
     * conditional distribution, a block of A consecutive hyperparameters, contributes
     * {@code logGamma(sum of the block) - sum logGamma(hyper)}.
     * <p>
     * Fix: the per-block {@code logGamma(sum)} term used to be assigned with {@code =}, which
     * discarded everything accumulated so far. It is now added with {@code +=}.
     */
    private void computeConstantsOfLogPrior() {
        int A = (int)this.alphabets.getAlphabetLengthAt(0);
        this.logGammaSum = Gamma.logOfGamma(this.getESS());
        for (int f = 0; f < this.params.length; f++) {
            this.logGammaSum -= Gamma.logOfGamma(this.frameHyper[f]);
            for (int o = 0; o < this.params[f].length; o++) {
                double[] hyp = this.hyper[f][o];
                double sum = 0.0;
                int inGroup = 0;
                for (int i = 0; i < this.params[f][o].length; i++) {
                    this.logGammaSum -= Gamma.logOfGamma(hyp[i]);
                    sum += hyp[i];
                    if (++inGroup == A) {
                        this.logGammaSum += Gamma.logOfGamma(sum);
                        sum = 0.0;
                        inGroup = 0;
                    }
                }
            }
        }
    }

    /**
     * Adds the gradient of the log prior to {@code grad}, starting at index {@code start}.
     * <p>
     * For a frame parameter the partial derivative is {@code frameHyper[p] - ESS * frameProbs[p]}.
     * For a Markov parameter it is {@code hyper - (sum of the block's hyperparameters) * prob}.
     * With free parameters, the last parameter of each distribution is skipped.
     *
     * @param grad the gradient vector that is updated in place
     * @param start the index of this model's first parameter in {@code grad}
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) {
        if (!this.optimize) {
            return;
        }
        int pos = start;
        if (this.optimizeFrame) {
            double classESS = this.getESS();
            int numFrameParams = this.freeParams ? this.period - 1 : this.period;
            for (int p = 0; p < numFrameParams; p++) {
                grad[pos++] += this.frameHyper[p] - classESS * this.frameProbs[p];
            }
        }
        int alphabetSize = this.powers[1];
        int stop = this.freeParams ? alphabetSize - 1 : alphabetSize;
        for (int p = 0; p < this.params.length; p++) {
            for (int o = 0; o < this.params[p].length; o++) {
                double[] hyp = this.hyper[p][o];
                double[] pr = this.probs[p][o];
                for (int block = 0; block < this.params[p][o].length; block += alphabetSize) {
                    double blockEss = 0.0;
                    for (int a = 0; a < alphabetSize; a++) {
                        blockEss += hyp[block + a];
                    }
                    for (int a = 0; a < stop; a++) {
                        grad[pos++] += hyp[block + a] - blockEss * pr[block + a];
                    }
                }
            }
        }
    }

    /**
     * States that this model is normalized by construction.
     *
     * @return always {@code true}
     */
    @Override
    public boolean isNormalized() {
        final boolean normalized = true;
        return normalized;
    }

    /**
     * States that this model is always considered initialized.
     *
     * @return always {@code true}
     */
    @Override
    public boolean isInitialized() {
        final boolean initialized = true;
        return initialized;
    }

    /**
     * Returns the number of recommended optimization starts stored in this model.
     *
     * @return the value of {@code starts}
     */
    @Override
    public int getNumberOfRecommendedStarts() {
        final int recommended = this.starts;
        return recommended;
    }

    /**
     * Enables or disables the optimization of the parameters and recomputes the parameter offsets.
     *
     * @param optimize whether the parameters are optimized
     */
    public void setParameterOptimization(boolean optimize) {
        this.optimize = optimize;
        // the offsets (and the total parameter count) depend on the optimize flag
        boolean keepFreeParams = this.freeParams;
        this.setFreeParams(keepFreeParams);
    }

    /**
     * Enables or disables the optimization of the frame parameters and recomputes the parameter offsets.
     *
     * @param optimize whether the frame parameters are optimized
     */
    public void setFrameParameterOptimization(boolean optimize) {
        this.optimizeFrame = optimize;
        // the frame parameters occupy the first offset[0][0] indices, so the offsets must be rebuilt
        boolean keepFreeParams = this.freeParams;
        this.setFreeParams(keepFreeParams);
    }

    /**
     * Not supported by this model.
     * <p>
     * Previously a raw, message-less {@code RuntimeException} was thrown.
     * {@link UnsupportedOperationException} is a subclass of {@code RuntimeException}, so existing
     * callers that catch {@code RuntimeException} are unaffected.
     *
     * @param length the sequence lengths (ignored)
     * @param weight the sequence weights (ignored)
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setStatisticForHyperparameters(int[] length, double[] weight) throws Exception {
        throw new UnsupportedOperationException("setStatisticForHyperparameters is not supported by CyclicMarkovModelDiffSM");
    }

    /**
     * Returns the groups of parameter indices that are sampled jointly.
     * <p>
     * If frame parameters are optimized, the first group holds them. Every following group holds the
     * parameters of one conditional distribution, i.e. one context of one order in one frame.
     * <p>
     * Fix: the result size used to be {@code optimizeFrame ? 1 : 0 + total % powers[1]}. By operator
     * precedence that is either 1 or {@code total % powers[1]} (normally 0), so the Markov-parameter
     * groups were never returned. The number of groups is now computed by division. Group sizes also
     * account for free parameters, consistent with {@code setFreeParams}, so indices never exceed the
     * parameter count.
     *
     * @param parameterOffset the index of this model's first parameter in the global parameter vector
     * @return the sampling groups, or an empty array if the parameters are not optimized
     */
    @Override
    public int[][] getSamplingGroups(int parameterOffset) {
        if (!this.optimize) {
            return new int[0][];
        }
        // offset[0][0] is the number of optimized frame parameters (0 if the frame is not optimized)
        int frameGroupSize = this.offset[0][0];
        int total = this.offset[this.period - 1][this.order + 1];
        int groupSize = this.powers[1] - (this.freeParams ? 1 : 0);
        int markovGroups = groupSize > 0 ? (total - frameGroupSize) / groupSize : 0;
        boolean hasFrameGroup = frameGroupSize > 0;
        int[][] res = new int[(hasFrameGroup ? 1 : 0) + markovGroups][];
        int next = parameterOffset;
        int g = 0;
        if (hasFrameGroup) {
            res[g] = new int[frameGroupSize];
            for (int j = 0; j < frameGroupSize; j++) {
                res[g][j] = next++;
            }
            g++;
        }
        for (; g < res.length; g++) {
            res[g] = new int[groupSize];
            for (int j = 0; j < groupSize; j++) {
                res[g][j] = next++;
            }
        }
        return res;
    }
}

