/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.text_similarity;

import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import java.util.ArrayList;
import java.util.List;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.DLModel;
import org.opensearch.ml.engine.algorithms.text_similarity.TextSimilarityTranslator;
import org.opensearch.ml.engine.annotation.Function;

/**
 * Cross-encoder text-similarity model.
 *
 * <p>Each document is scored by putting the query text and that document's text together into a
 * single predictor {@link Input}. The class is registered for {@link FunctionName#TEXT_SIMILARITY}
 * through the {@link Function} annotation.
 */
@Function(FunctionName.TEXT_SIMILARITY)
public class TextSimilarityCrossEncoderModel extends DLModel {

    /**
     * Scores every document of the input dataset against the query text.
     *
     * <p>One prediction runs per document. The result holds one {@link ModelTensors} per document,
     * in the same order as {@link TextSimilarityInputDataSet#getTextDocs()}.
     *
     * @param modelId identifier of the model; not used by this implementation
     * @param mlInput input whose dataset must be a {@link TextSimilarityInputDataSet}
     * @return the per-document tensors wrapped in a {@link ModelTensorOutput}
     * @throws TranslateException if the predictor fails for any document
     * @throws IllegalArgumentException if the input dataset is missing or is not a
     *     {@link TextSimilarityInputDataSet}
     */
    @Override
    public ModelTensorOutput predict(String modelId, MLInput mlInput) throws TranslateException {
        MLInputDataset inputDataSet = mlInput.getInputDataset();
        // Reject unsupported datasets with a descriptive message instead of an opaque
        // ClassCastException. instanceof is also false for a null dataset.
        if (!(inputDataSet instanceof TextSimilarityInputDataSet)) {
            String actualType = inputDataSet == null ? "null" : inputDataSet.getClass().getName();
            throw new IllegalArgumentException(
                "Text similarity model expects a TextSimilarityInputDataSet but got: " + actualType);
        }
        TextSimilarityInputDataSet textSimInput = (TextSimilarityInputDataSet) inputDataSet;
        String queryText = textSimInput.getQueryText();

        List<ModelTensors> tensorOutputs = new ArrayList<>();
        for (String doc : textSimInput.getTextDocs()) {
            // The query is added before the document. Presumably TextSimilarityTranslator
            // reads the pair in this order; keep it stable (TODO confirm against the translator).
            Input input = new Input();
            input.add(queryText);
            input.add(doc);
            Output output = (Output) getPredictor().predict(input);
            // The predictor output carries serialized ModelTensors bytes for this pair.
            tensorOutputs.add(ModelTensors.fromBytes(output.getData().getAsBytes()));
        }
        return new ModelTensorOutput(tensorOutputs);
    }

    /**
     * Returns the translator that converts (query, document) inputs for the predictor.
     *
     * <p>A new {@link TextSimilarityTranslator} is returned regardless of {@code engine} or
     * {@code modelConfig}.
     *
     * @param engine inference engine name; ignored
     * @param modelConfig model configuration; ignored
     * @return a new {@link TextSimilarityTranslator}
     */
    @Override
    public Translator<Input, Output> getTranslator(String engine, MLModelConfig modelConfig) throws IllegalArgumentException {
        return new TextSimilarityTranslator();
    }

    /**
     * Returns no translator factory for this model.
     *
     * <p>NOTE(review): presumably a {@code null} factory makes {@link DLModel} fall back to
     * {@link #getTranslator(String, MLModelConfig)}. Confirm against the base class.
     *
     * @param engine inference engine name; ignored
     * @param modelConfig model configuration; ignored
     * @return always {@code null}
     */
    @Override
    public TranslatorFactory getTranslatorFactory(String engine, MLModelConfig modelConfig) {
        return null;
    }
}

