/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.antigenic;

import dr.evomodel.antigenic.AntigenicGradientWrtParameter;
import dr.evomodel.antigenic.NewAntigenicLikelihood;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.CompoundParameter;
import dr.inference.model.Likelihood;
import dr.inference.model.Model;
import dr.inference.model.ModelListener;
import dr.inference.model.Parameter;
import dr.inference.multidimensionalscaling.MultiDimensionalScalingCore;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Gradient of the log-density of a {@link NewAntigenicLikelihood} with respect to one or more
 * parameters, each described by an {@link AntigenicGradientWrtParameter}.
 *
 * <p>Two intermediate gradients are obtained from the likelihood's
 * {@link MultiDimensionalScalingCore}: a location gradient of length
 * {@code (numViruses + numSera) * mdsDim} and an observation gradient of length
 * {@code numViruses * numSera}. Each is fetched only if at least one entry of the
 * "with-respect-to" list requires it. It is then cached until the likelihood reports a change or
 * a restore through the {@link ModelListener} callbacks. Each list entry then writes its own slice
 * of the final gradient, in list order.
 */
public class AntigenicLikelihoodGradient
        implements ModelListener, GradientWrtParameterProvider, Reportable {

    /** Tolerance handed to the gradient self-check performed in {@link #getReport()}. */
    private static final double TOLERANCE = 0.001;

    private final NewAntigenicLikelihood likelihood;
    private final MultiDimensionalScalingCore mdsCore;
    /** Unmodifiable snapshot of the constructor argument; its order defines the gradient layout. */
    private final List<AntigenicGradientWrtParameter> wrtList;
    private final int numViruses;
    private final int numSera;
    private final int mdsDim;
    /** The single wrt parameter, or a compound of all of them in {@link #wrtList} order. */
    private final Parameter parameter;

    // Cache state: buffers are allocated lazily and reused; flags are cleared on model events.
    private boolean locationGradientKnown;
    private boolean observationGradientKnown;
    private double[] locationGradient;
    private double[] observationGradient;

    /**
     * Creates a gradient provider for the given likelihood.
     *
     * @param newAntigenicLikelihood the likelihood whose log-density gradient is computed; this
     *     object registers itself as a model listener and restore listener on it
     * @param list the parameters to differentiate with respect to. The list is copied, so later
     *     changes to the caller's list have no effect. With exactly one entry, its parameter is
     *     used directly; otherwise the entries' parameters are wrapped, in order, in a
     *     {@link CompoundParameter} named {@code "AntigenicLikelihoodGradient"}.
     */
    public AntigenicLikelihoodGradient(NewAntigenicLikelihood newAntigenicLikelihood,
                                       List<AntigenicGradientWrtParameter> list) {
        this.likelihood = newAntigenicLikelihood;
        this.mdsCore = newAntigenicLikelihood.getCore();
        // Defensive copy: `parameter` below is built from the list's contents once, so the two
        // must not drift apart if the caller mutates its list afterwards.
        this.wrtList = Collections.unmodifiableList(new ArrayList<>(list));
        this.numViruses = newAntigenicLikelihood.getNumberOfViruses();
        this.numSera = newAntigenicLikelihood.getNumberOfSera();
        this.mdsDim = newAntigenicLikelihood.getMdsDimension();

        if (this.wrtList.size() == 1) {
            this.parameter = this.wrtList.get(0).getParameter();
        } else {
            CompoundParameter compound = new CompoundParameter("AntigenicLikelihoodGradient");
            for (AntigenicGradientWrtParameter wrt : this.wrtList) {
                compound.addParameter(wrt.getParameter());
            }
            this.parameter = compound;
        }

        this.locationGradientKnown = false;
        this.observationGradientKnown = false;

        // Register last so `this` is not published before every final field is assigned.
        newAntigenicLikelihood.addModelListener(this);
        newAntigenicLikelihood.addModelRestoreListener(this);
    }

    /**
     * Invalidates both cached gradients when the wrapped likelihood changes.
     *
     * @throws IllegalArgumentException if {@code model} is not the likelihood passed to the
     *     constructor
     */
    @Override
    public void modelChangedEvent(Model model, Object object, int n) {
        if (model != likelihood) {
            throw new IllegalArgumentException("Unknown model");
        }
        invalidateCache();
    }

    /** Invalidates both cached gradients on restore; the model argument is not checked. */
    @Override
    public void modelRestored(Model model) {
        invalidateCache();
    }

    /** @return the single wrt parameter, or the compound of all wrt parameters */
    @Override
    public Parameter getParameter() {
        return parameter;
    }

    /**
     * @return the dimension of {@link #getParameter()}. This is presumably equal to the summed
     *     {@code getSize()} of the wrt entries, but that is not verified here.
     */
    @Override
    public int getDimension() {
        return parameter.getDimension();
    }

    /**
     * Computes the gradient of the log-density with respect to {@link #getParameter()}.
     *
     * <p>Calls {@code updateParametersOnDevice()} on the likelihood each time. Presumably this
     * syncs current parameter values into the MDS core; confirm against
     * {@code NewAntigenicLikelihood}. It then refreshes whichever intermediate gradients are
     * required and stale, and lets each wrt entry fill its slice at a running offset.
     *
     * @return a newly allocated array whose length is the sum of the wrt entries' sizes
     */
    @Override
    public double[] getGradientLogDensity() {
        likelihood.updateParametersOnDevice();

        if (!locationGradientKnown && requiresLocationGradient()) {
            updateLocationGradient();
            locationGradientKnown = true;
        }
        if (!observationGradientKnown && requiresObservationGradient()) {
            updateObservationGradient();
            observationGradientKnown = true;
        }

        // A buffer that no entry requires may still be null here; it is passed through as-is.
        // NOTE(review): assumes entries ignore gradients they do not require.
        double[] gradient = new double[getGradientSize()];
        int offset = 0;
        for (AntigenicGradientWrtParameter wrt : wrtList) {
            wrt.getGradient(gradient, offset, locationGradient, observationGradient);
            offset += wrt.getSize();
        }
        return gradient;
    }

    /**
     * Runs the shared gradient self-check with unbounded parameter limits and {@link #TOLERANCE}.
     */
    @Override
    public String getReport() {
        return GradientWrtParameterProvider.getReportAndCheckForError(
                this, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, TOLERANCE);
    }

    /** @return the wrapped likelihood */
    @Override
    public Likelihood getLikelihood() {
        return likelihood;
    }

    /** Marks both cached intermediate gradients as stale. */
    private void invalidateCache() {
        locationGradientKnown = false;
        observationGradientKnown = false;
    }

    /** @return {@code true} if any wrt entry needs the location gradient */
    private boolean requiresLocationGradient() {
        for (AntigenicGradientWrtParameter wrt : wrtList) {
            if (wrt.requiresLocationGradient()) {
                return true;
            }
        }
        return false;
    }

    /** @return {@code true} if any wrt entry needs the observation gradient */
    private boolean requiresObservationGradient() {
        for (AntigenicGradientWrtParameter wrt : wrtList) {
            if (wrt.requiresObservationGradient()) {
                return true;
            }
        }
        return false;
    }

    /** @return the total length of the output gradient: the sum of every wrt entry's size */
    private int getGradientSize() {
        int size = 0;
        for (AntigenicGradientWrtParameter wrt : wrtList) {
            size += wrt.getSize();
        }
        return size;
    }

    /** Fills the location-gradient buffer from the MDS core, allocating it on first use. */
    private void updateLocationGradient() {
        if (locationGradient == null) {
            locationGradient = new double[(numViruses + numSera) * mdsDim];
        }
        mdsCore.getLocationGradient(locationGradient);
    }

    /** Fills the observation-gradient buffer from the MDS core, allocating it on first use. */
    private void updateObservationGradient() {
        if (observationGradient == null) {
            observationGradient = new double[numViruses * numSera];
        }
        mdsCore.getObservationGradient(observationGradient);
    }
}

