/*
 * Decompiled with CFR 0.152.
 */
package moa.evaluation;

import com.github.javacliparser.FlagOption;
import com.yahoo.labs.samoa.instances.Instance;
import com.yahoo.labs.samoa.instances.Prediction;
import java.io.Serializable;
import java.util.ArrayList;
import moa.core.Example;
import moa.core.Measurement;
import moa.core.ObjectRepository;
import moa.core.Utils;
import moa.evaluation.ClassificationPerformanceEvaluator;
import moa.options.AbstractOptionHandler;
import moa.tasks.TaskMonitor;

/**
 * Classification performance evaluator that accumulates accuracy, Kappa,
 * Kappa Temporal, Kappa M and (optionally) precision, recall and F1 figures
 * from a sequence of labelled examples and their class votes.
 *
 * <p>All statistics are kept in {@link Estimator} instances created through
 * {@link #newEstimator()}, so a subclass can swap in a different estimator
 * without touching the bookkeeping in this class.
 *
 * <p>Statistics that have not received any usable sample are reported as
 * {@code NaN} (the result of {@code 0.0 / 0.0} in {@link BasicEstimator}),
 * except for the Kappa family, which reports {@code 0.0} until some weight
 * has been observed.
 */
public class BasicClassificationPerformanceEvaluator
extends AbstractOptionHandler
implements ClassificationPerformanceEvaluator {
    private static final long serialVersionUID = 1L;

    /** Fed the instance weight when the prediction was correct, 0 otherwise. */
    protected Estimator weightCorrect;
    /** Per class: fed the instance weight when that class was the true class, 0 otherwise. */
    protected Estimator[] columnKappa;
    /** Per class: fed the instance weight when that class was predicted, 0 otherwise. */
    protected Estimator[] rowKappa;
    /** Per class: fed only for instances predicted as that class (NaN otherwise). */
    protected Estimator[] precision;
    /** Per class: fed only for instances whose true class is that class (NaN otherwise). */
    protected Estimator[] recall;
    /** Number of classes the per-class arrays are currently sized for. */
    protected int numClasses;
    /** Baseline that predicts the true class of the previous labelled instance. */
    private Estimator weightCorrectNoChangeClassifier;
    /** Baseline that predicts the class with the largest {@link #columnKappa} estimate. */
    private Estimator weightMajorityClassifier;
    /** True class of the most recent labelled instance (0 after a reset). */
    private int lastSeenClass;
    /** Sum of the weights of all instances that contributed to the statistics. */
    private double totalWeightObserved;

    public FlagOption precisionRecallOutputOption = new FlagOption("precisionRecallOutput", 'o', "Outputs average precision, recall and F1 scores.");
    public FlagOption precisionPerClassOption = new FlagOption("precisionPerClass", 'p', "Report precision per class.");
    public FlagOption recallPerClassOption = new FlagOption("recallPerClass", 'r', "Report recall per class.");
    public FlagOption f1PerClassOption = new FlagOption("f1PerClass", 'f', "Report F1 per class.");

    /**
     * Discards all accumulated statistics, keeping the current number of classes.
     */
    @Override
    public void reset() {
        this.reset(this.numClasses);
    }

    /**
     * Discards all accumulated statistics and re-sizes the per-class
     * estimators for the given number of classes.
     *
     * @param numClasses number of classes to track
     */
    public void reset(int numClasses) {
        this.numClasses = numClasses;
        this.rowKappa = new Estimator[numClasses];
        this.columnKappa = new Estimator[numClasses];
        this.precision = new Estimator[numClasses];
        this.recall = new Estimator[numClasses];
        for (int i = 0; i < this.numClasses; ++i) {
            this.rowKappa[i] = this.newEstimator();
            this.columnKappa[i] = this.newEstimator();
            this.precision[i] = this.newEstimator();
            this.recall[i] = this.newEstimator();
        }
        this.weightCorrect = this.newEstimator();
        this.weightCorrectNoChangeClassifier = this.newEstimator();
        this.weightMajorityClassifier = this.newEstimator();
        this.lastSeenClass = 0;
        this.totalWeightObserved = 0.0;
    }

    /**
     * Records the outcome of classifying one example.
     *
     * <p>Examples whose class is missing are ignored entirely. Examples with a
     * non-positive weight do not contribute to any statistic; they only update
     * the "last seen class" used by the no-change baseline. The first
     * positive-weight example (re)initialises the evaluator with the number of
     * classes declared by the instance's dataset.
     *
     * @param example    the labelled example that was classified
     * @param classVotes votes per class; the index of the largest vote is the prediction
     */
    @Override
    public void addResult(Example<Instance> example, double[] classVotes) {
        Instance inst = example.getData();
        double weight = inst.weight();
        if (inst.classIsMissing()) {
            return;
        }
        int trueClass = (int)inst.classValue();
        if (weight > 0.0) {
            if (this.totalWeightObserved == 0.0) {
                // First usable instance: size everything from the dataset header.
                this.reset(inst.dataset().numClasses());
            }
            int predictedClass = Utils.maxIndex(classVotes);
            double creditedWeight = predictedClass == trueClass ? weight : 0.0;
            this.totalWeightObserved += weight;
            this.weightCorrect.add(creditedWeight);
            for (int i = 0; i < this.numClasses; ++i) {
                this.rowKappa[i].add(predictedClass == i ? weight : 0.0);
                this.columnKappa[i].add(trueClass == i ? weight : 0.0);
                // NaN marks "this instance is not a sample for class i".
                this.precision[i].add(predictedClass == i ? creditedWeight : Double.NaN);
                this.recall[i].add(trueClass == i ? creditedWeight : Double.NaN);
            }
            // The baselines are updated under the same weight guard as
            // weightCorrect so that p0 and pc in the Kappa Temporal / Kappa M
            // formulas are computed over the same set of instances, and so the
            // baseline estimators are guaranteed to have been created by reset().
            this.weightCorrectNoChangeClassifier.add(this.lastSeenClass == trueClass ? weight : 0.0);
            this.weightMajorityClassifier.add(this.getMajorityClass() == trueClass ? weight : 0.0);
        }
        this.lastSeenClass = trueClass;
    }

    /**
     * Returns the index of the class with the largest {@link #columnKappa}
     * estimate, or 0 when no class has an estimate greater than zero.
     */
    private int getMajorityClass() {
        int majorityClass = 0;
        double maxProbClass = 0.0;
        for (int i = 0; i < this.numClasses; ++i) {
            double estimate = this.columnKappa[i].estimation();
            if (estimate > maxProbClass) {
                majorityClass = i;
                maxProbClass = estimate;
            }
        }
        return majorityClass;
    }

    /**
     * Builds the list of measurements to report. The aggregate accuracy and
     * Kappa figures are always present; F1, precision and recall (aggregate
     * and per class) are appended according to the corresponding options.
     *
     * @return the measurements, in reporting order
     */
    @Override
    public Measurement[] getPerformanceMeasurements() {
        ArrayList<Measurement> measurements = new ArrayList<Measurement>();
        measurements.add(new Measurement("classified instances", this.getTotalWeightObserved()));
        measurements.add(new Measurement("classifications correct (percent)", this.getFractionCorrectlyClassified() * 100.0));
        measurements.add(new Measurement("Kappa Statistic (percent)", this.getKappaStatistic() * 100.0));
        measurements.add(new Measurement("Kappa Temporal Statistic (percent)", this.getKappaTemporalStatistic() * 100.0));
        measurements.add(new Measurement("Kappa M Statistic (percent)", this.getKappaMStatistic() * 100.0));

        boolean reportAverages = this.precisionRecallOutputOption.isSet();
        if (reportAverages) {
            measurements.add(new Measurement("F1 Score (percent)", this.getF1Statistic() * 100.0));
        }
        if (this.f1PerClassOption.isSet()) {
            for (int i = 0; i < this.numClasses; ++i) {
                measurements.add(new Measurement("F1 Score for class " + i + " (percent)", 100.0 * this.getF1Statistic(i)));
            }
        }
        if (reportAverages) {
            measurements.add(new Measurement("Precision (percent)", this.getPrecisionStatistic() * 100.0));
        }
        if (this.precisionPerClassOption.isSet()) {
            for (int i = 0; i < this.numClasses; ++i) {
                measurements.add(new Measurement("Precision for class " + i + " (percent)", 100.0 * this.getPrecisionStatistic(i)));
            }
        }
        if (reportAverages) {
            measurements.add(new Measurement("Recall (percent)", this.getRecallStatistic() * 100.0));
        }
        if (this.recallPerClassOption.isSet()) {
            for (int i = 0; i < this.numClasses; ++i) {
                measurements.add(new Measurement("Recall for class " + i + " (percent)", 100.0 * this.getRecallStatistic(i)));
            }
        }
        return measurements.toArray(new Measurement[measurements.size()]);
    }

    /**
     * @return the summed weight of all instances that contributed to the statistics
     */
    public double getTotalWeightObserved() {
        return this.totalWeightObserved;
    }

    /**
     * @return the current estimate of {@link #weightCorrect}; {@code NaN} when
     *         the evaluator has not been reset or has not seen any sample yet
     */
    public double getFractionCorrectlyClassified() {
        return BasicClassificationPerformanceEvaluator.estimateOrNaN(this.weightCorrect);
    }

    /**
     * @return {@code 1 - getFractionCorrectlyClassified()}
     */
    public double getFractionIncorrectlyClassified() {
        return 1.0 - this.getFractionCorrectlyClassified();
    }

    /**
     * Kappa statistic against the chance agreement derived from the
     * predicted-class and true-class estimates.
     *
     * @return the statistic, or 0 when no weight has been observed
     */
    public double getKappaStatistic() {
        if (!(this.getTotalWeightObserved() > 0.0)) {
            return 0.0;
        }
        double pc = 0.0;
        for (int i = 0; i < this.numClasses; ++i) {
            pc += this.rowKappa[i].estimation() * this.columnKappa[i].estimation();
        }
        return BasicClassificationPerformanceEvaluator.kappa(this.getFractionCorrectlyClassified(), pc);
    }

    /**
     * Kappa statistic against the no-change baseline (predict the previous
     * instance's true class).
     *
     * @return the statistic, or 0 when no weight has been observed
     */
    public double getKappaTemporalStatistic() {
        if (!(this.getTotalWeightObserved() > 0.0)) {
            return 0.0;
        }
        return BasicClassificationPerformanceEvaluator.kappa(
                this.getFractionCorrectlyClassified(),
                this.weightCorrectNoChangeClassifier.estimation());
    }

    /**
     * Kappa statistic against the majority-class baseline.
     *
     * @return the statistic, or 0 when no weight has been observed
     */
    private double getKappaMStatistic() {
        if (!(this.getTotalWeightObserved() > 0.0)) {
            return 0.0;
        }
        return BasicClassificationPerformanceEvaluator.kappa(
                this.getFractionCorrectlyClassified(),
                this.weightMajorityClassifier.estimation());
    }

    /**
     * @return the unweighted mean of the per-class precision estimates;
     *         {@code NaN} if any class has no precision sample, if there are
     *         no classes, or if the evaluator has not been reset yet
     */
    public double getPrecisionStatistic() {
        return BasicClassificationPerformanceEvaluator.macroAverage(this.precision);
    }

    /**
     * @param numClass class index
     * @return the precision estimate for that class
     */
    public double getPrecisionStatistic(int numClass) {
        return this.precision[numClass].estimation();
    }

    /**
     * @return the unweighted mean of the per-class recall estimates;
     *         {@code NaN} if any class has no recall sample, if there are
     *         no classes, or if the evaluator has not been reset yet
     */
    public double getRecallStatistic() {
        return BasicClassificationPerformanceEvaluator.macroAverage(this.recall);
    }

    /**
     * @param numClass class index
     * @return the recall estimate for that class
     */
    public double getRecallStatistic(int numClass) {
        return this.recall[numClass].estimation();
    }

    /**
     * @return F1 computed from the averaged precision and averaged recall;
     *         0 when both are exactly 0, {@code NaN} when either is {@code NaN}
     */
    public double getF1Statistic() {
        return BasicClassificationPerformanceEvaluator.f1(this.getPrecisionStatistic(), this.getRecallStatistic());
    }

    /**
     * @param numClass class index
     * @return F1 for that class; 0 when its precision and recall are both
     *         exactly 0, {@code NaN} when either is {@code NaN}
     */
    public double getF1Statistic(int numClass) {
        return BasicClassificationPerformanceEvaluator.f1(this.getPrecisionStatistic(numClass), this.getRecallStatistic(numClass));
    }

    /**
     * Appends a textual description of the current measurements.
     *
     * @param sb     buffer to append to
     * @param indent indentation passed through to the measurement formatter
     */
    @Override
    public void getDescription(StringBuilder sb, int indent) {
        Measurement.getMeasurementsDescription(this.getPerformanceMeasurements(), sb, indent);
    }

    /**
     * No-op: this evaluator does not record results supplied as a {@link Prediction}.
     */
    @Override
    public void addResult(Example<Instance> testInst, Prediction prediction) {
    }

    /**
     * No-op: this evaluator performs no preparation work.
     */
    @Override
    protected void prepareForUseImpl(TaskMonitor monitor, ObjectRepository repository) {
    }

    /**
     * Factory for every estimator used by this evaluator; override to change
     * how values are aggregated.
     *
     * @return a fresh, empty estimator
     */
    protected Estimator newEstimator() {
        return new BasicEstimator();
    }

    /** Returns the estimator's value, or NaN when the estimator has not been created yet. */
    private static double estimateOrNaN(Estimator estimator) {
        return estimator == null ? Double.NaN : estimator.estimation();
    }

    /** Unweighted mean of the given estimators' values; NaN for a null or empty array. */
    private static double macroAverage(Estimator[] perClass) {
        if (perClass == null) {
            return Double.NaN;
        }
        double total = 0.0;
        for (Estimator estimator : perClass) {
            total += estimator.estimation();
        }
        return total / (double)perClass.length;
    }

    /**
     * Kappa formula {@code (p0 - pc) / (1 - pc)}. When {@code pc} is exactly 1
     * the division is by zero and the result is NaN or infinite.
     */
    private static double kappa(double p0, double pc) {
        return (p0 - pc) / (1.0 - pc);
    }

    /**
     * Harmonic mean of precision and recall. A zero denominator (both inputs
     * exactly 0) yields 0 rather than NaN; NaN inputs still propagate.
     */
    private static double f1(double p, double r) {
        double denominator = p + r;
        if (denominator == 0.0) {
            return 0.0;
        }
        return 2.0 * (p * r / denominator);
    }

    /**
     * Estimator that reports the arithmetic mean of all non-NaN values added
     * to it; NaN values are skipped and do not count as samples.
     */
    public class BasicEstimator
    implements Estimator {
        /** Number of non-NaN values added so far. */
        protected double len;
        /** Sum of the non-NaN values added so far. */
        protected double sum;

        /**
         * Adds one sample; a NaN value is ignored.
         */
        @Override
        public void add(double value) {
            if (!Double.isNaN(value)) {
                this.sum += value;
                this.len += 1.0;
            }
        }

        /**
         * @return {@code sum / len}; {@code NaN} when no sample has been added
         */
        @Override
        public double estimation() {
            return this.sum / this.len;
        }
    }

    /**
     * Aggregates a stream of values into a single running estimate.
     */
    public static interface Estimator
    extends Serializable {
        /** Feeds one value into the estimate. */
        public void add(double var1);

        /** @return the current estimate */
        public double estimation();
    }
}

