/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous.cdi;

import dr.evomodel.treedatalikelihood.continuous.cdi.ContinuousDiffusionIntegrator;
import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType;
import dr.math.matrixAlgebra.WrappedVector;
import dr.math.matrixAlgebra.missingData.InversionResult;
import dr.math.matrixAlgebra.missingData.MissingOps;
import java.util.HashMap;
import java.util.Map;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * Continuous-diffusion integrator for a multivariate Brownian diffusion whose precision matrix is
 * full (non-diagonal). It maintains post-order partials, pre-order partials and per-buffer log
 * remainders on top of {@link ContinuousDiffusionIntegrator.Basic}.
 *
 * <p>Per-trait partial layout as read/written by the methods below (offsets are relative to the
 * start of one trait's slice of {@code partials} / {@code preOrderPartials}):
 * <ul>
 *   <li>{@code [0, dimTrait)}: mean vector;</li>
 *   <li>{@code [dimTrait, dimTrait + dimTrait^2)}: precision matrix;</li>
 *   <li>{@code [dimTrait + dimTrait^2, dimTrait + 2 * dimTrait^2)}: variance matrix;</li>
 *   <li>{@code dimTrait + 2 * dimTrait^2}: (post-order only, see {@code updatePartial}) a scalar
 *       propagated like a precision: {@code p -> p / (1 + p * t)} along a branch, summed at a node.</li>
 * </ul>
 *
 * <p>NOTE(review): diffusion storage ({@code diffusions}, {@code inverseDiffusions}) is laid out
 * with a {@code dimProcess * dimProcess} stride, but most methods wrap {@code dimTrait x dimTrait}
 * blocks at {@code precisionOffset}. This appears to assume {@code dimTrait == dimProcess} —
 * confirm before using with differing dimensions.
 */
public class MultivariateIntegrator extends ContinuousDiffusionIntegrator.Basic {
    private static boolean DEBUG = false;
    private static final boolean TIMING = false;

    /** Accumulated elapsed nanoseconds per timing key, filled by {@link #endTime(String)}. */
    private final Map<String, Long> times;

    // Reusable dimTrait x dimTrait scratch matrices; every method overwrites them before reading.
    DenseMatrix64F matrix0;
    DenseMatrix64F matrix1;
    DenseMatrix64F matrixPip;
    DenseMatrix64F matrixPjp;
    DenseMatrix64F matrixPk;
    private DenseMatrix64F matrix5;
    private DenseMatrix64F matrix6;

    // Reusable scratch vector of length dimTrait.
    double[] vector0;

    /** Start timestamps ({@code System.nanoTime()}) per timing key, filled by {@link #startTime(String)}. */
    private final Map<String, Long> startTimes = new HashMap<String, Long>();

    /**
     * Inverses of the diffusion precision matrices (diffusion variances), stored contiguously as
     * {@code diffusionCount} blocks of {@code dimProcess * dimProcess} values.
     */
    double[] inverseDiffusions;

    /**
     * Creates an integrator for full-precision diffusions.
     *
     * <p>The integer arguments are forwarded unchanged to the superclass constructor.
     *
     * @param precisionType must be {@link PrecisionType#FULL} (checked by assertion)
     */
    public MultivariateIntegrator(PrecisionType precisionType, int n, int n2, int n3, int n4, int n5) {
        super(precisionType, n, n2, n3, n4, n5);
        assert (precisionType == PrecisionType.FULL);
        this.allocateStorage();
        // Always allocated so that endTime() cannot fail with a NullPointerException.
        this.times = new HashMap<String, Long>();
    }

    /**
     * Returns a diagnostic report. It is empty unless the compile-time {@code TIMING} flag is
     * enabled, in which case the accumulated per-key timings (nanoseconds) are listed.
     */
    @Override
    public String getReport() {
        StringBuilder report = new StringBuilder();
        if (TIMING) {
            report.append("\nTIMING:");
            for (Map.Entry<String, Long> entry : this.times.entrySet()) {
                report.append('\n').append(entry.getKey()).append(": ").append(entry.getValue()).append(" ns");
            }
        }
        return report.toString();
    }

    /** Allocates the diffusion-variance store and the per-call scratch buffers. */
    private void allocateStorage() {
        this.inverseDiffusions = new double[this.dimProcess * this.dimProcess * this.diffusionCount];
        this.vector0 = new double[this.dimTrait];
        this.matrix0 = new DenseMatrix64F(this.dimTrait, this.dimTrait);
        this.matrix1 = new DenseMatrix64F(this.dimTrait, this.dimTrait);
        this.matrixPip = new DenseMatrix64F(this.dimTrait, this.dimTrait);
        this.matrixPjp = new DenseMatrix64F(this.dimTrait, this.dimTrait);
        this.matrixPk = new DenseMatrix64F(this.dimTrait, this.dimTrait);
        this.matrix5 = new DenseMatrix64F(this.dimTrait, this.dimTrait);
        this.matrix6 = new DenseMatrix64F(this.dimTrait, this.dimTrait);
    }

    /**
     * Stores a post-order partial via the superclass. Also copies each trait's remainder, taken
     * from the {@link PrecisionType#FULL} remainder offset, into {@code remainders}.
     */
    @Override
    public void setPostOrderPartial(int bufferIndex, double[] partial) {
        super.setPostOrderPartial(bufferIndex, partial);
        final int remainderOffset = PrecisionType.FULL.getRemainderOffset(this.dimTrait);
        for (int trait = 0; trait < this.numTraits; ++trait) {
            this.remainders[bufferIndex * this.numTraits + trait] = partial[this.dimPartialForTrait * trait + remainderOffset];
        }
    }

    /**
     * Stores a diffusion precision via the superclass, then caches its inverse (the diffusion
     * variance) in {@link #inverseDiffusions} at the same {@code dimProcess^2}-strided slot.
     *
     * @param precisionIndex index of the diffusion matrix slot
     * @param matrix         precision values, forwarded to the superclass
     * @param d              forwarded unchanged to the superclass
     */
    @Override
    public void setDiffusionPrecision(int precisionIndex, double[] matrix, double d) {
        super.setDiffusionPrecision(precisionIndex, matrix, d);
        assert (this.inverseDiffusions != null);
        final int offset = this.dimProcess * this.dimProcess * precisionIndex;
        DenseMatrix64F precision = MissingOps.wrap(this.diffusions, offset, this.dimProcess, this.dimProcess);
        DenseMatrix64F variance = new DenseMatrix64F(this.dimProcess, this.dimProcess);
        // NOTE(review): the boolean result of invert (false when singular) is ignored.
        CommonOps.invert(precision, variance);
        MissingOps.unwrap(variance, this.inverseDiffusions, offset);
        if (DEBUG) {
            System.err.println("At precision index: " + precisionIndex);
            System.err.println("precision: " + precision);
            System.err.println("variance : " + variance);
        }
    }

    /** Returns a fresh copy of the diffusion variance stored for {@code precisionIndex}. */
    public double[] getVariance(int precisionIndex) {
        assert (this.inverseDiffusions != null);
        return this.getMatrixProcess(precisionIndex, this.inverseDiffusions);
    }

    /**
     * Returns a copy of the {@code precisionIndex}-th square block of {@code matrices}.
     *
     * <p>Uses the {@code dimProcess * dimProcess} stride with which diffusion storage is allocated
     * and written (the previous {@code dimTrait} stride disagreed with that layout).
     */
    double[] getMatrixProcess(int precisionIndex, double[] matrices) {
        final int blockSize = this.dimProcess * this.dimProcess;
        final int offset = blockSize * precisionIndex;
        double[] block = new double[blockSize];
        System.arraycopy(matrices, offset, block, 0, blockSize);
        return block;
    }

    /**
     * Writes {@code branchLength * diffusionVariance} into the first {@code dimTrait^2} entries of
     * {@code variance}.
     *
     * @throws RuntimeException if {@code bufferIndex == -1} (not implemented)
     */
    @Override
    public void getBranchVariance(int bufferIndex, int precisionIndex, double[] variance) {
        if (bufferIndex == -1) {
            throw new RuntimeException("Not yet implemented");
        }
        assert (variance != null);
        assert (variance.length >= this.dimTrait * this.dimTrait);
        this.updatePrecisionOffsetAndDeterminant(precisionIndex);
        final double branchLength = this.getBranchLength(bufferIndex);
        for (int i = 0; i < this.dimTrait * this.dimTrait; ++i) {
            variance[i] = branchLength * this.inverseDiffusions[this.precisionOffset + i];
        }
    }

    /**
     * Computes the pre-order partial of buffer {@code iBuffer}. It combines the pre-order partial
     * of {@code kBuffer} with the post-order partial of the sibling {@code jBuffer}:
     * <pre>
     *   Vjp = Vj + vj * Vd,  Pjp = Vjp^-1
     *   Pip = Pk + Pjp,      Vip = Pip^-1
     *   m_i = Vip (Pk m_k + Pjp m_j)
     *   Vi  = Vip + vi * Vd, Pi  = Vi^-1
     * </pre>
     * where {@code vi = branchLengths[iMatrix]}, {@code vj = branchLengths[jMatrix]} and {@code Vd}
     * is the diffusion variance at the current {@code precisionOffset}. That offset is not
     * refreshed here.
     */
    @Override
    public void updatePreOrderPartial(int kBuffer, int iBuffer, int iMatrix, int jBuffer, int jMatrix) {
        int kOffset = this.dimPartial * kBuffer;
        int iOffset = this.dimPartial * iBuffer;
        int jOffset = this.dimPartial * jBuffer;
        final double vi = this.branchLengths[iMatrix];
        final double vj = this.branchLengths[jMatrix];
        final int matrixSize = this.dimTrait * this.dimTrait;
        DenseMatrix64F diffusionVariance = MissingOps.wrap(this.inverseDiffusions, this.precisionOffset, this.dimTrait, this.dimTrait);
        if (DEBUG) {
            System.err.println("updatePreOrderPartial for node " + iBuffer);
            System.err.println("\tvi: " + vi + " vj: " + vj);
        }
        for (int trait = 0; trait < this.numTraits; ++trait) {
            DenseMatrix64F precisionK = MissingOps.wrap(this.preOrderPartials, kOffset + this.dimTrait, this.dimTrait, this.dimTrait);
            DenseMatrix64F varianceJ = MissingOps.wrap(this.partials, jOffset + this.dimTrait + matrixSize, this.dimTrait, this.dimTrait);
            if (MissingOps.allZeroDiagonals(varianceJ)) {
                // The variance block has an all-zero diagonal; derive it from j's precision block instead.
                DenseMatrix64F precisionJ = MissingOps.wrap(this.partials, jOffset + this.dimTrait, this.dimTrait, this.dimTrait);
                assert (!MissingOps.allZeroDiagonals(precisionJ));
                MissingOps.safeInvert2(precisionJ, varianceJ, false);
            }

            // Vjp = Vj + vj * Vd ; Pjp = Vjp^-1
            DenseMatrix64F varianceJp = this.matrix1;
            CommonOps.add(varianceJ, vj, diffusionVariance, varianceJp);
            DenseMatrix64F precisionJp = this.matrixPjp;
            MissingOps.safeInvert2(varianceJp, precisionJp, false);

            // Pip = Pk + Pjp ; Vip = Pip^-1
            DenseMatrix64F precisionIp = this.matrixPip;
            CommonOps.add(precisionK, precisionJp, precisionIp);
            DenseMatrix64F varianceIp = this.matrix0;
            MissingOps.safeInvert2(precisionIp, varianceIp, false);

            // weighted = Pk * m_k + Pjp * m_j (both terms accumulated in one pass, as before)
            double[] weighted = this.vector0;
            for (int row = 0; row < this.dimTrait; ++row) {
                double sum = 0.0;
                for (int col = 0; col < this.dimTrait; ++col) {
                    sum += precisionK.unsafe_get(row, col) * this.preOrderPartials[kOffset + col];
                    sum += precisionJp.unsafe_get(row, col) * this.partials[jOffset + col];
                }
                weighted[row] = sum;
            }
            // m_i = Vip * weighted
            for (int row = 0; row < this.dimTrait; ++row) {
                double sum = 0.0;
                for (int col = 0; col < this.dimTrait; ++col) {
                    sum += varianceIp.unsafe_get(row, col) * weighted[col];
                }
                this.preOrderPartials[iOffset + row] = sum;
            }

            // Vi = vi * Vd + Vip, computed in place in the Vip scratch matrix ; Pi = Vi^-1
            DenseMatrix64F varianceI = varianceIp;
            CommonOps.add(vi, diffusionVariance, varianceIp, varianceI);
            DenseMatrix64F precisionI = this.matrixPk;
            MissingOps.safeInvert2(varianceI, precisionI, false);
            MissingOps.unwrap(precisionI, this.preOrderPartials, iOffset + this.dimTrait);
            MissingOps.unwrap(varianceI, this.preOrderPartials, iOffset + this.dimTrait + matrixSize);

            if (DEBUG) {
                System.err.println("trait: " + trait);
                System.err.println("pM: " + new WrappedVector.Raw(this.preOrderPartials, kOffset, this.dimTrait));
                System.err.println("pP: " + precisionK);
                System.err.println("sM: " + new WrappedVector.Raw(this.partials, jOffset, this.dimTrait));
                System.err.println("sV: " + varianceJ);
                System.err.println("sVp: " + varianceJp);
                System.err.println("sPp: " + precisionJp);
                System.err.println("Pip: " + precisionIp);
                System.err.println("cM: " + new WrappedVector.Raw(this.preOrderPartials, iOffset, this.dimTrait));
                System.err.println("cV: " + varianceI);
            }
            kOffset += this.dimPartialForTrait;
            iOffset += this.dimPartialForTrait;
            jOffset += this.dimPartialForTrait;
        }
    }

    /**
     * Computes the post-order partial of {@code kBuffer} from children {@code iBuffer} and
     * {@code jBuffer}. For each trait it also computes the log remainder, which is accumulated
     * with the children's remainders.
     * <pre>
     *   Vip = Vi + vi * Vd, Pip = Vip^-1   (likewise for j)
     *   Pk  = Pip + Pjp,    Vk  = Pk^-1
     *   m_k = weighted average of m_i, m_j (MissingOps.weightedAverage)
     * </pre>
     * The diffusion variance is read at the current {@code precisionOffset}, which is not
     * refreshed here.
     *
     * @param bl                     not read by this implementation
     * @param incrementOuterProducts must be false
     * @throws RuntimeException if {@code incrementOuterProducts} is true
     */
    @Override
    protected void updatePartial(int kBuffer, int iBuffer, int iMatrix, int jBuffer, int jMatrix, boolean bl, boolean incrementOuterProducts) {
        if (incrementOuterProducts) {
            throw new RuntimeException("Outer-products are not supported.");
        }
        int kOffset = this.dimPartial * kBuffer;
        int iOffset = this.dimPartial * iBuffer;
        int jOffset = this.dimPartial * jBuffer;
        final double vi = this.branchLengths[iMatrix];
        final double vj = this.branchLengths[jMatrix];
        final int matrixSize = this.dimTrait * this.dimTrait;
        final int scalarOffset = this.dimTrait + 2 * matrixSize;
        DenseMatrix64F diffusionVariance = MissingOps.wrap(this.inverseDiffusions, this.precisionOffset, this.dimTrait, this.dimTrait);
        if (DEBUG) {
            System.err.println("variance diffusion: " + diffusionVariance);
            System.err.println("\tvi: " + vi + " vj: " + vj);
            System.err.println("precisionOffset = " + this.precisionOffset);
        }
        for (int trait = 0; trait < this.numTraits; ++trait) {
            final double scalarI = this.partials[iOffset + scalarOffset];
            final double scalarJ = this.partials[jOffset + scalarOffset];
            DenseMatrix64F precisionI = MissingOps.wrap(this.partials, iOffset + this.dimTrait, this.dimTrait, this.dimTrait);
            DenseMatrix64F precisionJ = MissingOps.wrap(this.partials, jOffset + this.dimTrait, this.dimTrait, this.dimTrait);
            DenseMatrix64F varianceI = MissingOps.wrap(this.partials, iOffset + this.dimTrait + matrixSize, this.dimTrait, this.dimTrait);
            DenseMatrix64F varianceJ = MissingOps.wrap(this.partials, jOffset + this.dimTrait + matrixSize, this.dimTrait, this.dimTrait);

            // Propagate the scalar along each branch: p -> p / (1 + p * v); an infinite p gives 1 / v.
            final double scalarIp = Double.isInfinite(scalarI) ? 1.0 / vi : scalarI / (1.0 + scalarI * vi);
            final double scalarJp = Double.isInfinite(scalarJ) ? 1.0 / vj : scalarJ / (1.0 + scalarJ * vj);

            // Vip = Vi + vi * Vd ; Vjp = Vj + vj * Vd
            DenseMatrix64F varianceIp = this.matrix0;
            DenseMatrix64F varianceJp = this.matrix1;
            CommonOps.add(varianceI, vi, diffusionVariance, varianceIp);
            CommonOps.add(varianceJ, vj, diffusionVariance, varianceJp);

            // Pip = Vip^-1 ; Pjp = Vjp^-1 (results carry log-determinants used in the remainder)
            DenseMatrix64F precisionIp = this.matrixPip;
            DenseMatrix64F precisionJp = this.matrixPjp;
            InversionResult resultI = MissingOps.safeInvert2(varianceIp, precisionIp, true);
            InversionResult resultJ = MissingOps.safeInvert2(varianceJp, precisionJp, true);
            final double scalarK = scalarIp + scalarJp;

            // Pk = Pip + Pjp ; Vk = Pk^-1
            DenseMatrix64F precisionK = this.matrixPk;
            CommonOps.add(precisionIp, precisionJp, precisionK);
            DenseMatrix64F varianceK = this.matrix5;
            InversionResult resultK = MissingOps.safeInvertPrecision(precisionK, varianceK, true);

            MissingOps.weightedAverage(this.partials, iOffset, precisionIp, this.partials, jOffset, precisionJp, this.partials, kOffset, varianceK, this.dimTrait, this.vector0);
            this.partials[kOffset + scalarOffset] = scalarK;
            MissingOps.unwrap(precisionK, this.partials, kOffset + this.dimTrait);
            MissingOps.unwrap(varianceK, this.partials, kOffset + this.dimTrait + matrixSize);

            if (DEBUG) {
                this.reportMeansAndPrecisions(trait, iOffset, jOffset, kOffset, precisionI, precisionJ, precisionK);
                System.err.println("i status: " + resultI);
                System.err.println("j status: " + resultJ);
                System.err.println("k status: " + resultK);
                System.err.println("Pip: " + precisionIp);
                System.err.println("Vip: " + varianceIp);
                System.err.println("Pjp: " + precisionJp);
                System.err.println("Vjp: " + varianceJp);
            }

            double remainder = 0.0;
            // A child whose inversion reports NOT_OBSERVED contributes no remainder term.
            if (resultI.getReturnCode() != InversionResult.Code.NOT_OBSERVED
                    && resultJ.getReturnCode() != InversionResult.Code.NOT_OBSERVED) {
                double ss = MissingOps.weightedThreeInnerProduct(this.partials, iOffset, precisionIp, this.partials, jOffset, precisionJp, this.partials, kOffset, precisionK, this.dimTrait);
                if (DEBUG) {
                    // Vt = Vip + Vjp is diagnostic only, so it is computed only when printed.
                    DenseMatrix64F varianceTotal = this.matrix6;
                    CommonOps.add(varianceIp, varianceJp, varianceTotal);
                    System.err.println("Vt: " + varianceTotal);
                }
                int effectiveDimension = resultI.getEffectiveDimension() + resultJ.getEffectiveDimension() - resultK.getEffectiveDimension();
                remainder += (double) (-effectiveDimension) * LOG_SQRT_2_PI
                        - 0.5 * (resultI.getLogDeterminant() + resultJ.getLogDeterminant() + resultK.getLogDeterminant())
                        - 0.5 * ss;
                if (DEBUG) {
                    System.err.println("\t\t\tSS = " + ss);
                    System.err.println("\t\t\tdetI = " + resultI.getLogDeterminant());
                    System.err.println("\t\t\tdetJ = " + resultJ.getLogDeterminant());
                    System.err.println("\t\t\tdetK = " + resultK.getLogDeterminant());
                    System.err.println("\t\tremainder: " + remainder);
                }
            }
            this.remainders[kBuffer * this.numTraits + trait] = remainder
                    + this.remainders[iBuffer * this.numTraits + trait]
                    + this.remainders[jBuffer * this.numTraits + trait];

            kOffset += this.dimPartialForTrait;
            iOffset += this.dimPartialForTrait;
            jOffset += this.dimPartialForTrait;
        }
    }

    /** Debug helper: prints the three precision matrices and the means at the given partial offsets to stderr. */
    void reportMeansAndPrecisions(int trait, int iOffset, int jOffset, int kOffset, DenseMatrix64F precisionI, DenseMatrix64F precisionJ, DenseMatrix64F precisionK) {
        System.err.println("\ttrait: " + trait);
        System.err.println("Pi: " + precisionI);
        System.err.println("Pj: " + precisionJ);
        System.err.println("Pk: " + precisionK);
        this.printPartialMean("\t\tmean i:", iOffset);
        this.printPartialMean("\t\tmean j:", jOffset);
        this.printPartialMean("\t\tmean k:", kOffset);
        System.err.println("");
    }

    /** Prints {@code label}, then the {@code dimTrait} values of {@code partials} from {@code offset}, without a newline. */
    private void printPartialMean(String label, int offset) {
        System.err.print(label);
        for (int d = 0; d < this.dimTrait; ++d) {
            System.err.print(" " + this.partials[offset + d]);
        }
    }

    /** Records the current {@code System.nanoTime()} as the start time for {@code key}. */
    void startTime(String key) {
        this.startTimes.put(key, System.nanoTime());
    }

    /**
     * Adds the nanoseconds elapsed since {@link #startTime(String)} for {@code key} to its running total.
     *
     * @throws IllegalStateException if no start time was recorded for {@code key}
     */
    void endTime(String key) {
        Long start = this.startTimes.get(key);
        if (start == null) {
            throw new IllegalStateException("endTime(\"" + key + "\") called without a matching startTime");
        }
        Long total = this.times.get(key);
        long elapsed = System.nanoTime() - start;
        this.times.put(key, (total == null ? 0L : total) + elapsed);
    }

    /**
     * Runs the superclass root pre-order computation, then rescales the root's pre-order blocks for
     * every trait. The precision block is left-multiplied by the diffusion precision (via
     * {@code MissingOps.safeMult}); the variance block is left-multiplied by the diffusion variance.
     */
    @Override
    public void calculatePreOrderRoot(int n, int rootBufferIndex, int precisionIndex) {
        super.calculatePreOrderRoot(n, rootBufferIndex, precisionIndex);
        this.updatePrecisionOffsetAndDeterminant(precisionIndex);
        final int matrixSize = this.dimTrait * this.dimTrait;
        DenseMatrix64F diffusionPrecision = MissingOps.wrap(this.diffusions, this.precisionOffset, this.dimTrait, this.dimTrait);
        DenseMatrix64F diffusionVariance = MissingOps.wrap(this.inverseDiffusions, this.precisionOffset, this.dimTrait, this.dimTrait);
        int offset = this.dimPartial * rootBufferIndex;
        for (int trait = 0; trait < this.numTraits; ++trait) {
            DenseMatrix64F precision = MissingOps.wrap(this.preOrderPartials, offset + this.dimTrait, this.dimTrait, this.dimTrait);
            DenseMatrix64F variance = MissingOps.wrap(this.preOrderPartials, offset + this.dimTrait + matrixSize, this.dimTrait, this.dimTrait);
            DenseMatrix64F product = this.matrix0;
            MissingOps.safeMult(diffusionPrecision, precision, product);
            MissingOps.unwrap(product, this.preOrderPartials, offset + this.dimTrait);
            CommonOps.mult(diffusionVariance, variance, product);
            MissingOps.unwrap(product, this.preOrderPartials, offset + this.dimTrait + matrixSize);
            offset += this.dimPartialForTrait;
        }
    }

    /**
     * Computes, per trait, the log-likelihood at the root by integrating against the prior buffer.
     * <pre>
     *   Vprior' = Vd * Vprior,  V = Vroot + Vprior',  P = V^-1
     *   logL    = -dimTrait * LOG_SQRT_2_PI - 0.5 * log det(V) - 0.5 * (m_root - m_prior)' P (m_root - m_prior)
     * </pre>
     * Each trait's root remainder is added to its result before storing.
     *
     * @param logLikelihoods output, one entry per trait (length must equal {@code numTraits})
     * @param bl             must be false (asserted)
     * @param bl2            must be false (asserted)
     */
    @Override
    public void calculateRootLogLikelihood(int rootBufferIndex, int priorBufferIndex, int precisionIndex, double[] logLikelihoods, boolean bl, boolean bl2) {
        assert (logLikelihoods.length == this.numTraits);
        assert (!bl);
        assert (!bl2);
        if (DEBUG) {
            System.err.println("Root calculation for " + rootBufferIndex);
            System.err.println("Prior buffer index is " + priorBufferIndex);
        }
        int rootOffset = this.dimPartial * rootBufferIndex;
        int priorOffset = this.dimPartial * priorBufferIndex;
        this.updatePrecisionOffsetAndDeterminant(precisionIndex);
        final int matrixSize = this.dimTrait * this.dimTrait;
        DenseMatrix64F diffusionVariance = MissingOps.wrap(this.inverseDiffusions, this.precisionOffset, this.dimTrait, this.dimTrait);
        for (int trait = 0; trait < this.numTraits; ++trait) {
            DenseMatrix64F rootPrecision = MissingOps.wrap(this.partials, rootOffset + this.dimTrait, this.dimTrait, this.dimTrait);
            DenseMatrix64F priorPrecision = MissingOps.wrap(this.partials, priorOffset + this.dimTrait, this.dimTrait, this.dimTrait);
            DenseMatrix64F rootVariance = MissingOps.wrap(this.partials, rootOffset + this.dimTrait + matrixSize, this.dimTrait, this.dimTrait);
            DenseMatrix64F priorVariance = MissingOps.wrap(this.partials, priorOffset + this.dimTrait + matrixSize, this.dimTrait, this.dimTrait);

            // Scale the prior variance by the diffusion variance: Vprior <- Vd * Vprior.
            DenseMatrix64F scaledPrior = new DenseMatrix64F(this.dimTrait, this.dimTrait);
            CommonOps.mult(diffusionVariance, priorVariance, scaledPrior);
            priorVariance.set(scaledPrior);

            DenseMatrix64F totalVariance = new DenseMatrix64F(this.dimTrait, this.dimTrait);
            CommonOps.add(rootVariance, priorVariance, totalVariance);
            DenseMatrix64F totalPrecision = new DenseMatrix64F(this.dimTrait, this.dimTrait);
            CommonOps.invert(totalVariance, totalPrecision);

            double ss = MissingOps.weightedInnerProductOfDifferences(this.partials, rootOffset, this.partials, priorOffset, totalPrecision, this.dimTrait);
            double logLikelihood = (double) (-this.dimTrait) * LOG_SQRT_2_PI - 0.5 * Math.log(CommonOps.det(totalVariance)) - 0.5 * ss;
            double remainder = this.remainders[rootBufferIndex * this.numTraits + trait];
            logLikelihoods[trait] = logLikelihood + remainder;

            if (DEBUG) {
                this.printPartialMean("mean:", rootOffset);
                System.err.println("");
                System.err.println("P  root: " + rootPrecision);
                System.err.println("V  root: " + rootVariance);
                System.err.println("P prior: " + priorPrecision);
                System.err.println("V prior: " + priorVariance);
                System.err.println("P total: " + totalPrecision);
                System.err.println("V total: " + totalVariance);
                System.err.println("\t" + logLikelihood + " " + (logLikelihood + remainder));
            }
            rootOffset += this.dimPartialForTrait;
            priorOffset += this.dimPartialForTrait;
        }
        if (DEBUG) {
            System.err.println("End");
        }
    }
}

