/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.convolutional;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;

public abstract class Deconvolution
extends AbstractBlock {
    protected Shape kernelShape;
    protected Shape stride;
    protected Shape padding;
    protected Shape outPadding;
    protected Shape dilation;
    protected int filters;
    protected int groups;
    protected boolean includeBias;
    protected Parameter weight;
    protected Parameter bias;

    public Deconvolution(DeconvolutionBuilder<?> builder) {
        this.kernelShape = builder.kernelShape;
        this.stride = builder.stride;
        this.padding = builder.padding;
        this.outPadding = builder.outPadding;
        this.dilation = builder.dilation;
        this.filters = builder.filters;
        this.groups = builder.groups;
        this.includeBias = builder.includeBias;
        this.weight = this.addParameter(Parameter.builder().setName("weight").setType(Parameter.Type.WEIGHT).build());
        if (this.includeBias) {
            this.bias = this.addParameter(Parameter.builder().setName("bias").setType(Parameter.Type.BIAS).build());
        }
    }

    protected abstract LayoutType[] getExpectedLayout();

    protected abstract String getStringLayout();

    protected abstract int numDimensions();

    @Override
    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        Device device = input.getDevice();
        NDArray weightArr = parameterStore.getValue(this.weight, device, training);
        NDArray biasArr = parameterStore.getValue(this.bias, device, training);
        return Deconvolution.deconvolution(input, weightArr, biasArr, this.stride, this.padding, this.outPadding, this.dilation, this.groups);
    }

    @Override
    protected void beforeInitialize(Shape ... inputShapes) {
        super.beforeInitialize(inputShapes);
        Block.validateLayout(this.getExpectedLayout(), inputShapes[0].getLayout());
    }

    @Override
    protected void prepare(Shape[] inputs) {
        long inputChannel = inputs[0].get(1);
        this.weight.setShape(new Shape(this.filters, inputChannel / (long)this.groups).addAll(this.kernelShape));
        if (this.bias != null) {
            this.bias.setShape(new Shape(this.filters));
        }
    }

    @Override
    public Shape[] getOutputShapes(Shape[] inputs) {
        long[] shape = new long[this.numDimensions()];
        shape[0] = inputs[0].get(0);
        shape[1] = this.filters;
        for (int i = 0; i < this.numDimensions() - 2; ++i) {
            shape[2 + i] = (inputs[0].get(2 + i) - 1L) * this.stride.get(i) - 2L * this.padding.get(i) + this.dilation.get(i) * (this.kernelShape.get(i) - 1L) + this.outPadding.get(i) + 1L;
        }
        return new Shape[]{new Shape(shape)};
    }

    @Override
    public void loadMetadata(byte loadVersion, DataInputStream is) throws IOException, MalformedModelException {
        if (loadVersion != this.version) {
            throw new MalformedModelException("Unsupported encoding version: " + loadVersion);
        }
        this.readInputShapes(is);
    }

    static NDList deconvolution(NDArray input, NDArray weight, NDArray bias, Shape stride, Shape padding, Shape outPadding, Shape dilation, int groups) {
        return input.getNDArrayInternal().deconvolution(input, weight, bias, stride, padding, outPadding, dilation, groups);
    }

    public static abstract class DeconvolutionBuilder<T extends DeconvolutionBuilder> {
        protected Shape kernelShape;
        protected Shape stride;
        protected Shape padding;
        protected Shape outPadding;
        protected Shape dilation;
        protected int filters;
        protected int groups = 1;
        protected boolean includeBias = true;

        public T setKernelShape(Shape kernelShape) {
            this.kernelShape = kernelShape;
            return this.self();
        }

        public T optStride(Shape stride) {
            this.stride = stride;
            return this.self();
        }

        public T optPadding(Shape padding) {
            this.padding = padding;
            return this.self();
        }

        public T optOutPadding(Shape outPadding) {
            this.outPadding = outPadding;
            return this.self();
        }

        public T optDilation(Shape dilate) {
            this.dilation = dilate;
            return this.self();
        }

        public T setFilters(int filters) {
            this.filters = filters;
            return this.self();
        }

        public T optGroups(int groups) {
            this.groups = groups;
            return this.self();
        }

        public T optBias(boolean includeBias) {
            this.includeBias = includeBias;
            return this.self();
        }

        protected void validate() {
            if (this.kernelShape == null || this.filters == 0) {
                throw new IllegalArgumentException("Kernel and numFilters must be set");
            }
        }

        protected abstract T self();
    }
}

