/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms;

import ai.djl.inference.Predictor;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.opensearch.ml.common.dataset.AsymmetricTextEmbeddingParameters;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.DLModel;

/**
 * Base class for text embedding models that run each document through the DJL {@link Predictor}
 * supplied by {@link DLModel}.
 *
 * <p>A model registered with a {@code query_prefix} and/or {@code passage_prefix} is treated as
 * asymmetric. Requests against it must carry {@link AsymmetricTextEmbeddingParameters}, and the
 * prefix matching the requested content type is prepended to every document before inference.
 * Conversely, {@link AsymmetricTextEmbeddingParameters} are rejected for models without any prefix.
 */
public abstract class TextEmbeddingModel extends DLModel {

    /**
     * Embeds every document of the request's {@link TextDocsInputDataSet}, one predictor call per document.
     *
     * @param modelId id of the invoked model (not used by this implementation)
     * @param mlInput request whose dataset must be a {@link TextDocsInputDataSet}; its parameters decide
     *     whether a query/passage prefix is prepended to each document
     * @return one {@link ModelTensors} per input document, in input order, each parsed with the
     *     request's result filter
     * @throws TranslateException if the predictor fails on a document
     * @throws IllegalArgumentException if the parameters do not match the model's symmetric/asymmetric
     *     configuration
     */
    @Override
    public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
        MLAlgoParams mlParams = mlInput.getParameters();
        // Validate the parameters against the model config before touching the dataset, so a
        // misconfigured request reports the configuration error first.
        boolean asymmetric = isAsymmetricModel(mlParams);
        TextDocsInputDataSet textDocsInput = (TextDocsInputDataSet) mlInput.getInputDataset();
        if (asymmetric) {
            textDocsInput = addPrefixesToData((AsymmetricTextEmbeddingParameters) mlParams, textDocsInput);
        }

        ModelResultFilter resultFilter = textDocsInput.getResultFilter();
        List<String> docs = textDocsInput.getDocs();
        List<ModelTensors> tensorOutputs = new ArrayList<>(docs.size());
        for (String doc : docs) {
            Input input = new Input();
            input.add(doc);
            Output output = (Output) getPredictor().predict(input);
            tensorOutputs.add(parseModelTensorOutput(output, resultFilter));
        }
        return new ModelTensorOutput(tensorOutputs);
    }

    /**
     * Decides whether this request should be handled as an asymmetric (prefixed) embedding request.
     *
     * @param mlParams the request parameters, possibly {@code null}
     * @return {@code true} if {@code mlParams} is an {@link AsymmetricTextEmbeddingParameters} and the
     *     model has at least one prefix configured; {@code false} if neither is the case
     * @throws IllegalArgumentException if exactly one side (request vs. model config) is asymmetric
     */
    private boolean isAsymmetricModel(MLAlgoParams mlParams) {
        boolean modelHasPrefix = hasConfiguredPrefix();
        if (mlParams instanceof AsymmetricTextEmbeddingParameters) {
            if (!modelHasPrefix) {
                throw new IllegalArgumentException("When passing AsymmetricTextEmbeddingParameters, the model requires to be registered with at least one of `query_prefix` or `passage_prefix`.");
            }
            return true;
        }
        if (modelHasPrefix) {
            throw new IllegalArgumentException("The embedding model chosen is asymmetric. To use it, you must declare whether the input is of type `QUERY` or of type `PASSAGE`.");
        }
        return false;
    }

    /**
     * Reports whether the model config declares a passage prefix or a query prefix.
     *
     * @return {@code false} when there is no model config or both prefixes are {@code null}
     */
    private boolean hasConfiguredPrefix() {
        if (this.modelConfig == null) {
            return false;
        }
        TextEmbeddingModelConfig config = (TextEmbeddingModelConfig) this.modelConfig;
        return config.getPassagePrefix() != null || config.getQueryPrefix() != null;
    }

    /**
     * Prepends the configured prefix for the requested content type to every document.
     *
     * <p>{@code PASSAGE} requests use the passage prefix; any other content type uses the query prefix.
     * If the selected prefix is {@code null}, the input dataset is returned unchanged.
     *
     * @param mlParams the asymmetric request parameters carrying the content type
     * @param inputDataSet the original documents (and result filter)
     * @return a dataset with prefixed documents and the original result filter, or {@code inputDataSet}
     */
    private TextDocsInputDataSet addPrefixesToData(AsymmetricTextEmbeddingParameters mlParams, TextDocsInputDataSet inputDataSet) {
        TextEmbeddingModelConfig config = (TextEmbeddingModelConfig) this.modelConfig;
        String prefix = mlParams.getEmbeddingContentType() == AsymmetricTextEmbeddingParameters.EmbeddingContentType.PASSAGE
            ? config.getPassagePrefix()
            : config.getQueryPrefix();
        if (prefix == null) {
            return inputDataSet;
        }
        List<String> prefixedDocs = inputDataSet.getDocs().stream().map(doc -> prefix + doc).collect(Collectors.toList());
        // Carry the caller's result filter over: rebuilding the dataset from the docs alone would
        // silently discard it and predict() would then return unfiltered outputs.
        return TextDocsInputDataSet.builder().docs(prefixedDocs).resultFilter(inputDataSet.getResultFilter()).build();
    }

    /**
     * Runs one throw-away inference so the predictor is initialised before serving requests.
     *
     * <p>When the config declares a model max length, the warm-up input is {@code "sentence "}
     * repeated that many times; otherwise a short fixed sentence is used.
     *
     * @param predictor the predictor to warm up
     * @param modelId id of the model (not used by this implementation)
     * @param modelConfig the model config, possibly {@code null}; assumed to be a
     *     {@link TextEmbeddingModelConfig} when non-null
     * @throws TranslateException if the warm-up inference fails
     */
    @Override
    public void warmUp(Predictor predictor, String modelId, MLModelConfig modelConfig) throws TranslateException {
        String warmUpSentence = "warm up sentence";
        if (modelConfig != null) {
            Integer modelMaxLength = ((TextEmbeddingModelConfig) modelConfig).getModelMaxLength();
            if (modelMaxLength != null) {
                warmUpSentence = "sentence ".repeat(modelMaxLength);
            }
        }
        Input input = new Input();
        input.add(warmUpSentence);
        predictor.predict(input);
    }

    /**
     * Builds the extra model arguments derived from the config.
     *
     * @param modelConfig the model config, possibly {@code null}; assumed to be a
     *     {@link TextEmbeddingModelConfig} when non-null
     * @return a mutable map holding {@code "modelMaxLength"} when the config declares one; empty otherwise
     */
    @Override
    public Map<String, Object> getArguments(MLModelConfig modelConfig) {
        Map<String, Object> arguments = new HashMap<>();
        if (modelConfig == null) {
            return arguments;
        }
        Integer modelMaxLength = ((TextEmbeddingModelConfig) modelConfig).getModelMaxLength();
        if (modelMaxLength != null) {
            arguments.put("modelMaxLength", modelMaxLength);
        }
        return arguments;
    }
}

