/*
 * Decompiled with CFR 0.152.
 */
package umontreal.iro.lecuyer.probdistmulti;

import umontreal.iro.lecuyer.functions.MathFunction;
import umontreal.iro.lecuyer.probdistmulti.DiscreteDistributionIntMulti;
import umontreal.iro.lecuyer.util.Num;
import umontreal.iro.lecuyer.util.RootFinder;

/**
 * Negative multinomial distribution over vectors of {@code d} integers, with
 * parameters {@code gamma > 0} and {@code p[0..d-1]}, where every
 * {@code p[i]} lies in {@code [0, 1)} and the {@code p[i]} sum to strictly
 * less than 1.
 *
 * <p>Writing {@code p0 = 1 - sum(p)}, the mass function evaluated by this class is
 * <pre>
 *   exp( lnGamma(gamma + sum(x)) - lnGamma(gamma) - sum(ln(x[i]!))
 *        + gamma * ln(p0) + sum(x[i] * ln(p[i])) )
 * </pre>
 *
 * <p>Every quantity is available both as an instance method (using the stored
 * parameters) and as a static method that validates the supplied parameters
 * on each call.
 */
public class NegativeMultinomialDist extends DiscreteDistributionIntMulti {
    /** Parameter gamma; strictly positive once {@link #setParams} has succeeded. */
    protected double gamma;

    /** Copy of the probability vector handed to {@link #setParams}. */
    protected double[] p;

    /**
     * Creates a distribution with the given parameters.
     *
     * @param gamma must be strictly positive
     * @param p     probabilities, each in {@code [0, 1)}, summing to less than 1
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public NegativeMultinomialDist(double gamma, double[] p) {
        this.setParams(gamma, p);
    }

    /**
     * Returns the probability mass at {@code x} for this distribution.
     *
     * @param x point of evaluation; must have the same length as {@code p}
     * @throws IllegalArgumentException if {@code x.length != p.length}
     */
    public double prob(int[] x) {
        return prob_(this.gamma, this.p, x);
    }

    /** Returns the mean vector, {@code gamma * p[i] / p0}. */
    public double[] getMean() {
        return getMean_(this.gamma, this.p);
    }

    /** Returns the covariance matrix of this distribution. */
    public double[][] getCovariance() {
        return getCovariance_(this.gamma, this.p);
    }

    /** Returns the correlation matrix of this distribution. */
    public double[][] getCorrelation() {
        return getCorrelation_(this.gamma, this.p);
    }

    /**
     * Validates a parameter set. Checks are made in this order: gamma, then
     * each {@code p[i]} individually, then the sum of the {@code p[i]}.
     *
     * @throws IllegalArgumentException if {@code gamma <= 0}, if some
     *         {@code p[i]} is outside {@code [0, 1)}, or if the {@code p[i]}
     *         sum to 1 or more
     */
    private static void verifParam(double gamma, double[] p) {
        if (gamma <= 0.0) {
            throw new IllegalArgumentException("gamma <= 0");
        }
        double sumPi = 0.0;
        for (int i = 0; i < p.length; ++i) {
            if (p[i] < 0.0 || p[i] >= 1.0) {
                throw new IllegalArgumentException("p is not a probability vector");
            }
            sumPi += p[i];
        }
        if (sumPi >= 1.0) {
            throw new IllegalArgumentException("p is not a probability vector");
        }
    }

    /** Returns {@code p0 = 1 - sum(p)}, summing left to right. */
    private static double complement(double[] p) {
        double sumPi = 0.0;
        for (int i = 0; i < p.length; ++i) {
            sumPi += p[i];
        }
        return 1.0 - sumPi;
    }

    /**
     * Mass function without parameter validation (the caller is responsible
     * for having validated {@code gamma} and {@code p}).
     *
     * @throws IllegalArgumentException if {@code x.length != p.length}
     */
    private static double prob_(double gamma, double[] p, int[] x) {
        if (x.length != p.length) {
            throw new IllegalArgumentException("x and p must have the same size");
        }
        double sumPi = 0.0;
        double sumXi = 0.0;
        double sumLnXiFact = 0.0;
        double sumXiLnPi = 0.0;
        for (int i = 0; i < p.length; ++i) {
            sumPi += p[i];
            sumXi += (double)x[i];
            sumLnXiFact += Num.lnFactorial(x[i]);
            // A zero count contributes p[i]^0 = 1, i.e. nothing to the log-sum.
            // Skipping it avoids 0 * log(0) = 0 * -Infinity = NaN when p[i] == 0,
            // which the parameter validation allows.
            if (x[i] != 0) {
                sumXiLnPi += (double)x[i] * Math.log(p[i]);
            }
        }
        double p0 = 1.0 - sumPi;
        return Math.exp(Num.lnGamma(gamma + sumXi) - (Num.lnGamma(gamma) + sumLnXiFact)
                + gamma * Math.log(p0) + sumXiLnPi);
    }

