/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.continuous.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.model.ScaledMatrixParameter;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

/**
 * Gradient provider that applies the chain rule to a gradient taken with respect to a
 * {@link ScaledMatrixParameter}, producing the gradient with respect to one of its two
 * components: the underlying matrix parameter or the scale parameter.
 *
 * <p>Gradient arrays are indexed column-major: entry (row, column) lives at index
 * {@code column * rowDimension + row}. The chain-rule factors used here are consistent with
 * each scaled entry equaling {@code matrix(row, column) * scale(column)}.
 * NOTE(review): that relationship is inferred from the arithmetic below; confirm against
 * {@code ScaledMatrixParameter}.
 */
public class ScaledMatrixChainGradient implements GradientWrtParameterProvider {
    private final GradientWrtParameterProvider originalGradient;
    private final ScaledMatrixParameter parameter;
    private final ComponentProvider componentProvider;
    private static final String SCALED_GRADIENT = "scaledMatrixGradient";
    private static final String COMPONENT = "component";
    private static final String SCALE = "scale";
    private static final String MATRIX = "matrix";

    /**
     * Parser for {@code <scaledMatrixGradient component="scale|matrix">} wrapping a single
     * {@link GradientWrtParameterProvider} child whose parameter is a ScaledMatrixParameter.
     */
    public static AbstractXMLObjectParser PARSER = new AbstractXMLObjectParser() {

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            String componentName = xMLObject.getStringAttribute(ScaledMatrixChainGradient.COMPONENT);

            // Case-insensitive lookup of the requested component; names are unique, so stop at the first hit.
            ComponentProvider componentProvider = null;
            for (ComponentProvider candidate : ComponentProvider.values()) {
                if (componentName.equalsIgnoreCase(candidate.name)) {
                    componentProvider = candidate;
                    break;
                }
            }
            if (componentProvider == null) {
                throw new XMLParseException("Unrecognized 'component'. Must be 'scale' or 'matrix'.");
            }

            GradientWrtParameterProvider gradient =
                    (GradientWrtParameterProvider) xMLObject.getChild(GradientWrtParameterProvider.class);

            // Report a misconfigured model as a parse error instead of a ClassCastException later.
            if (!(gradient.getParameter() instanceof ScaledMatrixParameter)) {
                throw new XMLParseException("The gradient inside '" + ScaledMatrixChainGradient.SCALED_GRADIENT
                        + "' must be taken with respect to a ScaledMatrixParameter.");
            }
            return new ScaledMatrixChainGradient(gradient, componentProvider);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return new XMLSyntaxRule[]{
                    new ElementRule(GradientWrtParameterProvider.class),
                    AttributeRule.newStringRule(ScaledMatrixChainGradient.COMPONENT)
            };
        }

        @Override
        public String getParserDescription() {
            return "Applies the chain rule to a gradient with respect to a scaled matrix parameter, "
                    + "yielding the gradient with respect to its 'matrix' or 'scale' component.";
        }

        @Override
        public Class getReturnType() {
            return ScaledMatrixChainGradient.class;
        }

