/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.util.onnx;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.ByteString;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXShape;

/**
 * Static helpers for building ONNX protobuf messages ({@link OnnxMl.TensorProto} and
 * {@link OnnxMl.TypeProto}).
 * <p>
 * Every tensor-producing method names its result by passing the requested name through
 * {@link ONNXContext#generateUniqueName(String)}. The array and tensor builders store their
 * values in the {@code raw_data} field using little-endian byte order; the scalar builders use
 * the typed data fields instead.
 */
public abstract class ONNXUtils {
    /** Utility class; not instantiable. */
    private ONNXUtils() {
    }

    /**
     * Builds a {@link OnnxMl.TypeProto} describing a tensor with the given shape and element type.
     *
     * @param shape the tensor shape; its {@link ONNXShape#getProto()} is embedded in the result.
     * @param type  the tensor element type.
     * @return a TypeProto whose tensor type carries {@code type} and {@code shape}.
     */
    public static OnnxMl.TypeProto buildTensorTypeNode(ONNXShape shape, OnnxMl.TensorProto.DataType type) {
        OnnxMl.TypeProto.Tensor tensorType = OnnxMl.TypeProto.Tensor.newBuilder()
                .setElemType(type.getNumber())
                .setShape(shape.getProto())
                .build();
        return OnnxMl.TypeProto.newBuilder().setTensorType(tensorType).build();
    }

    /**
     * Builds a single-value INT32 tensor with no dims entries.
     *
     * @param context the context used to generate the tensor name.
     * @param name    the base name for the tensor.
     * @param value   the value, stored in {@code int32_data}.
     * @return the tensor proto.
     */
    public static OnnxMl.TensorProto scalarBuilder(ONNXContext context, String name, int value) {
        return scalarProto(context, name, OnnxMl.TensorProto.DataType.INT32).addInt32Data(value).build();
    }

    /**
     * Builds a single-value INT64 tensor with no dims entries.
     *
     * @param context the context used to generate the tensor name.
     * @param name    the base name for the tensor.
     * @param value   the value, stored in {@code int64_data}.
     * @return the tensor proto.
     */
    public static OnnxMl.TensorProto scalarBuilder(ONNXContext context, String name, long value) {
        return scalarProto(context, name, OnnxMl.TensorProto.DataType.INT64).addInt64Data(value).build();
    }

    /**
     * Builds a single-value FLOAT tensor with no dims entries.
     *
     * @param context the context used to generate the tensor name.
     * @param name    the base name for the tensor.
     * @param value   the value, stored in {@code float_data}.
     * @return the tensor proto.
     */
    public static OnnxMl.TensorProto scalarBuilder(ONNXContext context, String name, float value) {
        return scalarProto(context, name, OnnxMl.TensorProto.DataType.FLOAT).addFloatData(value).build();
    }

    /**
     * Builds a single-value DOUBLE tensor with no dims entries.
     *
     * @param context the context used to generate the tensor name.
     * @param name    the base name for the tensor.
     * @param value   the value, stored in {@code double_data}.
     * @return the tensor proto.
     */
    public static OnnxMl.TensorProto scalarBuilder(ONNXContext context, String name, double value) {
        return scalarProto(context, name, OnnxMl.TensorProto.DataType.DOUBLE).addDoubleData(value).build();
    }

    /**
     * Builds a FLOAT tensor with the given dims whose contents are written by {@code dataPopulator}.
     * <p>
     * A zero-filled little-endian buffer with room for the product of {@code dims} values is
     * allocated (one value when {@code dims} is empty, i.e. a scalar), and a {@link FloatBuffer}
     * view of it is handed to the populator. Writing more values than fit throws
     * {@link java.nio.BufferOverflowException}; unwritten values remain zero.
     *
     * @param context       the context used to generate the tensor name.
     * @param name          the base name for the tensor.
     * @param dims          the tensor dimensions.
     * @param dataPopulator writes the tensor values into the supplied buffer.
     * @return the tensor proto, with values stored in {@code raw_data}.
     * @throws IllegalArgumentException if a dimension is negative or the tensor is too large
     *                                  to fit in a single byte buffer.
     */
    public static OnnxMl.TensorProto floatTensorBuilder(ONNXContext context, String name, List<Integer> dims, Consumer<FloatBuffer> dataPopulator) {
        ByteBuffer buffer = littleEndianBuffer(elementCount(dims), Float.BYTES);
        dataPopulator.accept(buffer.asFloatBuffer());
        return rawTensor(context, name, OnnxMl.TensorProto.DataType.FLOAT, dims, buffer);
    }

    /**
     * Builds a DOUBLE tensor with the given dims whose contents are written by {@code dataPopulator}.
     * <p>
     * A zero-filled little-endian buffer with room for the product of {@code dims} values is
     * allocated (one value when {@code dims} is empty, i.e. a scalar), and a {@link DoubleBuffer}
     * view of it is handed to the populator. Writing more values than fit throws
     * {@link java.nio.BufferOverflowException}; unwritten values remain zero.
     *
     * @param context       the context used to generate the tensor name.
     * @param name          the base name for the tensor.
     * @param dims          the tensor dimensions.
     * @param dataPopulator writes the tensor values into the supplied buffer.
     * @return the tensor proto, with values stored in {@code raw_data}.
     * @throws IllegalArgumentException if a dimension is negative or the tensor is too large
     *                                  to fit in a single byte buffer.
     */
    public static OnnxMl.TensorProto doubleTensorBuilder(ONNXContext context, String name, List<Integer> dims, Consumer<DoubleBuffer> dataPopulator) {
        ByteBuffer buffer = littleEndianBuffer(elementCount(dims), Double.BYTES);
        dataPopulator.accept(buffer.asDoubleBuffer());
        return rawTensor(context, name, OnnxMl.TensorProto.DataType.DOUBLE, dims, buffer);
    }