    /**
     * Returns the probability mass at {@code x} for the given parameters.
     *
     * @throws IllegalArgumentException if the parameters are invalid or
     *         {@code x.length != p.length}
     */
    public static double prob(double gamma, double[] p, int[] x) {
        verifParam(gamma, p);
        return prob_(gamma, p, x);
    }

    /** Placeholder: the static distribution function is not implemented. */
    private static double cdf_(double gamma, double[] p, int[] x) {
        throw new UnsupportedOperationException("cdf not implemented");
    }

    /**
     * Not implemented. Validates the parameters, then always throws.
     *
     * @throws IllegalArgumentException      if the parameters are invalid
     * @throws UnsupportedOperationException otherwise
     */
    public static double cdf(double gamma, double[] p, int[] x) {
        verifParam(gamma, p);
        return cdf_(gamma, p, x);
    }

    /** Mean vector without parameter validation: {@code gamma * p[i] / p0}. */
    private static double[] getMean_(double gamma, double[] p) {
        double p0 = complement(p);
        double[] mean = new double[p.length];
        for (int i = 0; i < p.length; ++i) {
            mean[i] = gamma * p[i] / p0;
        }
        return mean;
    }

    /**
     * Returns the mean vector for the given parameters.
     *
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public static double[] getMean(double gamma, double[] p) {
        verifParam(gamma, p);
        return getMean_(gamma, p);
    }

    /**
     * Covariance matrix without parameter validation. Off-diagonal entries
     * are {@code gamma * p[i] * p[j] / p0^2}; diagonal entries are
     * {@code gamma * p[i] * (p[i] + p0) / p0^2}.
     */
    private static double[][] getCovariance_(double gamma, double[] p) {
        double p0 = complement(p);
        double p0Squared = p0 * p0;
        double[][] cov = new double[p.length][p.length];
        for (int i = 0; i < p.length; ++i) {
            double scaled = gamma * p[i];
            for (int j = 0; j < p.length; ++j) {
                cov[i][j] = scaled * p[j] / p0Squared;
            }
            // The diagonal uses the variance formula, overwriting the value
            // written by the inner loop.
            cov[i][i] = scaled * (p[i] + p0) / p0Squared;
        }
        return cov;
    }

    /**
     * Returns the covariance matrix for the given parameters.
     *
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public static double[][] getCovariance(double gamma, double[] p) {
        verifParam(gamma, p);
        return getCovariance_(gamma, p);
    }

    /**
     * Correlation matrix without parameter validation. Off-diagonal entries
     * are {@code sqrt(p[i] * p[j] / ((p0 + p[i]) * (p0 + p[j])))}; the
     * diagonal is 1. The result does not depend on {@code gamma}.
     */
    private static double[][] getCorrelation_(double gamma, double[] p) {
        double p0 = complement(p);
        double[][] corr = new double[p.length][p.length];
        for (int i = 0; i < p.length; ++i) {
            for (int j = 0; j < p.length; ++j) {
                corr[i][j] = Math.sqrt(p[i] * p[j] / ((p0 + p[i]) * (p0 + p[j])));
            }
            corr[i][i] = 1.0;
        }
        return corr;
    }

    /**
     * Returns the correlation matrix for the given parameters.
     *
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public static double[][] getCorrelation(double gamma, double[] p) {
        verifParam(gamma, p);
        return getCorrelation_(gamma, p);
    }

    /**
     * Same as {@link #getMLE(int[][], int, int)}.
     *
     * @deprecated use {@link #getMLE(int[][], int, int)} instead
     */
    @Deprecated
    public static double[] getMaximumLikelihoodEstimate(int[][] x, int n, int d) {
        return getMLE(x, n, d);
    }

