/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote;

import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.util.TokenBucket;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.remote.AbstractConnectorExecutor;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
import org.opensearch.ml.engine.algorithms.remote.ExecutionContext;
import org.opensearch.ml.engine.algorithms.remote.MLSdkAsyncHttpResponseHandler;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.script.ScriptService;
import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher;
import software.amazon.awssdk.http.SdkHttpFullRequest;
import software.amazon.awssdk.http.SdkHttpMethod;
import software.amazon.awssdk.http.SdkHttpRequest;
import software.amazon.awssdk.http.async.AsyncExecuteRequest;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;
import software.amazon.awssdk.http.async.SdkHttpContentPublisher;

/**
 * Remote connector executor for connectors whose outgoing requests are signed with AWS credentials.
 *
 * <p>Each invocation does three things:
 * <ul>
 *   <li>builds an HTTP {@code POST} request from the {@link AwsConnector};
 *   <li>signs it with the connector's access key, secret key, session token, service name and region
 *       via {@link ConnectorUtils#signRequest};
 *   <li>submits it through an asynchronous AWS SDK HTTP client obtained from
 *       {@link MLHttpClientFactory}.
 * </ul>
 * The caller's {@link ActionListener} is handed to an {@link MLSdkAsyncHttpResponseHandler} for the
 * response.
 *
 * <p>NOTE(review): the {@code @ConnectorExecutor("aws_sigv4")} annotation presumably maps connectors
 * using the {@code aws_sigv4} protocol to this executor; confirm against the annotation's consumer.
 */
@ConnectorExecutor("aws_sigv4")
public class AwsConnectorExecutor extends AbstractConnectorExecutor {
    @Generated
    private static final Logger log = LogManager.getLogger(AwsConnectorExecutor.class);

    /** The connector this executor serves; assigned once in the constructor. */
    private final AwsConnector connector;
    /** Async HTTP client built from the connector's client config; assigned once in the constructor. */
    private final SdkAsyncHttpClient httpClient;

    private ScriptService scriptService;
    private TokenBucket rateLimiter;
    private Map<String, TokenBucket> userRateLimiterMap;
    private Client client;
    private MLGuard mlGuard;

