/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.ml;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;

/**
 * Wrapper around {@link MachineLearningNodeClient} that runs {@link FunctionName#TEXT_EMBEDDING}
 * inference against an ML Commons model and converts the returned tensors into float vectors.
 */
public class MLCommonsClientAccessor {
    @Generated
    private static final Logger log = LogManager.getLogger(MLCommonsClientAccessor.class);
    /** Default target-response filter handed to {@link ModelResultFilter} (selects "sentence_embedding"). */
    private static final List<String> TARGET_RESPONSE_FILTERS = List.of("sentence_embedding");
    private final MachineLearningNodeClient mlClient;

    /**
     * Runs inference for a single sentence and delivers its single embedding vector.
     *
     * @param modelId   id of the ML Commons model to call
     * @param inputText sentence to embed
     * @param listener  receives the vector; fails with {@link IllegalStateException} if the model
     *                  did not produce exactly one vector, or with any failure of the prediction call
     */
    public void inferenceSentence(@NonNull String modelId, @NonNull String inputText, @NonNull ActionListener<List<Float>> listener) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(inputText, "inputText is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, List.of(inputText), ActionListener.wrap(vectors -> {
            if (vectors.size() != 1) {
                listener.onFailure(new IllegalStateException("Unexpected number of vectors produced. Expected 1 vector to be returned, but got [" + vectors.size() + "]"));
                return;
            }
            listener.onResponse(vectors.get(0));
        }, listener::onFailure));
    }

    /**
     * Runs inference for a batch of sentences using the default target-response filter.
     *
     * @param modelId   id of the ML Commons model to call
     * @param inputText sentences to embed
     * @param listener  receives one vector per tensor returned by the model, or the failure
     */
    public void inferenceSentences(@NonNull String modelId, @NonNull List<String> inputText, @NonNull ActionListener<List<List<Float>>> listener) {
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(inputText, "inputText is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        this.inferenceSentences(TARGET_RESPONSE_FILTERS, modelId, inputText, listener);
    }

    /**
     * Runs inference for a batch of sentences with caller-supplied target-response filters.
     *
     * @param targetResponseFilters target-response names passed to {@link ModelResultFilter}
     * @param modelId               id of the ML Commons model to call
     * @param inputText             sentences to embed
     * @param listener              receives the vectors built from the model output, or the failure
     *                              (including any exception raised while converting the output)
     */
    public void inferenceSentences(@NonNull List<String> targetResponseFilters, @NonNull String modelId, @NonNull List<String> inputText, @NonNull ActionListener<List<List<Float>>> listener) {
        Objects.requireNonNull(targetResponseFilters, "targetResponseFilters is marked non-null but is null");
        Objects.requireNonNull(modelId, "modelId is marked non-null but is null");
        Objects.requireNonNull(inputText, "inputText is marked non-null but is null");
        Objects.requireNonNull(listener, "listener is marked non-null but is null");
        final MLInput mlInput = this.createMLInput(targetResponseFilters, inputText);
        this.mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
            final List<List<Float>> vectors = this.buildVectorFromResponse(mlOutput);
            log.debug("Inference Response for input sentence {} is : {} ", inputText, vectors);
            listener.onResponse(vectors);
        }, listener::onFailure));
    }

    /** Builds a TEXT_EMBEDDING {@link MLInput} over the given sentences with a numeric-only result filter. */
    private MLInput createMLInput(List<String> targetResponseFilters, List<String> inputText) {
        // Arguments: returnBytes=false, returnNumber=true, target responses, no target positions.
        final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
        final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
        return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset);
    }

    /**
     * Flattens every tensor of every model output into a list of float vectors, in output order.
     *
     * @throws IllegalStateException if the output is not a {@link ModelTensorOutput} or a tensor has no data
     */
    private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
        if (!(mlOutput instanceof ModelTensorOutput)) {
            throw new IllegalStateException(
                "Expected ModelTensorOutput from text embedding model, but got [" + (mlOutput == null ? "null" : mlOutput.getClass().getName()) + "]"
            );
        }
        final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
        final List<List<Float>> vectors = new ArrayList<>();
        for (ModelTensors tensors : modelTensorOutput.getMlModelOutputs()) {
            for (ModelTensor tensor : tensors.getMlModelTensors()) {
                final Number[] data = tensor.getData();
                if (data == null) {
                    throw new IllegalStateException("Model returned a tensor without numeric data");
                }
                // Data is typed Number[]; floatValue() converts any numeric element instead of
                // assuming each one is already a Float (a direct cast fails for e.g. Double).
                vectors.add(Arrays.stream(data).map(Number::floatValue).collect(Collectors.toList()));
            }
        }
        return vectors;
    }

    @Generated
    public MLCommonsClientAccessor(MachineLearningNodeClient mlClient) {
        this.mlClient = mlClient;
    }
}

