/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote;

import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import lombok.Generated;
import org.apache.commons.collections.MapUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.util.Strings;
import org.jetbrains.annotations.NotNull;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
import org.opensearch.ml.engine.algorithms.remote.ExecutionContext;
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorThrottlingException;
import org.opensearch.script.ScriptService;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.http.SdkHttpFullResponse;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;

/**
 * {@link SdkAsyncHttpResponseHandler} that collects the status code and body of a remote
 * model response and reports the outcome to an {@link ActionListener}.
 *
 * <p>The status code is captured in {@link #onHeaders(SdkHttpResponse)}, the body is
 * accumulated by {@link MLResponseSubscriber}, and once the body stream terminates the
 * accumulated text is either converted to {@link ModelTensors} via
 * {@link ConnectorUtils#processOutput} or turned into a failure.
 */
public class MLSdkAsyncHttpResponseHandler implements SdkAsyncHttpResponseHandler {
    @Generated
    private static final Logger log = LogManager.getLogger(MLSdkAsyncHttpResponseHandler.class);
    public static final String AMZ_ERROR_HEADER = "x-amzn-ErrorType";
    /** HTTP status of the remote response; {@code null} until {@link #onHeaders} has run. */
    private Integer statusCode;
    private final StringBuilder responseBody = new StringBuilder();
    private final ExecutionContext executionContext;
    private final ActionListener<Tuple<Integer, ModelTensors>> actionListener;
    private final Map<String, String> parameters;
    private final Connector connector;
    private final String action;
    private final ScriptService scriptService;
    private final MLGuard mlGuard;
    /** First exception recorded while handling headers; reported when the body stream ends. */
    private final AtomicReference<Exception> exceptionHolder = new AtomicReference<>();

    /**
     * Creates a handler for a single remote invocation.
     *
     * @param executionContext context whose sequence is returned alongside the tensors
     * @param actionListener listener notified exactly with the parsed tensors or a failure
     * @param parameters parameters forwarded to {@link ConnectorUtils#processOutput}
     * @param connector connector forwarded to {@link ConnectorUtils#processOutput}
     * @param scriptService script service forwarded to {@link ConnectorUtils#processOutput}
     * @param mlGuard guard forwarded to {@link ConnectorUtils#processOutput}
     * @param action connector action name, also used in failure messages
     */
    public MLSdkAsyncHttpResponseHandler(ExecutionContext executionContext, ActionListener<Tuple<Integer, ModelTensors>> actionListener, Map<String, String> parameters, Connector connector, ScriptService scriptService, MLGuard mlGuard, String action) {
        this.executionContext = executionContext;
        this.actionListener = actionListener;
        this.parameters = parameters;
        this.connector = connector;
        this.scriptService = scriptService;
        this.mlGuard = mlGuard;
        this.action = action;
    }

    /**
     * Records the response status code and, for a non-success status, checks the headers
     * for a throttling error.
     *
     * @param response the response head delivered by the SDK
     */
    @Override
    public void onHeaders(SdkHttpResponse response) {
        // Only headers() and statusCode() are needed, both declared on SdkHttpResponse,
        // so no downcast to SdkHttpFullResponse is required.
        log.debug("received response headers: {}", response.headers());
        this.statusCode = response.statusCode();
        if (this.isErrorStatus()) {
            this.handleThrottlingInHeader(response);
        }
    }

    /**
     * Subscribes a new {@link MLResponseSubscriber} to the response body stream.
     *
     * @param stream publisher of response body chunks
     */
    @Override
    public void onStream(Publisher<ByteBuffer> stream) {
        stream.subscribe(new MLResponseSubscriber());
    }

    /**
     * Fails the listener with an {@link OpenSearchStatusException}. The status is the
     * recorded status code, or {@code INTERNAL_SERVER_ERROR} when no headers were received.
     *
     * @param error the error reported by the SDK
     */
    @Override
    public void onError(Throwable error) {
        log.error(error.getMessage(), error);
        RestStatus status = this.statusCode == null ? RestStatus.INTERNAL_SERVER_ERROR : RestStatus.fromCode(this.statusCode);
        String errorMessage = "Error communicating with remote model: " + error.getMessage();
        this.actionListener.onFailure(new OpenSearchStatusException(errorMessage, status));
    }

    /** Remembers {@code e} unless an earlier exception has already been recorded. */
    private void handleException(Exception e) {
        // compareAndSet already performs the "only if still null" check atomically.
        this.exceptionHolder.compareAndSet(null, e);
    }

    /**
     * Returns whether the recorded status code is outside the range treated as success.
     * Must only be called once {@link #statusCode} is non-null.
     */
    private boolean isErrorStatus() {
        // NOTE(review): the upper bound is inclusive, so status 300 is treated as success.
        // Kept as-is to preserve existing behaviour - confirm whether ">= 300" was intended.
        return this.statusCode < 200 || this.statusCode > 300;
    }

    /**
     * Records a {@link RemoteConnectorThrottlingException} when any {@value #AMZ_ERROR_HEADER}
     * header value starts with {@code ThrottlingException}.
     */
    private void handleThrottlingInHeader(SdkHttpResponse sdkResponse) {
        Map<String, List<String>> headers = sdkResponse.headers();
        if (MapUtils.isEmpty(headers)) {
            return;
        }
        List<String> errorsInHeader = headers.get(AMZ_ERROR_HEADER);
        if (errorsInHeader == null || errorsInHeader.isEmpty()) {
            return;
        }
        boolean containsThrottlingException = errorsInHeader.stream().anyMatch(str -> str != null && str.startsWith("ThrottlingException"));
        if (containsThrottlingException) {
            log.error("Remote server returned error code: {}", this.statusCode);
            this.handleException(new RemoteConnectorThrottlingException("Error from remote service: The request was denied due to remote server throttling. To change the retry policy and behavior, please update the connector client_config.", RestStatus.fromCode(this.statusCode)));
        }
    }