    /**
     * Creates an executor bound to the given connector.
     *
     * <p>The connection and read timeouts are read from the connector client config made available
     * by {@link AbstractConnectorExecutor} after {@code initialize}, and are interpreted as
     * <em>seconds</em>. The max-connections setting is passed unchanged to
     * {@link MLHttpClientFactory#getAsyncHttpClient}.
     *
     * @param connector the connector to execute; it must be an {@link AwsConnector}, otherwise the
     *     cast below throws {@link ClassCastException}
     */
    public AwsConnectorExecutor(Connector connector) {
        super.initialize(connector);
        this.connector = (AwsConnector) connector;
        // NOTE(review): the timeouts are unboxed here, so a null value would throw NPE; presumably
        // initialize() guarantees defaults — confirm.
        Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout().intValue());
        Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout().intValue());
        Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
        this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection);
    }

    /** Returns this class's logger. */
    @Override
    public Logger getLogger() {
        return log;
    }

    /**
     * Sends {@code payload} for {@code action} as a signed, asynchronous HTTP {@code POST}.
     *
     * <p>This method returns once the request has been submitted. The response is handled by an
     * {@link MLSdkAsyncHttpResponseHandler} that is given {@code actionListener}.
     *
     * <p>Failures thrown synchronously while building, signing or submitting the request are
     * reported to {@code actionListener} here:
     * <ul>
     *   <li>runtime exceptions are forwarded as-is;
     *   <li>any other {@link Throwable} is wrapped in an {@link MLException}.
     * </ul>
     *
     * @param action name of the connector action; passed to {@link ConnectorUtils#buildSdkRequest},
     *     to the response handler, and used in error messages
     * @param mlInput not used by this implementation
     * @param parameters used to build the request and passed to the response handler
     * @param payload the request body handed to {@link ConnectorUtils#buildSdkRequest}
     * @param executionContext passed through to the response handler
     * @param actionListener receives the {@code Tuple<Integer, ModelTensors>} result, or the failure
     */
    @Override
    public void invokeRemoteService(
        String action,
        MLInput mlInput,
        Map<String, String> parameters,
        String payload,
        ExecutionContext executionContext,
        ActionListener<Tuple<Integer, ModelTensors>> actionListener
    ) {
        try {
            SdkHttpFullRequest request = ConnectorUtils.buildSdkRequest(action, connector, parameters, payload, SdkHttpMethod.POST);
            AsyncExecuteRequest executeRequest = AsyncExecuteRequest
                .builder()
                .request(signRequest(request))
                .requestContentPublisher(new SimpleHttpContentPublisher(request))
                .responseHandler(
                    new MLSdkAsyncHttpResponseHandler(executionContext, actionListener, parameters, connector, scriptService, mlGuard, action)
                )
                .build();
            // The explicit target type is required. A bare lambda matches both the PrivilegedAction and
            // the PrivilegedExceptionAction overloads of doPrivileged, and javac rejects that as ambiguous.
            // The returned future is not awaited here; actionListener is presumably completed by the
            // response handler it was given.
            AccessController
                .doPrivileged((PrivilegedExceptionAction<CompletableFuture<Void>>) () -> httpClient.execute(executeRequest));
        } catch (RuntimeException exception) {
            log.error("Failed to execute {} in aws connector: {}", action, exception.getMessage(), exception);
            actionListener.onFailure(exception);
        } catch (Throwable e) {
            // Also covers the checked PrivilegedActionException declared by doPrivileged.
            log.error("Failed to execute {} in aws connector", action, e);
            actionListener.onFailure(new MLException("Fail to execute " + action + " in aws connector", e));
        }
    }

    /**
     * Signs {@code request} with the credentials and signing scope held by the connector.
     *
     * <p>The connector supplies the access key, secret key, session token, service name and region.
     *
     * @param request the unsigned request
     * @return the signed request produced by {@link ConnectorUtils#signRequest}
     */
    private SdkHttpFullRequest signRequest(SdkHttpFullRequest request) {
        return ConnectorUtils.signRequest(
            request,
            connector.getAccessKey(),
            connector.getSecretKey(),
            connector.getSessionToken(),
            connector.getServiceName(),
            connector.getRegion()
        );
    }

    /** Returns the AWS connector this executor was created for. */
    @Generated
    public AwsConnector getConnector() {
        return this.connector;
    }

    /** Sets the script service handed to the response handler on each invocation. */
    @Override
    @Generated
    public void setScriptService(ScriptService scriptService) {
        this.scriptService = scriptService;
    }

    /** Returns the script service, or {@code null} if none has been set. */
    @Override
    @Generated
    public ScriptService getScriptService() {
        return this.scriptService;
    }

    /** Stores the model-level rate limiter; it is not consulted inside this class. */
    @Override
    @Generated
    public void setRateLimiter(TokenBucket rateLimiter) {
        this.rateLimiter = rateLimiter;
    }

    /** Returns the model-level rate limiter, or {@code null} if none has been set. */
    @Override
    @Generated
    public TokenBucket getRateLimiter() {
        return this.rateLimiter;
    }

    /** Stores the per-user rate limiter map as given (not copied); it is not consulted inside this class. */
    @Override
    @Generated
    public void setUserRateLimiterMap(Map<String, TokenBucket> userRateLimiterMap) {
        this.userRateLimiterMap = userRateLimiterMap;
    }

    /** Returns the per-user rate limiter map as stored (not copied), or {@code null} if none has been set. */
    @Override
    @Generated
    public Map<String, TokenBucket> getUserRateLimiterMap() {
        return this.userRateLimiterMap;
    }

    /** Stores the OpenSearch client; it is not used inside this class. */
    @Override
    @Generated
    public void setClient(Client client) {
        this.client = client;
    }

    /** Returns the OpenSearch client, or {@code null} if none has been set. */
    @Override
    @Generated
    public Client getClient() {
        return this.client;
    }

    /** Sets the guard handed to the response handler on each invocation. */
    @Override
    @Generated
    public void setMlGuard(MLGuard mlGuard) {
        this.mlGuard = mlGuard;
    }

    /** Returns the guard, or {@code null} if none has been set. */
    @Override
    @Generated
    public MLGuard getMlGuard() {
        return this.mlGuard;
    }
}

