/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchratemodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.util.Taxon;
import dr.evomodel.branchratemodel.AbstractBranchRateModel;
import dr.evomodel.branchratemodel.ArbitraryBranchRates;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.DifferentiableBranchRates;
import dr.evomodel.branchratemodel.NodeRateMap;
import dr.evomodel.tree.TreeModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;
import java.util.function.DoubleBinaryOperator;

/**
 * A {@link DifferentiableBranchRates} wrapper that pins one branch rate to exactly 1.0 and
 * delegates every other query to an underlying differentiable branch rate model.
 *
 * <p>The pinned branch is located from a user-specified reference taxon. The walk starts at
 * that taxon's tip and moves toward the root. It selects the last node before the root, i.e.
 * the child of the root on the path to the reference taxon. Tree changes can move that node,
 * so it is recomputed lazily after the tree model reports a change.
 */
public class FixedReferenceRates
extends AbstractBranchRateModel
implements DifferentiableBranchRates {
    private static final String FIXED_REFERENCE_RATES = "fixedReferenceRates";
    private static final String FIXED_LENGTH = "fixedLength";

    private final TreeModel treeModel;
    private final DifferentiableBranchRates differentiableBranchRateModel;
    private final Taxon referenceTaxon;

    /** Node whose branch rate is fixed to 1.0; only valid while {@code nodeKnown} is true. */
    private NodeRef oneNode;
    /** False after a tree change, forcing {@link #oneNode} to be recomputed on next use. */
    private boolean nodeKnown = false;
    private boolean storedNodeKnown;
    private NodeRef storedOneNode;

    public static XMLObjectParser PARSER = new AbstractXMLObjectParser() {
        private final XMLSyntaxRule[] rules = new XMLSyntaxRule[]{
                new ElementRule(TreeModel.class),
                new ElementRule(BranchRateModel.class),
                new ElementRule(Taxon.class),
                // parseXMLObject reads this attribute as an int, so validate it as one.
                AttributeRule.newIntegerRule(FIXED_LENGTH, true)
        };

        @Override
        public String getParserName() {
            return FIXED_REFERENCE_RATES;
        }

        @Override
        public Object parseXMLObject(XMLObject xo) throws XMLParseException {
            TreeModel tree = (TreeModel) xo.getChild(TreeModel.class);
            BranchRateModel baseRates = (BranchRateModel) xo.getChild(BranchRateModel.class);
            Taxon reference = (Taxon) xo.getChild(Taxon.class);
            int fixedLength = xo.getAttribute(FIXED_LENGTH, 0);
            return new FixedReferenceRates(FIXED_REFERENCE_RATES, tree, baseRates, reference, fixedLength);
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return rules;
        }

        @Override
        public String getParserDescription() {
            return "Fixes ancestral off-root branch (and optional addnl branches) to 1 with reference to a user-specified taxon.";
        }

        @Override
        public Class getReturnType() {
            return FixedReferenceRates.class;
        }
    };

    /**
     * Creates the wrapper and immediately locates the fixed branch in {@code treeModel}.
     *
     * @param name            model name passed to {@link AbstractBranchRateModel}
     * @param treeModel       tree used to locate the fixed branch; watched for changes
     * @param branchRateModel underlying rate model; must implement {@link DifferentiableBranchRates}
     * @param referenceTaxon  taxon whose root-adjacent ancestor carries the fixed branch
     * @param fixedLength     value of the {@code fixedLength} attribute; currently unused
     * @throws RuntimeException if {@code branchRateModel} is not a {@link DifferentiableBranchRates}
     * @throws IllegalArgumentException if {@code referenceTaxon} is not in {@code treeModel}
     */
    public FixedReferenceRates(String name, TreeModel treeModel, BranchRateModel branchRateModel,
                               Taxon referenceTaxon, int fixedLength) {
        super(name);
        this.treeModel = treeModel;
        this.referenceTaxon = referenceTaxon;
        this.differentiableBranchRateModel = branchRateModel instanceof DifferentiableBranchRates
                ? (DifferentiableBranchRates) branchRateModel
                : null;
        checkDifferentiability();
        // NOTE(review): fixedLength is ignored. The parser description mentions optional
        // additional fixed branches, which are not implemented here.
        updateNodeList(treeModel, referenceTaxon);
        addModel(treeModel);
        addModel(branchRateModel);
    }

    /** Returns 1.0 for the fixed branch; otherwise delegates to the base model. */
    @Override
    public double getUntransformedBranchRate(Tree tree, NodeRef nodeRef) {
        if (isFixedNode(tree, nodeRef)) {
            return 1.0;
        }
        return differentiableBranchRateModel.getUntransformedBranchRate(tree, nodeRef);
    }

    /** Returns 1.0 for the fixed branch; otherwise delegates to the base model. */
    @Override
    public double getBranchRate(Tree tree, NodeRef nodeRef) {
        if (isFixedNode(tree, nodeRef)) {
            return 1.0;
        }
        return differentiableBranchRateModel.getBranchRate(tree, nodeRef);
    }

    /**
     * Reports whether {@code nodeRef} is the node whose branch rate is fixed.
     * Recomputes the fixed node first if a tree change has invalidated it.
     */
    private boolean isFixedNode(Tree tree, NodeRef nodeRef) {
        if (!nodeKnown) {
            updateNodeList(tree, referenceTaxon);
        }
        return nodeRef.getNumber() == oneNode.getNumber();
    }

    /**
     * Locates the child of the root that lies on the path from the root to {@code taxon}.
     * Stores the result in {@link #oneNode}.
     *
     * @throws IllegalArgumentException if {@code taxon} is not a tip of {@code tree}
     */
    private void updateNodeList(Tree tree, Taxon taxon) {
        int tipIndex = tree.getTaxonIndex(taxon.getId());
        if (tipIndex < 0) {
            throw new IllegalArgumentException(
                    "Reference taxon '" + taxon.getId() + "' not found in tree");
        }
        NodeRef root = tree.getRoot();
        NodeRef node = tree.getNode(tipIndex);
        NodeRef parent = tree.getParent(node);
        // Climb until the parent is the root. The null check covers a tip that is itself the
        // root (single-node tree), where there is no parent to climb to.
        while (parent != null && parent != root) {
            node = parent;
            parent = tree.getParent(node);
        }
        oneNode = node;
        // Only mark as known once the lookup has fully succeeded.
        nodeKnown = true;
    }

    @Override
    public Tree getTree() {
        return treeModel;
    }

    /** The fixed branch has a constant rate, so its derivative is 0.0; others delegate. */
    @Override
    public double getBranchRateDifferential(Tree tree, NodeRef nodeRef) {
        if (isFixedNode(tree, nodeRef)) {
            return 0.0;
        }
        return differentiableBranchRateModel.getBranchRateDifferential(tree, nodeRef);
    }

    /** The fixed branch has a constant rate, so its second derivative is 0.0; others delegate. */
    @Override
    public double getBranchRateSecondDifferential(Tree tree, NodeRef nodeRef) {
        if (isFixedNode(tree, nodeRef)) {
            return 0.0;
        }
        return differentiableBranchRateModel.getBranchRateSecondDifferential(tree, nodeRef);
    }

    @Override
    public Parameter getRateParameter() {
        return differentiableBranchRateModel.getRateParameter();
    }

    @Override
    public int getParameterIndexFromNode(NodeRef nodeRef) {
        return differentiableBranchRateModel.getParameterIndexFromNode(nodeRef);
    }

    /** Fails fast when the wrapped model could not be cast to {@link DifferentiableBranchRates}. */
    private void checkDifferentiability() {
        if (differentiableBranchRateModel == null) {
            throw new RuntimeException("Non-differentiable base BranchRateModel");
        }
    }

    @Override
    public ArbitraryBranchRates.BranchRateTransform getTransform() {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Delegates unchanged to the base model.
     * NOTE(review): the fixed branch's entry is not zeroed here. Confirm whether callers
     * expect it to be.
     */
    @Override
    public double[] updateGradientLogDensity(double[] gradient, double[] value, int from, int to) {
        return differentiableBranchRateModel.updateGradientLogDensity(gradient, value, from, to);
    }

    @Override
    public double[] updateDiagonalHessianLogDensity(double[] diagonalHessian, double[] gradient,
                                                    double[] value, int from, int to) {
        throw new RuntimeException("Not yet implemented");
    }

    /**
     * Delegates to the base model.
     * NOTE(review): the base model does not see the fixed 1.0 rate. Confirm this is intended.
     */
    @Override
    public double mapReduceOverRates(NodeRateMap map, DoubleBinaryOperator reduce, double initial) {
        return differentiableBranchRateModel.mapReduceOverRates(map, reduce, initial);
    }

    /**
     * Delegates to the base model.
     * NOTE(review): the base model does not see the fixed 1.0 rate. Confirm this is intended.
     */
    @Override
    public void forEachOverRates(NodeRateMap map) {
        differentiableBranchRateModel.forEachOverRates(map);
    }

    /**
     * Forwards base-model changes to listeners. A tree change only invalidates the cached
     * fixed node.
     * NOTE(review): no event is fired on tree change. This presumably relies on listeners
     * also watching the tree directly. Confirm.
     */
    @Override
    protected void handleModelChangedEvent(Model model, Object object, int index) {
        if (model == differentiableBranchRateModel) {
            fireModelChanged();
        } else if (model == treeModel) {
            nodeKnown = false;
        } else {
            throw new RuntimeException("Should only watch branchRates or treeModel");
        }
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int index, Variable.ChangeType type) {
        throw new RuntimeException("Should not be variable changed event");
    }

    @Override
    protected void storeState() {
        storedNodeKnown = nodeKnown;
        storedOneNode = oneNode;
    }

    @Override
    protected void restoreState() {
        nodeKnown = storedNodeKnown;
        oneNode = storedOneNode;
    }

    @Override
    protected void acceptState() {
        // Nothing to do: the only state is the cached fixed node, handled by store/restore.
    }
}

