/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.hmc;

import dr.evomodel.treedatalikelihood.hmc.CorrelationPrecisionGradient;
import dr.evomodel.treedatalikelihood.hmc.GradientWrtPrecisionProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.MaskedParameter;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.model.TransformedMultivariateParameter;
import dr.util.MatrixInnerProductTransform;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * Correlation-precision gradient reported with respect to the full "decomposed" matrix that
 * generates the off-diagonal correlation parameter through an inner-product transform.
 *
 * <p>Construction requires the off-diagonal parameter of {@code compoundSymmetricMatrix} to be a
 * {@link MaskedParameter} whose unmasked parameter is a {@link TransformedMultivariateParameter}
 * using a {@link MatrixInnerProductTransform}. The untransformed parameter of that transform is
 * the parameter exposed by {@link #getParameter()}.
 */
public class FullCorrelationPrecisionGradient
extends CorrelationPrecisionGradient {

    /** Message used when the off-diagonal parameter does not have the required structure. */
    private static final String PARAMETER_ERROR_MESSAGE =
            "off-diagonal parameter must be a mask of a inner product transform.";

    /** Untransformed matrix of the inner-product transform; gradients are reported w.r.t. it. */
    private final MatrixParameterInterface decomposedMatrix;

    /**
     * Creates the gradient and locates the decomposed matrix behind the off-diagonal parameter.
     *
     * @param gradientWrtPrecisionProvider provider passed unchanged to the superclass
     * @param likelihood                   likelihood passed unchanged to the superclass
     * @param matrixParameterInterface     matrix parameter passed unchanged to the superclass
     * @throws RuntimeException if the off-diagonal parameter is not a masked inner-product
     *                          transform (see class documentation)
     */
    public FullCorrelationPrecisionGradient(GradientWrtPrecisionProvider gradientWrtPrecisionProvider,
                                            Likelihood likelihood,
                                            MatrixParameterInterface matrixParameterInterface) {
        super(gradientWrtPrecisionProvider, likelihood, matrixParameterInterface);
        this.decomposedMatrix =
                extractDecomposedMatrix(this.compoundSymmetricMatrix.getOffDiagonalParameter());
    }

    /**
     * Walks MaskedParameter -> TransformedMultivariateParameter -> untransformed parameter,
     * validating each step.
     *
     * @param offDiagonal the off-diagonal parameter of the compound symmetric matrix
     * @return the untransformed parameter of the inner-product transform
     * @throws RuntimeException if any step of the expected structure is missing
     */
    private static MatrixParameterInterface extractDecomposedMatrix(Parameter offDiagonal) {
        if (!(offDiagonal instanceof MaskedParameter)) {
            throw parameterException();
        }
        Object unmasked = ((MaskedParameter) offDiagonal).getUnmaskedParameter();
        if (!(unmasked instanceof TransformedMultivariateParameter)) {
            throw parameterException();
        }
        TransformedMultivariateParameter transformed = (TransformedMultivariateParameter) unmasked;
        if (!(transformed.getTransform() instanceof MatrixInnerProductTransform)) {
            throw parameterException();
        }
        // Unchecked cast kept from the original: a non-matrix untransformed parameter surfaces
        // as a ClassCastException here.
        return (MatrixParameterInterface) transformed.getUntransformedParameter();
    }

    /**
     * Builds a fresh exception per failure so its stack trace points at the actual call site
     * (a shared static instance would carry the trace of class initialization instead).
     */
    private static RuntimeException parameterException() {
        return new RuntimeException(PARAMETER_ERROR_MESSAGE);
    }

    /** @return the full dimension reported by the superclass ({@code getDimensionFull()}) */
    @Override
    public int getDimension() {
        return this.getDimensionFull();
    }

    /** @return the decomposed matrix parameter that gradients are taken with respect to */
    @Override
    public Parameter getParameter() {
        return this.decomposedMatrix;
    }

    /**
     * Maps the incoming gradient onto the decomposed matrix.
     *
     * <p>Let G be the n-by-n result of {@code updateGradientFullOffDiagonal(gradient)} and L the
     * decomposed matrix values, both wrapped as n-by-n EJML (row-major) matrices, with n the row
     * dimension of the decomposed matrix. Returns {@code 2 * G^T * L}.
     *
     * <p>NOTE(review): for C = L L^T the chain rule gives (G + G^T) L, which equals this result
     * only when G is symmetric; also confirm that the storage order of
     * {@code getParameterValues()} matches the row-major wrap used here.
     *
     * @param gradient gradient passed to {@code compoundSymmetricMatrix.updateGradientFullOffDiagonal}
     * @return array of length {@code gradient.length} whose first n*n entries hold the product
     */
    @Override
    double[] getGradientParameter(double[] gradient) {
        final int dim = this.decomposedMatrix.getRowDimension();
        final double[] correlationGradient =
                this.compoundSymmetricMatrix.updateGradientFullOffDiagonal(gradient);
        // NOTE(review): sized from the incoming gradient as before; the product writes dim*dim
        // entries, so this assumes gradient.length >= dim * dim.
        final double[] decomposedGradient = new double[gradient.length];

        DenseMatrix64F correlationGradientMatrix =
                DenseMatrix64F.wrap(dim, dim, correlationGradient);
        DenseMatrix64F decomposedValues =
                DenseMatrix64F.wrap(dim, dim, this.decomposedMatrix.getParameterValues());
        // Wrapping (not copying) means the EJML writes below land directly in decomposedGradient.
        DenseMatrix64F result = DenseMatrix64F.wrap(dim, dim, decomposedGradient);

        CommonOps.multTransA(correlationGradientMatrix, decomposedValues, result); // G^T * L
        CommonOps.scale(2.0, result);
        return decomposedGradient;
    }
}

