/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.IterativeClassifier;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.evaluation.Evaluation;
import weka.classifiers.evaluation.EvaluationMetricHelper;
import weka.classifiers.meta.LogitBoost;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;

/**
 * Meta classifier that chooses the number of iterations of an
 * {@link IterativeClassifier} by cross-validation.
 *
 * <p>One copy of the base classifier is initialised per fold and run. All
 * copies are then advanced one iteration at a time, and the chosen evaluation
 * metric is computed every {@code stepSize} iterations. The search stops once
 * {@code lookAheadIterations} iterations pass without improvement (or a copy
 * cannot iterate further). The base classifier is finally built on the full
 * training data for the best iteration count found.</p>
 */
public class IterativeClassifierOptimizer extends RandomizableClassifier
    implements AdditionalMeasureProducer {

    private static final long serialVersionUID = -3665485256313525864L;

    /** The base classifier whose number of iterations is optimised. */
    protected IterativeClassifier m_IterativeClassifier = new LogitBoost();

    /** Number of cross-validation folds requested by the user. */
    protected int m_NumFolds = 10;

    /** Number of cross-validation runs; the data is re-randomised for each run. */
    protected int m_NumRuns = 1;

    /** The metric is only evaluated every m_StepSize iterations. */
    protected int m_StepSize = 1;

    /** If true, per-fold estimates are averaged instead of pooling all predictions. */
    protected boolean m_UseAverage = false;

    /** Number of iterations without improvement after which the search stops. */
    protected int m_lookAheadIterations = 50;

    /** One tag per metric name reported by EvaluationMetricHelper.getAllMetricNames(). */
    public static Tag[] TAGS_EVAL;

    /** Name of the evaluation metric to optimise. */
    protected String m_evalMetric = "rmse";

    /** Class value index for IR-type metrics; a negative value selects the class-weighted average. */
    protected int m_classValueIndex = -1;

    /** Per-class thresholds found at the best iteration, or null if the metric produced none. */
    protected double[] m_thresholds = null;

    /** Best metric value found during the last build. */
    protected double m_bestResult = Double.MAX_VALUE;

    /** Iteration count at which m_bestResult was obtained. */
    protected int m_bestNumIts;

    /** Number of tasks the fold models are split into for each iteration. */
    protected int m_numThreads = 1;

    /** Size of the thread pool executing those tasks. */
    protected int m_poolSize = 1;

    /** Metric value (and optional per-class thresholds) estimated at one iteration. */
    private static final class Estimate {
        final double value;
        final double[] thresholds;

        Estimate(double value, double[] thresholds) {
            this.value = value;
            this.thresholds = thresholds;
        }
    }

    /**
     * Returns a description of this classifier for display in the GUI.
     *
     * @return the global info string
     */
    public String globalInfo() {
        return "Optimizes the number of iterations of the given iterative classifier using cross-validation.";
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return the default base classifier class name
     */
    protected String defaultIterativeClassifierString() {
        return "weka.classifiers.meta.LogitBoost";
    }

    /** @return tip text for the useAverage property */
    public String useAverageTipText() {
        return "If true, average estimates are used instead of one estimate from pooled predictions.";
    }

    /** @return whether per-fold estimates are averaged */
    public boolean getUseAverage() {
        return this.m_UseAverage;
    }

    /** @param newUseAverage whether per-fold estimates should be averaged */
    public void setUseAverage(boolean newUseAverage) {
        this.m_UseAverage = newUseAverage;
    }

    /** @return tip text for the numThreads property */
    public String numThreadsTipText() {
        return "The number of threads to use, which should be >= size of thread pool.";
    }

    /** @return the number of tasks per iteration */
    public int getNumThreads() {
        return this.m_numThreads;
    }

    /** @param nT the number of tasks per iteration */
    public void setNumThreads(int nT) {
        this.m_numThreads = nT;
    }

    /** @return tip text for the poolSize property */
    public String poolSizeTipText() {
        return "The size of the thread pool, for example, the number of cores in the CPU.";
    }

    /** @return the thread pool size */
    public int getPoolSize() {
        return this.m_poolSize;
    }

    /** @param nT the thread pool size */
    public void setPoolSize(int nT) {
        this.m_poolSize = nT;
    }

    /** @return tip text for the stepSize property */
    public String stepSizeTipText() {
        return "Step size for the evaluation, if evaluation is time consuming.";
    }

    /** @return the evaluation step size */
    public int getStepSize() {
        return this.m_StepSize;
    }

    /** @param newStepSize the evaluation step size */
    public void setStepSize(int newStepSize) {
        this.m_StepSize = newStepSize;
    }

    /** @return tip text for the numRuns property */
    public String numRunsTipText() {
        return "Number of runs for cross-validation.";
    }

    /** @return the number of cross-validation runs */
    public int getNumRuns() {
        return this.m_NumRuns;
    }

    /** @param newNumRuns the number of cross-validation runs */
    public void setNumRuns(int newNumRuns) {
        this.m_NumRuns = newNumRuns;
    }

    /** @return tip text for the numFolds property */
    public String numFoldsTipText() {
        return "Number of folds for cross-validation.";
    }

    /** @return the number of cross-validation folds */
    public int getNumFolds() {
        return this.m_NumFolds;
    }

    /** @param newNumFolds the number of cross-validation folds */
    public void setNumFolds(int newNumFolds) {
        this.m_NumFolds = newNumFolds;
    }

    /** @return tip text for the lookAheadIterations property */
    public String lookAheadIterationsTipText() {
        return "The number of iterations to look ahead for to find a better optimum.";
    }

    /** @return the number of look-ahead iterations */
    public int getLookAheadIterations() {
        return this.m_lookAheadIterations;
    }

    /** @param newLookAheadIterations the number of look-ahead iterations */
    public void setLookAheadIterations(int newLookAheadIterations) {
        this.m_lookAheadIterations = newLookAheadIterations;
    }

    /**
     * Finds the best number of iterations by cross-validation, then builds the
     * base classifier on all of {@code data} for that many iterations.
     *
     * @param data the training data
     * @throws Exception if no base classifier is set, the data is not supported,
     *         or cross-validation/training fails
     */
    @Override
    public void buildClassifier(Instances data) throws Exception {
        if (this.m_IterativeClassifier == null) {
            throw new Exception("A base classifier has not been specified!");
        }
        this.getCapabilities().testWithFail(data);
        Random randomInstance = new Random(this.m_Seed);
        Instances origData = data;
        data = new Instances(data);
        data.deleteWithMissingClass();

        // Use a local fold count so a small dataset does not permanently
        // overwrite the user's numFolds setting (and what getOptions() reports).
        int numFolds = this.m_NumFolds;
        if (data.numInstances() < numFolds) {
            System.err.println("WARNING: reducing number of folds to number of instances in IterativeClassifierOptimizer");
            numFolds = data.numInstances();
        }
        // A non-positive step size would make the modulo below divide by zero.
        int stepSize = Math.max(1, this.m_StepSize);

        // Set up one initialised copy of the base classifier per run and fold.
        // The order of the random calls is kept so that results stay reproducible.
        Instances[][] trainingSets = new Instances[this.m_NumRuns][numFolds];
        Instances[][] testSets = new Instances[this.m_NumRuns][numFolds];
        IterativeClassifier[][] classifiers = new IterativeClassifier[this.m_NumRuns][numFolds];
        for (int run = 0; run < this.m_NumRuns; ++run) {
            data.randomize(randomInstance);
            if (data.classAttribute().isNominal()) {
                data.stratify(numFolds);
            }
            for (int fold = 0; fold < numFolds; ++fold) {
                trainingSets[run][fold] = data.trainCV(numFolds, fold, randomInstance);
                testSets[run][fold] = data.testCV(numFolds, fold);
                classifiers[run][fold] = (IterativeClassifier) AbstractClassifier.makeCopy(this.m_IterativeClassifier);
                classifiers[run][fold].initializeClassifier(trainingSets[run][fold]);
            }
        }

        ExecutorService pool = Executors.newFixedThreadPool(this.m_poolSize);
        try {
            EvaluationMetricHelper helper = new EvaluationMetricHelper(new Evaluation(data));
            boolean maximise = helper.metricIsMaximisable(this.m_evalMetric);
            // Start from the worst possible value. Double.MIN_VALUE is the smallest
            // *positive* double, so it would make negative values of a maximisable
            // metric impossible to accept as an improvement.
            this.m_bestResult = maximise ? -Double.MAX_VALUE : Double.MAX_VALUE;
            this.m_thresholds = null;
            this.m_bestNumIts = 0;
            // numIts equals the number of next() calls made on every fold model so far.
            int numIts = 0;
            int numberOfIterationsSinceMinimum = -1;
            while (true) {
                if (numIts % stepSize == 0) {
                    Estimate estimate = this.m_UseAverage
                        ? this.evaluateAveraged(helper, classifiers, trainingSets, testSets, data.numClasses())
                        : this.evaluatePooled(helper, classifiers, testSets, data);
                    if (this.m_Debug) {
                        System.err.println("Iteration: " + numIts + " " + "Measure: " + estimate.value);
                        if (estimate.thresholds != null) {
                            System.err.print("Thresholds:");
                            for (double threshold : estimate.thresholds) {
                                System.err.print(" " + threshold);
                            }
                            System.err.println();
                        }
                    }
                    double delta = maximise
                        ? this.m_bestResult - estimate.value
                        : estimate.value - this.m_bestResult;
                    if (delta < 0.0) {
                        this.m_bestResult = estimate.value;
                        this.m_bestNumIts = numIts;
                        this.m_thresholds = estimate.thresholds;
                        numberOfIterationsSinceMinimum = -1;
                    }
                }
                ++numIts;
                if (++numberOfIterationsSinceMinimum >= this.m_lookAheadIterations) {
                    break;
                }
                try {
                    if (!this.advanceAll(pool, classifiers, numFolds)) {
                        break;
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw e;
                } catch (Exception e) {
                    // Best effort: report the failure and keep searching.
                    System.err.println("Classifiers could not be generated.");
                    e.printStackTrace();
                }
            }
        } finally {
            pool.shutdown();
        }

        // Release the cross-validation data and models before the final build.
        trainingSets = null;
        testSets = null;
        classifiers = null;
        data = null;

        this.m_IterativeClassifier.initializeClassifier(origData);
        int performed = 0;
        while (performed < this.m_bestNumIts && this.m_IterativeClassifier.next()) {
            ++performed;
        }
        this.m_IterativeClassifier.done();
    }

    /**
     * Computes the configured metric for the current evaluation held by the helper.
     */
    private double metricValue(EvaluationMetricHelper helper) throws Exception {
        int classValueIndex = this.getClassValueIndex();
        return classValueIndex >= 0
            ? helper.getNamedMetric(this.m_evalMetric, classValueIndex)
            : helper.getNamedMetric(this.m_evalMetric, new int[0]);
    }

    /**
     * Evaluates all fold models into a single Evaluation (pooled predictions).
     */
    private Estimate evaluatePooled(EvaluationMetricHelper helper, IterativeClassifier[][] classifiers,
        Instances[][] testSets, Instances header) throws Exception {
        Evaluation eval = new Evaluation(header);
        helper.setEvaluation(eval);
        for (int run = 0; run < classifiers.length; ++run) {
            for (int fold = 0; fold < classifiers[run].length; ++fold) {
                eval.evaluateModel(classifiers[run][fold], testSets[run][fold], new Object[0]);
            }
        }
        return new Estimate(this.metricValue(helper), helper.getNamedMetricThresholds(this.m_evalMetric));
    }

    /**
     * Evaluates each fold model separately and averages the metric and any
     * per-class thresholds over all fold models.
     */
    private Estimate evaluateAveraged(EvaluationMetricHelper helper, IterativeClassifier[][] classifiers,
        Instances[][] trainingSets, Instances[][] testSets, int numClasses) throws Exception {
        double sum = 0.0;
        double[] thresholdSums = null;
        int numModels = 0;
        for (int run = 0; run < classifiers.length; ++run) {
            for (int fold = 0; fold < classifiers[run].length; ++fold) {
                Evaluation eval = new Evaluation(trainingSets[run][fold]);
                helper.setEvaluation(eval);
                eval.evaluateModel(classifiers[run][fold], testSets[run][fold], new Object[0]);
                sum += this.metricValue(helper);
                ++numModels;
                double[] thresholds = helper.getNamedMetricThresholds(this.m_evalMetric);
                if (thresholds == null) {
                    continue;
                }
                if (thresholdSums == null) {
                    thresholdSums = new double[numClasses];
                }
                for (int c = 0; c < thresholds.length; ++c) {
                    thresholdSums[c] += thresholds[c];
                }
            }
        }
        // Thresholds are divided by the total number of models, even if some
        // folds reported none.
        if (thresholdSums != null) {
            for (int c = 0; c < thresholdSums.length; ++c) {
                thresholdSums[c] /= numModels;
            }
        }
        return new Estimate(sum / numModels, thresholdSums);
    }

    /**
     * Advances every fold model by one iteration, splitting the models into
     * contiguous chunks that are executed on the given pool.
     *
     * @return false if at least one model reported that it could not iterate
     * @throws InterruptedException if interrupted while waiting for the tasks
     * @throws Exception the first failure raised by a task, once all tasks have finished
     */
    private boolean advanceAll(ExecutorService pool, final IterativeClassifier[][] classifiers,
        final int numFolds) throws Exception {
        final int numTasks = classifiers.length * numFolds;
        // At least one chunk (avoids division by zero); no more chunks than models.
        int numChunks = Math.max(1, Math.min(this.m_numThreads, numTasks));
        int chunkSize = numTasks / numChunks;
        List<Future<Boolean>> futures = new ArrayList<Future<Boolean>>(numChunks);
        for (int chunk = 0; chunk < numChunks; ++chunk) {
            final int lo = chunk * chunkSize;
            final int hi = chunk < numChunks - 1 ? lo + chunkSize : numTasks;
            futures.add(pool.submit(new Callable<Boolean>() {

                @Override
                public Boolean call() throws Exception {
                    for (int k = lo; k < hi; ++k) {
                        if (!classifiers[k / numFolds][k % numFolds].next()) {
                            if (IterativeClassifierOptimizer.this.m_Debug) {
                                System.err.println("Classifier failed to iterate in cross-validation.");
                            }
                            return false;
                        }
                    }
                    return true;
                }
            }));
        }
        // Wait for every task, even after a failure, so no model is still being
        // iterated when the caller evaluates the models again.
        boolean allAdvanced = true;
        Exception firstError = null;
        for (Future<Boolean> future : futures) {
            try {
                if (!future.get()) {
                    allAdvanced = false;
                }
            } catch (InterruptedException e) {
                throw e;
            } catch (Exception e) {
                if (firstError == null) {
                    firstError = e;
                }
            }
        }
        if (!allAdvanced) {
            return false;
        }
        if (firstError != null) {
            throw firstError;
        }
        return true;
    }

    /**
     * Returns the class distribution for an instance. If per-class thresholds
     * were selected, each class whose probability reaches its threshold gets
     * equal weight.
     *
     * @param inst the instance to classify
     * @return the class distribution
     * @throws Exception if the base classifier fails
     */
    @Override
    public double[] distributionForInstance(Instance inst) throws Exception {
        double[] dist = this.m_IterativeClassifier.distributionForInstance(inst);
        if (this.m_thresholds == null) {
            return dist;
        }
        double[] newDist = new double[dist.length];
        boolean anySelected = false;
        for (int i = 0; i < dist.length; ++i) {
            if (dist[i] >= this.m_thresholds[i]) {
                newDist[i] = 1.0;
                anySelected = true;
            }
        }
        if (!anySelected) {
            // Utils.normalize() throws on an all-zero array; fall back to the
            // unthresholded distribution instead of failing the prediction.
            return dist;
        }
        Utils.normalize(newDist);
        return newDist;
    }

    /**
     * Returns a textual description of the search result and the final model.
     *
     * @return the description
     */
    @Override
    public String toString() {
        if (this.m_IterativeClassifier == null) {
            return "No classifier built yet.";
        }
        StringBuilder sb = new StringBuilder();
        sb.append("Best value found: " + this.m_bestResult + "\n");
        sb.append("Best number of iterations found: " + this.m_bestNumIts + "\n\n");
        if (this.m_thresholds != null) {
            sb.append("Thresholds found: ");
            for (double threshold : this.m_thresholds) {
                sb.append(threshold + " ");
            }
        }
        sb.append("\n\n");
        sb.append(this.m_IterativeClassifier.toString());
        return sb.toString();
    }

    /**
     * Lists the available options, including those of the base classifier.
     *
     * @return an enumeration of the options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>(7);
        newVector.addElement(new Option("\tIf set, average estimate is used rather than one estimate from pooled predictions.\n", "A", 0, "-A"));
        newVector.addElement(new Option("\t" + this.lookAheadIterationsTipText() + "\n" + "\t(default 50)", "L", 1, "-L <num>"));
        newVector.addElement(new Option("\t" + this.poolSizeTipText() + "\n\t(default 1)", "P", 1, "-P <int>"));
        newVector.addElement(new Option("\t" + this.numThreadsTipText() + "\n" + "\t(default 1)", "E", 1, "-E <int>"));
        newVector.addElement(new Option("\t" + this.stepSizeTipText() + "\n" + "\t(default 1)", "I", 1, "-I <num>"));
        newVector.addElement(new Option("\tNumber of folds for cross-validation.\n\t(default 10)", "F", 1, "-F <num>"));
        newVector.addElement(new Option("\tNumber of runs for cross-validation.\n\t(default 1)", "R", 1, "-R <num>"));
        newVector.addElement(new Option("\tFull name of base classifier.\n\t(default: " + this.defaultIterativeClassifierString() + ")", "W", 1, "-W"));
        // Comma-separated metric names, wrapped once a line reaches 60 characters.
        // The separator is written before each name, so no trailing comma needs stripping.
        StringBuilder b = new StringBuilder();
        int length = 0;
        for (String m : EvaluationMetricHelper.getAllMetricNames()) {
            if (b.length() > 0) {
                b.append(",");
                if (length >= 60) {
                    b.append("\n\t");
                    length = 0;
                }
            }
            b.append(m.toLowerCase(Locale.ROOT));
            length += m.length();
        }
        newVector.addElement(new Option("\tEvaluation metric to optimise (default rmse). Available metrics:\n\t" + b, "metric", 1, "-metric <name>"));
        newVector.addElement(new Option("\tClass value index to optimise. Ignored for all but information-retrieval\n\ttype metrics (such as roc area). If unspecified (or a negative value is supplied),\n\tand an information-retrieval metric is specified, then the class-weighted average\n\tmetric used. (default -1)", "class-value-index", 1, "-class-value-index <0-based index>"));
        newVector.addAll(Collections.list(super.listOptions()));
        if (this.m_IterativeClassifier instanceof OptionHandler) {
            newVector.addElement(new Option("", "", 0, "\nOptions specific to classifier " + this.m_IterativeClassifier.getClass().getName() + ":"));
            newVector.addAll(Collections.list(((OptionHandler) this.m_IterativeClassifier).listOptions()));
        }
        return newVector.elements();
    }

    /**
     * Parses an integer option, returning the default when it is absent.
     */
    private static int parseIntOption(char flag, String[] options, int defaultValue) throws Exception {
        String value = Utils.getOption(flag, options);
        return value.length() != 0 ? Integer.parseInt(value) : defaultValue;
    }

    /**
     * Parses the given options. Absent options are reset to their defaults.
     *
     * @param options the options
     * @throws Exception if an option is invalid or the metric is unknown
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        super.setOptions(options);
        this.setUseAverage(Utils.getFlag('A', options));
        this.setLookAheadIterations(parseIntOption('L', options, 50));
        this.setPoolSize(parseIntOption('P', options, 1));
        this.setNumThreads(parseIntOption('E', options, 1));
        this.setStepSize(parseIntOption('I', options, 1));
        this.setNumFolds(parseIntOption('F', options, 10));
        this.setNumRuns(parseIntOption('R', options, 1));
        String evalMetric = Utils.getOption("metric", options);
        if (evalMetric.length() > 0) {
            boolean found = false;
            for (int i = 0; i < TAGS_EVAL.length; ++i) {
                if (TAGS_EVAL[i].getIDStr().equalsIgnoreCase(evalMetric)) {
                    this.setEvaluationMetric(new SelectedTag(i, TAGS_EVAL));
                    found = true;
                    break;
                }
            }
            if (!found) {
                throw new Exception("Unknown evaluation metric: " + evalMetric);
            }
        } else {
            // Reset to the default, consistent with every other option here.
            this.m_evalMetric = "rmse";
        }
        String classValIndex = Utils.getOption("class-value-index", options);
        this.setClassValueIndex(classValIndex.length() > 0 ? Integer.parseInt(classValIndex) : -1);
        String classifierName = Utils.getOption('W', options);
        if (classifierName.length() == 0) {
            classifierName = this.defaultIterativeClassifierString();
        }
        this.setIterativeClassifier(this.getIterativeClassifier(classifierName, Utils.partitionOptions(options)));
    }

    /**
     * Instantiates a classifier by name and checks that it is iterative.
     *
     * @param name the classifier class name
     * @param options the options for the classifier
     * @return the classifier
     * @throws IllegalArgumentException if the classifier is not an IterativeClassifier
     * @throws Exception if the classifier cannot be created
     */
    protected IterativeClassifier getIterativeClassifier(String name, String[] options) throws Exception {
        Classifier c = AbstractClassifier.forName(name, options);
        if (c instanceof IterativeClassifier) {
            return (IterativeClassifier) c;
        }
        throw new IllegalArgumentException(name + " is not an IterativeClassifier.");
    }

    /**
     * Returns the current settings, including those of the base classifier.
     *
     * @return the options
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        if (this.getUseAverage()) {
            options.add("-A");
        }
        IterativeClassifier base = this.getIterativeClassifier();
        if (base != null) {
            options.add("-W");
            options.add(base.getClass().getName());
        }
        options.add("-L");
        options.add("" + this.getLookAheadIterations());
        options.add("-P");
        options.add("" + this.getPoolSize());
        options.add("-E");
        options.add("" + this.getNumThreads());
        options.add("-I");
        options.add("" + this.getStepSize());
        options.add("-F");
        options.add("" + this.getNumFolds());
        options.add("-R");
        options.add("" + this.getNumRuns());
        options.add("-metric");
        options.add(this.getEvaluationMetric().getSelectedTag().getIDStr());
        if (this.getClassValueIndex() >= 0) {
            options.add("-class-value-index");
            options.add("" + this.getClassValueIndex());
        }
        Collections.addAll(options, super.getOptions());
        if (base instanceof OptionHandler) {
            String[] classifierOptions = ((OptionHandler) base).getOptions();
            if (classifierOptions.length > 0) {
                options.add("--");
                Collections.addAll(options, classifierOptions);
            }
        }
        return options.toArray(new String[0]);
    }

    /** @return tip text for the evaluationMetric property */
    public String evaluationMetricTipText() {
        return "The evaluation metric to use";
    }

    /**
     * Sets the metric to optimise; tags from another tag set are ignored.
     *
     * @param metric the metric tag
     */
    public void setEvaluationMetric(SelectedTag metric) {
        if (metric.getTags() == TAGS_EVAL) {
            this.m_evalMetric = metric.getSelectedTag().getIDStr();
        }
    }

    /**
     * Returns the metric being optimised as a tag.
     *
     * @return the selected metric tag
     */
    public SelectedTag getEvaluationMetric() {
        for (int i = 0; i < TAGS_EVAL.length; ++i) {
            if (TAGS_EVAL[i].getIDStr().equalsIgnoreCase(this.m_evalMetric)) {
                return new SelectedTag(i, TAGS_EVAL);
            }
        }
        // NOTE(review): index 12 is presumably the default metric's position in
        // EvaluationMetricHelper.getAllMetricNames() -- verify against that list.
        return new SelectedTag(12, TAGS_EVAL);
    }

    /** @return tip text for the classValueIndex property */
    public String classValueIndexTipText() {
        return "The class value index to use with information retrieval type metrics. A value < 0 indicates to use the class weighted average version of the metric.";
    }

    /** @param i the class value index; negative selects the class-weighted average */
    public void setClassValueIndex(int i) {
        this.m_classValueIndex = i;
    }

    /** @return the class value index */
    public int getClassValueIndex() {
        return this.m_classValueIndex;
    }

    /** @return tip text for the iterativeClassifier property */
    public String iterativeClassifierTipText() {
        return "The iterative classifier to be optimized.";
    }

    /**
     * Returns the base classifier's capabilities with all dependencies enabled,
     * or no capabilities when there is no base classifier.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result;
        if (this.getIterativeClassifier() != null) {
            result = this.getIterativeClassifier().getCapabilities();
        } else {
            result = new Capabilities(this);
            result.disableAll();
        }
        for (Capabilities.Capability cap : Capabilities.Capability.values()) {
            result.enableDependency(cap);
        }
        result.setOwner(this);
        return result;
    }

    /** @param newIterativeClassifier the base classifier to optimise */
    public void setIterativeClassifier(IterativeClassifier newIterativeClassifier) {
        this.m_IterativeClassifier = newIterativeClassifier;
    }

    /** @return the base classifier */
    public IterativeClassifier getIterativeClassifier() {
        return this.m_IterativeClassifier;
    }

    /**
     * Returns the base classifier's class name followed by its options (if it
     * has any options to report).
     *
     * @return the classifier specification
     */
    protected String getIterativeClassifierSpec() {
        IterativeClassifier c = this.getIterativeClassifier();
        String spec = c.getClass().getName();
        if (c instanceof OptionHandler) {
            spec += " " + Utils.joinOptions(((OptionHandler) c).getOptions());
        }
        return spec;
    }

    /** @return the revision string */
    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10649 $");
    }

    /** @return the best number of iterations found */
    public double measureBestNumIts() {
        return this.m_bestNumIts;
    }

    /** @return the best metric value found */
    public double measureBestVal() {
        return this.m_bestResult;
    }

    /** @return the names of the additional measures */
    @Override
    public Enumeration<String> enumerateMeasures() {
        Vector<String> newVector = new Vector<String>(2);
        newVector.addElement("measureBestNumIts");
        newVector.addElement("measureBestVal");
        return newVector.elements();
    }

    /**
     * Returns the value of a named additional measure (case-insensitive).
     *
     * @param additionalMeasureName the measure name
     * @return the measure value
     * @throws IllegalArgumentException if the measure is not supported
     */
    @Override
    public double getMeasure(String additionalMeasureName) {
        if (additionalMeasureName.compareToIgnoreCase("measureBestNumIts") == 0) {
            return this.measureBestNumIts();
        }
        if (additionalMeasureName.compareToIgnoreCase("measureBestVal") == 0) {
            return this.measureBestVal();
        }
        throw new IllegalArgumentException(additionalMeasureName + " not supported (IterativeClassifierOptimizer)");
    }

    /**
     * Command-line entry point.
     *
     * @param argv the command-line options
     */
    public static void main(String[] argv) {
        IterativeClassifierOptimizer.runClassifier(new IterativeClassifierOptimizer(), argv);
    }

    static {
        List<String> evalNames = EvaluationMetricHelper.getAllMetricNames();
        TAGS_EVAL = new Tag[evalNames.size()];
        for (int i = 0; i < evalNames.size(); ++i) {
            TAGS_EVAL[i] = new Tag(i, evalNames.get(i), evalNames.get(i), false);
        }
    }
}

