/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.lazy;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.UpdateableClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.LinearNN;
import weka.core.NearestNeighbourSearch;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

/**
 * K-nearest neighbours classifier (Aha and Kibler, 1991).
 *
 * <p>Predicts by combining the class values of the k nearest training instances, optionally
 * weighting each neighbour by its distance ({@link #WEIGHT_INVERSE}, {@link #WEIGHT_SIMILARITY}).
 * When cross-validation is enabled, k is chosen between 1 and the configured value by
 * hold-one-out evaluation on the training data. An optional window caps the number of retained
 * training instances; the oldest instances are dropped first.
 */
public class IBk
extends Classifier
implements OptionHandler,
UpdateableClassifier,
WeightedInstancesHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = -3080186098777067172L;
    /** Training instances the neighbours are drawn from (null until buildClassifier runs). */
    protected Instances m_Train;
    /** Number of class values, as reported by the training data. */
    protected int m_NumClasses;
    /** Type code of the class attribute, as returned by {@code classAttribute().type()}. */
    protected int m_ClassType;
    /** Number of neighbours currently used for prediction. */
    protected int m_kNN;
    /** Largest k considered by cross-validation (the value last passed to setKNN). */
    protected int m_kNNUpper;
    /** False whenever the training data or k changed since the last cross-validation. */
    protected boolean m_kNNValid;
    /** Maximum number of training instances retained; 0 means no limit. */
    protected int m_WindowSize;
    /** Distance weighting scheme: WEIGHT_NONE, WEIGHT_INVERSE or WEIGHT_SIMILARITY. */
    protected int m_DistanceWeighting;
    /** Whether k is selected by hold-one-out cross-validation. */
    protected boolean m_CrossValidate;
    /** Whether cross-validation minimises squared (rather than absolute) error for numeric classes. */
    protected boolean m_MeanSquared;
    /** No distance weighting: every neighbour counts equally. */
    public static final int WEIGHT_NONE = 1;
    /** Neighbours are weighted by 1/distance. */
    public static final int WEIGHT_INVERSE = 2;
    /** Neighbours are weighted by 1-distance. */
    public static final int WEIGHT_SIMILARITY = 4;
    /** Tags describing the available distance weighting schemes. */
    public static final Tag[] TAGS_WEIGHTING = new Tag[]{new Tag(1, "No distance weighting"), new Tag(2, "Weight by 1/distance"), new Tag(4, "Weight by 1-distance")};
    /** Class-type code of a nominal class attribute (mirrors weka.core.Attribute.NOMINAL). */
    private static final int CLASS_TYPE_NOMINAL = 1;
    /** Search structure used to locate the nearest neighbours. */
    protected NearestNeighbourSearch m_NNSearch = new LinearNN();
    /** Number of nominal/numeric non-class attributes; used to normalise distances. */
    protected double m_NumAttributesUsed;

    /**
     * Creates a classifier that uses {@code n} neighbours.
     *
     * @param n the number of nearest neighbours to use
     */
    public IBk(int n) {
        this.init();
        this.setKNN(n);
    }

    /** Creates a classifier with default settings (k = 1, no weighting, no window). */
    public IBk() {
        this.init();
    }

    /**
     * Returns a description of the classifier for display in the GUI.
     *
     * @return the description, including the literature reference
     */
    public String globalInfo() {
        return "K-nearest neighbours classifier. Can select appropriate value of K based on cross-validation. Can also do distance weighting.\n\nFor more information, see\n\n" + this.getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for the underlying algorithm.
     *
     * @return the technical information for Aha and Kibler (1991)
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        info.setValue(TechnicalInformation.Field.AUTHOR, "D. Aha and D. Kibler");
        info.setValue(TechnicalInformation.Field.YEAR, "1991");
        info.setValue(TechnicalInformation.Field.TITLE, "Instance-based learning algorithms");
        info.setValue(TechnicalInformation.Field.JOURNAL, "Machine Learning");
        info.setValue(TechnicalInformation.Field.VOLUME, "6");
        info.setValue(TechnicalInformation.Field.PAGES, "37-66");
        return info;
    }

    /** @return the tip text for the KNN property */
    public String KNNTipText() {
        return "The number of neighbours to use.";
    }

    /**
     * Sets the number of neighbours. This value is also the upper bound searched by
     * cross-validation, and it invalidates any previously selected k.
     *
     * @param n the number of neighbours
     */
    public void setKNN(int n) {
        this.m_kNN = n;
        this.m_kNNUpper = n;
        this.m_kNNValid = false;
    }

    /** @return the number of neighbours currently used */
    public int getKNN() {
        return this.m_kNN;
    }

    /** @return the tip text for the windowSize property */
    public String windowSizeTipText() {
        return "Gets the maximum number of instances allowed in the training pool. The addition of new instances above this value will result in old instances being removed. A value of 0 signifies no limit to the number of training instances.";
    }

    /** @return the maximum number of retained training instances (0 = unlimited) */
    public int getWindowSize() {
        return this.m_WindowSize;
    }

    /** @param n the maximum number of retained training instances (0 = unlimited) */
    public void setWindowSize(int n) {
        this.m_WindowSize = n;
    }

    /** @return the tip text for the distanceWeighting property */
    public String distanceWeightingTipText() {
        return "Gets the distance weighting method used.";
    }

    /** @return the current distance weighting scheme as a tag from TAGS_WEIGHTING */
    public SelectedTag getDistanceWeighting() {
        return new SelectedTag(this.m_DistanceWeighting, TAGS_WEIGHTING);
    }

    /**
     * Sets the distance weighting scheme. Tags not drawn from TAGS_WEIGHTING are ignored.
     *
     * @param selectedTag the weighting scheme
     */
    public void setDistanceWeighting(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_WEIGHTING) {
            this.m_DistanceWeighting = selectedTag.getSelectedTag().getID();
        }
    }

    /** @return the tip text for the meanSquared property */
    public String meanSquaredTipText() {
        return "Whether the mean squared error is used rather than mean absolute error when doing cross-validation for regression problems.";
    }

    /** @return whether cross-validation minimises squared error for numeric classes */
    public boolean getMeanSquared() {
        return this.m_MeanSquared;
    }

    /** @param bl whether cross-validation should minimise squared error for numeric classes */
    public void setMeanSquared(boolean bl) {
        this.m_MeanSquared = bl;
    }

    /** @return the tip text for the crossValidate property */
    public String crossValidateTipText() {
        return "Whether hold-one-out cross-validation will be used to select the best k value.";
    }

    /** @return whether k is selected by hold-one-out cross-validation */
    public boolean getCrossValidate() {
        return this.m_CrossValidate;
    }

    /** @param bl whether k should be selected by hold-one-out cross-validation */
    public void setCrossValidate(boolean bl) {
        this.m_CrossValidate = bl;
    }

    /** @return the tip text for the nearestNeighbourSearchAlgorithm property */
    public String nearestNeighbourSearchAlgorithmTipText() {
        return "The nearest neighbour search algorithm to use (Default: LinearNN).";
    }

    /** @return the nearest neighbour search algorithm in use */
    public NearestNeighbourSearch getNearestNeighbourSearchAlgorithm() {
        return this.m_NNSearch;
    }

    /** @param nearestNeighbourSearch the nearest neighbour search algorithm to use */
    public void setNearestNeighbourSearchAlgorithm(NearestNeighbourSearch nearestNeighbourSearch) {
        this.m_NNSearch = nearestNeighbourSearch;
    }

    /**
     * Returns the number of retained training instances. Requires buildClassifier to have run
     * (m_Train is dereferenced).
     *
     * @return the training set size
     */
    public int getNumTraining() {
        return this.m_Train.numInstances();
    }

    /**
     * Returns the data types this classifier accepts.
     *
     * @return nominal/numeric/date attributes and classes, with missing values allowed
     */
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
        capabilities.enable(Capabilities.Capability.DATE_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        capabilities.setMinimumNumberInstances(0);
        return capabilities;
    }

    /**
     * Stores a copy of the training data (minus instances with a missing class and trimmed to the
     * window, if any) and initialises the neighbour search over it.
     *
     * @param instances the training data; not modified
     * @throws Exception if the data fails the capabilities test or the search cannot be built
     */
    public void buildClassifier(Instances instances) throws Exception {
        this.getCapabilities().testWithFail(instances);
        // Private copy: deleteWithMissingClass must not touch the caller's dataset.
        Instances train = new Instances(instances);
        train.deleteWithMissingClass();
        this.m_NumClasses = train.numClasses();
        this.m_ClassType = train.classAttribute().type();
        if (this.m_WindowSize > 0 && train.numInstances() > this.m_WindowSize) {
            // Keep only the last m_WindowSize instances.
            train = new Instances(train, train.numInstances() - this.m_WindowSize, this.m_WindowSize);
        }
        this.m_Train = train;
        this.m_NumAttributesUsed = 0.0;
        for (int i = 0; i < train.numAttributes(); ++i) {
            boolean usable = train.attribute(i).isNominal() || train.attribute(i).isNumeric();
            if (i != train.classIndex() && usable) {
                this.m_NumAttributesUsed += 1.0;
            }
        }
        this.m_NNSearch.setInstances(this.m_Train);
        this.m_kNNValid = false;
    }

    /**
     * Adds one training instance, dropping the oldest instances if the window is exceeded.
     * Instances with a missing class are ignored. Requires buildClassifier to have run.
     *
     * @param instance the new training instance
     * @throws Exception if the instance's header differs from the training data's
     */
    public void updateClassifier(Instance instance) throws Exception {
        if (!this.m_Train.equalHeaders(instance.dataset())) {
            throw new Exception("Incompatible instance types");
        }
        if (instance.classIsMissing()) {
            return;
        }
        this.m_Train.add(instance);
        this.m_NNSearch.update(instance);
        this.m_kNNValid = false;
        this.trimToWindow();
    }

    /**
     * Predicts the class distribution for an instance from its k nearest neighbours.
     *
     * @param instance the instance to classify
     * @return class probabilities (nominal class) or a one-element array holding the weighted
     *     mean target value (numeric/date class)
     * @throws Exception if there are no training instances or the search fails
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.m_Train == null || this.m_Train.numInstances() == 0) {
            throw new Exception("No training instances!");
        }
        // The window may have been shrunk since training.
        this.trimToWindow();
        if (!this.m_kNNValid && this.m_CrossValidate && this.m_kNNUpper >= 1) {
            this.crossValidate();
        }
        this.m_NNSearch.addInstanceInfo(instance);
        Instances neighbours = this.m_NNSearch.kNearestNeighbours(instance, this.m_kNN);
        double[] distances = this.m_NNSearch.getDistances();
        return this.makeDistribution(neighbours, distances);
    }

    /**
     * Drops the oldest training instances until at most m_WindowSize remain. When anything was
     * dropped, the neighbour search is rebuilt over the remaining data and the selected k is
     * invalidated.
     */
    private void trimToWindow() throws Exception {
        if (this.m_WindowSize <= 0 || this.m_Train.numInstances() <= this.m_WindowSize) {
            return;
        }
        while (this.m_Train.numInstances() > this.m_WindowSize) {
            this.m_Train.delete(0);
        }
        this.m_NNSearch.setInstances(this.m_Train);
        this.m_kNNValid = false;
    }

    /**
     * Lists the command-line options.
     *
     * @return an enumeration of Option objects
     */
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>(8);
        options.addElement(new Option("\tWeight neighbours by the inverse of their distance\n\t(use when k > 1)", "I", 0, "-I"));
        options.addElement(new Option("\tWeight neighbours by 1 - their distance\n\t(use when k > 1)", "F", 0, "-F"));
        options.addElement(new Option("\tNumber of nearest neighbours (k) used in classification.\n\t(Default = 1)", "K", 1, "-K <number of neighbors>"));
        options.addElement(new Option("\tMinimise mean squared error rather than mean absolute\n\terror when using -X option with numeric prediction.", "E", 0, "-E"));
        options.addElement(new Option("\tMaximum number of training instances maintained.\n\tTraining instances are dropped FIFO. (Default = no window)", "W", 1, "-W <window size>"));
        options.addElement(new Option("\tSelect the number of nearest neighbours between 1\n\tand the k value specified using hold-one-out evaluation\n\ton the training data (use when k > 1)", "X", 0, "-X"));
        // -A takes one argument: the search class name plus its options (see setOptions).
        options.addElement(new Option("\tThe nearest neighbour search algorithm to use (default: LinearNN).\n", "A", 1, "-A"));
        return options.elements();
    }

    /**
     * Parses the options -K, -W, -I, -F, -X, -E and -A. Absent options reset to their defaults.
     *
     * @param stringArray the options; recognised entries are consumed
     * @throws Exception if an option is malformed or unrecognised options remain
     */
    public void setOptions(String[] stringArray) throws Exception {
        String knn = Utils.getOption('K', stringArray);
        this.setKNN(knn.length() != 0 ? Integer.parseInt(knn) : 1);
        String window = Utils.getOption('W', stringArray);
        this.setWindowSize(window.length() != 0 ? Integer.parseInt(window) : 0);
        if (Utils.getFlag('I', stringArray)) {
            this.setDistanceWeighting(new SelectedTag(WEIGHT_INVERSE, TAGS_WEIGHTING));
        } else if (Utils.getFlag('F', stringArray)) {
            this.setDistanceWeighting(new SelectedTag(WEIGHT_SIMILARITY, TAGS_WEIGHTING));
        } else {
            this.setDistanceWeighting(new SelectedTag(WEIGHT_NONE, TAGS_WEIGHTING));
        }
        this.setCrossValidate(Utils.getFlag('X', stringArray));
        this.setMeanSquared(Utils.getFlag('E', stringArray));
        String searchSpec = Utils.getOption('A', stringArray);
        if (searchSpec.length() != 0) {
            String[] searchOptions = Utils.splitOptions(searchSpec);
            if (searchOptions.length == 0) {
                throw new Exception("Invalid NearestNeighbourSearch algorithm specification string.");
            }
            // First token is the class name; the rest are that algorithm's own options.
            String searchClass = searchOptions[0];
            searchOptions[0] = "";
            this.setNearestNeighbourSearchAlgorithm((NearestNeighbourSearch)Utils.forName(NearestNeighbourSearch.class, searchClass, searchOptions));
        } else {
            this.setNearestNeighbourSearchAlgorithm(new LinearNN());
        }
        Utils.checkForRemainingOptions(stringArray);
    }

    /**
     * Returns the current settings as command-line options, padded with empty strings.
     *
     * @return the option array
     */
    public String[] getOptions() {
        String[] options = new String[11];
        int pos = 0;
        options[pos++] = "-K";
        options[pos++] = "" + this.getKNN();
        options[pos++] = "-W";
        options[pos++] = "" + this.m_WindowSize;
        if (this.getCrossValidate()) {
            options[pos++] = "-X";
        }
        if (this.getMeanSquared()) {
            options[pos++] = "-E";
        }
        if (this.m_DistanceWeighting == WEIGHT_INVERSE) {
            options[pos++] = "-I";
        } else if (this.m_DistanceWeighting == WEIGHT_SIMILARITY) {
            options[pos++] = "-F";
        }
        options[pos++] = "-A";
        options[pos++] = this.m_NNSearch.getClass().getName() + " " + Utils.joinOptions(this.m_NNSearch.getOptions());
        while (pos < options.length) {
            options[pos++] = "";
        }
        return options;
    }

    /**
     * Describes the model. This may run cross-validation first so that the reported k is
     * current.
     *
     * @return a human-readable description
     */
    public String toString() {
        if (this.m_Train == null) {
            return "IBk: No model built yet.";
        }
        // Same guard as distributionForInstance: at least one candidate k is required.
        if (!this.m_kNNValid && this.m_CrossValidate && this.m_kNNUpper >= 1) {
            this.crossValidate();
        }
        StringBuilder text = new StringBuilder("IB1 instance-based classifier\nusing ");
        text.append(this.m_kNN);
        if (this.m_DistanceWeighting == WEIGHT_INVERSE) {
            text.append(" inverse-distance-weighted");
        } else if (this.m_DistanceWeighting == WEIGHT_SIMILARITY) {
            text.append(" similarity-weighted");
        }
        text.append(" nearest neighbour(s) for classification\n");
        if (this.m_WindowSize != 0) {
            text.append("using a maximum of ").append(this.m_WindowSize).append(" (windowed) training instances\n");
        }
        return text.toString();
    }

    /** Resets all settings to their defaults. */
    protected void init() {
        this.setKNN(1);
        this.m_WindowSize = 0;
        this.m_DistanceWeighting = WEIGHT_NONE;
        this.m_CrossValidate = false;
        this.m_MeanSquared = false;
    }

    /**
     * Combines neighbours into a prediction.
     *
     * <p>Each raw distance d is rescaled in place to sqrt(d*d / attributes used), then turned
     * into a weight according to the weighting scheme and multiplied by the neighbour's instance
     * weight. A nominal class starts from a small uniform prior of 1/N per class (N = training
     * size); each neighbour then adds its weight to its own class. Any other class type (numeric
     * or date) accumulates the weighted sum of target values. The result is normalised by the
     * total weight.
     *
     * @param instances the neighbours
     * @param dArray their distances, aligned with {@code instances}; modified in place
     * @return the class distribution, or a one-element weighted mean for non-nominal classes
     * @throws Exception declared for subclasses; errors reading class values surface as Error
     */
    protected double[] makeDistribution(Instances instances, double[] dArray) throws Exception {
        boolean nominalClass = this.m_ClassType == CLASS_TYPE_NOMINAL;
        double totalWeight = 0.0;
        double[] distribution = new double[this.m_NumClasses];
        if (nominalClass) {
            for (int c = 0; c < this.m_NumClasses; ++c) {
                distribution[c] = 1.0 / (double)Math.max(1, this.m_Train.numInstances());
            }
            totalWeight = (double)this.m_NumClasses / (double)Math.max(1, this.m_Train.numInstances());
        }
        // Guard against 0 usable attributes, which would otherwise make every distance NaN/Inf.
        double attributeCount = Math.max(1.0, this.m_NumAttributesUsed);
        for (int n = 0; n < instances.numInstances(); ++n) {
            Instance neighbour = instances.instance(n);
            dArray[n] = Math.sqrt(dArray[n] * dArray[n] / attributeCount);
            double weight;
            switch (this.m_DistanceWeighting) {
                case WEIGHT_INVERSE:
                    // The small offset avoids division by zero for exact matches.
                    weight = 1.0 / (dArray[n] + 0.001);
                    break;
                case WEIGHT_SIMILARITY:
                    weight = 1.0 - dArray[n];
                    break;
                default:
                    weight = 1.0;
                    break;
            }
            weight *= neighbour.weight();
            try {
                if (nominalClass) {
                    distribution[(int)neighbour.classValue()] += weight;
                } else {
                    // Numeric and date classes (both enabled in getCapabilities): weighted mean.
                    distribution[0] += neighbour.classValue() * weight;
                }
            } catch (Exception exception) {
                throw new Error("Data has no class attribute!", exception);
            }
            totalWeight += weight;
        }
        if (totalWeight > 0.0) {
            Utils.normalize(distribution, totalWeight);
        }
        return distribution;
    }

    /**
     * Selects k in [1, m_kNNUpper] by hold-one-out evaluation on the training data.
     *
     * <p>For each training instance, the m_kNNUpper nearest neighbours are fetched once. The set
     * is then pruned step by step to evaluate every smaller k. The k with the lowest error is
     * kept; ties go to the smaller k. The error measure is the misclassification count for
     * nominal classes, and absolute or (with m_MeanSquared) squared error for numeric classes.
     * Sets m_kNN and marks it valid.
     *
     * @throws Error wrapping any exception raised during the evaluation
     */
    protected void crossValidate() {
        try {
            boolean numericClass = this.m_Train.classAttribute().isNumeric();
            int numInstances = this.m_Train.numInstances();
            // Misclassification counts (nominal) or absolute-error sums (numeric), per k - 1.
            double[] errors = new double[this.m_kNNUpper];
            double[] squaredErrors = new double[this.m_kNNUpper];
            this.m_kNN = this.m_kNNUpper;
            for (int n = 0; n < numInstances; ++n) {
                if (this.m_Debug && n % 50 == 0) {
                    System.err.print("Cross validating " + n + "/" + numInstances + "\r");
                }
                Instance instance = this.m_Train.instance(n);
                Instances neighbours = this.m_NNSearch.kNearestNeighbours(instance, this.m_kNN);
                double[] distances = this.m_NNSearch.getDistances();
                for (int k = this.m_kNNUpper - 1; k >= 0; --k) {
                    // makeDistribution rescales distances in place, so give it a fresh copy.
                    double[] scratch = distances.clone();
                    double[] prediction = this.makeDistribution(neighbours, scratch);
                    if (numericClass) {
                        double error = prediction[0] - instance.classValue();
                        squaredErrors[k] += error * error;
                        errors[k] += Math.abs(error);
                    } else if ((double)Utils.maxIndex(prediction) != instance.classValue()) {
                        errors[k] += 1.0;
                    }
                    // Shrink to the k nearest (keeping ties) for the next, smaller k. pruneToK
                    // returns null for an empty set, so leave an empty set as it is.
                    if (k >= 1 && neighbours.numInstances() > 0) {
                        neighbours = this.pruneToK(neighbours, scratch, k);
                    }
                }
            }
            if (this.m_Debug) {
                for (int k = 0; k < this.m_kNNUpper; ++k) {
                    System.err.print("Hold-one-out performance of " + (k + 1) + " neighbors ");
                    if (!numericClass) {
                        System.err.println("(%ERR) = " + 100.0 * errors[k] / (double)numInstances);
                    } else if (this.m_MeanSquared) {
                        System.err.println("(RMSE) = " + Math.sqrt(squaredErrors[k] / (double)numInstances));
                    } else {
                        System.err.println("(MAE) = " + errors[k] / (double)numInstances);
                    }
                }
            }
            double[] criterion = numericClass && this.m_MeanSquared ? squaredErrors : errors;
            double bestError = Double.NaN;
            int bestK = 1;
            for (int k = 0; k < this.m_kNNUpper; ++k) {
                if (Double.isNaN(bestError) || bestError > criterion[k]) {
                    bestError = criterion[k];
                    bestK = k + 1;
                }
            }
            this.m_kNN = bestK;
            if (this.m_Debug) {
                System.err.println("Selected k = " + bestK);
            }
            this.m_kNNValid = true;
        }
        catch (Exception exception) {
            throw new Error("Couldn't optimize by cross-validation: " + exception.getMessage(), exception);
        }
    }

    /**
     * Truncates a distance-sorted neighbour set to its first {@code n} members, also keeping any
     * later members whose distance equals that of the last one kept (ties).
     *
     * @param instances neighbours, assumed ordered by ascending distance
     * @param dArray distances aligned with {@code instances}
     * @param n number of neighbours to keep; values below 1 are treated as 1
     * @return the pruned set, the unchanged set if nothing needs cutting, or null if
     *     {@code instances} or {@code dArray} is null or the set is empty
     */
    public Instances pruneToK(Instances instances, double[] dArray, int n) {
        if (instances == null || dArray == null || instances.numInstances() == 0) {
            return null;
        }
        int keep = Math.max(1, n);
        for (int i = keep; i < instances.numInstances(); ++i) {
            if (dArray[i] != dArray[i - 1]) {
                return new Instances(instances, 0, i);
            }
        }
        return instances;
    }

    /**
     * Command-line entry point.
     *
     * @param stringArray the classifier options
     */
    public static void main(String[] stringArray) {
        IBk.runClassifier(new IBk(), stringArray);
    }
}

