/*
 * Decompiled with CFR 0.152.
 */
package dr.math.distributions;

import dr.inference.model.GradientProvider;
import dr.inference.model.HessianProvider;
import dr.inference.model.Likelihood;
import dr.math.MathUtils;
import dr.math.distributions.GaussianProcessRandomGenerator;
import dr.math.distributions.MultivariateDistribution;
import dr.math.matrixAlgebra.CholeskyDecomposition;
import dr.math.matrixAlgebra.IllegalDimension;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.ReadableMatrix;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.SymmetricMatrix;
import dr.math.matrixAlgebra.Vector;
import dr.math.matrixAlgebra.WritableVector;
import java.util.Arrays;
import org.ejml.alg.dense.decomposition.TriangularSolver;
import org.ejml.alg.dense.decomposition.chol.CholeskyDecompositionInner_D64;
import org.ejml.data.DenseMatrix64F;

/**
 * Multivariate normal distribution parameterised by a mean vector and a precision
 * (inverse covariance) matrix.
 *
 * <p>The mean and precision arrays passed to the constructors are stored by reference and the
 * precision array is returned as-is from {@link #getScaleMatrix()} and
 * {@link #getPrecisionMatrix()}. The variance, its Cholesky factor and the log-determinant of
 * the precision are computed lazily on first use and cached; mutating the precision array after
 * those caches are filled leaves them stale. The lazy caches are filled without any
 * synchronisation.
 */
