/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import java.util.List;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.general_regression.GeneralRegressionModel;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.Label;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.SchemaUtil;
import org.jpmml.converter.general_regression.GeneralRegressionModelUtil;
import org.jpmml.rexp.LMConverter;
import org.jpmml.rexp.RDoubleVector;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RExpUtil;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RIntegerVector;
import org.jpmml.rexp.RStringVector;

public class GLMConverter
extends LMConverter {
    public GLMConverter(RGenericVector glm) {
        super(glm);
    }

    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector glm = (RGenericVector)this.getObject();
        RGenericVector family = glm.getGenericElement("family");
        RGenericVector model = glm.getGenericElement("model");
        RStringVector familyFamily = family.getStringElement("family");
        super.encodeSchema(encoder);
        MiningFunction miningFunction = GLMConverter.getMiningFunction((String)familyFamily.asScalar());
        switch (miningFunction) {
            case CLASSIFICATION: {
                Label label = encoder.getLabel();
                RIntegerVector variable = model.getFactorElement(label.getName().getValue());
                DataField dataField = (DataField)encoder.toCategorical(label.getName(), RExpUtil.getFactorLevels(variable));
                encoder.setLabel(dataField);
                break;
            }
        }
    }

    @Override
    public Model encodeModel(Schema schema) {
        RGenericVector glm = (RGenericVector)this.getObject();
        RDoubleVector coefficients = glm.getDoubleElement("coefficients");
        RGenericVector family = glm.getGenericElement("family");
        Double intercept = (Double)coefficients.getElement(this.getInterceptName(), false);
        RStringVector familyFamily = family.getStringElement("family");
        RStringVector familyLink = family.getStringElement("link");
        Label label = schema.getLabel();
        List features = schema.getFeatures();
        SchemaUtil.checkSize((int)(coefficients.size() - (intercept != null ? 1 : 0)), (List)features);
        List<Double> featureCoefficients = this.getFeatureCoefficients(features, coefficients);
        MiningFunction miningFunction = GLMConverter.getMiningFunction((String)familyFamily.asScalar());
        Object targetCategory = null;
        switch (miningFunction) {
            case CLASSIFICATION: {
                CategoricalLabel categoricalLabel = (CategoricalLabel)label;
                SchemaUtil.checkSize((int)2, (CategoricalLabel)categoricalLabel);
                targetCategory = categoricalLabel.getValue(1);
                break;
            }
        }
        GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, miningFunction, ModelUtil.createMiningSchema((Label)label), null, null, null).setDistribution(GLMConverter.parseFamily((String)familyFamily.asScalar())).setLinkFunction(GLMConverter.parseLinkFunction((String)familyLink.asScalar())).setLinkParameter(GLMConverter.parseLinkParameter((String)familyLink.asScalar()));
        GeneralRegressionModelUtil.encodeRegressionTable((GeneralRegressionModel)generalRegressionModel, (List)features, featureCoefficients, (Double)intercept, (Object)targetCategory);
        switch (miningFunction) {
            case CLASSIFICATION: {
                generalRegressionModel.setOutput(ModelUtil.createProbabilityOutput((DataType)DataType.DOUBLE, (CategoricalLabel)((CategoricalLabel)label)));
                break;
            }
        }
        return generalRegressionModel;
    }

    private static MiningFunction getMiningFunction(String family) {
        GeneralRegressionModel.Distribution distribution = GLMConverter.parseFamily(family);
        switch (distribution) {
            case BINOMIAL: {
                return MiningFunction.CLASSIFICATION;
            }
            case NORMAL: 
            case GAMMA: 
            case IGAUSS: 
            case POISSON: {
                return MiningFunction.REGRESSION;
            }
        }
        throw new IllegalArgumentException();
    }

    private static GeneralRegressionModel.Distribution parseFamily(String family) {
        switch (family) {
            case "binomial": {
                return GeneralRegressionModel.Distribution.BINOMIAL;
            }
            case "gaussian": {
                return GeneralRegressionModel.Distribution.NORMAL;
            }
            case "Gamma": {
                return GeneralRegressionModel.Distribution.GAMMA;
            }
            case "inverse.gaussian": {
                return GeneralRegressionModel.Distribution.IGAUSS;
            }
            case "poisson": {
                return GeneralRegressionModel.Distribution.POISSON;
            }
        }
        throw new IllegalArgumentException(family);
    }

    private static GeneralRegressionModel.LinkFunction parseLinkFunction(String link) {
        switch (link) {
            case "cloglog": {
                return GeneralRegressionModel.LinkFunction.CLOGLOG;
            }
            case "identity": {
                return GeneralRegressionModel.LinkFunction.IDENTITY;
            }
            case "inverse": {
                return GeneralRegressionModel.LinkFunction.POWER;
            }
            case "log": {
                return GeneralRegressionModel.LinkFunction.LOG;
            }
            case "logit": {
                return GeneralRegressionModel.LinkFunction.LOGIT;
            }
            case "probit": {
                return GeneralRegressionModel.LinkFunction.PROBIT;
            }
            case "sqrt": {
                return GeneralRegressionModel.LinkFunction.POWER;
            }
        }
        throw new IllegalArgumentException(link);
    }

    private static Double parseLinkParameter(String link) {
        switch (link) {
            case "inverse": {
                return -1.0;
            }
            case "sqrt": {
                return 0.5;
            }
        }
        return null;
    }
}

