/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.model.MatrixParameterInterface;
import dr.inference.model.Parameter;
import dr.inference.operators.AdaptationMode;
import dr.inference.operators.hmc.HamiltonianMonteCarloOperator;
import dr.inference.operators.hmc.MassPreconditioner;
import dr.math.matrixAlgebra.EJMLUtils;
import dr.math.matrixAlgebra.Matrix;
import dr.math.matrixAlgebra.SkewSymmetricMatrixExponential;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.Transform;
import dr.xml.Reportable;
import java.util.ArrayList;
import java.util.Collections;
import org.ejml.alg.dense.decomposition.TriangularSolver;
import org.ejml.data.DenseMatrix64F;
import org.ejml.factory.DecompositionFactory;
import org.ejml.interfaces.decomposition.CholeskyDecomposition;
import org.ejml.ops.CommonOps;

/**
 * Hamiltonian Monte Carlo operator for a matrix-valued parameter whose columns are grouped into
 * blocks that are kept orthonormal: positions move along a matrix-exponential flow, and each
 * position block is re-orthonormalized (via a Cholesky factor) after every position update.
 *
 * <p>NOTE(review): the flow appears to be the Stiefel-manifold geodesic of Byrne &amp; Girolami
 * (2013); confirm before relying on that description.
 */
