/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.time.Duration;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.client.Client;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.remote.AbstractConnectorExecutor;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.script.ScriptService;
import software.amazon.awssdk.core.internal.http.loader.DefaultSdkHttpClientBuilder;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.http.AbortableInputStream;
import software.amazon.awssdk.http.ContentStreamProvider;
import software.amazon.awssdk.http.HttpExecuteRequest;
import software.amazon.awssdk.http.HttpExecuteResponse;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.SdkHttpConfigurationOption;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.utils.AttributeMap;

/**
 * Connector executor registered under {@code "aws_sigv4"}: sends model predictions as
 * AWS-signed HTTP {@code POST} requests to the connector's predict endpoint.
 *
 * <p>The response body is optionally validated by an output {@link MLGuard}. It is then
 * converted into {@link ModelTensors} via {@link ConnectorUtils#processOutput}.
 */
@ConnectorExecutor(value="aws_sigv4")
public class AwsConnectorExecutor
extends AbstractConnectorExecutor {
    @Generated
    private static final Logger log = LogManager.getLogger(AwsConnectorExecutor.class);
    private AwsConnector connector;
    private SdkHttpClient httpClient;
    private ScriptService scriptService;
    private TokenBucket rateLimiter;
    private Map<String, TokenBucket> userRateLimiterMap;
    private Client client;
    private MLGuard mlGuard;

    /**
     * Creates an executor that sends requests through a caller-supplied HTTP client.
     *
     * <p>Unlike {@link #AwsConnectorExecutor(Connector)}, this constructor does not call
     * {@code initialize(connector)} on the base class.
     *
     * @param connector the connector definition; cast to {@link AwsConnector}
     * @param httpClient the client used to send signed requests
     * @throws ClassCastException if {@code connector} is not an {@link AwsConnector}
     */
    public AwsConnectorExecutor(Connector connector, SdkHttpClient httpClient) {
        this.connector = (AwsConnector) connector;
        this.httpClient = httpClient;
    }

    /**
     * Creates an executor and builds its HTTP client from the connector's client config.
     * The config supplies the connection timeout, read timeout and max connections.
     *
     * @param connector the connector definition; cast to {@link AwsConnector}
     * @throws ClassCastException if {@code connector} is not an {@link AwsConnector}
     * @throws MLException if building the HTTP client fails with a non-runtime throwable
     */
    public AwsConnectorExecutor(Connector connector) {
        super.initialize(connector);
        this.connector = (AwsConnector) connector;
        Duration connectionTimeout = Duration.ofMillis(super.getConnectorClientConfig().getConnectionTimeout().intValue());
        Duration readTimeout = Duration.ofMillis(super.getConnectorClientConfig().getReadTimeout().intValue());
        Integer maxConnections = super.getConnectorClientConfig().getMaxConnections();
        // Typed SdkHttpConfigurationOption keys let the compiler check each value's type,
        // instead of the raw AttributeMap.Key casts the decompiler produced.
        try (AttributeMap attributeMap = AttributeMap.builder()
                .put(SdkHttpConfigurationOption.CONNECTION_TIMEOUT, connectionTimeout)
                .put(SdkHttpConfigurationOption.READ_TIMEOUT, readTimeout)
                .put(SdkHttpConfigurationOption.MAX_CONNECTIONS, maxConnections)
                .build()) {
            log.info("Initializing aws connector http client with attributes: connectionTimeout={}, readTimeout={}, maxConnections={}", connectionTimeout, readTimeout, maxConnections);
            this.httpClient = new DefaultSdkHttpClientBuilder().buildWithDefaults(attributeMap);
        }
        catch (RuntimeException e) {
            log.error("Error initializing AWS connector HTTP client.", e);
            throw e;
        }
        catch (Throwable e) {
            log.error("Error initializing AWS connector HTTP client.", e);
            throw new MLException(e);
        }
    }

    /**
     * Sends {@code payload} to the connector's predict endpoint as a SigV4-signed {@code POST}.
     * The parsed output is appended to {@code tensorOutputs}.
     *
     * <p>The output guardrail (if configured) is evaluated on the raw response body before
     * the HTTP status code is checked.
     *
     * @param mlInput the original ML input (not read by this method)
     * @param parameters connector parameters used to resolve the endpoint and process output
     * @param payload the request body, sent as a string
     * @param tensorOutputs list that receives the processed {@link ModelTensors}
     * @throws OpenSearchStatusException if the response has no body, or the status is not 2xx
     * @throws IllegalArgumentException if the output guardrail rejects the response
     * @throws MLException wrapping any non-runtime throwable (e.g. I/O failures)
     */
    @Override
    public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, String payload, List<ModelTensors> tensorOutputs) {
        try {
            String endpoint = this.connector.getPredictEndpoint(parameters);
            RequestBody requestBody = RequestBody.fromString(payload);
            SdkHttpFullRequest.Builder builder = SdkHttpFullRequest.builder()
                .method(SdkHttpMethod.POST)
                .uri(URI.create(endpoint))
                .contentStreamProvider(requestBody.contentStreamProvider());
            // Typed map: a raw Map would make the header loop iterate Objects, which does not compile.
            Map<String, String> headers = this.connector.getDecryptedHeaders();
            if (headers != null) {
                for (Map.Entry<String, String> header : headers.entrySet()) {
                    builder.putHeader(header.getKey(), header.getValue());
                }
            }
            SdkHttpFullRequest request = builder.build();
            HttpExecuteRequest executeRequest = HttpExecuteRequest.builder()
                .request(this.signRequest(request))
                .contentStreamProvider(request.contentStreamProvider().orElse(null))
                .build();
            // The explicit target type is needed: call() throws IOException, and without the cast
            // the lambda is ambiguous between the PrivilegedAction and PrivilegedExceptionAction overloads.
            HttpExecuteResponse response = AccessController.doPrivileged(
                (PrivilegedExceptionAction<HttpExecuteResponse>) () -> this.httpClient.prepareRequest(executeRequest).call());
            int statusCode = response.httpResponse().statusCode();
            AbortableInputStream body = response.responseBody().orElse(null);
            if (body == null) {
                throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST);
            }
            String modelResponse = readResponseBody(body);
            if (this.getMlGuard() != null && !this.getMlGuard().validate(modelResponse, MLGuard.Type.OUTPUT)) {
                throw new IllegalArgumentException("guardrails triggered for LLM output");
            }
            if (statusCode < 200 || statusCode >= 300) {
                throw new OpenSearchStatusException("Error from remote service: " + modelResponse, toRestStatus(statusCode));
            }
            ModelTensors tensors = ConnectorUtils.processOutput(modelResponse, this.connector, this.scriptService, parameters);
            tensors.setStatusCode(statusCode);
            tensorOutputs.add(tensors);
        }
        catch (RuntimeException exception) {
            log.error("Failed to execute predict in aws connector: " + exception.getMessage(), exception);
            throw exception;
        }
        catch (Throwable e) {
            log.error("Failed to execute predict in aws connector", e);
            throw new MLException("Fail to execute predict in aws connector", e);
        }
    }

    /**
     * Reads the whole response body as UTF-8 and closes the stream.
     *
     * <p>NOTE: lines are concatenated without their line terminators. This preserves the
     * existing behavior; confirm downstream parsing does not need the newlines.
     */
    private static String readResponseBody(InputStream body) throws IOException {
        StringBuilder responseBuilder = new StringBuilder();
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(body, StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                responseBuilder.append(line);
            }
        }
        return responseBuilder.toString();
    }

    /**
     * Maps an upstream HTTP status code to a {@link RestStatus}.
     *
     * <p>{@code RestStatus.fromCode} returns {@code null} for codes OpenSearch has not mapped.
     * In that case this falls back to {@code INTERNAL_SERVER_ERROR}, so the thrown exception
     * never carries a null status.
     */
    private static RestStatus toRestStatus(int statusCode) {
        RestStatus status = RestStatus.fromCode(statusCode);
        return status != null ? status : RestStatus.INTERNAL_SERVER_ERROR;
    }

    /**
     * Signs {@code request} with AWS SigV4 using the connector's credentials, service name and
     * region, delegating to {@link ConnectorUtils#signRequest}.
     */
    private SdkHttpFullRequest signRequest(SdkHttpFullRequest request) {
        String accessKey = this.connector.getAccessKey();
        String secretKey = this.connector.getSecretKey();
        String sessionToken = this.connector.getSessionToken();
        String signingName = this.connector.getServiceName();
        String region = this.connector.getRegion();
        return ConnectorUtils.signRequest(request, accessKey, secretKey, sessionToken, signingName, region);
    }

    @Generated
    public AwsConnector getConnector() {
        return this.connector;
    }

    @Override
    @Generated
    public void setScriptService(ScriptService scriptService) {
        this.scriptService = scriptService;
    }

    @Override
    @Generated
    public ScriptService getScriptService() {
        return this.scriptService;
    }

    @Override
    @Generated
    public void setRateLimiter(TokenBucket rateLimiter) {
        this.rateLimiter = rateLimiter;
    }

    @Override
    @Generated
    public TokenBucket getRateLimiter() {
        return this.rateLimiter;
    }

    @Override
    @Generated
    public void setUserRateLimiterMap(Map<String, TokenBucket> userRateLimiterMap) {
        this.userRateLimiterMap = userRateLimiterMap;
    }

    @Override
    @Generated
    public Map<String, TokenBucket> getUserRateLimiterMap() {
        return this.userRateLimiterMap;
    }

    @Override
    @Generated
    public void setClient(Client client) {
        this.client = client;
    }

    @Override
    @Generated
    public Client getClient() {
        return this.client;
    }

    @Override
    @Generated
    public void setMlGuard(MLGuard mlGuard) {
        this.mlGuard = mlGuard;
    }

    @Override
    @Generated
    public MLGuard getMlGuard() {
        return this.mlGuard;
    }
}

