/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.continuous;

import dr.inference.distribution.NormalStatisticsProvider;
import dr.inference.distribution.shrinkage.BayesianBridgeLikelihood;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.CompoundLikelihood;
import dr.inference.model.Likelihood;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.inference.operators.repeatedMeasures.MultiplicativeGammaGibbsHelper;
import dr.xml.Reportable;
import java.util.Arrays;

/**
 * Shrinkage prior on the entries of a loadings matrix, assembled from one
 * {@link BayesianBridgeLikelihood} per column of the matrix parameter.
 *
 * <p>Prior {@code k} covers column {@code k} of the loadings (rows {@code 0..rowDimension-1}).
 * The flat index used by {@link #getGradientLogDensity()} and {@link #getNormalSD(int)} is
 * {@code column * rowDimension + row}. NOTE(review): this presumes the matrix parameter flattens
 * column-major, consistent with {@code getParameterValue(row, column)} — confirm.
 */
public class MatrixShrinkageLikelihood
        extends AbstractModelLikelihood
        implements GradientWrtParameterProvider,
        NormalStatisticsProvider,
        Reportable,
        MultiplicativeGammaGibbsHelper {

    /** The matrix parameter whose entries receive the prior. */
    private final MatrixParameterInterface loadings;
    /** One prior per column of {@link #loadings}. */
    private final BayesianBridgeLikelihood[] rowPriors;
    /** Reused output buffer returned by {@link #getGradientLogDensity()}. */
    private final double[] gradientLogDensity;
    /** Combination of all per-column priors; backs {@link #getLogLikelihood()}. */
    private final CompoundLikelihood likelihood;

    /**
     * Builds the matrix prior from per-column Bayesian bridge priors.
     *
     * @param name      model name passed to {@link AbstractModelLikelihood}
     * @param loadings  matrix parameter whose entries receive the prior
     * @param rowPriors one prior per column of {@code loadings}; the array is copied
     * @throws IllegalArgumentException if {@code rowPriors.length} differs from the column
     *                                  dimension of {@code loadings}
     */
    public MatrixShrinkageLikelihood(String name,
                                     MatrixParameterInterface loadings,
                                     BayesianBridgeLikelihood[] rowPriors) {
        super(name);
        // Every other method indexes rowPriors by column; a mismatch would otherwise surface
        // later as an out-of-bounds access or as silently zero gradient entries.
        if (rowPriors.length != loadings.getColumnDimension()) {
            throw new IllegalArgumentException("Expected one prior per loadings column ("
                    + loadings.getColumnDimension() + ") but got " + rowPriors.length);
        }
        this.loadings = loadings;
        // Defensive copy: later edits to the caller's array must not desynchronise the
        // per-column lookups from the compound likelihood built below.
        this.rowPriors = rowPriors.clone();
        for (BayesianBridgeLikelihood prior : this.rowPriors) {
            addModel(prior);
        }
        this.gradientLogDensity = new double[loadings.getDimension()];
        addVariable(loadings);
        this.likelihood = new CompoundLikelihood(Arrays.asList(this.rowPriors));
    }

    /** Returns this object, which is itself the likelihood whose gradient is provided. */
    @Override
    public Likelihood getLikelihood() {
        return this;
    }

    /**
     * Returns the prior for one column of the loadings.
     *
     * @param column column index into the loadings matrix
     * @return the prior attached to that column
     */
    public BayesianBridgeLikelihood getLikelihood(int column) {
        return rowPriors[column];
    }

    /** Returns the loadings matrix parameter the gradient is taken with respect to. */
    @Override
    public Parameter getParameter() {
        return loadings;
    }

    /** Returns the total number of loadings entries. */
    @Override
    public int getDimension() {
        return loadings.getDimension();
    }

    /**
     * Concatenates each column prior's gradient into a single flat array, column by column.
     *
     * @return gradient of the log prior with respect to every loadings entry; the same internal
     *         buffer is returned (and overwritten) on every call
     */
    @Override
    public double[] getGradientLogDensity() {
        final int rowDimension = loadings.getRowDimension();
        for (int column = 0; column < rowPriors.length; ++column) {
            System.arraycopy(rowPriors[column].getGradientLogDensity(), 0,
                    gradientLogDensity, column * rowDimension, rowDimension);
        }
        return gradientLogDensity;
    }

    // NOTE(review): the priors were registered with addModel(...) in the constructor, so they may
    // already receive their own events and store/restore/accept calls through the model framework.
    // The explicit forwarding below may therefore invoke each prior's hooks a second time —
    // confirm this is intended and harmless (e.g. that restoreState is not a buffer swap).

    /** Forwards the model-changed event to every column prior. */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        for (BayesianBridgeLikelihood prior : rowPriors) {
            prior.handleModelChangedEvent(model, object, index);
        }
    }

    /** Forwards the variable-changed event to every column prior. */
    @Override
    protected void handleVariableChangedEvent(Variable variable, int index,
                                              Variable.ChangeType type) {
        for (BayesianBridgeLikelihood prior : rowPriors) {
            prior.handleVariableChangedEvent(variable, index, type);
        }
    }

    /** Asks every column prior to store its state. */
    @Override
    protected void storeState() {
        for (BayesianBridgeLikelihood prior : rowPriors) {
            prior.storeState();
        }
    }

    /** Asks every column prior to restore its stored state. */
    @Override
    protected void restoreState() {
        for (BayesianBridgeLikelihood prior : rowPriors) {
            prior.restoreState();
        }
    }

    /** Asks every column prior to accept its current state. */
    @Override
    protected void acceptState() {
        for (BayesianBridgeLikelihood prior : rowPriors) {
            prior.acceptState();
        }
    }

    /** Returns this object as its own model. */
    @Override
    public Model getModel() {
        return this;
    }

    /** Returns the log density of the whole matrix, as computed by the compound likelihood. */
    @Override
    public double getLogLikelihood() {
        return likelihood.getLogLikelihood();
    }

    /** Marks the compound likelihood (and hence the column priors) dirty. */
    @Override
    public void makeDirty() {
        likelihood.makeDirty();
    }

    /** Every entry is given a mean of zero. */
    @Override
    public double getNormalMean(int index) {
        return 0.0;
    }

    /**
     * Returns the normal standard deviation for one flat loadings index: the column prior's
     * global scale times its local scale for that row.
     *
     * @param index flat index {@code column * rowDimension + row}
     */
    @Override
    public double getNormalSD(int index) {
        final int rowDimension = loadings.getRowDimension();
        final BayesianBridgeLikelihood prior = rowPriors[index / rowDimension];
        final double globalScale = prior.getGlobalScale().getParameterValue(0);
        final double localScale = prior.getLocalScale().getParameterValue(index % rowDimension);
        return globalScale * localScale;
    }

    /** Lists each column prior's log density followed by the total. */
    @Override
    public String getReport() {
        StringBuilder report = new StringBuilder("MatrixShrinkageLikelihood\n");
        for (int i = 0; i < rowPriors.length; ++i) {
            report.append("\tLikelihood ").append(i + 1).append(": ")
                    .append(rowPriors[i].getLogLikelihood())
                    .append("\n");
        }
        report.append("Likelihood: ").append(getLogLikelihood()).append("\n");
        return report.toString();
    }

    /**
     * Sums, over the rows of one column, the squared loadings entry divided by the column prior's
     * local scale for that row.
     *
     * @param column column index into the loadings matrix
     * @return {@code sum_i (loadings[i, column] / localScale_i)^2}
     */
    @Override
    public double computeSumSquaredErrors(int column) {
        final BayesianBridgeLikelihood prior = getLikelihood(column);
        final int rowDimension = loadings.getRowDimension();
        double sum = 0.0;
        for (int row = 0; row < rowDimension; ++row) {
            final double scaled = loadings.getParameterValue(row, column)
                    / prior.getLocalScale().getParameterValue(row);
            sum += scaled * scaled;
        }
        return sum;
    }

    /** Returns the row dimension of the loadings matrix. */
    @Override
    public int getRowDimension() {
        return loadings.getRowDimension();
    }

    /** Returns the column dimension of the loadings matrix. */
    @Override
    public int getColumnDimension() {
        return loadings.getColumnDimension();
    }
}

