/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
import org.opensearch.script.ScriptService;

public interface RemoteConnectorExecutor {
    default public ModelTensorOutput executePredict(MLInput mlInput) {
        ArrayList<ModelTensors> tensorOutputs = new ArrayList<ModelTensors>();
        if (mlInput.getInputDataset() instanceof TextDocsInputDataSet) {
            TextDocsInputDataSet textDocsInputDataSet = (TextDocsInputDataSet)mlInput.getInputDataset();
            int processedDocs = 0;
            while (processedDocs < textDocsInputDataSet.getDocs().size()) {
                Map parameters;
                List textDocs = textDocsInputDataSet.getDocs().subList(processedDocs, textDocsInputDataSet.getDocs().size());
                ArrayList<ModelTensors> tempTensorOutputs = new ArrayList<ModelTensors>();
                this.preparePayloadAndInvokeRemoteModel(MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset((MLInputDataset)TextDocsInputDataSet.builder().docs(textDocs).build()).build(), tempTensorOutputs);
                int tensorCount = 0;
                if (tempTensorOutputs.size() > 0 && ((ModelTensors)tempTensorOutputs.get(0)).getMlModelTensors() != null) {
                    tensorCount = ((ModelTensors)tempTensorOutputs.get(0)).getMlModelTensors().size();
                }
                if ((parameters = this.getConnector().getParameters()) != null && parameters.containsKey("input_docs_processed_step_size")) {
                    int stepSize = Integer.parseInt((String)parameters.get("input_docs_processed_step_size"));
                    if (stepSize <= 0) {
                        throw new IllegalArgumentException("Invalid parameter: input_docs_processed_step_size. It must be positive integer.");
                    }
                    processedDocs += stepSize;
                } else {
                    processedDocs += Math.max(tensorCount, 1);
                }
                tensorOutputs.addAll(tempTensorOutputs);
            }
        } else {
            this.preparePayloadAndInvokeRemoteModel(mlInput, tensorOutputs);
        }
        return new ModelTensorOutput(tensorOutputs);
    }

    default public void setScriptService(ScriptService scriptService) {
    }

    public ScriptService getScriptService();

    public Connector getConnector();

    public TokenBucket getRateLimiter();

    public Map<String, TokenBucket> getUserRateLimiterMap();

    public MLGuard getMlGuard();

    public Client getClient();

    default public void setClient(Client client) {
    }

    default public void setXContentRegistry(NamedXContentRegistry xContentRegistry) {
    }

    default public void setClusterService(ClusterService clusterService) {
    }

    default public void setRateLimiter(TokenBucket rateLimiter) {
    }

    default public void setUserRateLimiterMap(Map<String, TokenBucket> userRateLimiterMap) {
    }

    default public void setMlGuard(MLGuard mlGuard) {
    }

    default public void preparePayloadAndInvokeRemoteModel(MLInput mlInput, List<ModelTensors> tensorOutputs) {
        Connector connector = this.getConnector();
        HashMap<String, String> parameters = new HashMap<String, String>();
        if (connector.getParameters() != null) {
            parameters.putAll(connector.getParameters());
        }
        MLInputDataset inputDataset = mlInput.getInputDataset();
        HashMap inputParameters = new HashMap();
        if (inputDataset instanceof RemoteInferenceInputDataSet && ((RemoteInferenceInputDataSet)inputDataset).getParameters() != null) {
            ConnectorUtils.escapeRemoteInferenceInputData((RemoteInferenceInputDataSet)inputDataset);
            inputParameters.putAll(((RemoteInferenceInputDataSet)inputDataset).getParameters());
        }
        parameters.putAll(inputParameters);
        RemoteInferenceInputDataSet inputData = ConnectorUtils.processInput(mlInput, connector, parameters, this.getScriptService());
        if (inputData.getParameters() != null) {
            parameters.putAll(inputData.getParameters());
        }
        parameters.putAll(inputParameters);
        String payload = (String)connector.createPredictPayload(parameters);
        connector.validatePayload(payload);
        String userStr = (String)this.getClient().threadPool().getThreadContext().getTransient("_opendistro_security_user_info");
        User user = User.parse((String)userStr);
        if (this.getRateLimiter() != null && !this.getRateLimiter().request()) {
            throw new OpenSearchStatusException("Request is throttled at model level.", RestStatus.TOO_MANY_REQUESTS, new Object[0]);
        }
        if (user != null && this.getUserRateLimiterMap() != null && this.getUserRateLimiterMap().get(user.getName()) != null && !this.getUserRateLimiterMap().get(user.getName()).request()) {
            throw new OpenSearchStatusException("Request is throttled at user level. If you think there's an issue, please contact your cluster admin.", RestStatus.TOO_MANY_REQUESTS, new Object[0]);
        }
        if (this.getMlGuard() != null && !this.getMlGuard().validate(payload, MLGuard.Type.INPUT).booleanValue()) {
            throw new IllegalArgumentException("guardrails triggered for user input");
        }
        this.invokeRemoteModel(mlInput, parameters, payload, tensorOutputs);
    }

    public void invokeRemoteModel(MLInput var1, Map<String, String> var2, String var3, List<ModelTensors> var4);
}

