/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.algorithms.remote;

import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.util.Strings;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
import org.opensearch.ml.engine.algorithms.remote.ExecutionContext;
import org.opensearch.script.ScriptService;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.http.SdkHttpFullResponse;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;

/**
 * Async AWS SDK response handler for one remote-model HTTP call.
 *
 * <p>It records the status code from the response headers and collects the streamed body. When
 * the stream terminates, it converts the body into {@link ModelTensors} stored under this call's
 * execution sequence. Several handlers may share the same count-down latch and exception holder
 * from their {@link ExecutionContext}. The handler that counts the latch down to zero notifies
 * the listener, either with the tensors ordered by sequence or with the first recorded exception.
 */
public class MLSdkAsyncHttpResponseHandler
implements SdkAsyncHttpResponseHandler {
    @Generated
    private static final Logger log = LogManager.getLogger(MLSdkAsyncHttpResponseHandler.class);
    /** HTTP status from {@link #onHeaders}; null until headers arrive. */
    private Integer statusCode;
    /** Response body, filled in once the body stream has terminated. */
    private final StringBuilder responseBody = new StringBuilder();
    private final ExecutionContext executionContext;
    private final ActionListener<List<ModelTensors>> actionListener;
    private final Map<String, String> parameters;
    /** Output map keyed by execution sequence; presumably shared by all handlers of one request. */
    private final Map<Integer, ModelTensors> tensorOutputs;
    private final Connector connector;
    private final ScriptService scriptService;
    private final MLGuard mlGuard;

    /**
     * @param executionContext sequence number, shared latch and shared exception holder for this call
     * @param actionListener notified once every call sharing the latch has finished
     * @param parameters request parameters passed through to output processing
     * @param tensorOutputs map into which this call's tensors are put under its sequence number
     * @param connector connector used to post-process the raw response body
     * @param scriptService script service used during post-processing
     * @param mlGuard guard applied during post-processing
     */
    public MLSdkAsyncHttpResponseHandler(ExecutionContext executionContext, ActionListener<List<ModelTensors>> actionListener, Map<String, String> parameters, Map<Integer, ModelTensors> tensorOutputs, Connector connector, ScriptService scriptService, MLGuard mlGuard) {
        this.executionContext = executionContext;
        this.actionListener = actionListener;
        this.parameters = parameters;
        this.tensorOutputs = tensorOutputs;
        this.connector = connector;
        this.scriptService = scriptService;
        this.mlGuard = mlGuard;
    }

    /** Records the HTTP status code; {@link SdkHttpResponse} already exposes it, so no cast is needed. */
    @Override
    public void onHeaders(SdkHttpResponse response) {
        log.debug("received response headers: {}", response.headers());
        this.statusCode = response.statusCode();
    }

    /** Subscribes a collector that buffers the body and processes it when the stream ends. */
    @Override
    public void onStream(Publisher<ByteBuffer> stream) {
        stream.subscribe(new MLResponseSubscriber());
    }

    /**
     * Fails the listener immediately with a status-carrying exception. The status is derived from
     * the received headers, or is 500 when none arrived or the code is unknown.
     */
    @Override
    public void onError(Throwable error) {
        log.error(error.getMessage(), error);
        String errorMessage = "Error communicating with remote model: " + error.getMessage();
        this.actionListener.onFailure(new OpenSearchStatusException(errorMessage, toRestStatus(this.statusCode)));
    }

    /**
     * Maps an HTTP status code to a {@link RestStatus}. {@code RestStatus.fromCode} yields null for
     * codes it does not know, and the code itself is null if headers never arrived; both fall back
     * to INTERNAL_SERVER_ERROR so an exception is never built with a null status.
     */
    private static RestStatus toRestStatus(Integer code) {
        RestStatus status = code == null ? null : RestStatus.fromCode(code);
        return status == null ? RestStatus.INTERNAL_SERVER_ERROR : status;
    }

    /** Stores {@code e} in the shared exception holder unless an earlier failure was already recorded. */
    private void recordException(Exception e) {
        this.executionContext.getExceptionHolder().compareAndSet(null, e);
    }

    /**
     * Turns one response into tensors, or records why it could not.
     *
     * <p>A blank body, a missing status or a non-2xx status records an error. Otherwise the body is
     * post-processed and stored in {@code tensorOutputs} under this call's sequence number.
     * Processing failures are recorded rather than thrown, so the caller can always count down the
     * latch.
     */
    private void processResponse(Integer statusCode, String body, Map<String, String> parameters, Map<Integer, ModelTensors> tensorOutputs) {
        if (Strings.isBlank(body)) {
            log.error("Remote model response body is empty!");
            recordException(new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST));
            return;
        }
        // 2xx is [200, 300): 300 (Multiple Choices) is not a success. A null code means headers never arrived.
        if (statusCode == null || statusCode < 200 || statusCode >= 300) {
            log.error("Remote server returned error code: {}", statusCode);
            recordException(new OpenSearchStatusException("Error from remote service: " + body, toRestStatus(statusCode)));
            return;
        }
        try {
            ModelTensors tensors = ConnectorUtils.processOutput(body, this.connector, this.scriptService, parameters, this.mlGuard);
            tensors.setStatusCode(statusCode);
            tensorOutputs.put(this.executionContext.getSequence(), tensors);
        } catch (Exception e) {
            log.error("Failed to process response body: {}", body, e);
            recordException(new MLException("Fail to execute predict in aws connector", e));
        }
    }

    /**
     * Delivers the collected tensors to the listener in sequence order. Assumes the keys are exactly
     * {@code 0..size-1}; TODO confirm this holds for every caller.
     */
    private void reOrderTensorResponses(Map<Integer, ModelTensors> tensorOutputs) {
        ModelTensors[] modelTensors = new ModelTensors[tensorOutputs.size()];
        log.debug("Reordered tensor outputs size is {}", tensorOutputs.size());
        for (Map.Entry<Integer, ModelTensors> entry : tensorOutputs.entrySet()) {
            modelTensors[entry.getKey()] = entry.getValue();
        }
        this.actionListener.onResponse(Arrays.asList(modelTensors));
    }

    /**
     * Processes this call's body, counts down the shared latch and, if this was the last pending
     * call, notifies the listener.
     */
    private void response(Map<Integer, ModelTensors> tensors) {
        this.processResponse(this.statusCode, this.responseBody.toString(), this.parameters, tensors);
        long remaining;
        // countDown() and getCount() must be observed atomically across handlers sharing the latch.
        // Otherwise two handlers finishing together can both see 0 and notify the listener twice.
        synchronized (this.executionContext.getCountDownLatch()) {
            this.executionContext.getCountDownLatch().countDown();
            remaining = this.executionContext.getCountDownLatch().getCount();
        }
        if (remaining > 0L) {
            log.debug("Not all responses received, left response count is: {}", remaining);
            return;
        }
        Exception failure = this.executionContext.getExceptionHolder().get();
        if (failure != null) {
            this.actionListener.onFailure(failure);
            return;
        }
        this.reOrderTensorResponses(tensors);
    }

    @Generated
    public Integer getStatusCode() {
        return this.statusCode;
    }

    @Generated
    public StringBuilder getResponseBody() {
        return this.responseBody;
    }

    /** Collects the streamed response body and triggers processing when the stream terminates. */
    protected class MLResponseSubscriber
    implements Subscriber<ByteBuffer> {
        private Subscription subscription;
        // Raw body bytes. Decoding is deferred to stream end because a multi-byte UTF-8 character
        // can be split across two chunks, and decoding each chunk on its own would corrupt it.
        private final ByteArrayOutputStream bodyBytes = new ByteArrayOutputStream();

        protected MLResponseSubscriber() {
        }

        @Override
        public void onSubscribe(Subscription s) {
            this.subscription = s;
            s.request(Long.MAX_VALUE);
        }

        @Override
        public void onNext(ByteBuffer byteBuffer) {
            byte[] chunk = new byte[byteBuffer.remaining()];
            byteBuffer.get(chunk);
            this.bodyBytes.write(chunk, 0, chunk.length);
            this.subscription.request(Long.MAX_VALUE);
        }

        @Override
        public void onError(Throwable t) {
            log.error("Error on receiving response body from remote: {}", (Object)(t instanceof NullPointerException ? "NullPointerException" : t.getMessage()), (Object)t);
            this.flushBody();
            // Keep the real cause. Without this, callers only see a derived "empty body" or parse error.
            MLSdkAsyncHttpResponseHandler.this.recordException(new MLException("Error on receiving response body from remote", t));
            MLSdkAsyncHttpResponseHandler.this.response(MLSdkAsyncHttpResponseHandler.this.tensorOutputs);
        }

        @Override
        public void onComplete() {
            this.flushBody();
            MLSdkAsyncHttpResponseHandler.this.response(MLSdkAsyncHttpResponseHandler.this.tensorOutputs);
        }

        /**
         * Decodes all buffered bytes as UTF-8 into the handler's response body. Malformed input
         * becomes U+FFFD, matching what {@code Charset.decode} produced before.
         */
        private void flushBody() {
            MLSdkAsyncHttpResponseHandler.this.responseBody.append(new String(this.bodyBytes.toByteArray(), StandardCharsets.UTF_8));
            this.bodyBytes.reset();
        }
    }
}

