/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.meta;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.IteratedSingleClassifierEnhancer;
import weka.classifiers.IterativeClassifier;
import weka.classifiers.rules.ZeroR;
import weka.classifiers.trees.DecisionStump;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.UnassignedClassException;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;

public class AdditiveRegression
extends IteratedSingleClassifierEnhancer
implements OptionHandler,
AdditionalMeasureProducer,
WeightedInstancesHandler,
TechnicalInformationHandler,
IterativeClassifier {
    static final long serialVersionUID = -2368937577670527151L;
    protected ArrayList<Classifier> m_Classifiers;
    protected double m_shrinkage = 1.0;
    protected ZeroR m_zeroR;
    protected boolean m_SuitableData = true;
    protected Instances m_Data;
    protected double m_SSE;
    protected double m_Diff;

    public String globalInfo() {
        return " Meta classifier that enhances the performance of a regression base classifier. Each iteration fits a model to the residuals left by the classifier on the previous iteration. Prediction is accomplished by adding the predictions of each classifier. Reducing the shrinkage (learning rate) parameter helps prevent overfitting and has a smoothing effect but increases the learning time.\n\nFor more information see:\n\n" + this.getTechnicalInformation().toString();
    }

    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.TECHREPORT);
        result.setValue(TechnicalInformation.Field.AUTHOR, "J.H. Friedman");
        result.setValue(TechnicalInformation.Field.YEAR, "1999");
        result.setValue(TechnicalInformation.Field.TITLE, "Stochastic Gradient Boosting");
        result.setValue(TechnicalInformation.Field.INSTITUTION, "Stanford University");
        result.setValue(TechnicalInformation.Field.PS, "http://www-stat.stanford.edu/~jhf/ftp/stobst.ps");
        return result;
    }

    public AdditiveRegression() {
        this(new DecisionStump());
    }

    public AdditiveRegression(Classifier classifier) {
        this.m_Classifier = classifier;
    }

    @Override
    protected String defaultClassifierString() {
        return "weka.classifiers.trees.DecisionStump";
    }

    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> newVector = new Vector<Option>(1);
        newVector.addElement(new Option("\tSpecify shrinkage rate. (default = 1.0, ie. no shrinkage)\n", "S", 1, "-S"));
        newVector.addAll(Collections.list(super.listOptions()));
        return newVector.elements();
    }

    @Override
    public void setOptions(String[] options) throws Exception {
        String optionString = Utils.getOption('S', options);
        if (optionString.length() != 0) {
            Double temp = Double.valueOf(optionString);
            this.setShrinkage(temp);
        }
        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-S");
        options.add("" + this.getShrinkage());
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    public String shrinkageTipText() {
        return "Shrinkage rate. Smaller values help prevent overfitting and have a smoothing effect (but increase learning time). Default = 1.0, ie. no shrinkage.";
    }

    public void setShrinkage(double l) {
        this.m_shrinkage = l;
    }

    public double getShrinkage() {
        return this.m_shrinkage;
    }

    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAllClasses();
        result.disableAllClassDependencies();
        result.enable(Capabilities.Capability.NUMERIC_CLASS);
        result.enable(Capabilities.Capability.DATE_CLASS);
        return result;
    }

    @Override
    public void buildClassifier(Instances data) throws Exception {
        this.initializeClassifier(data);
        while (this.next()) {
        }
        this.done();
    }

    @Override
    public void initializeClassifier(Instances data) throws Exception {
        this.getCapabilities().testWithFail(data);
        this.m_Data = new Instances(data);
        this.m_Data.deleteWithMissingClass();
        this.m_zeroR = new ZeroR();
        this.m_zeroR.buildClassifier(this.m_Data);
        if (this.m_Data.numAttributes() == 1) {
            System.err.println("Cannot build model (only class attribute present in data!), using ZeroR model instead!");
            this.m_SuitableData = false;
            return;
        }
        this.m_SuitableData = true;
        this.m_Classifiers = new ArrayList(this.m_NumIterations);
        this.m_Data = this.residualReplace(this.m_Data, this.m_zeroR, false);
        this.m_SSE = 0.0;
        this.m_Diff = Double.MAX_VALUE;
        for (int i = 0; i < this.m_Data.numInstances(); ++i) {
            this.m_SSE += this.m_Data.instance(i).weight() * this.m_Data.instance(i).classValue() * this.m_Data.instance(i).classValue();
        }
        if (this.m_Debug) {
            System.err.println("Sum of squared residuals (predicting the mean) : " + this.m_SSE);
        }
    }

    @Override
    public boolean next() throws Exception {
        if (!this.m_SuitableData || this.m_Classifiers.size() >= this.m_NumIterations || this.m_Diff <= Utils.SMALL) {
            return false;
        }
        this.m_Classifiers.add(AbstractClassifier.makeCopy(this.m_Classifier));
        this.m_Classifiers.get(this.m_Classifiers.size() - 1).buildClassifier(this.m_Data);
        this.m_Data = this.residualReplace(this.m_Data, this.m_Classifiers.get(this.m_Classifiers.size() - 1), true);
        double sum = 0.0;
        for (int i = 0; i < this.m_Data.numInstances(); ++i) {
            sum += this.m_Data.instance(i).weight() * this.m_Data.instance(i).classValue() * this.m_Data.instance(i).classValue();
        }
        if (this.m_Debug) {
            System.err.println("Sum of squared residuals : " + sum);
        }
        this.m_Diff = this.m_SSE - sum;
        this.m_SSE = sum;
        return true;
    }

    @Override
    public void done() {
        this.m_Data = null;
    }

    @Override
    public double classifyInstance(Instance inst) throws Exception {
        double prediction = this.m_zeroR.classifyInstance(inst);
        if (!this.m_SuitableData) {
            return prediction;
        }
        for (Classifier classifier : this.m_Classifiers) {
            double toAdd = classifier.classifyInstance(inst);
            if (Utils.isMissingValue(toAdd)) {
                throw new UnassignedClassException("AdditiveRegression: base learner predicted missing value.");
            }
            prediction += (toAdd *= this.getShrinkage());
        }
        return prediction;
    }

    private Instances residualReplace(Instances data, Classifier c, boolean useShrinkage) throws Exception {
        Instances newInst = new Instances(data);
        for (int i = 0; i < newInst.numInstances(); ++i) {
            double pred = c.classifyInstance(newInst.instance(i));
            if (Utils.isMissingValue(pred)) {
                throw new UnassignedClassException("AdditiveRegression: base learner predicted missing value.");
            }
            if (useShrinkage) {
                pred *= this.getShrinkage();
            }
            double residual = newInst.instance(i).classValue() - pred;
            newInst.instance(i).setClassValue(residual);
        }
        return newInst;
    }

    @Override
    public Enumeration<String> enumerateMeasures() {
        Vector<String> newVector = new Vector<String>(1);
        newVector.addElement("measureNumIterations");
        return newVector.elements();
    }

    @Override
    public double getMeasure(String additionalMeasureName) {
        if (additionalMeasureName.compareToIgnoreCase("measureNumIterations") == 0) {
            return this.measureNumIterations();
        }
        throw new IllegalArgumentException(additionalMeasureName + " not supported (AdditiveRegression)");
    }

    public double measureNumIterations() {
        return this.m_Classifiers.size();
    }

    public String toString() {
        StringBuffer text = new StringBuffer();
        if (this.m_zeroR == null) {
            return "Classifier hasn't been built yet!";
        }
        if (!this.m_SuitableData) {
            StringBuffer buf = new StringBuffer();
            buf.append(this.getClass().getName().replaceAll(".*\\.", "") + "\n");
            buf.append(this.getClass().getName().replaceAll(".*\\.", "").replaceAll(".", "=") + "\n\n");
            buf.append("Warning: No model could be built, hence ZeroR model is used:\n\n");
            buf.append(this.m_zeroR.toString());
            return buf.toString();
        }
        text.append("Additive Regression\n\n");
        text.append("ZeroR model\n\n" + this.m_zeroR + "\n\n");
        text.append("Base classifier " + this.getClassifier().getClass().getName() + "\n\n");
        text.append("" + this.m_Classifiers.size() + " models generated.\n");
        for (int i = 0; i < this.m_Classifiers.size(); ++i) {
            text.append("\nModel number " + i + "\n\n" + this.m_Classifiers.get(i) + "\n");
        }
        return text.toString();
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 11192 $");
    }

    public static void main(String[] argv) {
        AdditiveRegression.runClassifier(new AdditiveRegression(), argv);
    }
}