    /**
     * Estimates the parameters from {@code n} observations of dimension
     * {@code d}, where {@code x[j][i]} is component {@code i} of observation
     * {@code j}.
     *
     * <p>With {@code ups[j]} the sum of observation {@code j}, {@code M} the
     * largest such sum and {@code Fl[l]} the fraction of observations whose
     * sum exceeds {@code l}, the gamma estimate is the root of
     * <pre>
     *   sum_{l=0}^{M-1} Fl[l] / (gamma + l) - log(1 + sum(ups) / (n * gamma))
     * </pre>
     * located by {@code RootFinder.brentDekker}. The probabilities are then
     * {@code lambda[i] / (1 + sum(lambda))} with
     * {@code lambda[i] = mean[i] / gamma}.
     *
     * @param x observations; only the first {@code n} rows and the first
     *          {@code d} columns are read
     * @param n number of observations, must be positive
     * @param d dimension of each observation, must be positive
     * @return array of length {@code d + 1}: index 0 holds the gamma
     *         estimate, indices {@code 1..d} hold the estimates of {@code p}
     * @throws IllegalArgumentException if {@code n <= 0}, {@code d <= 0}, an
     *         observation is negative, a row sum does not fit below
     *         {@code Integer.MAX_VALUE}, or an estimated probability exceeds 1
     */
    public static double[] getMLE(int[][] x, int n, int d) {
        if (n <= 0) {
            throw new IllegalArgumentException("n <= 0");
        }
        if (d <= 0) {
            throw new IllegalArgumentException("d <= 0");
        }

        double[] mean = new double[d];
        int[] ups = new int[n];
        int M = 0;
        for (int j = 0; j < n; ++j) {
            // Accumulate in long so that an overflowing row is detected
            // instead of silently wrapping around.
            long rowTotal = 0L;
            for (int i = 0; i < d; ++i) {
                int count = x[j][i];
                if (count < 0) {
                    throw new IllegalArgumentException("negative observation in x");
                }
                rowTotal += count;
                mean[i] += (double)count;
            }
            if (rowTotal >= Integer.MAX_VALUE) {
                throw new IllegalArgumentException("gamma/p_i too large");
            }
            ups[j] = (int)rowTotal;
            if (ups[j] > M) {
                M = ups[j];
            }
        }
        for (int i = 0; i < d; ++i) {
            mean[i] /= (double)n;
        }

        // Fl[l] = (number of observations with ups > l) / n, built from a
        // histogram of the row sums and a running count taken from the top.
        int[] countAt = new int[M + 1];
        for (int j = 0; j < n; ++j) {
            ++countAt[ups[j]];
        }
        double[] Fl = new double[M];
        int exceeding = 0;
        for (int l = M - 1; l >= 0; --l) {
            exceeding += countAt[l + 1];
            Fl[l] = (double)exceeding / (double)n;
        }

        double[] parameters = new double[d + 1];
        Function f = new Function(n, M, ups, Fl);
        // NOTE(review): the arguments are presumably the search bracket
        // [1e-9, 1e9] and the tolerance 1e-5 -- confirm against RootFinder.
        parameters[0] = RootFinder.brentDekker(1.0E-9, 1.0E9, f, 1.0E-5);

        double[] lambda = new double[d];
        double sumLambda = 0.0;
        for (int i = 0; i < d; ++i) {
            lambda[i] = mean[i] / parameters[0];
            sumLambda += lambda[i];
        }
        for (int i = 0; i < d; ++i) {
            parameters[i + 1] = lambda[i] / (1.0 + sumLambda);
            if (parameters[i + 1] > 1.0) {
                throw new IllegalArgumentException("p_i > 1");
            }
        }
        return parameters;
    }

    /** Returns the parameter gamma. */
    public double getGamma() {
        return this.gamma;
    }

    /**
     * Returns the probability parameters. The returned array is the internal
     * one, not a copy, so changes made to it are seen by this object and are
     * not validated.
     */
    public double[] getP() {
        return this.p;
    }

    /**
     * Replaces the parameters of this distribution. The arguments are fully
     * validated before any field is modified, so a rejected call leaves the
     * object unchanged. The array {@code p} is copied.
     *
     * @param gamma must be strictly positive
     * @param p     probabilities, each in {@code [0, 1)}, summing to less than 1
     * @throws IllegalArgumentException if the parameters are invalid
     */
    public void setParams(double gamma, double[] p) {
        verifParam(gamma, p);
        this.gamma = gamma;
        this.dimension = p.length;
        this.p = p.clone();
    }

    /**
     * Function of gamma whose root is the maximum likelihood estimate:
     * {@code sum_{l<M} Fl[l] / (gamma + l) - log1p(sumUps / (n * gamma))}.
     */
    private static class Function
    implements MathFunction {
        private final double[] Fl;
        private final int[] ups;
        private final int n;
        private final int M;
        /** Sum of all row sums; kept as long so it cannot overflow int. */
        private final long sumUps;

        /**
         * @param n   number of observations
         * @param m   number of leading entries of {@code Fl} used by {@link #evaluate}
         * @param ups per-observation sums (copied)
         * @param Fl  fraction of observations whose sum exceeds each index (copied)
         */
        public Function(int n, int m, int[] ups, double[] Fl) {
            this.n = n;
            this.M = m;
            this.Fl = Fl.clone();
            this.ups = ups.clone();
            long total = 0L;
            for (int i = 0; i < ups.length; ++i) {
                total += ups[i];
            }
            this.sumUps = total;
        }

        /** Evaluates the function at {@code gamma}. */
        public double evaluate(double gamma) {
            double sum = 0.0;
            for (int l = 0; l < this.M; ++l) {
                sum += this.Fl[l] / (gamma + (double)l);
            }
            return sum - Math.log1p((double)this.sumUps / ((double)this.n * gamma));
        }
    }
}

