/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.bayes;

import java.util.Enumeration;
import java.util.Random;
import java.util.StringTokenizer;
import java.util.Vector;
import weka.classifiers.Classifier;
import weka.classifiers.bayes.blr.GaussianPriorImpl;
import weka.classifiers.bayes.blr.LaplacePriorImpl;
import weka.classifiers.bayes.blr.Prior;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.SerializedObject;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;

/**
 * Bayesian logistic regression for a binary class, with either a Gaussian or
 * a Laplace prior on the regression coefficients.
 *
 * <p>The model is fitted in {@link #buildClassifier(Instances)} by updating one
 * coefficient at a time until {@link #stoppingCriterion()} reports convergence
 * or the iteration limit is reached. The prior hyperparameter is either derived
 * from the data norm, chosen by cross-validation, or supplied by the user.
 *
 * <p>Instances of this class are not thread-safe: training and (when
 * normalization is enabled) prediction mutate instance state.
 */
public class BayesianLogisticRegression
extends Classifier
implements OptionHandler,
TechnicalInformationHandler {
    static final long serialVersionUID = -8013478897911757631L;

    /** Not referenced inside this class; retained as part of the public surface. */
    public static double[] LogLikelihood;
    /** Not referenced inside this class; retained as part of the public surface. */
    public static double[] InputHyperparameterValues;

    /** Whether progress information is printed to standard output. */
    boolean debug = false;
    /** Whether the training data is passed through a {@link Normalize} filter. */
    public boolean NormalizeData = false;
    /** Convergence tolerance compared against the relative change in {@link #stoppingCriterion()}. */
    public double Tolerance = 5.0E-4;
    /** Probability above which {@link #classifyInstance(Instance)} predicts class 1. */
    public double Threshold = 0.5;

    /** Tag ID of the Gaussian prior. */
    public static final int GAUSSIAN = 1;
    /** Tag ID of the Laplacian prior. */
    public static final int LAPLACIAN = 2;
    /** Selectable prior distributions. */
    public static final Tag[] TAGS_PRIOR = new Tag[]{
        new Tag(GAUSSIAN, "Gaussian"),
        new Tag(LAPLACIAN, "Laplacian")
    };
    /** Selected prior; one of {@link #GAUSSIAN} or {@link #LAPLACIAN}. */
    public int PriorClass = GAUSSIAN;
    /** Number of folds used by {@link #CVBasedHyperparameter()}. */
    public int NumFolds = 2;

    /** Tag ID: hyperparameter derived from the data norm. */
    public static final int NORM_BASED = 1;
    /** Tag ID: hyperparameter chosen by cross-validation. */
    public static final int CV_BASED = 2;
    /** Tag ID: hyperparameter given explicitly by the user. */
    public static final int SPECIFIC_VALUE = 3;
    /** Selectable hyperparameter selection methods. */
    public static final Tag[] TAGS_HYPER_METHOD = new Tag[]{
        new Tag(NORM_BASED, "Norm-based"),
        new Tag(CV_BASED, "CV-based"),
        new Tag(SPECIFIC_VALUE, "Specific value")
    };
    /** Selected method; one of {@link #NORM_BASED}, {@link #CV_BASED}, {@link #SPECIFIC_VALUE}. */
    public int HyperparameterSelection = NORM_BASED;

    /** Class index of the working data, i.e. after the intercept column was inserted at position 0. */
    public int ClassIndex = -1;
    /** Hyperparameter value in effect; overwritten by the norm-based and CV-based methods. */
    public double HyperparameterValue = 0.27;
    /** Candidate specification for CV-based selection ({@code R:start-end,multiplier} or {@code L:v1,v2,...}). */
    public String HyperparameterRange = "R:0.01-316,3.16";
    /** Upper bound on the number of passes over the coefficients. */
    public int maxIterations = 100;
    /** Number of completed calls to {@link #stoppingCriterion()} in the current build. */
    public int iterationCounter = 0;

    /** Regression coefficients, indexed like the attributes of the working data (index 0 = intercept). */
    public double[] BetaVector;
    /** Step actually applied to each coefficient in the current pass (the clamped update). */
    public double[] DeltaBeta;
    /** Unclamped step proposed by the prior for each coefficient. */
    public double[] DeltaUpdate;
    /** Per-coefficient bound on the absolute step size. */
    public double[] Delta;
    /** Per-coefficient prior hyperparameter. */
    public double[] Hyperparameters;
    /** Per-instance running sum of applied steps times attribute value times class sign. */
    public double[] R;
    /** Most recent per-instance change added to {@link #R}. */
    public double[] DeltaR;
    /** Between calls to {@link #stoppingCriterion()}: the sum of |DeltaR| from the previous call. */
    public double Change;
    /** Normalization filter fitted on the training data, or {@code null} when normalization is off. */
    public Filter m_Filter;
    /** Working copy of the training data (optionally normalized, with the intercept column at index 0). */
    protected Instances m_Instances;
    /** Prior-specific update rule and likelihood/penalty calculator. */
    protected Prior m_PriorUpdate;

    /**
     * Returns a description of the classifier suitable for display in a GUI.
     *
     * @return the description, including the technical reference
     */
    public String globalInfo() {
        return "Implements Bayesian Logistic Regression for both Gaussian and Laplace Priors.\n\nFor more information, see\n\n" + this.getTechnicalInformation();
    }

    /**
     * Prepares the working data and all per-build state.
     *
     * <p>Optionally normalizes {@code m_Instances}, inserts a constant
     * {@code "(intercept)"} attribute at index 0, determines the hyperparameter
     * according to {@link #HyperparameterSelection}, allocates the coefficient and
     * per-instance arrays, and instantiates the prior implementation.
     *
     * @throws Exception if filtering or CV-based hyperparameter selection fails
     */
    public void initialize() throws Exception {
        Change = 0.0;

        if (NormalizeData) {
            m_Filter = new Normalize();
            m_Filter.setInputFormat(m_Instances);
            m_Instances = Filter.useFilter(m_Instances, m_Filter);
        } else {
            // Drop any filter left over from an earlier build so that
            // classifyInstance does not normalize against stale statistics.
            m_Filter = null;
        }

        // The intercept is modelled as an extra attribute whose value is always 1.
        m_Instances.insertAttributeAt(new Attribute("(intercept)"), 0);
        for (int i = 0; i < m_Instances.numInstances(); i++) {
            m_Instances.instance(i).setValue(0, 1.0);
        }

        int numAttributes = m_Instances.numAttributes();
        int numInstances = m_Instances.numInstances();
        ClassIndex = m_Instances.classIndex();
        iterationCounter = 0;

        switch (HyperparameterSelection) {
            case NORM_BASED: {
                HyperparameterValue = normBasedHyperParameter();
                if (debug) {
                    System.out.println("Norm-based Hyperparameter: " + HyperparameterValue);
                }
                break;
            }
            case CV_BASED: {
                HyperparameterValue = CVBasedHyperparameter();
                if (debug) {
                    System.out.println("CV-based Hyperparameter: " + HyperparameterValue);
                }
                break;
            }
            default: {
                // SPECIFIC_VALUE (or any other ID): keep HyperparameterValue as configured.
                break;
            }
        }

        // new double[] is zero-filled, so only the non-zero defaults are written.
        BetaVector = new double[numAttributes];
        DeltaBeta = new double[numAttributes];
        DeltaUpdate = new double[numAttributes];
        Delta = new double[numAttributes];
        Hyperparameters = new double[numAttributes];
        for (int j = 0; j < numAttributes; j++) {
            Delta[j] = 1.0;
            Hyperparameters[j] = HyperparameterValue;
        }
        DeltaR = new double[numInstances];
        R = new double[numInstances];

        m_PriorUpdate = PriorClass == GAUSSIAN ? new GaussianPriorImpl() : new LaplacePriorImpl();
    }

    /**
     * Returns the capabilities of this classifier: numeric and binary
     * attributes with a binary class.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.BINARY_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.BINARY_CLASS);
        capabilities.setMinimumNumberInstances(0);
        return capabilities;
    }

    /**
     * Fits the model to the given training data.
     *
     * <p>Each pass visits every non-class attribute, asks the prior for a
     * proposed step, clamps it to {@code [-Delta[i], Delta[i]]}, applies it to the
     * coefficient and to the per-instance sums in {@link #R}, and then adapts the
     * step bound. Passes repeat until {@link #stoppingCriterion()} returns true.
     *
     * @param instances the training data; it is copied, not modified
     * @throws Exception if the data is not supported or initialization fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        m_Instances = new Instances(instances);
        initialize();

        do {
            for (int attr = 0; attr < m_Instances.numAttributes(); attr++) {
                if (attr == ClassIndex) {
                    continue;
                }
                DeltaUpdate[attr] = m_PriorUpdate.update(attr, m_Instances, BetaVector[attr], Hyperparameters[attr], R, Delta[attr]);
                // Limit the step to the current per-coefficient bound.
                DeltaBeta[attr] = Math.min(Math.max(DeltaUpdate[attr], 0.0 - Delta[attr]), Delta[attr]);

                for (int row = 0; row < m_Instances.numInstances(); row++) {
                    Instance instance = m_Instances.instance(row);
                    if (instance.value(attr) == 0.0) {
                        // A zero attribute value contributes nothing to R.
                        continue;
                    }
                    DeltaR[row] = DeltaBeta[attr] * instance.value(attr) * classSgn(instance.classValue());
                    R[row] += DeltaR[row];
                }

                BetaVector[attr] += DeltaBeta[attr];
                // Widen the bound after a large step, otherwise halve it.
                Delta[attr] = Math.max(2.0 * Math.abs(DeltaBeta[attr]), Delta[attr] / 2.0);
            }
        } while (!stoppingCriterion());

        m_PriorUpdate.computelogLikelihood(BetaVector, m_Instances);
        m_PriorUpdate.computePenalty(BetaVector, Hyperparameters);
    }

    /**
     * Maps a class value to a sign.
     *
     * @param d the class value
     * @return {@code -1.0} if {@code d} is {@code 0.0}, otherwise {@code 1.0}
     */
    public static double classSgn(double d) {
        return d == 0.0 ? -1.0 : 1.0;
    }

    /**
     * Returns the technical reference for this method.
     *
     * @return the technical information
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.TECHREPORT);
        info.setValue(TechnicalInformation.Field.AUTHOR, "Alexander Genkin and David D. Lewis and David Madigan");
        info.setValue(TechnicalInformation.Field.YEAR, "2004");
        info.setValue(TechnicalInformation.Field.TITLE, "Large-scale bayesian logistic regression for text categorization");
        info.setValue(TechnicalInformation.Field.INSTITUTION, "DIMACS");
        info.setValue(TechnicalInformation.Field.URL, "http://www.stat.rutgers.edu/~madigan/PAPERS/shortFat-v3a.pdf");
        return info;
    }

    /**
     * Evaluates the helper function F(r, delta).
     *
     * @param d  the value r
     * @param d2 the bound delta
     * @return {@code 0.25} when {@code |d| <= d2}, otherwise
     *         {@code 1 / (2 + exp(|d| - d2) + exp(d2 - |d|))}
     */
    public static double bigF(double d, double d2) {
        double magnitude = Math.abs(d);
        if (magnitude > d2) {
            return 1.0 / (2.0 + Math.exp(magnitude - d2) + Math.exp(d2 - magnitude));
        }
        return 0.25;
    }

    /**
     * Decides whether the fitting loop should stop and advances the iteration
     * counter.
     *
     * <p>The relative change is {@code |sum|DeltaR| - previousSum| / (1 + sum|R|)}.
     * The loop stops when that value is at most {@link #Tolerance} or when
     * {@link #iterationCounter} has reached {@link #maxIterations}. On return,
     * {@link #Change} holds the current {@code sum|DeltaR|} for the next call.
     *
     * @return {@code true} if fitting should stop
     */
    public boolean stoppingCriterion() {
        double sumDeltaR = 0.0;
        double sumR = 1.0;
        for (int i = 0; i < m_Instances.numInstances(); i++) {
            sumDeltaR += Math.abs(DeltaR[i]);
            sumR += Math.abs(R[i]);
        }

        Change = Math.abs(sumDeltaR - Change) / sumR;
        if (debug) {
            System.out.println(Change + " <= " + Tolerance);
        }

        boolean stop = Change <= Tolerance || iterationCounter >= maxIterations;
        ++iterationCounter;
        Change = sumDeltaR;
        return stop;
    }

    /**
     * Evaluates the logistic function {@code 1 / (1 + exp(-d))}.
     *
     * <p>This form is used instead of {@code exp(d) / (1 + exp(d))} because the
     * latter evaluates to {@code Infinity / Infinity = NaN} once {@code exp(d)}
     * overflows, whereas this one saturates cleanly at 0 and 1.
     *
     * @param d the linear score
     * @return the logistic transform of {@code d}, in {@code [0, 1]}
     */
    public static double logisticLinkFunction(double d) {
        return 1.0 / (1.0 + Math.exp(-d));
    }

    /**
     * Returns the sign of a value.
     *
     * @param d the value
     * @return {@code 1.0} if positive, {@code -1.0} if negative, otherwise {@code 0.0}
     *         (including for NaN)
     */
    public static double sgn(double d) {
        if (d > 0.0) {
            return 1.0;
        }
        if (d < 0.0) {
            return -1.0;
        }
        return 0.0;
    }

    /**
     * Derives the hyperparameter from the data: the number of attributes divided
     * by the mean squared Euclidean norm of the instances (class attribute
     * excluded from the norm).
     *
     * @return the norm-based hyperparameter
     */
    public double normBasedHyperParameter() {
        double totalSquaredNorm = 0.0;
        for (int row = 0; row < m_Instances.numInstances(); row++) {
            Instance instance = m_Instances.instance(row);
            double squaredNorm = 0.0;
            for (int attr = 0; attr < m_Instances.numAttributes(); attr++) {
                if (attr == ClassIndex) {
                    continue;
                }
                squaredNorm += instance.value(attr) * instance.value(attr);
            }
            totalSquaredNorm += squaredNorm;
        }
        double meanSquaredNorm = totalSquaredNorm / (double) m_Instances.numInstances();
        return (double) m_Instances.numAttributes() / meanSquaredNorm;
    }

    /**
     * Classifies an instance given in the format of the original training data.
     *
     * <p>If the model was trained on normalized data, the instance is passed
     * through the same fitted filter first so that it is scored on the scale the
     * coefficients were learned on.
     *
     * @param instance the instance to classify
     * @return {@code 1.0} if the estimated probability exceeds {@link #Threshold},
     *         otherwise {@code 0.0}
     * @throws Exception if the normalization filter rejects the instance
     */
    public double classifyInstance(Instance instance) throws Exception {
        if (m_Filter != null) {
            m_Filter.input(instance);
            instance = m_Filter.output();
        }

        // BetaVector is indexed on the working data, which has the intercept at
        // index 0; attribute i of the incoming instance maps to coefficient i + 1.
        double score = BetaVector[0];
        for (int i = 0; i < instance.numAttributes(); i++) {
            if (i == ClassIndex - 1) {
                continue;
            }
            score += BetaVector[i + 1] * instance.value(i);
        }
        return logisticLinkFunction(score) > Threshold ? 1.0 : 0.0;
    }

    /**
     * Returns a textual description of the fitted model: the hyperparameter, the
     * non-zero coefficients, and the likelihood, penalty and log posterior.
     *
     * @return the model description, or a placeholder if no model is available
     */
    public String toString() {
        if (m_Instances == null || BetaVector == null || m_PriorUpdate == null) {
            return "Bayesian logistic regression: No model built yet.";
        }

        String label = "";
        switch (HyperparameterSelection) {
            case NORM_BASED: {
                label = "Norm-Based Hyperparameter Selection: ";
                break;
            }
            case CV_BASED: {
                label = "Cross-Validation Based Hyperparameter Selection: ";
                break;
            }
            case SPECIFIC_VALUE: {
                label = "Specified Hyperparameter: ";
                break;
            }
            default: {
                break;
            }
        }

        StringBuilder text = new StringBuilder();
        text.append(label).append(HyperparameterValue).append("\n\n");
        text.append("Regression Coefficients\n");
        text.append("=========================\n\n");
        for (int i = 0; i < m_Instances.numAttributes(); i++) {
            if (i == ClassIndex || BetaVector[i] == 0.0) {
                continue;
            }
            text.append(m_Instances.attribute(i).name()).append(" : ").append(BetaVector[i]).append("\n");
        }
        text.append("===========================\n\n");
        text.append("Likelihood: ").append(m_PriorUpdate.getLoglikelihood()).append("\n\n");
        text.append("Penalty: ").append(m_PriorUpdate.getPenalty()).append("\n\n");
        text.append("Regularized Log Posterior: ").append(m_PriorUpdate.getLogPosterior()).append("\n");
        text.append("===========================\n\n");
        return text.toString();
    }

    /**
     * Chooses the hyperparameter by cross-validation over the candidates described
     * by {@link #HyperparameterRange}.
     *
     * <p>For every candidate and fold, a copy of this classifier is trained with
     * that specific value and the log-likelihood on the held-out fold is
     * computed; the candidate of the single best (candidate, fold) result wins.
     *
     * <p>NOTE(review): fold assignment uses an unseeded {@link Random}, so the
     * result is not reproducible between runs. This method is called from
     * {@link #initialize()} after the intercept column has been inserted, and each
     * copy inserts its own intercept column again when it is built - confirm that
     * this is intended.
     *
     * @return the selected hyperparameter, or the current
     *         {@link #HyperparameterValue} if the range yields no candidates
     * @throws IllegalArgumentException if the range specification is malformed
     * @throws Exception if copying or training a fold classifier fails
     */
    public double CVBasedHyperparameter() throws Exception {
        double[] candidates = parseHyperparameterCandidates(HyperparameterRange);
        if (candidates == null || candidates.length == 0) {
            return HyperparameterValue;
        }

        int folds = NumFolds;
        Random random = new Random();
        m_Instances.randomize(random);
        m_Instances.stratify(folds);

        double bestValue = 0.0;
        double bestLikelihood = 0.0;
        // Log-likelihoods may all be negative, so "no best yet" is tracked
        // explicitly instead of relying on the initial 0.0.
        boolean haveBest = false;

        for (int c = 0; c < candidates.length; c++) {
            for (int fold = 0; fold < folds; fold++) {
                Instances train = m_Instances.trainCV(folds, fold, random);

                BayesianLogisticRegression copy = (BayesianLogisticRegression) new SerializedObject(this).getObject();
                copy.setHyperparameterSelection(new SelectedTag(SPECIFIC_VALUE, TAGS_HYPER_METHOD));
                copy.setHyperparameterValue(candidates[c]);
                copy.setPriorClass(new SelectedTag(PriorClass, TAGS_PRIOR));
                copy.setThreshold(Threshold);
                copy.setTolerance(Tolerance);
                copy.buildClassifier(train);

                Instances test = m_Instances.testCV(folds, fold);
                double likelihood = copy.getLoglikeliHood(copy.BetaVector, test);
                if (debug) {
                    System.out.println("Fold " + fold + "Hyperparameter: " + candidates[c]);
                    System.out.println("===================================");
                    System.out.println(" Likelihood: " + likelihood);
                }

                if (!haveBest || likelihood > bestLikelihood) {
                    haveBest = true;
                    bestLikelihood = likelihood;
                    bestValue = candidates[c];
                }
            }
        }
        return bestValue;
    }

    /**
     * Expands a hyperparameter range specification into candidate values.
     *
     * <p>Supported forms are {@code R:start-end,multiplier} (start, start*m,
     * start*m*m, ... up to and including end) and {@code L:v1,v2,...} (an explicit
     * list). Whitespace around the individual parts is ignored.
     *
     * @param range the specification; may be {@code null}
     * @return the candidates, or {@code null} if the specification is absent or
     *         has neither the {@code R:} nor the {@code L:} prefix
     * @throws IllegalArgumentException if an {@code R:} specification is
     *         incomplete or cannot terminate, or if a number cannot be parsed
     */
    private static double[] parseHyperparameterCandidates(String range) {
        if (range == null) {
            return null;
        }
        int colon = range.indexOf(':');
        if (colon < 0) {
            return null;
        }
        String kind = range.substring(0, colon).trim();
        String body = range.substring(colon + 1).trim();

        Vector<Double> values = new Vector<Double>();
        if (kind.equals("R")) {
            int dash = body.indexOf('-');
            int comma = body.indexOf(',', dash + 1);
            if (dash <= 0 || comma < 0) {
                throw new IllegalArgumentException("Invalid hyperparameter range '" + range + "', expected R:start-end,multiplier");
            }
            double start = Double.parseDouble(body.substring(0, dash).trim());
            double end = Double.parseDouble(body.substring(dash + 1, comma).trim());
            double multiplier = Double.parseDouble(body.substring(comma + 1).trim());
            // A non-positive start or a multiplier <= 1 would never pass the end
            // value; an infinite end would never be passed at all.
            if (!(start > 0.0) || !(multiplier > 1.0) || Double.isInfinite(end)) {
                throw new IllegalArgumentException("Invalid hyperparameter range '" + range + "': start must be > 0, multiplier must be > 1 and end must be finite");
            }
            for (double value = start; value <= end; value *= multiplier) {
                values.add(Double.valueOf(value));
            }
        } else if (kind.equals("L")) {
            StringTokenizer tokens = new StringTokenizer(body, ",");
            while (tokens.hasMoreTokens()) {
                String token = tokens.nextToken().trim();
                if (token.length() != 0) {
                    values.add(Double.valueOf(token));
                }
            }
        } else {
            return null;
        }

        double[] candidates = new double[values.size()];
        for (int i = 0; i < candidates.length; i++) {
            candidates[i] = values.get(i).doubleValue();
        }
        return candidates;
    }

    /**
     * Computes the log-likelihood of the given coefficients on the given data
     * using this classifier's prior implementation.
     *
     * @param dArray    the coefficient vector
     * @param instances the data to evaluate on
     * @return the log-likelihood reported by the prior implementation
     */
    public double getLoglikeliHood(double[] dArray, Instances instances) {
        m_PriorUpdate.computelogLikelihood(dArray, instances);
        return m_PriorUpdate.getLoglikelihood();
    }

    /**
     * Returns an enumeration describing the available options.
     *
     * @return an enumeration of {@link Option} objects
     */
    public Enumeration listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tShow Debugging Output\n", "D", 0, "-D"));
        options.addElement(new Option("\tDistribution of the Prior (1=Gaussian, 2=Laplacian)\n\t(default: 1=Gaussian)", "P", 1, "-P <integer>"));
        options.addElement(new Option("\tHyperparameter Selection Method (1=Norm-based, 2=CV-based, 3=specific value)\n\t(default: 1=Norm-based)", "H", 1, "-H <integer>"));
        options.addElement(new Option("\tSpecified Hyperparameter Value (use in conjunction with -H 3)\n\t(default: 0.27)", "V", 1, "-V <double>"));
        options.addElement(new Option("\tHyperparameter Range (use in conjunction with -H 2)\n\t(format: R:start-end,multiplier OR L:val(1), val(2), ..., val(n))\n\t(default: R:0.01-316,3.16)", "R", 1, "-R <string>"));
        options.addElement(new Option("\tTolerance Value\n\t(default: 0.0005)", "Tl", 1, "-Tl <double>"));
        options.addElement(new Option("\tThreshold Value\n\t(default: 0.5)", "S", 1, "-S <double>"));
        options.addElement(new Option("\tNumber Of Folds (use in conjuction with -H 2)\n\t(default: 2)", "F", 1, "-F <integer>"));
        options.addElement(new Option("\tMax Number of Iterations\n\t(default: 100)", "I", 1, "-I <integer>"));
        options.addElement(new Option("\tNormalize the data", "N", 0, "-N"));
        return options.elements();
    }

    /**
     * Parses the given options. Options that are absent leave the corresponding
     * setting unchanged, except for the flags {@code -D} and {@code -N}, which
     * are set to whether the flag is present.
     *
     * @param stringArray the options
     * @throws Exception if an option value cannot be parsed or unknown options remain
     */
    public void setOptions(String[] stringArray) throws Exception {
        debug = Utils.getFlag('D', stringArray);

        String tolerance = Utils.getOption("Tl", stringArray);
        if (tolerance.length() != 0) {
            Tolerance = Double.parseDouble(tolerance);
        }

        String threshold = Utils.getOption('S', stringArray);
        if (threshold.length() != 0) {
            Threshold = Double.parseDouble(threshold);
        }

        String selection = Utils.getOption('H', stringArray);
        if (selection.length() != 0) {
            HyperparameterSelection = Integer.parseInt(selection);
        }

        String value = Utils.getOption('V', stringArray);
        if (value.length() != 0) {
            HyperparameterValue = Double.parseDouble(value);
        }

        // The range used to be consumed here but never stored, which made
        // -R a no-op.
        String range = Utils.getOption("R", stringArray);
        if (range.length() != 0) {
            HyperparameterRange = range;
        }

        String prior = Utils.getOption('P', stringArray);
        if (prior.length() != 0) {
            PriorClass = Integer.parseInt(prior);
        }

        String folds = Utils.getOption('F', stringArray);
        if (folds.length() != 0) {
            NumFolds = Integer.parseInt(folds);
        }

        String iterations = Utils.getOption('I', stringArray);
        if (iterations.length() != 0) {
            maxIterations = Integer.parseInt(iterations);
        }

        NormalizeData = Utils.getFlag('N', stringArray);
        Utils.checkForRemainingOptions(stringArray);
    }

    /**
     * Returns the current settings as an option array that
     * {@link #setOptions(String[])} accepts. The flags {@code -D} and {@code -N}
     * are included only when the corresponding setting is enabled.
     *
     * @return the current options
     */
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        if (debug) {
            options.add("-D");
        }
        options.add("-Tl");
        options.add("" + Tolerance);
        options.add("-S");
        options.add("" + Threshold);
        options.add("-H");
        options.add("" + HyperparameterSelection);
        options.add("-V");
        options.add("" + HyperparameterValue);
        options.add("-R");
        options.add("" + HyperparameterRange);
        options.add("-P");
        options.add("" + PriorClass);
        options.add("-F");
        options.add("" + NumFolds);
        options.add("-I");
        options.add("" + maxIterations);
        if (NormalizeData) {
            options.add("-N");
        }
        return options.toArray(new String[options.size()]);
    }

    /**
     * Command-line entry point.
     *
     * @param stringArray the command-line options
     */
    public static void main(String[] stringArray) {
        runClassifier(new BayesianLogisticRegression(), stringArray);
    }

    /** @return the tip text for the debug property */
    public String debugTipText() {
        return "Turns on debugging mode.";
    }

    /**
     * Enables or disables debugging output.
     *
     * @param bl whether to print debugging output
     */
    public void setDebug(boolean bl) {
        debug = bl;
    }

    /** @return the tip text for the hyperparameter selection property */
    public String hyperparameterSelectionTipText() {
        return "Select the type of Hyperparameter to be used.";
    }

    /** @return the current hyperparameter selection method as a tag */
    public SelectedTag getHyperparameterSelection() {
        return new SelectedTag(HyperparameterSelection, TAGS_HYPER_METHOD);
    }

    /**
     * Sets the hyperparameter selection method. Tags that do not belong to
     * {@link #TAGS_HYPER_METHOD} are ignored.
     *
     * @param selectedTag the method to use
     * @throws IllegalArgumentException if the tag ID is not 1, 2 or 3
     */
    public void setHyperparameterSelection(SelectedTag selectedTag) {
        if (selectedTag.getTags() != TAGS_HYPER_METHOD) {
            return;
        }
        int id = selectedTag.getSelectedTag().getID();
        if (id < NORM_BASED || id > SPECIFIC_VALUE) {
            throw new IllegalArgumentException("Wrong selection type, -H value should be: 1 for norm-based, 2 for CV-based and 3 for specific value");
        }
        HyperparameterSelection = id;
    }

    /** @return the tip text for the prior class property */
    public String priorClassTipText() {
        return "The type of prior to be used.";
    }

    /**
     * Sets the prior distribution. Tags that do not belong to
     * {@link #TAGS_PRIOR} are ignored.
     *
     * @param selectedTag the prior to use
     * @throws IllegalArgumentException if the tag ID is neither 1 nor 2
     */
    public void setPriorClass(SelectedTag selectedTag) {
        if (selectedTag.getTags() != TAGS_PRIOR) {
            return;
        }
        int id = selectedTag.getSelectedTag().getID();
        if (id != GAUSSIAN && id != LAPLACIAN) {
            throw new IllegalArgumentException("Wrong selection type, -P value should be: 1 for Gaussian or 2 for Laplacian");
        }
        PriorClass = id;
    }

    /** @return the current prior distribution as a tag */
    public SelectedTag getPriorClass() {
        return new SelectedTag(PriorClass, TAGS_PRIOR);
    }

    /** @return the tip text for the threshold property */
    public String thresholdTipText() {
        return "Set the threshold for classifiction. The logistic function doesn't return a class label but an estimate of p(y=+1|B,x(i)). These estimates need to be converted to binary class label predictions. values above the threshold are assigned class +1.";
    }

    /** @return the classification threshold */
    public double getThreshold() {
        return Threshold;
    }

    /**
     * Sets the classification threshold.
     *
     * @param d the probability above which class 1 is predicted
     */
    public void setThreshold(double d) {
        Threshold = d;
    }

    /** @return the tip text for the tolerance property */
    public String toleranceTipText() {
        return "This value decides the stopping criterion.";
    }

    /** @return the convergence tolerance */
    public double getTolerance() {
        return Tolerance;
    }

    /**
     * Sets the convergence tolerance.
     *
     * @param d the tolerance
     */
    public void setTolerance(double d) {
        Tolerance = d;
    }

    /** @return the tip text for the hyperparameter value property */
    public String hyperparameterValueTipText() {
        return "Specific hyperparameter value. Used when the hyperparameter selection method is set to specific value";
    }

    /** @return the hyperparameter value currently in effect */
    public double getHyperparameterValue() {
        return HyperparameterValue;
    }

    /**
     * Sets the hyperparameter value.
     *
     * @param d the hyperparameter value
     */
    public void setHyperparameterValue(double d) {
        HyperparameterValue = d;
    }

    /** @return the tip text for the number-of-folds property */
    public String numFoldsTipText() {
        return "The number of folds to use for CV-based hyperparameter selection.";
    }

    /** @return the number of cross-validation folds */
    public int getNumFolds() {
        return NumFolds;
    }

    /**
     * Sets the number of cross-validation folds.
     *
     * @param n the number of folds
     */
    public void setNumFolds(int n) {
        NumFolds = n;
    }

    /** @return the tip text for the maximum-iterations property */
    public String maxIterationsTipText() {
        return "The maximum number of iterations to perform.";
    }

    /** @return the maximum number of iterations */
    public int getMaxIterations() {
        return maxIterations;
    }

    /**
     * Sets the maximum number of iterations.
     *
     * @param n the iteration limit
     */
    public void setMaxIterations(int n) {
        maxIterations = n;
    }

    /** @return the tip text for the normalize-data property */
    public String normalizeDataTipText() {
        return "Normalize the data.";
    }

    /** @return whether the training data is normalized */
    public boolean isNormalizeData() {
        return NormalizeData;
    }

    /**
     * Enables or disables normalization of the data.
     *
     * @param bl whether to normalize
     */
    public void setNormalizeData(boolean bl) {
        NormalizeData = bl;
    }

    /** @return the tip text for the hyperparameter range property */
    public String hyperparameterRangeTipText() {
        return "Hyperparameter value range. In case of CV-based Hyperparameters, you can specify the range in two ways: \nComma-Separated: L: 3,5,6 (This will be a list of possible values.)\nRange: R:0.01-316,3.16 (This will take values from 0.01-316 (inclusive) in multiplications of 3.16";
    }

    /** @return the hyperparameter range specification */
    public String getHyperparameterRange() {
        return HyperparameterRange;
    }

    /**
     * Sets the hyperparameter range specification used by CV-based selection.
     *
     * @param string the specification ({@code R:start-end,multiplier} or {@code L:v1,v2,...})
     */
    public void setHyperparameterRange(String string) {
        HyperparameterRange = string;
    }

    /** @return whether debugging output is enabled */
    public boolean isDebug() {
        return debug;
    }

    /** @return the revision string */
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 5516 $");
    }
}