    /**
     * Reports the final outcome to the listener once the body stream has terminated.
     *
     * <p>Checks, in order: a previously recorded exception, an empty body, a missing or
     * non-success status code, and finally parses the body into {@link ModelTensors}.
     */
    private void response() {
        Exception recorded = this.exceptionHolder.get();
        if (recorded != null) {
            this.actionListener.onFailure(recorded);
            return;
        }
        String body = this.responseBody.toString();
        if (Strings.isBlank(body)) {
            log.error("Remote model response body is empty!");
            this.actionListener.onFailure(new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST));
            return;
        }
        if (this.statusCode == null) {
            // Headers never arrived: fail explicitly instead of throwing a
            // NullPointerException on unboxing and leaving the listener un-notified.
            log.error("Remote model response has no status code");
            this.actionListener.onFailure(new OpenSearchStatusException("Error from remote service: " + body, RestStatus.INTERNAL_SERVER_ERROR));
            return;
        }
        if (this.isErrorStatus()) {
            log.error("Remote server returned error code: {}", this.statusCode);
            this.actionListener.onFailure(new OpenSearchStatusException("Error from remote service: " + body, RestStatus.fromCode(this.statusCode)));
            return;
        }
        try {
            ModelTensors tensors = ConnectorUtils.processOutput(this.action, body, this.connector, this.scriptService, this.parameters, this.mlGuard);
            tensors.setStatusCode(this.statusCode);
            this.actionListener.onResponse(new Tuple<>(this.executionContext.getSequence(), tensors));
        }
        catch (Exception e) {
            log.error("Failed to process response body: {}", body, e);
            this.actionListener.onFailure(new MLException("Fail to execute " + this.action + " in aws connector", e));
        }
    }

    @Generated
    public Integer getStatusCode() {
        return this.statusCode;
    }

    @Generated
    public StringBuilder getResponseBody() {
        return this.responseBody;
    }

    /**
     * Subscriber that appends each body chunk, decoded as UTF-8, to the enclosing handler's
     * response body and triggers {@link MLSdkAsyncHttpResponseHandler#response()} when the
     * stream completes or fails.
     */
    protected class MLResponseSubscriber implements Subscriber<ByteBuffer> {
        /**
         * Stateful decoder so that a multi-byte character split across two chunks is decoded
         * correctly. REPLACE mirrors the behaviour of {@code Charset.decode(ByteBuffer)}.
         */
        private final CharsetDecoder decoder = StandardCharsets.UTF_8
            .newDecoder()
            .onMalformedInput(CodingErrorAction.REPLACE)
            .onUnmappableCharacter(CodingErrorAction.REPLACE);
        /** Trailing bytes of an incomplete character from the previous chunk, or {@code null}. */
        private ByteBuffer pendingBytes;
        /** Set once the decoder has been flushed; guards against a second terminal signal. */
        private boolean decodingFinished;
        private Subscription subscription;

        protected MLResponseSubscriber() {
        }

        @Override
        public void onSubscribe(@NotNull Subscription s) {
            this.subscription = s;
            s.request(Long.MAX_VALUE);
        }

        @Override
        public void onNext(ByteBuffer byteBuffer) {
            this.appendDecoded(byteBuffer, false);
            this.subscription.request(Long.MAX_VALUE);
        }

        @Override
        public void onError(Throwable t) {
            log.error("Error on receiving response body from remote: {}", t instanceof NullPointerException ? "NullPointerException" : t.getMessage(), t);
            this.finishDecoding();
            MLSdkAsyncHttpResponseHandler.this.response();
        }

        @Override
        public void onComplete() {
            this.finishDecoding();
            MLSdkAsyncHttpResponseHandler.this.response();
        }

        /** Flushes any bytes still pending in the decoder; does nothing when called again. */
        private void finishDecoding() {
            if (this.decodingFinished) {
                return;
            }
            this.decodingFinished = true;
            this.appendDecoded(ByteBuffer.allocate(0), true);
        }

        /**
         * Decodes {@code chunk}, prefixed with bytes left over from the previous call, and
         * appends the characters to the response body. Bytes forming an incomplete trailing
         * character are kept for the next call unless {@code endOfInput} is set.
         */
        private void appendDecoded(ByteBuffer chunk, boolean endOfInput) {
            ByteBuffer input = chunk;
            if (this.pendingBytes != null) {
                input = ByteBuffer.allocate(this.pendingBytes.remaining() + chunk.remaining());
                input.put(this.pendingBytes);
                input.put(chunk);
                input.flip();
                this.pendingBytes = null;
            }
            // UTF-8 never yields more chars than input bytes, so this capacity cannot overflow.
            CharBuffer decoded = CharBuffer.allocate(input.remaining() + 1);
            this.decoder.decode(input, decoded, endOfInput);
            if (endOfInput) {
                this.decoder.flush(decoded);
            }
            decoded.flip();
            MLSdkAsyncHttpResponseHandler.this.responseBody.append(decoded);
            if (input.hasRemaining()) {
                ByteBuffer rest = ByteBuffer.allocate(input.remaining());
                rest.put(input);
                rest.flip();
                this.pendingBytes = rest;
            }
        }
    }
}