        @Override
        public String getParserName() {
            return ScaledMatrixChainGradient.SCALED_GRADIENT;
        }
    };

    /**
     * Creates a chain-rule wrapper.
     *
     * @param gradientWrtParameterProvider gradient whose parameter must be a ScaledMatrixParameter
     * @param componentProvider            which component (matrix or scale) to differentiate with respect to
     * @throws IllegalArgumentException if the wrapped gradient's parameter is not a ScaledMatrixParameter
     */
    ScaledMatrixChainGradient(GradientWrtParameterProvider gradientWrtParameterProvider,
                              ComponentProvider componentProvider) {
        Parameter wrapped = gradientWrtParameterProvider.getParameter();
        if (!(wrapped instanceof ScaledMatrixParameter)) {
            throw new IllegalArgumentException(
                    "Wrapped gradient must be taken with respect to a ScaledMatrixParameter");
        }
        this.originalGradient = gradientWrtParameterProvider;
        this.parameter = (ScaledMatrixParameter) wrapped;
        this.componentProvider = componentProvider;
    }

    /** Returns the likelihood of the wrapped gradient provider. */
    @Override
    public Likelihood getLikelihood() {
        return this.originalGradient.getLikelihood();
    }

    /** Returns the selected component (matrix or scale) of the wrapped ScaledMatrixParameter. */
    @Override
    public Parameter getParameter() {
        return this.componentProvider.getParameter(this.parameter);
    }

    /** Returns the dimension of the selected component. */
    @Override
    public int getDimension() {
        return this.componentProvider.getDimension(this.parameter);
    }

    /**
     * Returns the gradient with respect to the selected component. The array produced by the
     * wrapped provider is not modified.
     */
    @Override
    public double[] getGradientLogDensity() {
        return this.componentProvider.chainGradient(this.originalGradient.getGradientLogDensity(), this.parameter);
    }

    /** Selects which component of a ScaledMatrixParameter the gradient is mapped onto. */
    public static enum ComponentProvider {
        MATRIX(ScaledMatrixChainGradient.MATRIX) {

            @Override
            Parameter getParameter(ScaledMatrixParameter scaledMatrixParameter) {
                return scaledMatrixParameter.getMatrixParameter();
            }

            /**
             * Multiplies every entry of column {@code c} by {@code scale(c)}. Returns a new
             * array; the input is left untouched so a provider that reuses its gradient
             * buffer is not corrupted.
             */
            @Override
            double[] chainGradient(double[] gradient, ScaledMatrixParameter scaledMatrixParameter) {
                final int rowDim = scaledMatrixParameter.getRowDimension();
                final int colDim = scaledMatrixParameter.getColumnDimension();
                final Parameter scale = scaledMatrixParameter.getScaleParameter();

                final double[] result = gradient.clone();
                for (int col = 0; col < colDim; ++col) {
                    final double factor = scale.getParameterValue(col);
                    final int offset = col * rowDim;
                    for (int row = 0; row < rowDim; ++row) {
                        result[offset + row] *= factor;
                    }
                }
                return result;
            }
        },
        SCALE(ScaledMatrixChainGradient.SCALE) {

            @Override
            Parameter getParameter(ScaledMatrixParameter scaledMatrixParameter) {
                return scaledMatrixParameter.getScaleParameter();
            }

            /**
             * For each column {@code c}, sums {@code gradient(r, c) * matrix(r, c)} over rows,
             * producing one entry per column.
             */
            @Override
            double[] chainGradient(double[] gradient, ScaledMatrixParameter scaledMatrixParameter) {
                final int rowDim = scaledMatrixParameter.getRowDimension();
                final int colDim = scaledMatrixParameter.getColumnDimension();

                final double[] result = new double[colDim];
                for (int col = 0; col < colDim; ++col) {
                    final int offset = col * rowDim;
                    double sum = 0.0;
                    for (int row = 0; row < rowDim; ++row) {
                        sum += gradient[offset + row]
                                * scaledMatrixParameter.getMatrixParameter().getParameterValue(row, col);
                    }
                    result[col] = sum;
                }
                return result;
            }
        };

        /** Lower-case name matched (case-insensitively) against the XML 'component' attribute. */
        public final String name;

        private ComponentProvider(String name) {
            this.name = name;
        }

        /** Returns the component of {@code scaledMatrixParameter} this provider differentiates with respect to. */
        abstract Parameter getParameter(ScaledMatrixParameter scaledMatrixParameter);

        /** Maps a gradient with respect to the scaled matrix onto this component. */
        abstract double[] chainGradient(double[] gradient, ScaledMatrixParameter scaledMatrixParameter);

        /** Returns the dimension of this component of {@code scaledMatrixParameter}. */
        public int getDimension(ScaledMatrixParameter scaledMatrixParameter) {
            return this.getParameter(scaledMatrixParameter).getDimension();
        }
    }
}