public class GeodesicHamiltonianMonteCarloOperator
extends HamiltonianMonteCarloOperator
implements Reportable {

    /**
     * Creates the operator. All arguments are forwarded unchanged to
     * {@link HamiltonianMonteCarloOperator}; the leap-frog engine is then replaced by a
     * {@link GeodesicLeapFrogEngine} built from {@code parameter} and the inherited mask.
     * {@code parameter} must implement {@link MatrixParameterInterface} (the engine casts it).
     */
    public GeodesicHamiltonianMonteCarloOperator(AdaptationMode adaptationMode, double d, GradientWrtParameterProvider gradientWrtParameterProvider, Parameter parameter, Transform transform, Parameter parameter2, HamiltonianMonteCarloOperator.Options options, MassPreconditioner massPreconditioner) {
        super(adaptationMode, d, gradientWrtParameterProvider, parameter, transform, parameter2, options, massPreconditioner);
        this.leapFrogEngine = new GeodesicLeapFrogEngine(parameter, this.getDefaultInstabilityHandler(), this.preconditioning, this.mask);
    }

    @Override
    public String getOperatorName() {
        return "GeodesicHMC(" + this.parameter.getParameterName() + ")";
    }

    /**
     * Runs one trajectory from a fixed, deterministic momentum (0, 1, 2, ...) and reports the
     * position before and after, together with the value returned by the trajectory.
     *
     * <p>This has side effects: {@code leapFrogGivenMomentum} operates on the live parameter and
     * the "final position" is read back from it afterwards.
     *
     * @return a human-readable report
     * @throws RuntimeException wrapping the instability exception if the trajectory fails
     */
    @Override
    public String getReport() {
        MatrixParameterInterface matrixParameter = (MatrixParameterInterface) this.parameter;
        int nColumns = matrixParameter.getColumnDimension();
        int nRows = matrixParameter.getRowDimension();

        StringBuilder report = new StringBuilder("operator: geodesicHamiltonianMonteCarloOperator");
        report.append("\n");
        report.append("\toriginal position:\n");
        report.append(new Matrix(matrixParameter.getParameterAsMatrix()).toString(2));

        // Deterministic momentum so the report is reproducible.
        double[] momentum = new double[this.parameter.getDimension()];
        for (int i = 0; i < momentum.length; ++i) {
            momentum[i] = i;
        }
        // Show the momentum laid out column-major, i.e. entry (row, col) = momentum[row + col * nRows].
        Matrix momentumMatrix = new Matrix(nRows, nColumns);
        for (int row = 0; row < nRows; ++row) {
            for (int col = 0; col < nColumns; ++col) {
                momentumMatrix.set(row, col, momentum[row + col * nRows]);
            }
        }
        report.append("\toriginal momentum (unprojected):\n");
        report.append(momentumMatrix.toString(2));

        double hastingsRatio;
        try {
            hastingsRatio = this.leapFrogGivenMomentum(new WrappedVector.Raw(momentum));
        }
        catch (HamiltonianMonteCarloOperator.NumericInstabilityException e) {
            // Keep the original failure as the cause instead of only printing it.
            throw new RuntimeException("HMC failed", e);
        }

        report.append("\n");
        report.append("\tfinal position:\n");
        report.append(new Matrix(matrixParameter.getParameterAsMatrix()).toString(2));
        report.append("\n");
        report.append("\thastings ratio: " + hastingsRatio + "\n\n");
        return report.toString();
    }

    /**
     * Refines the orthogonality blocks of the underlying {@link GeodesicLeapFrogEngine}.
     *
     * @param arrayList groups of column indices; each group must lie inside one existing block
     */
    public void setOrthogonalityStructure(ArrayList<ArrayList<Integer>> arrayList) {
        ((GeodesicLeapFrogEngine) this.leapFrogEngine).setOrthogonalityStructure(arrayList);
    }

    /**
     * Leap-frog engine that updates a matrix parameter block by block, keeping the columns of
     * each block orthonormal.
     *
     * <p>Block {@code b} consists of the parameter columns {@code orthogonalityStructure.get(b)}
     * restricted to the rows {@code orthogonalityBlockRows.get(b)}. Internally each block is
     * handled as its transpose: one matrix row per parameter column.
     */
    public static class GeodesicLeapFrogEngine
    extends HamiltonianMonteCarloOperator.LeapFrogEngine.Default {
        /** The sampled parameter viewed as a matrix (column-major storage is assumed by the index arithmetic below). */
        private final MatrixParameterInterface matrixParameter;
        /** For each block, the parameter column indices it contains (kept in increasing order). */
        private final ArrayList<ArrayList<Integer>> orthogonalityStructure;
        /** For each block (parallel to {@link #orthogonalityStructure}), the row indices it spans. */
        private final ArrayList<ArrayList<Integer>> orthogonalityBlockRows;

        /**
         * @param parameter          the matrix parameter (must implement {@link MatrixParameterInterface})
         * @param instabilityHandler forwarded to the default engine
         * @param massPreconditioner forwarded to the default engine
         * @param dArray             optional mask in the parameter's layout; {@code null} means a
         *                           single block of all columns over all rows
         */
        GeodesicLeapFrogEngine(Parameter parameter, HamiltonianMonteCarloOperator.InstabilityHandler instabilityHandler, MassPreconditioner massPreconditioner, double[] dArray) {
            super(parameter, instabilityHandler, massPreconditioner, dArray);
            this.matrixParameter = (MatrixParameterInterface) parameter;
            this.orthogonalityStructure = new ArrayList<>();
            this.orthogonalityBlockRows = new ArrayList<>();
            if (dArray == null) {
                this.orthogonalityStructure.add(range(this.matrixParameter.getColumnDimension()));
                this.orthogonalityBlockRows.add(range(this.matrixParameter.getRowDimension()));
            } else {
                this.parseStructureFromMask(dArray);
            }
        }

        /** Returns a new mutable list {@code [0, 1, ..., n - 1]}. */
        private static ArrayList<Integer> range(int n) {
            ArrayList<Integer> indices = new ArrayList<>(n);
            for (int i = 0; i < n; ++i) {
                indices.add(i);
            }
            return indices;
        }

        /**
         * Builds the blocks from a mask: columns whose sets of rows with mask value exactly 1.0
         * are identical share a block. Columns with no such rows belong to no block.
         */
        private void parseStructureFromMask(double[] mask) {
            int nRows = this.matrixParameter.getRowDimension();
            int nColumns = this.matrixParameter.getColumnDimension();
            ArrayList<Integer> activeRows = new ArrayList<>();
            for (int col = 0; col < nColumns; ++col) {
                activeRows.clear();
                int offset = col * nRows;
                for (int row = 0; row < nRows; ++row) {
                    if (mask[offset + row] == 1.0) {
                        activeRows.add(row);
                    }
                }
                if (activeRows.isEmpty()) {
                    continue;
                }
                int block = this.findMatchingArray(this.orthogonalityBlockRows, activeRows);
                if (block == -1) {
                    ArrayList<Integer> columns = new ArrayList<>();
                    columns.add(col);
                    this.orthogonalityStructure.add(columns);
                    // Copy: activeRows is reused for the next column.
                    this.orthogonalityBlockRows.add(new ArrayList<>(activeRows));
                } else {
                    this.orthogonalityStructure.get(block).add(col);
                }
            }
        }

        /**
         * Returns the index of the first list in {@code candidates} whose elements equal those of
         * {@code target} (same size, same values, same order), or -1 if none does.
         */
        private int findMatchingArray(ArrayList<ArrayList<Integer>> candidates, ArrayList<Integer> target) {
            for (int i = 0; i < candidates.size(); ++i) {
                // Compare by value: '==' on Integer compares references and is unreliable
                // for indices outside the Integer cache (-128..127).
                if (candidates.get(i).equals(target)) {
                    return i;
                }
            }
            return -1;
        }

        /**
         * Finds the first list in {@code candidates} that contains {@code subset} as a
         * subsequence (same relative order). On success, {@code remainder} holds that candidate's
         * other elements (in order) and the candidate's index is returned; otherwise returns -1.
         */
        private int findSubArray(ArrayList<ArrayList<Integer>> candidates, ArrayList<Integer> subset, ArrayList<Integer> remainder) {
            for (int i = 0; i < candidates.size(); ++i) {
                ArrayList<Integer> candidate = candidates.get(i);
                remainder.clear();
                if (subset.size() > candidate.size()) {
                    continue;
                }
                int matched = 0;
                for (Integer element : candidate) {
                    // Value comparison; see findMatchingArray.
                    if (matched < subset.size() && element.equals(subset.get(matched))) {
                        ++matched;
                    } else {
                        remainder.add(element);
                    }
                }
                if (matched == subset.size()) {
                    return i;
                }
            }
            return -1;
        }

        /**
         * Splits existing blocks according to a user-supplied grouping of columns.
         *
         * <p>Each group is sorted and must be contained in one existing block; that block is then
         * split into the group and its remaining columns, both spanning the block's rows. A group
         * equal to a whole block leaves the structure unchanged. The caller's lists are copied and
         * are neither reordered nor retained.
         *
         * @param arrayList groups of column indices
         * @throws RuntimeException if a group is not contained in any single existing block
         */
        public void setOrthogonalityStructure(ArrayList<ArrayList<Integer>> arrayList) {
            for (ArrayList<Integer> group : arrayList) {
                ArrayList<Integer> columns = new ArrayList<>(group);
                Collections.sort(columns);
                ArrayList<Integer> remainder = new ArrayList<>();
                int block = this.findSubArray(this.orthogonalityStructure, columns, remainder);
                if (block == -1) {
                    throw new RuntimeException("Orthogonality structure incompatible with itself or mask.");
                }
                if (remainder.isEmpty()) {
                    continue;
                }
                this.orthogonalityStructure.set(block, remainder);
                this.orthogonalityStructure.add(columns);
                this.orthogonalityBlockRows.add(new ArrayList<>(this.orthogonalityBlockRows.get(block)));
            }
        }

        /**
         * Gathers block {@code block} from {@code buffer} (parameter layout, starting at
         * {@code offset}) into a new matrix with one row per block column and one column per
         * block row.
         */
        private DenseMatrix64F setOrthogonalSubMatrix(double[] buffer, int offset, int block) {
            int nRows = this.matrixParameter.getRowDimension();
            ArrayList<Integer> columns = this.orthogonalityStructure.get(block);
            ArrayList<Integer> rows = this.orthogonalityBlockRows.get(block);
            DenseMatrix64F subMatrix = new DenseMatrix64F(columns.size(), rows.size());
            for (int r = 0; r < rows.size(); ++r) {
                for (int c = 0; c < columns.size(); ++c) {
                    subMatrix.set(c, r, buffer[nRows * columns.get(c) + rows.get(r) + offset]);
                }
            }
            return subMatrix;
        }

        /** Same as {@link #setOrthogonalSubMatrix(double[], int, int)} with offset 0. */
        private DenseMatrix64F setOrthogonalSubMatrix(double[] buffer, int block) {
            return this.setOrthogonalSubMatrix(buffer, 0, block);
        }

        /** Inverse of {@link #setOrthogonalSubMatrix}: scatters {@code subMatrix} back into {@code buffer}. */
        private void unwrapSubMatrix(DenseMatrix64F subMatrix, int block, double[] buffer, int offset) {
            int nRows = this.matrixParameter.getRowDimension();
            ArrayList<Integer> columns = this.orthogonalityStructure.get(block);
            ArrayList<Integer> rows = this.orthogonalityBlockRows.get(block);
            for (int r = 0; r < rows.size(); ++r) {
                for (int c = 0; c < columns.size(); ++c) {
                    buffer[nRows * columns.get(c) + rows.get(r) + offset] = subMatrix.get(c, r);
                }
            }
        }

        /** Same as {@link #unwrapSubMatrix(DenseMatrix64F, int, double[], int)} with offset 0. */
        private void unwrapSubMatrix(DenseMatrix64F subMatrix, int block, double[] buffer) {
            this.unwrapSubMatrix(subMatrix, block, buffer, 0);
        }

        /**
         * Performs the default momentum update, then projects the momentum with
         * {@link #projectMomentum}.
         *
         * @param dArray  position
         * @param dArray2 momentum; updated in place
         * @param dArray3 forwarded to the default update (presumably the gradient — confirm)
         * @param d       forwarded to the default update
         */
        @Override
        public void updateMomentum(double[] dArray, double[] dArray2, double[] dArray3, double d) throws HamiltonianMonteCarloOperator.NumericInstabilityException {
            super.updateMomentum(dArray, dArray2, dArray3, d);
            this.projectMomentum(dArray2, dArray);
        }

        /**
         * Moves position and momentum along the matrix-exponential flow, block by block, then
         * re-orthonormalizes each position block and writes the position into the parameter.
         *
         * <p>With X, V the (transposed) position and momentum blocks, the step is
         * {@code [X; V] <- (exp(d*A) * blkdiag(E, E))^T [X; V]}, where
         * {@code A = [[X V^T, -V V^T], [I, X V^T]]} and {@code E = exp(-d * X V^T)}.
         * Afterwards X is replaced by {@code L^{-1} X} with {@code L L^T = X X^T}.
         *
         * @param dArray        position in parameter layout; updated in place
         * @param wrappedVector momentum; its buffer is updated in place
         * @param d             step size
         * @throws HamiltonianMonteCarloOperator.NumericInstabilityException if the Cholesky
         *         factorization fails, or if re-orthonormalization moves X too far (or yields NaN)
         */
        @Override
        public void updatePosition(double[] dArray, WrappedVector wrappedVector, double d) throws HamiltonianMonteCarloOperator.NumericInstabilityException {
            for (int block = 0; block < this.orthogonalityStructure.size(); ++block) {
                int k = this.orthogonalityStructure.get(block).size();
                int p = this.orthogonalityBlockRows.get(block).size();

                DenseMatrix64F x = this.setOrthogonalSubMatrix(dArray, block);
                DenseMatrix64F v = this.setOrthogonalSubMatrix(wrappedVector.getBuffer(), wrappedVector.getOffset(), block);

                DenseMatrix64F xvT = new DenseMatrix64F(k, k);
                DenseMatrix64F vvT = new DenseMatrix64F(k, k);
                CommonOps.multTransB(x, v, xvT);
                CommonOps.multTransB(v, v, vvT);

                // Generator A = [[X V^T, -V V^T], [I, X V^T]] (2k x 2k).
                double[][] generator = new double[2 * k][2 * k];
                for (int r = 0; r < k; ++r) {
                    generator[r + k][r] = 1.0;
                    for (int c = 0; c < k; ++c) {
                        generator[r][c] = xvT.get(r, c);
                        generator[r + k][c + k] = xvT.get(r, c);
                        generator[r][c + k] = -vvT.get(c, r);
                    }
                }

                // E = exp(-d * X V^T), flat row-major k*k. (xvT is scaled in place; its values
                // were already copied into the generator above.)
                double[] rotation = new double[k * k];
                CommonOps.scale(-d, xvT);
                new SkewSymmetricMatrixExponential(k).exponentiate(xvT.data, rotation);

                // exp(d * A), flat row-major (2k)*(2k).
                double[] flowData = new double[4 * k * k];
                DenseMatrix64F scaledGenerator = new DenseMatrix64F(generator);
                CommonOps.scale(d, scaledGenerator);
                new SkewSymmetricMatrixExponential(2 * k).exponentiate(scaledGenerator.data, flowData);

                DenseMatrix64F rotationBlocks = new DenseMatrix64F(2 * k, 2 * k);
                for (int r = 0; r < k; ++r) {
                    for (int c = 0; c < k; ++c) {
                        double value = rotation[r * k + c];
                        rotationBlocks.set(r, c, value);
                        rotationBlocks.set(r + k, c + k, value);
                    }
                }
                DenseMatrix64F flow = new DenseMatrix64F(2 * k, 2 * k);
                flow.setData(flowData);

                DenseMatrix64F propagator = new DenseMatrix64F(2 * k, 2 * k);
                CommonOps.mult(flow, rotationBlocks, propagator);
                CommonOps.transpose(propagator);

                // Stack [X; V], propagate, and split back into X and V.
                DenseMatrix64F stacked = new DenseMatrix64F(2 * k, p);
                for (int c = 0; c < p; ++c) {
                    for (int r = 0; r < k; ++r) {
                        stacked.set(r, c, x.get(r, c));
                        stacked.set(r + k, c, v.get(r, c));
                    }
                }
                DenseMatrix64F propagated = new DenseMatrix64F(2 * k, p);
                CommonOps.mult(propagator, stacked, propagated);
                for (int c = 0; c < p; ++c) {
                    for (int r = 0; r < k; ++r) {
                        x.set(r, c, propagated.get(r, c));
                        v.set(r, c, propagated.get(r + k, c));
                    }
                }

                // Re-orthonormalize X: with X X^T = L L^T, replace X by L^{-1} X.
                DenseMatrix64F xxT = new DenseMatrix64F(k, k);
                CommonOps.multTransB(x, x, xxT);
                CholeskyDecomposition<DenseMatrix64F> cholesky = DecompositionFactory.chol(k, true);
                if (!cholesky.decompose(xxT)) {
                    // X X^T is not numerically positive definite; the factor would be garbage.
                    System.err.println("unstable");
                    throw new HamiltonianMonteCarloOperator.NumericInstabilityException();
                }
                // Read the factor explicitly rather than relying on in-place decomposition,
                // which not every EJML Cholesky implementation performs.
                DenseMatrix64F lowerInverse = cholesky.getT(null);
                TriangularSolver.invertLower(lowerInverse.data, k);
                DenseMatrix64F orthonormalX = new DenseMatrix64F(k, p);
                CommonOps.mult(lowerInverse, x, orthonormalX);

                double squaredDrift = 0.0;
                for (int idx = 0; idx < x.data.length; ++idx) {
                    double diff = orthonormalX.data[idx] - x.data[idx];
                    squaredDrift += diff * diff;
                }
                // Negated '<=' so that a NaN drift is rejected too.
                if (!(squaredDrift / (double) dArray.length <= 0.01)) {
                    System.err.println("unstable");
                    throw new HamiltonianMonteCarloOperator.NumericInstabilityException();
                }
                this.unwrapSubMatrix(orthonormalX, block, dArray);
                this.unwrapSubMatrix(v, block, wrappedVector.getBuffer(), wrappedVector.getOffset());
            }
            this.matrixParameter.setAllParameterValuesQuietly(dArray, 0);
            this.matrixParameter.fireParameterChangedEvent();
        }

        /**
         * Projects the momentum block by block: {@code V <- V - 0.5 * (X V^T + V X^T) X}, where X
         * and V are the (transposed) position and momentum blocks.
         * NOTE(review): assumes {@code EJMLUtils.addWithTransposed(M)} forms {@code M + M^T} in place.
         *
         * @param dArray  momentum; updated in place
         * @param dArray2 position; read only
         */
        @Override
        public void projectMomentum(double[] dArray, double[] dArray2) {
            for (int block = 0; block < this.orthogonalityStructure.size(); ++block) {
                DenseMatrix64F x = this.setOrthogonalSubMatrix(dArray2, block);
                DenseMatrix64F v = this.setOrthogonalSubMatrix(dArray, block);
                int k = this.orthogonalityStructure.get(block).size();
                int p = this.orthogonalityBlockRows.get(block).size();
                DenseMatrix64F symmetric = new DenseMatrix64F(k, k);
                CommonOps.multTransB(x, v, symmetric);
                EJMLUtils.addWithTransposed(symmetric);
                DenseMatrix64F correction = new DenseMatrix64F(k, p);
                CommonOps.mult(0.5, symmetric, x, correction);
                CommonOps.subtractEquals(v, correction);
                this.unwrapSubMatrix(v, block, dArray);
            }
        }
    }
}

