/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.io.Serializable;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.functions.Logistic;
import weka.classifiers.rules.ZeroR;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.Range;
import weka.core.RevisionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.MakeIndicator;
import weka.filters.unsupervised.instance.RemoveWithValues;

public class MultiClassClassifier
extends RandomizableSingleClassifierEnhancer
implements OptionHandler {
    /** Serialization version identifier. */
    static final long serialVersionUID = -3879602011542849141L;
    /**
     * Trained copies of the base classifier, one per two-class problem. In 1-against-1 mode an
     * entry is null when its pair of classes had no training instances.
     */
    private Classifier[] m_Classifiers;
    /** Whether pairwise coupling is used to combine 1-against-1 predictions. */
    private boolean m_pairwiseCoupling = false;
    /** Per-pair sum of training instance weights; only filled in 1-against-1 mode. */
    private double[] m_SumOfWeights;
    /**
     * Class filters, parallel to m_Classifiers: RemoveWithValues in 1-against-1 mode,
     * MakeIndicator for the code-based methods, and null as a whole when the data has at most
     * two classes.
     */
    private Filter[] m_ClassFilters;
    /** Fallback model used when the combined class scores do not sum to a positive value. */
    private ZeroR m_ZeroR;
    /** Class attribute of the training data. */
    private Attribute m_ClassAttribute;
    /** Empty header with a two-valued class attribute; only built in 1-against-1 mode. */
    private Instances m_TwoClassDataset;
    /** Multiplier applied to the number of classes to get the number of random codes. */
    private double m_RandomWidthFactor = 2.0;
    /** Selected decomposition method; holds one of the METHOD_* ids below (default 0). */
    private int m_Method = 0;
    /** Method id: one classifier per class, that class against all others. */
    public static final int METHOD_1_AGAINST_ALL = 0;
    /** Method id: random correction code. */
    public static final int METHOD_ERROR_RANDOM = 1;
    /** Method id: exhaustive correction code. */
    public static final int METHOD_ERROR_EXHAUSTIVE = 2;
    /** Method id: one classifier per pair of classes. */
    public static final int METHOD_1_AGAINST_1 = 3;
    /** Tags mapping each method id to its display name. */
    public static final Tag[] TAGS_METHOD = new Tag[]{new Tag(0, "1-against-all"), new Tag(1, "Random correction code"), new Tag(2, "Exhaustive correction code"), new Tag(3, "1-against-1")};

    /**
     * Creates a MultiClassClassifier whose base classifier is a new {@link Logistic} instance.
     */
    public MultiClassClassifier() {
        this.m_Classifier = new Logistic();
    }

    /**
     * Returns the fully qualified class name of the default base classifier.
     *
     * @return the class name of the Logistic classifier
     */
    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.functions.Logistic";
    }

    /**
     * Returns the capabilities of this meta classifier: the inherited capabilities with every
     * class type and class dependency switched off, and only nominal classes switched back on.
     *
     * @return the capabilities of this classifier
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities caps = super.getCapabilities();

        // Start from a clean slate for the class attribute ...
        caps.disableAllClasses();
        caps.disableAllClassDependencies();

        // ... and allow nominal classes only.
        caps.enable(Capabilities.Capability.NOMINAL_CLASS);

        return caps;
    }

    /**
     * Trains the ensemble of two-class base classifiers on the given data.
     * <p>
     * The data is copied and instances with a missing class value are removed from the copy.
     * With two or fewer classes a single copy of the base classifier is trained directly.
     * Otherwise the problem is decomposed according to the configured method: one classifier
     * per pair of classes (1-against-1), or one classifier per code word of an indicator code
     * (1-against-all, random or exhaustive). A {@link ZeroR} model is always trained as well.
     *
     * @param insts the training data
     * @throws Exception if the data fails the capabilities test, no base classifier has been
     *                   set, the method id is not recognised, or a filter or base classifier
     *                   fails
     */
    @Override
    public void buildClassifier(Instances insts) throws Exception {
        this.getCapabilities().testWithFail(insts);

        // Work on a private copy so the caller's dataset is not modified.
        insts = new Instances(insts);
        insts.deleteWithMissingClass();
        if (this.m_Classifier == null) {
            throw new Exception("No base classifier has been set!");
        }
        this.m_ZeroR = new ZeroR();
        this.m_ZeroR.buildClassifier(insts);

        // Drop state that only the 1-against-1 branch fills, so a rebuild with a
        // different method does not keep data from a previous model.
        this.m_TwoClassDataset = null;
        this.m_SumOfWeights = null;

        int numClasses = insts.numClasses();
        if (numClasses <= 2) {
            // Nothing to decompose: train the base classifier on the data as is.
            this.m_Classifiers = Classifier.makeCopies(this.m_Classifier, 1);
            this.m_Classifiers[0].buildClassifier(insts);
            this.m_ClassFilters = null;
        } else if (this.m_Method == 3) { // METHOD_1_AGAINST_1
            this.buildPairwiseClassifiers(insts);
        } else {
            this.buildCodeClassifiers(insts, numClasses);
        }
        this.m_ClassAttribute = insts.classAttribute();
    }

    /**
     * Trains one base classifier per unordered pair of class values (1-against-1) and builds
     * the two-class header used at prediction time. Pairs without training instances get a
     * null classifier and a null filter.
     *
     * @param insts the training data (already copied, without missing class values)
     * @throws Exception if a filter or base classifier fails
     */
    private void buildPairwiseClassifiers(Instances insts) throws Exception {
        int numClasses = insts.numClasses();

        // Enumerate the pairs (i, j) with i < j, ordered by i and then by j.
        int numPairs = numClasses * (numClasses - 1) / 2;
        int[][] pairs = new int[numPairs][];
        int next = 0;
        for (int i = 0; i < numClasses; i++) {
            for (int j = i + 1; j < numClasses; j++) {
                pairs[next++] = new int[]{i, j};
            }
        }

        this.m_Classifiers = Classifier.makeCopies(this.m_Classifier, numPairs);
        this.m_ClassFilters = new Filter[numPairs];
        this.m_SumOfWeights = new double[numPairs];
        for (int i = 0; i < numPairs; i++) {
            // Inverted selection: keep only the instances of the two classes of this pair.
            RemoveWithValues classFilter = new RemoveWithValues();
            classFilter.setAttributeIndex("" + (insts.classIndex() + 1));
            classFilter.setModifyHeader(true);
            classFilter.setInvertSelection(true);
            classFilter.setNominalIndicesArr(pairs[i]);

            // The filter's input format is a header copy with the class index unset.
            Instances header = new Instances(insts, 0);
            header.setClassIndex(-1);
            classFilter.setInputFormat(header);

            Instances pairData = Filter.useFilter(insts, classFilter);
            if (pairData.numInstances() > 0) {
                pairData.setClassIndex(insts.classIndex());
                this.m_Classifiers[i].buildClassifier(pairData);
                this.m_ClassFilters[i] = classFilter;
                this.m_SumOfWeights[i] = pairData.sumOfWeights();
            } else {
                this.m_Classifiers[i] = null;
                this.m_ClassFilters[i] = null;
            }
        }

        // Header whose class attribute is replaced by a generic two-valued one; test
        // instances are attached to it before being passed to a pairwise classifier.
        this.m_TwoClassDataset = new Instances(insts, 0);
        int classIndex = this.m_TwoClassDataset.classIndex();
        this.m_TwoClassDataset.setClassIndex(-1);
        this.m_TwoClassDataset.deleteAttributeAt(classIndex);
        FastVector classLabels = new FastVector();
        classLabels.addElement("class0");
        classLabels.addElement("class1");
        this.m_TwoClassDataset.insertAttributeAt(new Attribute("class", classLabels), classIndex);
        this.m_TwoClassDataset.setClassIndex(classIndex);
    }

    /**
     * Trains one base classifier per code word of the configured indicator code
     * (1-against-all, random or exhaustive). Each classifier is trained on the data after a
     * MakeIndicator filter has turned the class into a two-valued indicator.
     *
     * @param insts the training data (already copied, without missing class values)
     * @param numClasses the number of class values in the data
     * @throws Exception if the method id is not recognised or a filter or base classifier fails
     */
    private void buildCodeClassifiers(Instances insts, int numClasses) throws Exception {
        Code code;
        switch (this.m_Method) {
            case 2: // METHOD_ERROR_EXHAUSTIVE
                code = new ExhaustiveCode(numClasses);
                break;
            case 1: // METHOD_ERROR_RANDOM
                code = new RandomCode(numClasses, (int)((double)numClasses * this.m_RandomWidthFactor), insts);
                break;
            case 0: // METHOD_1_AGAINST_ALL
                code = new StandardCode(numClasses);
                break;
            default:
                throw new Exception("Unrecognized correction code type");
        }

        int numClassifiers = code.size();
        this.m_Classifiers = Classifier.makeCopies(this.m_Classifier, numClassifiers);
        this.m_ClassFilters = new MakeIndicator[numClassifiers];
        for (int i = 0; i < this.m_Classifiers.length; i++) {
            MakeIndicator classFilter = new MakeIndicator();
            this.m_ClassFilters[i] = classFilter;
            classFilter.setAttributeIndex("" + (insts.classIndex() + 1));
            classFilter.setValueIndices(code.getIndices(i));
            classFilter.setNumeric(false);
            classFilter.setInputFormat(insts);
            Instances indicatorData = Filter.useFilter(insts, classFilter);
            this.m_Classifiers[i].buildClassifier(indicatorData);
        }
    }

    /**
     * Returns, for each base classifier, the probability it assigns to the second value
     * (index 1) of its two-class problem for the given instance.
     * <p>
     * Entries belonging to skipped (null) classifiers are left at 0.0.
     *
     * @param inst the instance to get the predictions for
     * @return one value per base classifier
     * @throws Exception if a filter or base classifier fails
     */
    public double[] individualPredictions(Instance inst) throws Exception {
        // A single classifier means the data had at most two classes: no filtering needed.
        if (this.m_Classifiers.length == 1) {
            return new double[]{this.m_Classifiers[0].distributionForInstance(inst)[1]};
        }

        double[] predictions = new double[this.m_ClassFilters.length];
        for (int k = 0; k < this.m_ClassFilters.length; k++) {
            if (this.m_Classifiers[k] == null) {
                continue;
            }
            if (this.m_Method == 3) { // METHOD_1_AGAINST_1
                // Attach a copy of the instance to the two-class header before predicting.
                Instance pairInst = (Instance)inst.copy();
                pairInst.setDataset(this.m_TwoClassDataset);
                predictions[k] = this.m_Classifiers[k].distributionForInstance(pairInst)[1];
            } else {
                // Push the instance through the indicator filter of this classifier.
                this.m_ClassFilters[k].input(inst);
                this.m_ClassFilters[k].batchFinished();
                predictions[k] = this.m_Classifiers[k].distributionForInstance(this.m_ClassFilters[k].output())[1];
            }
        }
        return predictions;
    }

    /**
     * Returns the class distribution for the given instance.
     * <p>
     * With a single base classifier its distribution is returned unchanged. In 1-against-1
     * mode each pairwise classifier either votes for one of its two classes or, when pairwise
     * coupling is enabled, contributes to the matrices passed to
     * {@link #pairwiseCoupling(double[][], double[][])}. For the code-based methods every
     * classifier adds its indicator probability to the classes inside its value range and the
     * complementary probability to all other classes. If the accumulated scores do not sum to
     * a positive value, the ZeroR distribution is returned instead.
     *
     * @param inst the instance to classify
     * @return the class probability distribution
     * @throws Exception if a filter or base classifier fails
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        if (this.m_Classifiers.length == 1) {
            return this.m_Classifiers[0].distributionForInstance(inst);
        }

        final int numClasses = inst.numClasses();
        double[] probs = new double[numClasses];

        if (this.m_Method == 3) { // METHOD_1_AGAINST_1
            final boolean couple = this.m_pairwiseCoupling && numClasses > 2;
            double[][] r = new double[numClasses][numClasses];
            double[][] n = new double[numClasses][numClasses];

            for (int k = 0; k < this.m_ClassFilters.length; k++) {
                if (this.m_Classifiers[k] == null) {
                    continue; // pair had no training data
                }
                Instance pairInst = (Instance)inst.copy();
                pairInst.setDataset(this.m_TwoClassDataset);
                double[] current = this.m_Classifiers[k].distributionForInstance(pairInst);

                // Recover the two class indices of this pair from the filter's range string.
                Range range = new Range(((RemoveWithValues)this.m_ClassFilters[k]).getNominalIndices());
                range.setUpper(this.m_ClassAttribute.numValues());
                int[] pair = range.getSelection();

                if (couple) {
                    r[pair[0]][pair[1]] = current[0];
                    n[pair[0]][pair[1]] = this.m_SumOfWeights[k];
                } else {
                    // Plain voting: the more probable class of the pair gets one vote.
                    int winner = current[0] > current[1] ? pair[0] : pair[1];
                    probs[winner] += 1.0;
                }
            }
            if (couple) {
                return MultiClassClassifier.pairwiseCoupling(n, r);
            }
        } else {
            for (int k = 0; k < this.m_ClassFilters.length; k++) {
                this.m_ClassFilters[k].input(inst);
                this.m_ClassFilters[k].batchFinished();
                double[] current = this.m_Classifiers[k].distributionForInstance(this.m_ClassFilters[k].output());
                for (int c = 0; c < this.m_ClassAttribute.numValues(); c++) {
                    boolean inIndicatorSet = ((MakeIndicator)this.m_ClassFilters[k]).getValueRange().isInRange(c);
                    probs[c] += inIndicatorSet ? current[1] : current[0];
                }
            }
        }

        if (!Utils.gr(Utils.sum(probs), 0.0)) {
            // No usable evidence from the ensemble: fall back to ZeroR.
            return this.m_ZeroR.distributionForInstance(inst);
        }
        Utils.normalize(probs);
        return probs;
    }

    /**
     * Returns a textual description of the model: a header followed by one section per base
     * classifier. Each section names the class pair or indicator values the classifier was
     * trained on, where a class filter exists, followed by the classifier's own description.
     * Skipped classifiers are reported as such.
     *
     * @return the description, or a short notice if no model has been built
     */
    @Override
    public String toString() {
        if (this.m_Classifiers == null) {
            return "MultiClassClassifier: No model built yet.";
        }

        StringBuilder text = new StringBuilder("MultiClassClassifier\n\n");
        for (int k = 0; k < this.m_Classifiers.length; k++) {
            text.append("Classifier ").append(k + 1);
            if (this.m_Classifiers[k] == null) {
                text.append(" Skipped (no training examples)\n");
                continue;
            }
            if (this.m_ClassFilters != null && this.m_ClassFilters[k] != null) {
                if (this.m_ClassFilters[k] instanceof RemoveWithValues) {
                    // 1-against-1: report the pair using 1-based class indices.
                    Range range = new Range(((RemoveWithValues)this.m_ClassFilters[k]).getNominalIndices());
                    range.setUpper(this.m_ClassAttribute.numValues());
                    int[] pair = range.getSelection();
                    text.append(", ").append(pair[0] + 1).append(" vs ").append(pair[1] + 1);
                } else if (this.m_ClassFilters[k] instanceof MakeIndicator) {
                    text.append(", using indicator values: ");
                    text.append(((MakeIndicator)this.m_ClassFilters[k]).getValueRange());
                }
            }
            text.append('\n');
            text.append(this.m_Classifiers[k].toString()).append("\n\n");
        }
        return text.toString();
    }

    /**
     * Returns an enumeration describing the available options: -M (method), -R (random code
     * width multiplier) and -P (pairwise coupling), followed by the options of the superclass.
     *
     * @return an enumeration of {@link Option} objects
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> vec = new Vector<Option>(4);
        vec.addElement(new Option("\tSets the method to use. Valid values are 0 (1-against-all),\n\t1 (random codes), 2 (exhaustive code), and 3 (1-against-1). (default 0)\n", "M", 1, "-M <num>"));
        vec.addElement(new Option("\tSets the multiplier when using random codes. (default 2.0)", "R", 1, "-R <num>"));
        // Method name spelled "1-against-1", as in the method tags and the tip text.
        vec.addElement(new Option("\tUse pairwise coupling (only has an effect for 1-against-1)", "P", 0, "-P"));

        // Append whatever the superclass offers.
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            vec.addElement((Option)enu.nextElement());
        }
        return vec.elements();
    }

    /**
     * Parses the given option list.
     * <p>
     * Recognised here: {@code -M <num>} (method id, default 0), {@code -R <num>} (random code
     * width multiplier, default 2.0) and {@code -P} (use pairwise coupling). The remaining
     * options are passed on to the superclass.
     *
     * @param options the list of options as an array of strings
     * @throws Exception if an option is not supported or its value cannot be parsed
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        String methodString = Utils.getOption('M', options);
        if (methodString.length() != 0) {
            this.setMethod(new SelectedTag(Integer.parseInt(methodString), TAGS_METHOD));
        } else {
            this.setMethod(new SelectedTag(0, TAGS_METHOD));
        }

        String rfactorString = Utils.getOption('R', options);
        if (rfactorString.length() != 0) {
            // parseDouble instead of the deprecated boxing constructor Double(String).
            this.setRandomWidthFactor(Double.parseDouble(rfactorString));
        } else {
            this.setRandomWidthFactor(2.0);
        }

        this.setUsePairwiseCoupling(Utils.getFlag('P', options));
        super.setOptions(options);
    }

    /**
     * Returns the current settings as an option array: -M and -R with their values, -P when
     * pairwise coupling is enabled, followed by the options of the superclass.
     * <p>
     * The array always has room for five own entries; unused trailing slots hold empty strings.
     *
     * @return the option array
     */
    @Override
    public String[] getOptions() {
        String[] inherited = super.getOptions();
        String[] result = new String[inherited.length + 5];
        int pos = 0;

        result[pos++] = "-M";
        result[pos++] = String.valueOf(this.m_Method);
        if (this.getUsePairwiseCoupling()) {
            result[pos++] = "-P";
        }
        result[pos++] = "-R";
        result[pos++] = String.valueOf(this.m_RandomWidthFactor);

        System.arraycopy(inherited, 0, result, pos, inherited.length);
        pos += inherited.length;

        // Pad the slots that were reserved but not used.
        for (int k = pos; k < result.length; k++) {
            result[k] = "";
        }
        return result;
    }

    /**
     * Returns a short description of this classifier.
     *
     * @return the description text
     */
    public String globalInfo() {
        return "A metaclassifier for handling multi-class datasets with 2-class classifiers. This classifier is also capable of applying error correcting output codes for increased accuracy.";
    }

    /**
     * Returns the tip text for the random width factor property.
     *
     * @return the tip text
     */
    public String randomWidthFactorTipText() {
        // "this number" replaces the earlier typo "thus number".
        return "Sets the width multiplier when using random codes. The number of codes generated will be this number multiplied by the number of classes.";
    }

    /**
     * Returns the multiplier used to size random codes.
     *
     * @return the current random width factor
     */
    public double getRandomWidthFactor() {
        return this.m_RandomWidthFactor;
    }

    /**
     * Sets the multiplier used to size random codes. The value is stored without validation.
     *
     * @param newRandomWidthFactor the new random width factor
     */
    public void setRandomWidthFactor(double newRandomWidthFactor) {
        this.m_RandomWidthFactor = newRandomWidthFactor;
    }

    /**
     * Returns the tip text for the method property.
     *
     * @return the tip text
     */
    public String methodTipText() {
        return "Sets the method to use for transforming the multi-class problem into several 2-class ones.";
    }

    /**
     * Returns the currently selected method wrapped in a new SelectedTag over TAGS_METHOD.
     *
     * @return the selected method
     */
    public SelectedTag getMethod() {
        return new SelectedTag(this.m_Method, TAGS_METHOD);
    }

    /**
     * Sets the decomposition method. The request is silently ignored unless the given tag was
     * created over the very same {@code TAGS_METHOD} array (identity comparison).
     *
     * @param newMethod the method to use
     */
    public void setMethod(SelectedTag newMethod) {
        if (newMethod.getTags() != TAGS_METHOD) {
            return;
        }
        this.m_Method = newMethod.getSelectedTag().getID();
    }

    /**
     * Sets whether pairwise coupling is used.
     *
     * @param p true to use pairwise coupling
     */
    public void setUsePairwiseCoupling(boolean p) {
        this.m_pairwiseCoupling = p;
    }

    /**
     * Returns whether pairwise coupling is used.
     *
     * @return true if pairwise coupling is enabled
     */
    public boolean getUsePairwiseCoupling() {
        return this.m_pairwiseCoupling;
    }

    /**
     * Returns the tip text for the pairwise coupling property.
     *
     * @return the tip text
     */
    public String usePairwiseCouplingTipText() {
        return "Use pairwise coupling (only has an effect for 1-against-1).";
    }

    /**
     * Combines pairwise estimates into a single class distribution by iterative rescaling.
     * <p>
     * Only the upper triangles (row index &lt; column index) of the two matrices are read.
     * Starting from a uniform distribution, each class probability is repeatedly multiplied by
     * the ratio of its weighted observed pairwise estimates to the weighted estimates implied
     * by the current distribution. The distribution is renormalised after every sweep, and the
     * loop ends once a sweep changes no probability by more than 0.001 (measured before
     * renormalisation). A class whose observed or implied sum is zero gets probability zero.
     *
     * @param n for each pair (i, j) with i &lt; j, the weight given to that pair
     * @param r for each pair (i, j) with i &lt; j, the estimated probability of class i
     *          within the pair
     * @return the normalised class distribution
     */
    public static double[] pairwiseCoupling(double[][] n, double[][] r) {
        // Start from the uniform distribution.
        double[] p = new double[r.length];
        for (int c = 0; c < p.length; c++) {
            p[c] = 1.0 / (double)p.length;
        }

        // Pairwise probabilities implied by p; 0.5 everywhere for the uniform start.
        double[][] u = new double[r.length][r.length];
        for (int a = 0; a < r.length; a++) {
            for (int b = a + 1; b < r.length; b++) {
                u[a][b] = 0.5;
            }
        }

        // Weighted observed estimates per class; these do not change between sweeps.
        double[] firstSum = new double[p.length];
        for (int a = 0; a < p.length; a++) {
            for (int b = a + 1; b < p.length; b++) {
                firstSum[a] += n[a][b] * r[a][b];
                firstSum[b] += n[a][b] * (1.0 - r[a][b]);
            }
        }

        boolean changed;
        do {
            changed = false;

            // Weighted estimates implied by the current p.
            double[] secondSum = new double[p.length];
            for (int a = 0; a < p.length; a++) {
                for (int b = a + 1; b < p.length; b++) {
                    secondSum[a] += n[a][b] * u[a][b];
                    secondSum[b] += n[a][b] * (1.0 - u[a][b]);
                }
            }

            // Rescale every class probability.
            for (int c = 0; c < p.length; c++) {
                if (firstSum[c] == 0.0 || secondSum[c] == 0.0) {
                    if (p[c] > 0.0) {
                        changed = true;
                    }
                    p[c] = 0.0;
                } else {
                    double factor = firstSum[c] / secondSum[c];
                    double previous = p[c];
                    p[c] *= factor;
                    if (Math.abs(previous - p[c]) > 0.001) {
                        changed = true;
                    }
                }
            }
            Utils.normalize(p);

            // Refresh the implied pairwise probabilities.
            for (int a = 0; a < r.length; a++) {
                for (int b = a + 1; b < r.length; b++) {
                    u[a][b] = p[a] / (p[a] + p[b]);
                }
            }
        } while (changed);

        return p;
    }

    /**
     * Returns the revision string obtained by passing this class's revision tag to
     * RevisionUtils.extract.
     *
     * @return the revision
     */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.48 $");
    }

    /**
     * Command-line entry point; hands a new MultiClassClassifier and the given arguments to
     * runClassifier.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        MultiClassClassifier.runClassifier(new MultiClassClassifier(), argv);
    }

    /**
     * Base class for the code matrices that define the two-class problems. Row {@code k} of
     * {@code m_Codebits} belongs to classifier {@code k}; a true entry in column {@code c}
     * puts class {@code c} into that classifier's indicator set.
     */
    private abstract class Code implements Serializable, RevisionHandler {
        static final long serialVersionUID = 418095077487120846L;

        /** The code matrix, indexed as [code word][class]. */
        protected boolean[][] m_Codebits;

        private Code() {
        }

        /**
         * Returns the number of code words, i.e. the number of two-class problems.
         *
         * @return the number of rows of the code matrix
         */
        public int size() {
            return this.m_Codebits.length;
        }

        /**
         * Returns the classes selected by the given code word as a comma-separated list of
         * 1-based indices.
         *
         * @param which the index of the code word
         * @return the index list; empty if no class is selected
         */
        public String getIndices(int which) {
            boolean[] bits = this.m_Codebits[which];
            StringBuilder indices = new StringBuilder();
            for (int c = 0; c < bits.length; c++) {
                if (!bits[c]) {
                    continue;
                }
                if (indices.length() > 0) {
                    indices.append(',');
                }
                indices.append(c + 1);
            }
            return indices.toString();
        }

        /**
         * Returns the code matrix as text, one line per class and one " 0"/" 1" column per
         * code word.
         *
         * @return the textual code matrix
         */
        public String toString() {
            StringBuilder out = new StringBuilder();
            int numClasses = this.m_Codebits[0].length;
            for (int c = 0; c < numClasses; c++) {
                for (int k = 0; k < this.m_Codebits.length; k++) {
                    out.append(this.m_Codebits[k][c] ? " 1" : " 0");
                }
                out.append('\n');
            }
            return out.toString();
        }

        /**
         * Returns the revision string.
         *
         * @return the revision
         */
        @Override
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 1.48 $");
        }
    }

    /**
     * Exhaustive code with 2^(numClasses - 1) - 1 code words. The first class is in the
     * indicator set of every code word; for class {@code c > 0} the bit alternates in runs of
     * 2^(numClasses - (c + 1)) code words, starting with false.
     */
    private class ExhaustiveCode extends Code {
        static final long serialVersionUID = 8090991039670804047L;

        /**
         * Builds the exhaustive code matrix.
         *
         * @param numClasses the number of classes
         */
        public ExhaustiveCode(int numClasses) {
            int width = (int)Math.pow(2.0, numClasses - 1) - 1;
            this.m_Codebits = new boolean[width][numClasses];

            // Column 0 is set in every code word.
            for (int word = 0; word < width; word++) {
                this.m_Codebits[word][0] = true;
            }

            // Each further column halves the run length of the previous one.
            for (int cls = 1; cls < numClasses; cls++) {
                int skip = (int)Math.pow(2.0, numClasses - (cls + 1));
                for (int word = 0; word < width; word++) {
                    this.m_Codebits[word][cls] = (word / skip) % 2 != 0;
                }
            }
        }

        /**
         * Returns the revision string.
         *
         * @return the revision
         */
        @Override
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 1.48 $");
        }
    }

    /**
     * Random code: every bit is drawn independently. The matrix is re-drawn until it passes
     * {@link #good()} or the retry limit is reached, in which case the last draw is kept.
     */
    private class RandomCode extends Code {
        static final long serialVersionUID = 4413410540703926563L;

        /** Random number generator used to draw the code bits. */
        Random r = null;

        /**
         * Builds a random code matrix.
         *
         * @param numClasses the number of classes
         * @param numCodes the requested number of code words; at least 2 are always used
         * @param data the dataset that supplies the random number generator, seeded with the
         *             enclosing classifier's seed
         */
        public RandomCode(int numClasses, int numCodes, Instances data) {
            this.r = data.getRandomNumberGenerator(MultiClassClassifier.this.m_Seed);
            this.m_Codebits = new boolean[Math.max(2, numCodes)][numClasses];

            // First draw, then retry while the matrix is unusable and retries remain.
            int retries = 0;
            this.randomize();
            while (!this.good() && retries++ < 100) {
                this.randomize();
            }
        }

        /**
         * Checks that the matrix is usable: no code word selects none or all of the classes,
         * and no class is selected by none or all of the code words.
         *
         * @return true if every row and every column is mixed
         */
        private boolean good() {
            int numClasses = this.m_Codebits[0].length;

            // Every code word must select at least one class, but not all of them.
            for (int k = 0; k < this.m_Codebits.length; k++) {
                boolean anySet = false;
                boolean allSet = true;
                for (int c = 0; c < this.m_Codebits[k].length; c++) {
                    anySet |= this.m_Codebits[k][c];
                    allSet &= this.m_Codebits[k][c];
                }
                if (!anySet || allSet) {
                    return false;
                }
            }

            // Every class must be selected by at least one code word, but not by all.
            for (int c = 0; c < numClasses; c++) {
                boolean anySet = false;
                boolean allSet = true;
                for (int k = 0; k < this.m_Codebits.length; k++) {
                    anySet |= this.m_Codebits[k][c];
                    allSet &= this.m_Codebits[k][c];
                }
                if (!anySet || allSet) {
                    return false;
                }
            }
            return true;
        }

        /**
         * Re-draws every bit of the matrix; a bit is set when the drawn value is at least 0.5.
         */
        private void randomize() {
            for (int k = 0; k < this.m_Codebits.length; k++) {
                for (int c = 0; c < this.m_Codebits[k].length; c++) {
                    this.m_Codebits[k][c] = this.r.nextDouble() >= 0.5;
                }
            }
        }

        /**
         * Returns the revision string.
         *
         * @return the revision
         */
        @Override
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 1.48 $");
        }
    }

    /**
     * 1-against-all code: a square matrix in which code word {@code k} selects exactly
     * class {@code k}.
     */
    private class StandardCode extends Code {
        static final long serialVersionUID = 3707829689461467358L;

        /**
         * Builds the identity code matrix.
         *
         * @param numClasses the number of classes
         */
        public StandardCode(int numClasses) {
            this.m_Codebits = new boolean[numClasses][numClasses];
            for (int k = 0; k < numClasses; k++) {
                this.m_Codebits[k][k] = true;
            }
        }

        /**
         * Returns the revision string.
         *
         * @return the revision
         */
        @Override
        public String getRevision() {
            return RevisionUtils.extract("$Revision: 1.48 $");
        }
    }
}

