/*
 * Decompiled with CFR 0.152.
 */
package projects.crispr;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.ReferenceSequenceAnnotation;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractVariableLengthDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.homogeneous.HomogeneousMMDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import java.text.NumberFormat;

/**
 * Position-wise model built from a grid of {@link HomogeneousMMDiffSM} sub-models.
 * <p>
 * For every position {@code i} of a scored sequence one sub-model is selected by
 * <ul>
 * <li>the symbol at position {@code i} of the sequence's {@code "reference"} annotation
 * (a {@link ReferenceSequenceAnnotation}, presumably the gRNA), and</li>
 * <li>if {@code conditionalOnMatch} is set, whether the symbol of the sequence at position
 * {@code i-1} differs from the reference symbol there (condition index 1) or not (index 0).</li>
 * </ul>
 * The contribution of position {@code i} is the selected sub-model's score of
 * {@code x[i-order..i]} minus its score of {@code x[i-order..i-1]}, where {@code order} is the
 * sub-model's maximal Markov order clamped to {@code i}. With no context, the contribution is
 * the sub-model's score of {@code x[i..i]} alone.
 * <p>
 * NOTE(review): {@link #toXML()} returns {@code null} and {@link #fromXML(StringBuffer)} is a
 * no-op. Instances created via {@link #ConditionalHomMM(StringBuffer)} therefore have no
 * sub-models, and most methods will fail on them.
 */
public class ConditionalHomMM extends AbstractVariableLengthDiffSM {

    /** Sub-models, indexed by [condition][reference symbol]. */
    private HomogeneousMMDiffSM[][] condMMs;

    /**
     * Start of each sub-model's parameters within this model's parameter vector, indexed like
     * {@link #condMMs}. Recomputed on initialization and whenever parameters are set.
     */
    int[][] offsets;

    /** Whether sub-models are additionally selected by a mismatch at the preceding position. */
    private boolean conditionalOnMatch;

    /** Equivalent sample size of the whole model, as given to the constructor. */
    private double ess;

    /**
     * Creates a new model with one {@link HomogeneousMMDiffSM} per condition and reference symbol.
     *
     * @param alphabets the alphabet container; the size of its first alphabet determines the
     *        number of reference symbols (second index of the sub-model grid)
     * @param hMMOrder the Markov order passed to every sub-model
     * @param conditionalOnMatch if {@code true}, two conditions (previous position matches /
     *        mismatches the reference) are modelled, otherwise one
     * @param length the initial length passed to every sub-model
     * @param ess the equivalent sample size, split evenly among all sub-models
     * @throws IllegalArgumentException as propagated from the superclass or sub-model constructors
     */
    public ConditionalHomMM(AlphabetContainer alphabets, int hMMOrder, boolean conditionalOnMatch, int length, double ess) throws IllegalArgumentException {
        super(alphabets, 0);
        this.ess = ess;
        this.conditionalOnMatch = conditionalOnMatch;
        int numConditions = conditionalOnMatch ? 2 : 1;
        int numSymbols = (int) alphabets.getAlphabetLengthAt(0);
        this.condMMs = new HomogeneousMMDiffSM[numConditions][numSymbols];
        this.offsets = new int[numConditions][numSymbols];
        for (int i = 0; i < condMMs.length; i++) {
            for (int j = 0; j < condMMs[i].length; j++) {
                condMMs[i][j] = new HomogeneousMMDiffSM(alphabets, hMMOrder, ess / (double) condMMs.length / (double) condMMs[i].length, length);
            }
        }
    }

    /**
     * Creates a model from its XML representation.
     * <p>
     * NOTE(review): {@link #fromXML(StringBuffer)} is empty, so the resulting instance has no
     * sub-models.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the superclass cannot parse {@code xml}
     */
    public ConditionalHomMM(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Returns a deep copy of this model, including copies of all sub-models and offsets.
     */
    @Override
    public ConditionalHomMM clone() throws CloneNotSupportedException {
        ConditionalHomMM clone = (ConditionalHomMM) super.clone();
        clone.offsets = (int[][]) ArrayHandler.clone((Cloneable[]) this.offsets);
        clone.condMMs = (HomogeneousMMDiffSM[][]) ArrayHandler.clone((Cloneable[]) this.condMMs);
        return clone;
    }

    /** Always returns 0. */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return 0;
    }

    /** Always returns 0 (log of 1); the model is treated as normalized. */
    @Override
    public double getLogNormalizationConstant() {
        return 0.0;
    }

