/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.tokenize;

import ai.djl.MalformedModelException;
import ai.djl.huggingface.tokenizers.Encoding;
import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.parameter.textembedding.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.input.parameter.textembedding.SparseEmbeddingFormat;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.DLModel;
import org.opensearch.ml.engine.analysis.DJLUtils;
import org.opensearch.ml.engine.annotation.Function;

/**
 * Sparse tokenizer model: encodes each input document with a HuggingFace tokenizer and
 * assigns every distinct, non-empty token a weight looked up from an optional IDF table
 * ({@code idf.json} next to {@code tokenizer.json}). Tokens missing from the table get a
 * weight of {@code 1.0}.
 *
 * <p>The tokenizer is an {@link AutoCloseable} DJL resource, so this class closes it
 * explicitly in {@link #close()} and when a load is replaced or fails.
 */
@Function(value=FunctionName.SPARSE_TOKENIZE)
public class SparseTokenizerModel
extends DLModel {
    @Generated
    private static final Logger log = LogManager.getLogger(SparseTokenizerModel.class);

    /** Weight used for tokens that have no entry in the IDF table. */
    private static final float DEFAULT_TOKEN_WEIGHT = 1.0f;

    private HuggingFaceTokenizer tokenizer;
    private Map<String, Float> idf;

    // Name of the optional IDF weights file inside the model directory.
    // NOTE(review): kept as a public mutable instance field so existing callers keep
    // working; it reads like a constant and could become static final if nothing writes it.
    public String IDF_FILE_NAME = "idf.json";

    /**
     * Tokenizes every document of the input and returns one {@link ModelTensors} per document.
     * Each tensor carries a map {@code {"response": [ {key -> weight} ]}} where the key is the
     * decoded token, or the token id when the request asks for
     * {@link SparseEmbeddingFormat#TOKEN_ID}.
     *
     * @param modelId id of the model (not used by this implementation)
     * @param mlInput input whose dataset must be a {@link TextDocsInputDataSet}
     * @return the per-document token weights
     * @throws IllegalArgumentException if the input does not carry a {@link TextDocsInputDataSet}
     * @throws IllegalStateException if the tokenizer has not been loaded or was already closed
     */
    @Override
    public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
        if (mlInput == null || !(mlInput.getInputDataset() instanceof TextDocsInputDataSet)) {
            throw new IllegalArgumentException("Sparse tokenizer model requires a TextDocsInputDataSet input");
        }
        // Read the fields once so a single request works against one consistent pair.
        HuggingFaceTokenizer activeTokenizer = this.tokenizer;
        Map<String, Float> weights = this.idf;
        if (activeTokenizer == null || weights == null) {
            throw new IllegalStateException("Sparse tokenizer model is not deployed");
        }

        TextDocsInputDataSet textDocsInput = (TextDocsInputDataSet) mlInput.getInputDataset();
        MLAlgoParams parameters = mlInput.getParameters();
        SparseEmbeddingFormat sparseEmbeddingFormat = SparseEmbeddingFormat.WORD;
        if (parameters instanceof AsymmetricTextEmbeddingParameters) {
            sparseEmbeddingFormat = ((AsymmetricTextEmbeddingParameters) parameters).getSparseEmbeddingFormat();
        }
        // Any format other than TOKEN_ID (including null) keys the result by token text.
        boolean keyByTokenId = sparseEmbeddingFormat == SparseEmbeddingFormat.TOKEN_ID;

        List<ModelTensors> tensorOutputs = new ArrayList<>();
        for (String doc : textDocsInput.getDocs()) {
            Map<String, Float> tokenWeights = buildTokenWeights(activeTokenizer, weights, doc, keyByTokenId);
            Map<String, List<Map<String, Float>>> wrappedMap = Map.of("response", Collections.singletonList(tokenWeights));
            ModelTensor tensor = ModelTensor.builder().dataAsMap(wrappedMap).build();
            tensorOutputs.add(new ModelTensors(List.of(tensor)));
        }
        return new ModelTensorOutput(tensorOutputs);
    }

    /** Maps each distinct, non-empty token of {@code doc} (by text or by id) to its IDF weight. */
    private Map<String, Float> buildTokenWeights(HuggingFaceTokenizer activeTokenizer, Map<String, Float> weights, String doc, boolean keyByTokenId) {
        Encoding encoding = activeTokenizer.encode(doc);
        long[] uniqueIds = Arrays.stream(encoding.getIds()).distinct().toArray();
        Map<String, Float> tokenWeights = new HashMap<>();
        for (long id : uniqueIds) {
            // Decode one id at a time; ids that decode to an empty string are skipped.
            String token = activeTokenizer.decode(new long[]{id}, true);
            if (token.isEmpty()) {
                continue;
            }
            String key = keyByTokenId ? String.valueOf(id) : token;
            // The weight is always looked up by token text, even when keyed by id.
            tokenWeights.put(key, weights.getOrDefault(token, DEFAULT_TOKEN_WEIGHT));
        }
        return tokenWeights;
    }

    /**
     * Loads {@code tokenizer.json} and, when present, the IDF weights file from
     * {@code modelPath}. The new state is published only after both steps succeed; a
     * tokenizer created by a failed load, or replaced by a successful one, is closed.
     *
     * @param predictorList not used by this implementation
     * @param modelList not used by this implementation
     * @param engine not used by this implementation
     * @param modelPath directory containing the tokenizer and optional IDF file
     * @param modelConfig not used by this implementation
     */
    @Override
    protected void doLoadModel(List<Predictor<Input, Output>> predictorList, List<ZooModel<Input, Output>> modelList, String engine, Path modelPath, MLModelConfig modelConfig) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
        HuggingFaceTokenizer loadedTokenizer = HuggingFaceTokenizer.builder()
            .optPadding(true)
            .optTokenizerPath(modelPath.resolve("tokenizer.json"))
            .build();
        Map<String, Float> loadedIdf = new HashMap<>();
        boolean loaded = false;
        try {
            Path idfPath = modelPath.resolve(this.IDF_FILE_NAME);
            if (Files.exists(idfPath)) {
                Map<String, Float> fetched = DJLUtils.fetchTokenWeights(idfPath);
                if (fetched != null) {
                    loadedIdf = fetched;
                }
            }
            loaded = true;
        } finally {
            if (!loaded) {
                // Do not leak the freshly created tokenizer when reading the weights fails.
                loadedTokenizer.close();
            }
        }

        HuggingFaceTokenizer previous = this.tokenizer;
        this.tokenizer = loadedTokenizer;
        this.idf = loadedIdf;
        if (previous != null) {
            previous.close();
        }
        log.info("sparse tokenize Model {} is successfully deployed", this.modelId);
    }

    /** Returns {@code true} once the model helper, model id and tokenizer are all set. */
    @Override
    public boolean isModelReady() {
        return this.modelHelper != null && this.modelId != null && this.tokenizer != null;
    }

    /**
     * Deletes the model's file cache (when the helper and id are known) and releases the
     * tokenizer and the IDF table. The tokenizer is closed even if cache deletion throws.
     */
    @Override
    public void close() {
        try {
            if (this.modelHelper != null && this.modelId != null) {
                this.modelHelper.deleteFileCache(this.modelId);
            }
        } finally {
            HuggingFaceTokenizer loadedTokenizer = this.tokenizer;
            this.tokenizer = null;
            this.idf = null;
            if (loadedTokenizer != null) {
                loadedTokenizer.close();
            }
        }
    }

    /** Returns {@code null}: this class tokenizes directly in {@code predict}. */
    @Override
    public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) {
        return null;
    }

    /** Returns {@code null}: this class tokenizes directly in {@code predict}. */
    @Override
    public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) {
        return null;
    }
}