public class MultivariateNormalDistribution
implements MultivariateDistribution,
GaussianProcessRandomGenerator,
GradientProvider,
HessianProvider {
    public static final String TYPE = "MultivariateNormal";
    private final double[] mean;
    private final double[][] precision;
    private double[][] variance = null;
    private double[][] cholesky = null;
    private Double logDet = null;
    // True when the precision is a scalar multiple of the identity; enables the cheaper
    // scalar code paths in logPdf / gradLogPdf / hessianLogPdf.
    private final boolean hasSinglePrecision;
    private final double singlePrecision;
    // log(1 / sqrt(2 * pi)), the per-dimension normalising constant.
    private static final double logNormalize = -0.5 * Math.log(Math.PI * 2);

    /**
     * Creates a distribution with a full precision matrix.
     *
     * @param mean      mean vector (stored by reference)
     * @param precision precision matrix, the inverse of the covariance (stored by reference)
     */
    public MultivariateNormalDistribution(double[] mean, double[][] precision) {
        this.mean = mean;
        this.precision = precision;
        this.hasSinglePrecision = false;
        this.singlePrecision = 1.0;
    }

    /**
     * Creates a distribution whose precision matrix is {@code scalarPrecision * I}.
     *
     * @param mean            mean vector (stored by reference); its length fixes the dimension
     * @param scalarPrecision value placed on every diagonal entry of the precision matrix
     */
    public MultivariateNormalDistribution(double[] mean, double scalarPrecision) {
        this.mean = mean;
        this.hasSinglePrecision = true;
        this.singlePrecision = scalarPrecision;
        final int dim = mean.length;
        this.precision = new double[dim][dim];
        for (int i = 0; i < dim; ++i) {
            this.precision[i][i] = scalarPrecision;
        }
    }

    @Override
    public String getType() {
        return TYPE;
    }

    /**
     * Returns the covariance matrix (inverse of the precision), computing and caching it on
     * first call.
     */
    public double[][] getVariance() {
        if (variance == null) {
            variance = getInverse(precision);
        }
        return variance;
    }

    /**
     * Returns the lower-triangular Cholesky factor {@code L} of the covariance matrix
     * ({@code Sigma = L L^T}), computing and caching it on first call.
     */
    public double[][] getCholeskyDecomposition() {
        if (cholesky == null) {
            cholesky = getCholeskyDecomposition(getVariance());
        }
        return cholesky;
    }

    /**
     * Returns the log-determinant of the precision matrix, computing and caching it on first
     * call.
     *
     * <p>If the general-purpose computation yields an infinite value and the precision is
     * diagonal, the sum of the logs of the diagonal entries is used instead (presumably to
     * avoid overflow/underflow in the general routine).
     */
    public double getLogDet() {
        if (logDet == null) {
            double value = calculatePrecisionMatrixLogDeterminate(precision);
            if (Double.isInfinite(value) && isDiagonal(precision)) {
                value = logDetForDiagonal(precision);
            }
            logDet = value;
        }
        return logDet;
    }

    /**
     * Returns true when every entry strictly above the main diagonal is exactly zero.
     * Only the upper triangle is inspected, so a symmetric matrix is assumed.
     */
    private boolean isDiagonal(double[][] matrix) {
        for (int row = 0; row < matrix.length; ++row) {
            for (int col = row + 1; col < matrix.length; ++col) {
                if (matrix[row][col] != 0.0) {
                    return false;
                }
            }
        }
        return true;
    }

    /** Returns the sum of {@code log(matrix[i][i])} over the main diagonal. */
    private double logDetForDiagonal(double[][] matrix) {
        double sum = 0.0;
        for (int i = 0; i < matrix.length; ++i) {
            sum += Math.log(matrix[i][i]);
        }
        return sum;
    }

    /** Returns the precision matrix (the same array held internally, not a copy). */
    @Override
    public double[][] getScaleMatrix() {
        return precision;
    }

    /** Returns the mean vector (the same array held internally, not a copy). */
    @Override
    public double[] getMean() {
        return mean;
    }

    /** Draws a sample from this distribution. */
    public double[] nextMultivariateNormal() {
        return nextMultivariateNormalCholesky(mean, getCholeskyDecomposition(), 1.0);
    }

    /** Draws a sample with this distribution's covariance, centred at {@code center}. */
    public double[] nextMultivariateNormal(double[] center) {
        return nextMultivariateNormalCholesky(center, getCholeskyDecomposition(), 1.0);
    }

    /**
     * Draws a sample centred at {@code center} whose covariance is this distribution's
     * covariance multiplied by {@code scale}.
     */
    public double[] nextScaledMultivariateNormal(double[] center, double scale) {
        return nextMultivariateNormalCholesky(center, getCholeskyDecomposition(), Math.sqrt(scale));
    }

    /**
     * Same as {@link #nextScaledMultivariateNormal(double[], double)} but writes the sample
     * into {@code result}.
     */
    public void nextScaledMultivariateNormal(double[] center, double scale, double[] result) {
        nextMultivariateNormalCholesky(center, getCholeskyDecomposition(), Math.sqrt(scale), result);
    }

    /**
     * Returns the log-determinant of {@code precision}.
     *
     * @throws RuntimeException wrapping {@link IllegalDimension} (message and cause preserved)
     */
    public static double calculatePrecisionMatrixLogDeterminate(double[][] precision) {
        try {
            return new Matrix(precision).logDeterminant();
        }
        catch (IllegalDimension e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /** Returns the log density of {@code x} under this distribution. */
    @Override
    public double logPdf(double[] x) {
        if (hasSinglePrecision) {
            return logPdf(x, mean, singlePrecision, 1.0);
        }
        return logPdf(x, mean, precision, getLogDet(), 1.0);
    }

    /** Returns the gradient of the log density with respect to {@code x}. */
    public double[] gradLogPdf(double[] x) {
        if (hasSinglePrecision) {
            return gradLogPdf(x, mean, singlePrecision);
        }
        return gradLogPdf(x, mean, precision);
    }

    /**
     * Gradient of the log density for precision {@code scalarPrecision * I}:
     * {@code scalarPrecision * (mean - x)}.
     */
    public static double[] gradLogPdf(double[] x, double[] mean, double scalarPrecision) {
        final int dim = x.length;
        double[] gradient = new double[dim];
        for (int i = 0; i < dim; ++i) {
            gradient[i] = scalarPrecision * (mean[i] - x[i]);
        }
        return gradient;
    }

    /**
     * Gradient of the log density, {@code precision * (mean - x)}. This equals the true
     * gradient only when {@code precision} is symmetric.
     */
    public static double[] gradLogPdf(double[] x, double[] mean, double[][] precision) {
        final int dim = x.length;
        double[] residual = new double[dim];
        for (int i = 0; i < dim; ++i) {
            residual[i] = mean[i] - x[i];
        }
        double[] gradient = new double[dim];
        for (int row = 0; row < dim; ++row) {
            double sum = 0.0;
            for (int col = 0; col < dim; ++col) {
                sum += precision[row][col] * residual[col];
            }
            gradient[row] = sum;
        }
        return gradient;
    }

    /**
     * Gradient of the log density with the precision supplied as a flat row-major
     * {@code dim * dim} array; see {@link #gradLogPdf(double[], double[], double[][])}.
     */
    public static double[] gradLogPdf(double[] x, double[] mean, double[] precision) {
        final int dim = x.length;
        double[] residual = new double[dim];
        for (int i = 0; i < dim; ++i) {
            residual[i] = mean[i] - x[i];
        }
        double[] gradient = new double[dim];
        for (int row = 0; row < dim; ++row) {
            double sum = 0.0;
            for (int col = 0; col < dim; ++col) {
                sum += precision[row * dim + col] * residual[col];
            }
            gradient[row] = sum;
        }
        return gradient;
    }

    /** Returns the Hessian of the log density (constant {@code -precision}). */
    public double[][] hessianLogPdf(double[] x) {
        if (hasSinglePrecision) {
            return hessianLogPdf(x, mean, singlePrecision);
        }
        return hessianLogPdf(x, mean, precision);
    }

    /**
     * Hessian for precision {@code scalarPrecision * I}: {@code -scalarPrecision * I}.
     * Only {@code x.length} is read (to size the result); {@code mean} is ignored.
     */
    public static double[][] hessianLogPdf(double[] x, double[] mean, double scalarPrecision) {
        final int dim = x.length;
        double[][] hessian = new double[dim][dim];
        for (int i = 0; i < dim; ++i) {
            hessian[i][i] = -scalarPrecision;
        }
        return hessian;
    }

    /**
     * Hessian of the log density: a fresh copy of {@code -precision}.
     * Only {@code x.length} is read (to size the result); {@code mean} is ignored.
     */
    public static double[][] hessianLogPdf(double[] x, double[] mean, double[][] precision) {
        final int dim = x.length;
        double[][] hessian = new double[dim][dim];
        for (int row = 0; row < dim; ++row) {
            for (int col = 0; col < dim; ++col) {
                hessian[row][col] = -precision[row][col];
            }
        }
        return hessian;
    }

    /** Returns the diagonal of the Hessian of the log density. */
    public double[] diagonalHessianLogPdf(double[] x) {
        if (hasSinglePrecision) {
            return diagonalHessianLogPdf(x, mean, singlePrecision);
        }
        return diagonalHessianLogPdf(x, mean, precision);
    }

    /** Diagonal Hessian for precision {@code scalarPrecision * I}; every entry is {@code -scalarPrecision}. */
    public static double[] diagonalHessianLogPdf(double[] x, double[] mean, double scalarPrecision) {
        double[] diagonal = new double[x.length];
        Arrays.fill(diagonal, -scalarPrecision);
        return diagonal;
    }

    /** Diagonal Hessian: {@code -precision[i][i]} for each {@code i < x.length}. */
    public static double[] diagonalHessianLogPdf(double[] x, double[] mean, double[][] precision) {
        final int dim = x.length;
        double[] diagonal = new double[dim];
        for (int i = 0; i < dim; ++i) {
            diagonal[i] = -precision[i][i];
        }
        return diagonal;
    }

    /**
     * Log density of {@code x} under a normal with the given mean and covariance
     * {@code scale * precision^-1}.
     *
     * @param x         point at which to evaluate
     * @param mean      mean vector
     * @param precision precision matrix (before scaling)
     * @param logDet    log-determinant of {@code precision}; if negative infinity it is
     *                  returned immediately
     * @param scale     multiplier applied to the covariance
     * @return the log density
     */
    public static double logPdf(double[] x, double[] mean, double[][] precision, double logDet, double scale) {
        if (logDet == Double.NEGATIVE_INFINITY) {
            return logDet;
        }
        final int dim = x.length;
        double[] delta = new double[dim];
        for (int i = 0; i < dim; ++i) {
            delta[i] = x[i] - mean[i];
        }
        // Quadratic form delta^T * precision * delta, accumulated column by column.
        double quadratic = 0.0;
        for (int col = 0; col < dim; ++col) {
            double columnSum = 0.0;
            for (int row = 0; row < dim; ++row) {
                columnSum += delta[row] * precision[row][col];
            }
            quadratic += columnSum * delta[col];
        }
        return (double)dim * logNormalize + 0.5 * (logDet - (double)dim * Math.log(scale) - quadratic / scale);
    }

    /**
     * Log density of {@code x} under a normal with precision {@code scalarPrecision * I}
     * whose covariance is further multiplied by {@code scale}.
     */
    public static double logPdf(double[] x, double[] mean, double scalarPrecision, double scale) {
        final int dim = x.length;
        double sumOfSquares = 0.0;
        for (int i = 0; i < dim; ++i) {
            double delta = x[i] - mean[i];
            sumOfSquares += delta * delta;
        }
        return (double)dim * logNormalize + 0.5 * ((double)dim * (Math.log(scalarPrecision) - Math.log(scale)) - sumOfSquares * scalarPrecision / scale);
    }

    /** Returns the inverse of a symmetric matrix. */
    private static double[][] getInverse(double[][] matrix) {
        return new SymmetricMatrix(matrix).inverse().toComponents();
    }

    /**
     * Returns the Cholesky factor {@code L} of {@code matrix}.
     *
     * @throws RuntimeException if the matrix dimensions are rejected (cause preserved)
     */
    private static double[][] getCholeskyDecomposition(double[][] matrix) {
        try {
            return new CholeskyDecomposition(matrix).getL();
        }
        catch (IllegalDimension e) {
            throw new RuntimeException("Attempted Cholesky decomposition on non-square matrix", e);
        }
    }

    /**
     * Draws a sample from N(mean, precision^-1) by Cholesky-factoring the precision and
     * back-solving. The caller's {@code precision} array is not modified.
     *
     * @throws IllegalArgumentException if {@code precision} is not positive definite
     */
    public static double[] nextMultivariateNormalViaBackSolvePrecision(double[] mean, double[][] precision) {
        final int dim = mean.length;
        double[] flat = new double[dim * dim];
        for (int row = 0; row < dim; ++row) {
            System.arraycopy(precision[row], 0, flat, row * dim, dim);
        }
        return backSolveDraw(mean, flat);
    }

    /**
     * Same as {@link #nextMultivariateNormalViaBackSolvePrecision(double[], double[][])} with
     * the precision given as a flat row-major array. The caller's array is not modified.
     *
     * @throws IllegalArgumentException if {@code precision} is not positive definite
     */
    public static double[] nextMultivariateNormalViaBackSolvePrecision(double[] mean, double[] precision) {
        // EJML's inner Cholesky decomposes in place, so work on a private copy rather than
        // overwriting the caller's precision with its factor.
        return backSolveDraw(mean, precision.clone());
    }

    /**
     * Draws mean + L^-T z with z standard normal, where precision = L L^T.
     * {@code ownedPrecision} is overwritten with the factor, so it must not be shared.
     */
    private static double[] backSolveDraw(double[] mean, double[] ownedPrecision) {
        final int dim = mean.length;
        DenseMatrix64F factor = DenseMatrix64F.wrap(dim, dim, ownedPrecision);
        CholeskyDecompositionInner_D64 decomposition = new CholeskyDecompositionInner_D64();
        if (!decomposition.decompose(factor)) {
            throw new IllegalArgumentException("Precision matrix is not positive definite");
        }
        double[] draw = new double[dim];
        for (int i = 0; i < dim; ++i) {
            draw[i] = MathUtils.nextGaussian();
        }
        // Solve L^T y = z in place, giving y with covariance (L L^T)^-1 = precision^-1.
        TriangularSolver.solveTranL(factor.getData(), draw, dim);
        for (int i = 0; i < dim; ++i) {
            draw[i] += mean[i];
        }
        return draw;
    }

    /** Draws a sample from N(mean, precision^-1) by first inverting the precision. */
    public static double[] nextMultivariateNormalPrecision(double[] mean, double[][] precision) {
        return nextMultivariateNormalVariance(mean, getInverse(precision));
    }

    /** Draws a sample from N(mean, variance). */
    public static double[] nextMultivariateNormalVariance(double[] mean, double[][] variance) {
        return nextMultivariateNormalVariance(mean, variance, 1.0);
    }

    /** Draws a sample from N(mean, scale * variance). */
    public static double[] nextMultivariateNormalVariance(double[] mean, double[][] variance, double scale) {
        return nextMultivariateNormalCholesky(mean, getCholeskyDecomposition(variance), Math.sqrt(scale));
    }

    /** Draws a sample from N(mean, L L^T) given the lower-triangular factor {@code L}. */
    public static double[] nextMultivariateNormalCholesky(double[] mean, double[][] cholesky) {
        return nextMultivariateNormalCholesky(mean, cholesky, 1.0);
    }

    /**
     * Draws a sample from N(mean, sd^2 L L^T) given the lower-triangular factor {@code L}.
     */
    public static double[] nextMultivariateNormalCholesky(double[] mean, double[][] cholesky, double sd) {
        double[] result = new double[mean.length];
        nextMultivariateNormalCholesky(mean, cholesky, sd, result);
        return result;
    }

    /**
     * Writes mean + L z into {@code result}, where z ~ N(0, sd^2 I). Only the lower triangle of
     * {@code cholesky} is read.
     */
    public static void nextMultivariateNormalCholesky(double[] mean, double[][] cholesky, double sd, double[] result) {
        final int dim = mean.length;
        System.arraycopy(mean, 0, result, 0, dim);
        double[] noise = new double[dim];
        for (int i = 0; i < dim; ++i) {
            noise[i] = MathUtils.nextGaussian() * sd;
        }
        for (int row = 0; row < dim; ++row) {
            for (int col = 0; col <= row; ++col) {
                result[row] += cholesky[row][col] * noise[col];
            }
        }
    }

    /**
     * Vector/matrix-interface variant: writes mean + L z into {@code result}, using
     * {@code scratch} (length at least the dimension) to hold z ~ N(0, sd^2 I).
     */
    public static void nextMultivariateNormalCholesky(ReadableVector mean, ReadableMatrix cholesky, double sd, WritableVector result, double[] scratch) {
        final int dim = mean.getDim();
        for (int i = 0; i < dim; ++i) {
            scratch[i] = MathUtils.nextGaussian() * sd;
        }
        for (int row = 0; row < dim; ++row) {
            double value = mean.get(row);
            for (int col = 0; col <= row; ++col) {
                value += cholesky.get(row, col) * scratch[col];
            }
            result.set(row, value);
        }
    }

    /**
     * Offset variant: reads the mean from {@code mean[meanOffset ...]} and writes the sample to
     * {@code result[resultOffset ...]}. The dimension is {@code scratch.length}; {@code scratch}
     * is overwritten with the scaled standard-normal noise.
     */
    public static void nextMultivariateNormalCholesky(double[] mean, int meanOffset, double[][] cholesky, double sd, double[] result, int resultOffset, double[] scratch) {
        final int dim = scratch.length;
        System.arraycopy(mean, meanOffset, result, resultOffset, dim);
        for (int i = 0; i < dim; ++i) {
            scratch[i] = MathUtils.nextGaussian() * sd;
        }
        for (int row = 0; row < dim; ++row) {
            for (int col = 0; col <= row; ++col) {
                result[resultOffset + row] += cholesky[row][col] * scratch[col];
            }
        }
    }

    /** Runs the ad-hoc self checks, printing to standard error. */
    public static void main(String[] args) {
        testPdf();
        testRandomDraws();
    }

    /** Prints two log densities next to their known reference values. */
    public static void testPdf() {
        double[] x = new double[]{1.0, 2.0};
        double[] mu = new double[]{0.0, 0.0};
        double[][] prec = new double[][]{{2.0, 0.5}, {0.5, 1.0}};
        double scale = 0.2;
        System.err.println("logPDF = " + logPdf(x, mu, prec, calculatePrecisionMatrixLogDeterminate(prec), scale));
        System.err.println("Should = -19.94863\n");
        System.err.println("logPDF = " + logPdf(x, mu, 2.0, 0.2));
        System.err.println("Should = -24.53529\n");
    }

    /** Prints Monte Carlo moments of precision-based draws next to their exact values. */
    public static void testRandomDraws() {
        double[] center = new double[]{1.0, 2.0};
        double[][] prec = new double[][]{{2.0, 0.5}, {0.5, 1.0}};
        int draws = 1000000;
        System.err.println("Random draws (via precision) ...");
        double[] sampleMean = new double[2];
        double[] secondMoment = new double[2];
        double[] sampleVariance = new double[2];
        double crossMoment = 0.0;
        for (int k = 0; k < draws; ++k) {
            double[] sample = nextMultivariateNormalViaBackSolvePrecision(center, prec);
            for (int j = 0; j < 2; ++j) {
                sampleMean[j] += sample[j];
                secondMoment[j] += sample[j] * sample[j];
            }
            crossMoment += sample[0] * sample[1];
        }
        for (int j = 0; j < 2; ++j) {
            sampleMean[j] /= (double)draws;
            secondMoment[j] /= (double)draws;
            sampleVariance[j] = secondMoment[j] - sampleMean[j] * sampleMean[j];
        }
        crossMoment /= (double)draws;
        System.err.println("Mean: " + new Vector(sampleMean));
        System.err.println("TRUE: [ 1 2 ]\n");
        System.err.println("MVar: " + new Vector(sampleVariance));
        System.err.println("TRUE: [ 0.571 1.14 ]\n");
        double covariance = crossMoment - sampleMean[0] * sampleMean[1];
        System.err.println("Covv: " + covariance);
        System.err.println("TRUE: -0.286");
    }

    @Override
    public Object nextRandom() {
        return nextMultivariateNormal();
    }

    /** Interprets {@code object} as a {@code double[]} and returns its log density. */
    @Override
    public double logPdf(Object object) {
        return logPdf((double[])object);
    }

    @Override
    public Likelihood getLikelihood() {
        return null;
    }

    @Override
    public int getDimension() {
        return mean.length;
    }

    @Override
    public double[] getGradientLogDensity(Object object) {
        return gradLogPdf((double[])object);
    }

    @Override
    public double[][] getPrecisionMatrix() {
        return precision;
    }

    @Override
    public double[] getDiagonalHessianLogDensity(Object object) {
        return diagonalHessianLogPdf((double[])object);
    }

    @Override
    public double[][] getHessianLogDensity(Object object) {
        return hessianLogPdf((double[])object);
    }
}