    /** Always returns negative infinity (log of 0). */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    /** Returns the sum of the log-prior terms of all sub-models. */
    @Override
    public double getLogPriorTerm() {
        double lp = 0.0;
        for (HomogeneousMMDiffSM[] row : condMMs) {
            for (HomogeneousMMDiffSM mm : row) {
                lp += mm.getLogPriorTerm();
            }
        }
        return lp;
    }

    /**
     * Adds the prior gradient of every sub-model to {@code grad}, laying out the sub-models'
     * parameter blocks consecutively from {@code start}.
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        for (HomogeneousMMDiffSM[] row : condMMs) {
            for (HomogeneousMMDiffSM mm : row) {
                mm.addGradientOfLogPriorTerm(grad, start);
                start += mm.getNumberOfParameters();
            }
        }
    }

    /** Returns the equivalent sample size given to the constructor. */
    @Override
    public double getESS() {
        return this.ess;
    }

    /**
     * Initializes the model; the data and weights are ignored and a random initialization is
     * performed instead.
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        this.initializeFunctionRandomly(freeParams);
    }

    /**
     * Randomly initializes every sub-model and recomputes the parameter offsets.
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        for (HomogeneousMMDiffSM[] row : condMMs) {
            for (HomogeneousMMDiffSM mm : row) {
                mm.initializeFunctionRandomly(freeParams);
            }
        }
        // Offsets depend on the sub-models' parameter counts, which may depend on freeParams.
        updateOffsets();
    }

    /** Scores and differentiates from {@code start} to the last position of {@code seq}. */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        return this.getLogScoreAndPartialDerivation(seq, start, seq.getLength() - 1, indices, partialDer);
    }

    /**
     * Returns the log-score of positions {@code start..end} (inclusive).
     * <p>
     * For each position, the parameter indices reported by the selected sub-model are appended
     * to {@code indices}, shifted by that sub-model's offset. The corresponding partial
     * derivatives are appended to {@code partialDer}. Derivatives of the subtracted context
     * score are negated.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, int end, IntList indices, DoubleList partialDer) {
        if (end < start) {
            return 0.0;
        }
        Sequence reference = getReferenceSequence(seq);
        IntList localIndices = new IntList();
        double ls = 0.0;
        for (int i = start; i <= end; i++) {
            localIndices.clear();
            int cond = getConditionIndex(seq, reference, i);
            int symbol = reference.discreteVal(i);
            HomogeneousMMDiffSM model = condMMs[cond][symbol];
            int order = getEffectiveOrder(model, i);
            if (order > 0) {
                ls += model.getLogScoreAndPartialDerivation(seq, i - order, i, localIndices, partialDer);
                int firstNegated = partialDer.length();
                ls -= model.getLogScoreAndPartialDerivation(seq, i - order, i - 1, localIndices, partialDer);
                // The context score is subtracted, so its derivatives enter with a negative sign.
                partialDer.multiply(firstNegated, partialDer.length(), -1.0);
            } else {
                ls += model.getLogScoreAndPartialDerivation(seq, i, i, localIndices, partialDer);
            }
            int offset = offsets[cond][symbol];
            for (int k = 0; k < localIndices.length(); k++) {
                indices.add(localIndices.get(k) + offset);
            }
        }
        return ls;
    }

    /** Returns the total number of parameters of all sub-models. */
    @Override
    public int getNumberOfParameters() {
        int num = 0;
        for (HomogeneousMMDiffSM[] row : condMMs) {
            for (HomogeneousMMDiffSM mm : row) {
                num += mm.getNumberOfParameters();
            }
        }
        return num;
    }

    /**
     * Returns the concatenated parameter vectors of all sub-models, in [condition][symbol]
     * order.
     */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        double[] pars = new double[this.getNumberOfParameters()];
        int start = 0;
        for (HomogeneousMMDiffSM[] row : condMMs) {
            for (HomogeneousMMDiffSM mm : row) {
                double[] temp = mm.getCurrentParameterValues();
                System.arraycopy(temp, 0, pars, start, temp.length);
                start += mm.getNumberOfParameters();
            }
        }
        return pars;
    }

    /**
     * Distributes {@code params} (from {@code start} on) to the sub-models in [condition][symbol]
     * order and refreshes the parameter offsets to match that layout.
     */
    @Override
    public void setParameters(double[] params, int start) {
        for (HomogeneousMMDiffSM[] row : condMMs) {
            for (HomogeneousMMDiffSM mm : row) {
                mm.setParameters(params, start);
                start += mm.getNumberOfParameters();
            }
        }
        // Keep gradient indices consistent even if the model was never randomly initialized.
        updateOffsets();
    }

    /** Returns the simple class name. */
    @Override
    public String getInstanceName() {
        return this.getClass().getSimpleName();
    }

    /**
     * Recomputes {@link #offsets}: each sub-model's parameters start right after those of all
     * sub-models preceding it in [condition][symbol] order.
     */
    private void updateOffsets() {
        int off = 0;
        for (int i = 0; i < condMMs.length; i++) {
            for (int j = 0; j < condMMs[i].length; j++) {
                offsets[i][j] = off;
                off += condMMs[i][j].getNumberOfParameters();
            }
        }
    }

    /**
     * Returns the reference sequence attached to {@code seq} as annotation of type
     * {@code "reference"}.
     *
     * @throws IllegalArgumentException if no such {@link ReferenceSequenceAnnotation} exists
     */
    private static Sequence getReferenceSequence(Sequence seq) {
        Object annotation = seq.getSequenceAnnotationByType("reference", 0);
        if (!(annotation instanceof ReferenceSequenceAnnotation)) {
            throw new IllegalArgumentException("Sequence has no ReferenceSequenceAnnotation of type \"reference\"");
        }
        return ((ReferenceSequenceAnnotation) annotation).getReferenceSequence();
    }

    /**
     * Returns the condition index for {@code position}. The index is 1 if conditioning is
     * enabled and the sequence differs from the reference at the preceding position; otherwise
     * it is 0.
     */
    private int getConditionIndex(Sequence seq, Sequence reference, int position) {
        if (conditionalOnMatch && position > 0
                && reference.discreteVal(position - 1) != seq.discreteVal(position - 1)) {
            return 1;
        }
        return 0;
    }

    /**
     * Returns the number of context symbols used at {@code position}: the model's maximal Markov
     * order, clamped so the context does not extend before position 0.
     */
    private static int getEffectiveOrder(HomogeneousMMDiffSM model, int position) {
        return Math.min(model.getMaximalMarkovOrder(), position);
    }

    /** Scores from {@code start} to the last position of {@code seq}. */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        return this.getLogScoreFor(seq, start, seq.getLength() - 1);
    }

    /**
     * Returns the log-score of positions {@code start..end} (inclusive). The score is computed
     * as described in the class documentation.
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start, int end) {
        if (end < start) {
            return 0.0;
        }
        Sequence reference = getReferenceSequence(seq);
        double ls = 0.0;
        for (int i = start; i <= end; i++) {
            int cond = getConditionIndex(seq, reference, i);
            HomogeneousMMDiffSM model = condMMs[cond][reference.discreteVal(i)];
            int order = getEffectiveOrder(model, i);
            if (order > 0) {
                ls += model.getLogScoreFor(seq, i - order, i) - model.getLogScoreFor(seq, i - order, i - 1);
            } else {
                ls += model.getLogScoreFor(seq, i, i);
            }
        }
        return ls;
    }

    /** Returns {@code true} only if every sub-model is initialized. */
    @Override
    public boolean isInitialized() {
        for (HomogeneousMMDiffSM[] row : condMMs) {
            for (HomogeneousMMDiffSM mm : row) {
                if (!mm.isInitialized()) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Returns, for every sub-model, a header line {@code "i, j"} followed by the sub-model's own
     * string representation.
     */
    @Override
    public String toString(NumberFormat nf) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < condMMs.length; i++) {
            for (int j = 0; j < condMMs[i].length; j++) {
                sb.append(String.valueOf(i) + ", " + j + "\n");
                sb.append(condMMs[i][j].toString(nf));
            }
        }
        return sb.toString();
    }

    /**
     * NOTE(review): XML serialization is not implemented; this returns {@code null}.
     */
    @Override
    public StringBuffer toXML() {
        return null;
    }

    /**
     * NOTE(review): XML deserialization is not implemented; this is a no-op.
     */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
    }

    /** Always returns 0 (log of 1), independent of {@code length}. */
    @Override
    public double getLogNormalizationConstant(int length) {
        return 0.0;
    }

    /** Always returns negative infinity (log of 0), independent of {@code length}. */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex, int length) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    /** No-op; the equivalent sample size is fixed at construction. */
    @Override
    public void setStatisticForHyperparameters(int[] length, double[] weight) throws Exception {
    }
}

