/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.core;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.stream.IntStream;

/**
 * A top-k softmax block.
 *
 * <p>Along {@code axis}, only the {@code topK} largest entries of the input keep non-zero weight:
 * their exponentials are renormalised so that they sum to one, and every other entry becomes
 * zero. The output has the same shape as the input. Note that this is a softmax restricted to the
 * top-k entries rather than the Euclidean-projection form of sparsemax.
 */
public class SparseMax extends AbstractBlock {

    private static final byte VERSION = 1;

    /** Axis along which the top-k softmax is taken; {@code -1} means the last axis. */
    private final int axis;

    /** Number of largest entries per slice that receive non-zero weight. */
    private final int topK;

    /** Creates a block over the last axis that keeps the 3 largest entries. */
    public SparseMax() {
        this(-1, 3);
    }

    /**
     * Creates a block that keeps the 3 largest entries along the given axis.
     *
     * @param axis the axis to normalise over ({@code -1} for the last axis)
     */
    public SparseMax(int axis) {
        this(axis, 3);
    }

    /**
     * Creates a block that keeps the {@code topK} largest entries along {@code axis}.
     *
     * @param axis the axis to normalise over ({@code -1} for the last axis)
     * @param topK how many of the largest entries to keep; a value at least as large as the axis
     *     length keeps every entry, i.e. behaves like a plain softmax
     * @throws IllegalArgumentException if {@code topK} is less than 1
     */
    public SparseMax(int axis, int topK) {
        super(VERSION);
        if (topK < 1) {
            throw new IllegalArgumentException("topK must be at least 1, but was " + topK);
        }
        this.axis = axis;
        this.topK = topK;
    }

    /** {@inheritDoc} The output shape is identical to the first input shape. */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return new Shape[] {inputShapes[0]};
    }

    @Override
    protected NDList forwardInternal(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        // Move the target axis to the end so the helpers can always work on axis -1.
        if (axis != -1) {
            input = input.swapAxes(axis, -1);
        }
        int lastDimSize = (int) input.size(input.getShape().dimension() - 1);
        if (lastDimSize == 0) {
            throw new IllegalArgumentException("SparseMax requires a non-empty axis " + axis);
        }
        // Clamp so that topK larger than the axis length degrades to a plain softmax
        // instead of indexing past the end of the sorted order.
        NDArray mask = topKMask(input, Math.min(topK, lastDimSize), lastDimSize);

        // Subtract the per-slice maximum before exponentiating: mathematically identical, but it
        // prevents exp overflow (inf / inf = NaN) and full underflow (0 / 0 = NaN). The maximum is
        // always kept by the mask, so each denominator is at least exp(0) = 1.
        NDArray shifted = input.sub(input.max(new int[] {-1}, true));
        NDArray maskedExp = shifted.exp().mul(mask);
        NDArray output = maskedExp.div(maskedExp.sum(new int[] {-1}, true));

        // Restore the caller's axis order.
        if (axis != -1) {
            output = output.swapAxes(axis, -1);
        }
        return new NDList(output);
    }

    /**
     * Builds a 0/1 mask marking the {@code k} largest entries along the last axis.
     *
     * @param input the array whose last axis is being normalised
     * @param k how many entries to mark; must be in {@code [1, lastDimSize]}
     * @param lastDimSize the length of the last axis of {@code input}
     * @return a mask with the shape of {@code input} holding 1 at the top-k positions, else 0
     */
    private static NDArray topKMask(NDArray input, int k, int lastDimSize) {
        // Indices along the last axis, ordered from largest to smallest value.
        NDArray order = input.argSort(-1, false).toType(DataType.INT64, false);
        NDArray[] oneHots =
                IntStream.range(0, k)
                        .mapToObj(j -> order.get("..., {}", j).oneHot(lastDimSize))
                        .toArray(NDArray[]::new);
        // NDArrays.add rejects fewer than two operands, so the single-entry case is returned as is.
        return oneHots.length == 1 ? oneHots[0] : NDArrays.add(oneHots);
    }
}

