/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import lombok.Generated;
import org.apache.http.HttpEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.util.EntityUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.client.Client;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.remote.AbstractConnectorExecutor;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.script.ScriptService;

@ConnectorExecutor(value="http")
public class HttpJsonConnectorExecutor
extends AbstractConnectorExecutor {
    @Generated
    private static final Logger log = LogManager.getLogger(HttpJsonConnectorExecutor.class);
    private HttpConnector connector;
    private ScriptService scriptService;
    private TokenBucket rateLimiter;
    private Map<String, TokenBucket> userRateLimiterMap;
    private Client client;
    private MLGuard mlGuard;
    private CloseableHttpClient httpClient;

    public HttpJsonConnectorExecutor(Connector connector) {
        super.initialize(connector);
        this.connector = (HttpConnector)connector;
        this.httpClient = MLHttpClientFactory.getCloseableHttpClient(super.getConnectorClientConfig().getConnectionTimeout(), super.getConnectorClientConfig().getReadTimeout(), super.getConnectorClientConfig().getMaxConnections());
    }

    public HttpJsonConnectorExecutor(Connector connector, CloseableHttpClient httpClient) {
        this(connector);
        this.httpClient = httpClient;
    }

    @Override
    public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
        try {
            HttpGet request;
            AtomicReference<String> responseRef = new AtomicReference<String>("");
            AtomicReference statusCodeRef = new AtomicReference();
            switch (this.connector.getPredictHttpMethod().toUpperCase(Locale.ROOT)) {
                case "POST": {
                    try {
                        String predictEndpoint = this.connector.getPredictEndpoint(parameters);
                        request = new HttpPost(predictEndpoint);
                        String charset = parameters.containsKey("charset") ? parameters.get("charset") : "UTF-8";
                        StringEntity entity = new StringEntity(payload, charset);
                        ((HttpPost)request).setEntity((HttpEntity)entity);
                        break;
                    }
                    catch (Exception e) {
                        throw new MLException("Failed to create http request for remote model", (Throwable)e);
                    }
                }
                case "GET": {
                    try {
                        request = new HttpGet(this.connector.getPredictEndpoint(parameters));
                        break;
                    }
                    catch (Exception e) {
                        throw new MLException("Failed to create http request for remote model", (Throwable)e);
                    }
                }
                default: {
                    throw new IllegalArgumentException("unsupported http method");
                }
            }
            Map headers = this.connector.getDecryptedHeaders();
            boolean hasContentTypeHeader = false;
            if (headers != null) {
                for (String key : headers.keySet()) {
                    request.addHeader(key, (String)headers.get(key));
                    if (!key.toLowerCase().equals("Content-Type")) continue;
                    hasContentTypeHeader = true;
                }
            }
            if (!hasContentTypeHeader) {
                request.addHeader("Content-Type", "application/json");
            }
            AccessController.doPrivileged(() -> this.lambda$invokeRemoteModel$0((HttpUriRequest)request, responseRef, statusCodeRef));
            String modelResponse = responseRef.get();
            if (this.getMlGuard() != null && !this.getMlGuard().validate(modelResponse, MLGuard.Type.OUTPUT).booleanValue()) {
                throw new IllegalArgumentException("guardrails triggered for LLM output");
            }
            Integer statusCode = (Integer)statusCodeRef.get();
            if (statusCode < 200 || statusCode >= 300) {
                throw new OpenSearchStatusException("Error from remote service: " + modelResponse, RestStatus.fromCode((int)statusCode), new Object[0]);
            }
            ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, (Connector)this.connector, this.scriptService, parameters);
            tensors.setStatusCode(statusCode);
            tensorOutputs.add(tensors);
        }
        catch (RuntimeException e) {
            log.error("Fail to execute http connector", (Throwable)e);
            throw e;
        }
        catch (Throwable e) {
            log.error("Fail to execute http connector", e);
            throw new MLException("Fail to execute http connector", e);
        }
    }

    @Generated
    public HttpConnector getConnector() {
        return this.connector;
    }

    @Override
    @Generated
    public void setScriptService(ScriptService scriptService) {
        this.scriptService = scriptService;
    }

    @Override
    @Generated
    public ScriptService getScriptService() {
        return this.scriptService;
    }

    @Override
    @Generated
    public void setRateLimiter(TokenBucket rateLimiter) {
        this.rateLimiter = rateLimiter;
    }

    @Override
    @Generated
    public TokenBucket getRateLimiter() {
        return this.rateLimiter;
    }

    @Override
    @Generated
    public void setUserRateLimiterMap(Map<String, TokenBucket> userRateLimiterMap) {
        this.userRateLimiterMap = userRateLimiterMap;
    }

    @Override
    @Generated
    public Map<String, TokenBucket> getUserRateLimiterMap() {
        return this.userRateLimiterMap;
    }

    @Override
    @Generated
    public void setClient(Client client) {
        this.client = client;
    }

    @Override
    @Generated
    public Client getClient() {
        return this.client;
    }

    @Override
    @Generated
    public void setMlGuard(MLGuard mlGuard) {
        this.mlGuard = mlGuard;
    }

    @Override
    @Generated
    public MLGuard getMlGuard() {
        return this.mlGuard;
    }

    private /* synthetic */ Void lambda$invokeRemoteModel$0(HttpUriRequest request, AtomicReference responseRef, AtomicReference statusCodeRef) throws Exception {
        try (CloseableHttpResponse response = this.httpClient.execute(request);){
            HttpEntity responseEntity = response.getEntity();
            String responseBody = EntityUtils.toString((HttpEntity)responseEntity);
            EntityUtils.consume((HttpEntity)responseEntity);
            responseRef.set(responseBody);
            statusCodeRef.set(response.getStatusLine().getStatusCode());
        }
        return null;
    }
}