    /**
     * Builds a one-dimensional FLOAT tensor containing {@code parameters}.
     *
     * @param context    the context used to generate the tensor name.
     * @param name       the base name for the tensor.
     * @param parameters the values.
     * @return the tensor proto.
     */
    public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, float[] parameters) {
        return floatTensorBuilder(context, name, Collections.singletonList(parameters.length), fb -> fb.put(parameters));
    }

    /**
     * Builds a one-dimensional FLOAT tensor from {@code parameters}, casting each value to float.
     * Equivalent to {@code arrayBuilder(context, name, parameters, true)}.
     *
     * @param context    the context used to generate the tensor name.
     * @param name       the base name for the tensor.
     * @param parameters the values.
     * @return the tensor proto.
     */
    public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, double[] parameters) {
        return arrayBuilder(context, name, parameters, true);
    }

    /**
     * Builds a one-dimensional tensor from {@code parameters}.
     *
     * @param context    the context used to generate the tensor name.
     * @param name       the base name for the tensor.
     * @param parameters the values.
     * @param downcast   if true, the values are cast to float and a FLOAT tensor is produced;
     *                   otherwise a DOUBLE tensor is produced.
     * @return the tensor proto.
     */
    public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, double[] parameters, boolean downcast) {
        List<Integer> dims = Collections.singletonList(parameters.length);
        if (downcast) {
            return floatTensorBuilder(context, name, dims, fb -> {
                for (double d : parameters) {
                    fb.put((float) d);
                }
            });
        }
        return doubleTensorBuilder(context, name, dims, db -> db.put(parameters));
    }

    /**
     * Builds a one-dimensional INT32 tensor containing {@code parameters}.
     *
     * @param context    the context used to generate the tensor name.
     * @param name       the base name for the tensor.
     * @param parameters the values.
     * @return the tensor proto, with values stored in {@code raw_data}.
     * @throws IllegalArgumentException if the array is too large to fit in a single byte buffer.
     */
    public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, int[] parameters) {
        ByteBuffer buffer = littleEndianBuffer(parameters.length, Integer.BYTES);
        IntBuffer ints = buffer.asIntBuffer();
        ints.put(parameters);
        return rawTensor(context, name, OnnxMl.TensorProto.DataType.INT32, Collections.singletonList(parameters.length), buffer);
    }

    /**
     * Builds a one-dimensional INT64 tensor containing {@code parameters}.
     *
     * @param context    the context used to generate the tensor name.
     * @param name       the base name for the tensor.
     * @param parameters the values.
     * @return the tensor proto, with values stored in {@code raw_data}.
     * @throws IllegalArgumentException if the array is too large to fit in a single byte buffer.
     */
    public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, long[] parameters) {
        ByteBuffer buffer = littleEndianBuffer(parameters.length, Long.BYTES);
        LongBuffer longs = buffer.asLongBuffer();
        longs.put(parameters);
        return rawTensor(context, name, OnnxMl.TensorProto.DataType.INT64, Collections.singletonList(parameters.length), buffer);
    }

    /** Starts a tensor builder with a generated name and the given data type, for scalar values. */
    private static OnnxMl.TensorProto.Builder scalarProto(ONNXContext context, String name, OnnxMl.TensorProto.DataType type) {
        return OnnxMl.TensorProto.newBuilder()
                .setName(context.generateUniqueName(name))
                .setDataType(type.getNumber());
    }

    /**
     * Assembles a tensor whose values are the bytes of {@code data} (position to limit).
     * The name is generated here, after the caller has populated {@code data}.
     */
    private static OnnxMl.TensorProto rawTensor(ONNXContext context, String name, OnnxMl.TensorProto.DataType type, List<Integer> dims, ByteBuffer data) {
        return OnnxMl.TensorProto.newBuilder()
                .setName(context.generateUniqueName(name))
                .setDataType(type.getNumber())
                .addAllDims(dims.stream().map(Integer::longValue).collect(Collectors.toList()))
                .setRawData(ByteString.copyFrom(data))
                .build();
    }

    /**
     * Returns the number of elements described by {@code dims}: their product, or 1 for an
     * empty list (an ONNX scalar). Rejects negative dims and counts above Integer.MAX_VALUE.
     */
    private static int elementCount(List<Integer> dims) {
        long count = 1;
        for (int dim : dims) {
            if (dim < 0) {
                throw new IllegalArgumentException("Tensor dimensions must be non-negative, found " + dims);
            }
            // count <= Integer.MAX_VALUE before this multiply, so the long product cannot overflow.
            count *= dim;
            if (count > Integer.MAX_VALUE) {
                throw new IllegalArgumentException("Tensor with dims " + dims + " has more than Integer.MAX_VALUE elements");
            }
        }
        return (int) count;
    }

    /** Allocates a zero-filled little-endian heap buffer for {@code elements} values of {@code bytesPerElement} bytes. */
    private static ByteBuffer littleEndianBuffer(int elements, int bytesPerElement) {
        if (elements > Integer.MAX_VALUE / bytesPerElement) {
            throw new IllegalArgumentException("Tensor of " + elements + " elements is too large for a single byte buffer");
        }
        return ByteBuffer.allocate(elements * bytesPerElement).order(ByteOrder.LITTLE_ENDIAN);
    }
}

