/*
 * Decompiled with CFR 0.152.
 */
package moa.classifiers.functions;

import com.github.javacliparser.FloatOption;
import com.yahoo.labs.samoa.instances.Instance;
import moa.classifiers.functions.SGD;
import moa.core.DoubleVector;
import moa.core.Utils;

/**
 * Online linear learner that applies AdaGrad-style per-coordinate step sizes
 * on top of the state inherited from {@link SGD}.
 *
 * <p>For every weight (and for the bias) the squares of all past gradients are
 * accumulated; each step is the learning rate divided by the square root of
 * that accumulator plus {@code epsilon}, multiplied by the current gradient.
 *
 * <p>Loss selection as used by this class: {@code m_loss == 1} is log loss,
 * {@code m_loss == 0} is reported as hinge loss by {@link #toString()}, and a
 * numeric class attribute always uses the squared-loss gradient.
 */
public class AdaGrad
extends SGD {
    private static final long serialVersionUID = -3732968666673530291L;

    /** Constant added to the step-size denominator. */
    protected double m_epsilon = 1.0E-8;

    /** User-facing option from which {@link #m_epsilon} is read on reset. */
    public FloatOption epsilonOption = new FloatOption("epsilon", 'p', "epsilon parameter.", 1.0E-8, 0.0, 1.0);

    /** Scratch vector holding the gradient of the instance currently being processed. */
    protected DoubleVector m_gradients;

    /** Per-weight running sum of squared gradients. */
    protected DoubleVector m_velocity;

    /** Running sum of squared gradients for the bias term. */
    protected double m_biasVelocity;

    @Override
    public String getPurposeString() {
        return "An online optimiser for learning various linear models (binary class SVM, binary class logistic regression and linear regression).";
    }

    /**
     * Sets the constant added to the step-size denominator.
     *
     * @param eps the new epsilon value
     */
    public void setEpsilon(double eps) {
        this.m_epsilon = eps;
    }

    /**
     * Returns the constant added to the step-size denominator.
     *
     * @return the current epsilon value
     */
    public double getEpsilon() {
        return this.m_epsilon;
    }

    /**
     * Resets the learner and re-reads lambda, learning rate, epsilon and the
     * loss function from their options.
     */
    @Override
    public void resetLearningImpl() {
        this.reset();
        // The per-weight accumulators are rebuilt lazily on the next training
        // call; the bias accumulator is a plain field and must be cleared here,
        // otherwise a fresh model inherits the previous model's step damping.
        this.m_biasVelocity = 0.0;
        this.setLambda(this.lambdaRegularizationOption.getValue());
        this.setLearningRate(this.learningRateOption.getValue());
        this.setEpsilon(this.epsilonOption.getValue());
        this.setLossFunction(this.lossFunctionOption.getChosenIndex());
    }

    /**
     * Performs one AdaGrad update using the given instance.
     *
     * <p>Instances whose class value is missing are ignored (the model state is
     * still lazily initialised first). The class attribute and attributes with
     * missing (NaN) values do not contribute to the gradient.
     *
     * @param instance the training instance
     */
    @Override
    public void trainOnInstanceImpl(Instance instance) {
        if (this.m_weights == null) {
            this.m_weights = new DoubleVector();
            this.m_gradients = new DoubleVector();
            this.m_velocity = new DoubleVector();
            // Keep the bias accumulator in step with the freshly created vectors.
            this.m_biasVelocity = 0.0;
            // Writing at index numAttributes() sizes the weight vector up front.
            this.m_weights.setValue(instance.numAttributes(), 0.0);
        }
        if (instance.classIsMissing()) {
            return;
        }

        int classIndex = instance.classIndex();
        // NOTE(review): dotProd is inherited; it is handed the class index,
        // presumably so the label is excluded from the prediction - confirm in SGD.
        double z = AdaGrad.dotProd(instance, this.m_weights, classIndex) + this.m_bias;
        double dldz = this.lossDerivative(instance, z);

        // Regularisation part of the gradient: (lambda / t) * w for every weight.
        int numWeights = this.m_weights.numValues();
        double shrink = this.m_lambda / this.m_t;
        for (int i = 0; i < numWeights; ++i) {
            this.m_gradients.setValue(i, shrink * this.m_weights.getValue(i));
        }

        // Data part of the gradient: x_j * dL/dz for every stored attribute value.
        int numStored = instance.numValues();
        for (int p = 0; p < numStored; ++p) {
            int attIndex = instance.index(p);
            if (attIndex == classIndex) {
                // The label must not be learned as a feature of itself.
                continue;
            }
            double value = instance.valueSparse(p);
            if (Double.isNaN(value)) {
                // A NaN (missing) value would turn the weight and its
                // accumulator into NaN permanently.
                continue;
            }
            this.m_gradients.addToValue(attIndex, value * dldz);
        }

        // A zero gradient changes neither accumulator nor parameter; skipping it
        // also avoids 0 / 0 = NaN when epsilon is 0 and nothing has accumulated yet.
        if (dldz != 0.0) {
            this.m_biasVelocity += dldz * dldz;
            this.m_bias -= this.m_learningRate / (Math.sqrt(this.m_biasVelocity) + this.m_epsilon) * dldz;
        }
        for (int i = 0; i < numWeights; ++i) {
            double g = this.m_gradients.getValue(i);
            if (g == 0.0) {
                continue;
            }
            this.m_velocity.addToValue(i, g * g);
            double step = this.m_learningRate / (Math.sqrt(this.m_velocity.getValue(i)) + this.m_epsilon);
            this.m_weights.addToValue(i, -step * g);
        }
        this.m_t += 1.0;
    }

    /**
     * Computes the derivative of the loss with respect to the linear output.
     *
     * @param instance the training instance supplying the target
     * @param z the linear output (dot product plus bias) for the instance
     * @return dL/dz for the configured loss
     */
    private double lossDerivative(Instance instance, double z) {
        if (!instance.classAttribute().isNominal()) {
            // Numeric target: squared-loss gradient.
            return z - instance.classValue();
        }
        // Nominal target: class value 0 maps to 0, every other value to 1.
        double y = instance.classValue() == 0.0 ? 0.0 : 1.0;
        if (this.m_loss == 1) {
            // Log loss: sigmoid(z) - y.
            double yhat = 1.0 / (1.0 + Math.exp(-z));
            return yhat - y;
        }
        // Otherwise hinge-style: map the target to {-1, +1}; the gradient is
        // non-zero only while the margin y * z is below 1.
        double signedY = y * 2.0 - 1.0;
        return signedY * z < 1.0 ? -signedY : 0.0;
    }

    /**
     * Returns a textual description of the loss function, the weights (one per
     * line) and the bias, or a placeholder when no model exists yet.
     */
    @Override
    public String toString() {
        if (this.m_weights == null) {
            return "AdaGrad: No model built yet.\n";
        }
        StringBuilder buff = new StringBuilder();
        buff.append("Loss function: ");
        if (this.m_loss == 0) {
            buff.append("Hinge loss (SVM)\n\n");
        } else if (this.m_loss == 1) {
            buff.append("Log loss (logistic regression)\n\n");
        } else {
            buff.append("Squared loss (linear regression)\n\n");
        }
        int numWeights = this.m_weights.numValues();
        for (int i = 0; i < numWeights; ++i) {
            // The first term is indented; later terms are joined with " + ".
            buff.append(i > 0 ? " + " : "   ");
            buff.append(Utils.doubleToString(this.m_weights.getValue(i), 12, 4)).append(" \n");
        }
        if (this.m_bias > 0.0) {
            buff.append(" + ").append(Utils.doubleToString(this.m_bias, 12, 4));
        } else {
            buff.append(" - ").append(Utils.doubleToString(-this.m_bias, 12, 4));
        }
        return buff.toString();
    }
}

