/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.extensions.rest;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionModule;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.extensions.DiscoveryExtensionNode;
import org.opensearch.extensions.rest.ExtensionRestRequest;
import org.opensearch.extensions.rest.RegisterRestActionsRequest;
import org.opensearch.extensions.rest.RestExecuteOnExtensionResponse;
import org.opensearch.http.HttpRequest;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.NamedRoute;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestStatus;
import org.opensearch.transport.TransportException;
import org.opensearch.transport.TransportResponseHandler;
import org.opensearch.transport.TransportService;

/**
 * REST handler that forwards requests arriving under an extension's path prefix
 * ({@code /_extensions/_<uniqueId>}) to that extension over the transport layer and relays the
 * extension's reply back to the REST client.
 */
public class RestSendToExtensionAction extends BaseRestHandler {
    private static final String SEND_TO_EXTENSION_ACTION = "send_to_extension_action";
    // Transport action the forwarded request is sent under. Kept byte-for-byte; presumably it must
    // match the action name registered on the receiving side — verify before renaming.
    private static final String REQUEST_REST_EXECUTE_ON_EXTENSION_ACTION = "internal:extensions/restexecuteonextensiontaction";
    // How long prepareRequest blocks waiting for the extension before answering 408.
    private static final long EXTENSION_REQUEST_WAIT_TIMEOUT_SECONDS = 10L;
    private static final Logger logger = LogManager.getLogger(RestSendToExtensionAction.class);
    // NOTE(review): not referenced anywhere in this class; presumably reserved for identity work.
    private static final Principal DEFAULT_PRINCIPAL = new Principal() {

        @Override
        public String getName() {
            return "OpenSearchUser";
        }
    };
    // Only allow-listed headers are forwarded; credential headers are explicitly denied.
    private static final Set<String> allowList = Set.of("Content-Type");
    private static final Set<String> denyList = Set.of("Authorization", "Proxy-Authorization");

    private final List<RestHandler.Route> routes;
    private final List<RestHandler.DeprecatedRoute> deprecatedRoutes;
    private final String pathPrefix;
    private final DiscoveryExtensionNode discoveryExtensionNode;
    private final TransportService transportService;

    /**
     * Builds the handler for one extension and registers its routes with the dynamic registry.
     *
     * <p>Every route is parsed before anything is registered, so a malformed entry cannot leave
     * some routes registered against a partially constructed handler.
     *
     * @param restActionsRequest      the extension's route declarations. Regular actions use the form
     *                                {@code "<METHOD> <path> [name]"}. Deprecated actions come as
     *                                alternating {@code "<METHOD> <path>"} / message entries.
     * @param discoveryExtensionNode  the extension node requests are forwarded to
     * @param transportService        transport used to send the forwarded requests
     * @param dynamicActionRegistry   registry that receives each (non-deprecated) route
     * @throws IllegalArgumentException if an action lacks a method and path, or names an unknown
     *                                  REST method
     */
    public RestSendToExtensionAction(RegisterRestActionsRequest restActionsRequest, DiscoveryExtensionNode discoveryExtensionNode, TransportService transportService, ActionModule.DynamicActionRegistry dynamicActionRegistry) {
        this.pathPrefix = "/_extensions/_" + restActionsRequest.getUniqueId();
        this.routes = Collections.unmodifiableList(parseRoutes(this.pathPrefix, restActionsRequest.getRestActions()));
        this.deprecatedRoutes = Collections.unmodifiableList(parseDeprecatedRoutes(this.pathPrefix, restActionsRequest.getDeprecatedRestActions()));
        this.discoveryExtensionNode = discoveryExtensionNode;
        this.transportService = transportService;
        // Publish `this` only once all fields are assigned.
        for (RestHandler.Route route : this.routes) {
            dynamicActionRegistry.registerDynamicRoute(route, this);
        }
    }

    /**
     * Parses {@code "<METHOD> <path> [name]"} declarations into routes under {@code pathPrefix}.
     * Any run of whitespace separates the tokens. A third token turns the route into a
     * {@link NamedRoute}.
     */
    private static List<RestHandler.Route> parseRoutes(String pathPrefix, List<String> restActions) {
        List<RestHandler.Route> parsed = new ArrayList<>();
        for (String restAction : restActions) {
            String[] parts = restAction.trim().split("\\s+");
            if (parts.length < 2) {
                throw new IllegalArgumentException("REST action must contain at least a REST method and route");
            }
            RestRequest.Method method = parseMethod(restAction, parts[0]);
            String path = pathPrefix + parts[1];
            logger.info("Registering: {} {}", method, path);
            parsed.add(parts.length > 2 ? new NamedRoute(method, path, parts[2]) : new RestHandler.Route(method, path));
        }
        return parsed;
    }

    /**
     * Parses alternating {@code "<METHOD> <path>"} / deprecation-message entries into deprecated
     * routes under {@code pathPrefix}. A trailing entry without a message is skipped with a warning.
     */
    private static List<RestHandler.DeprecatedRoute> parseDeprecatedRoutes(String pathPrefix, List<String> deprecatedActions) {
        List<RestHandler.DeprecatedRoute> parsed = new ArrayList<>();
        for (int i = 0; i + 1 < deprecatedActions.size(); i += 2) {
            String restAction = deprecatedActions.get(i);
            String message = deprecatedActions.get(i + 1);
            int delim = restAction.indexOf(' ');
            if (delim < 0) {
                throw new IllegalArgumentException(restAction + " does not begin with a valid REST method");
            }
            RestRequest.Method method = parseMethod(restAction, restAction.substring(0, delim));
            String path = pathPrefix + restAction.substring(delim).trim();
            logger.info("Registering: {} {} with deprecation message {}", method, path, message);
            parsed.add(new RestHandler.DeprecatedRoute(method, path, message));
        }
        if (deprecatedActions.size() % 2 != 0) {
            logger.warn("Ignoring deprecated REST action without a deprecation message: {}", deprecatedActions.get(deprecatedActions.size() - 1));
        }
        return parsed;
    }

    /** Resolves {@code token} to a REST method, reporting the whole {@code restAction} on failure. */
    private static RestRequest.Method parseMethod(String restAction, String token) {
        try {
            return RestRequest.Method.valueOf(token);
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException(restAction + " does not begin with a valid REST method", e);
        }
    }

    @Override
    public String getName() {
        return SEND_TO_EXTENSION_ACTION;
    }

    @Override
    public List<RestHandler.Route> routes() {
        return this.routes;
    }

    @Override
    public List<RestHandler.DeprecatedRoute> deprecatedRoutes() {
        return this.deprecatedRoutes;
    }

    /**
     * Returns the headers that may be forwarded to an extension: those whose name is in
     * {@code allowList} and not in {@code denyList}. Header names are compared case-insensitively,
     * because HTTP field names are case-insensitive. Matching entries keep their original key and
     * values.
     *
     * @param headers   the incoming request headers
     * @param allowList header names permitted to be forwarded
     * @param denyList  header names that must never be forwarded (takes precedence)
     * @return a new map containing only the permitted headers
     */
    public Map<String, List<String>> filterHeaders(Map<String, List<String>> headers, Set<String> allowList, Set<String> denyList) {
        return headers.entrySet()
            .stream()
            .filter(e -> !containsIgnoreCase(denyList, e.getKey()))
            .filter(e -> containsIgnoreCase(allowList, e.getKey()))
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
    }

    /** Whether {@code names} contains {@code name}, ignoring ASCII/Unicode case. */
    private static boolean containsIgnoreCase(Set<String> names, String name) {
        return names.stream().anyMatch(candidate -> candidate.equalsIgnoreCase(name));
    }

    /**
     * Forwards the request to the extension. Blocks the calling thread for up to
     * {@link #EXTENSION_REQUEST_WAIT_TIMEOUT_SECONDS} seconds waiting for the extension's reply.
     *
     * @return a consumer that sends one of:
     *         <ul>
     *           <li>the extension's status, content type, body and headers;</li>
     *           <li>408 if the extension does not answer in time;</li>
     *           <li>500 if forwarding fails outside the transport callback.</li>
     *         </ul>
     * @throws RuntimeException the transport failure reported by the extension, rethrown unwrapped
     */
    @Override
    public BaseRestHandler.RestChannelConsumer prepareRequest(final RestRequest request, NodeClient client) throws IOException {
        HttpRequest httpRequest = request.getHttpRequest();
        String path = request.path();
        RestRequest.Method method = request.method();
        String uri = httpRequest.uri();
        Map<String, String> params = request.params();
        Map<String, List<String>> headers = request.getHeaders();
        XContentType contentType = request.getXContentType();
        BytesReference content = request.content();
        HttpRequest.HttpVersion httpVersion = httpRequest.protocolVersion();
        // The extension sees the path relative to its own prefix.
        if (path.startsWith(this.pathPrefix)) {
            path = path.substring(this.pathPrefix.length());
        }
        String message = "Forwarding the request " + method + " " + path + " to " + this.discoveryExtensionNode;
        logger.info(message);

        final CompletableFuture<RestExecuteOnExtensionResponse> inProgressFuture = new CompletableFuture<>();
        // The callbacks only hand the outcome to the waiting thread. They never touch `request`,
        // because they may run after the wait has already timed out.
        TransportResponseHandler<RestExecuteOnExtensionResponse> restExecuteOnExtensionResponseHandler = new TransportResponseHandler<RestExecuteOnExtensionResponse>() {

            @Override
            public RestExecuteOnExtensionResponse read(StreamInput in) throws IOException {
                return new RestExecuteOnExtensionResponse(in);
            }

            @Override
            public void handleResponse(RestExecuteOnExtensionResponse response) {
                logger.info("Received response from extension: {}", response.getStatus());
                inProgressFuture.complete(response);
            }

            @Override
            public void handleException(TransportException exp) {
                logger.debug("REST request failed", exp);
                inProgressFuture.completeExceptionally(exp);
            }

            @Override
            public String executor() {
                return "generic";
            }
        };

        RestExecuteOnExtensionResponse response;
        try {
            // Placeholder until a real request-issuer identity is available.
            final String requestIssuerIdentity = "placeholder_request_issuer_identity";
            // Never forward headers carrying security or privacy information.
            Map<String, List<String>> filteredHeaders = this.filterHeaders(headers, allowList, denyList);
            this.transportService.sendRequest(
                this.discoveryExtensionNode,
                REQUEST_REST_EXECUTE_ON_EXTENSION_ACTION,
                new ExtensionRestRequest(method, uri, path, params, filteredHeaders, contentType, content, requestIssuerIdentity, httpVersion),
                restExecuteOnExtensionResponseHandler
            );
            response = inProgressFuture.orTimeout(EXTENSION_REQUEST_WAIT_TIMEOUT_SECONDS, TimeUnit.SECONDS).join();
        } catch (CompletionException e) {
            markRequestConsumed(request);
            Throwable cause = e.getCause();
            if (cause instanceof TimeoutException) {
                return channel -> channel.sendResponse(new BytesRestResponse(RestStatus.REQUEST_TIMEOUT, "No response from extension to request."));
            }
            if (cause instanceof RuntimeException) {
                throw (RuntimeException) cause;
            }
            if (cause instanceof Error) {
                throw (Error) cause;
            }
            throw new RuntimeException(cause);
        } catch (Exception ex) {
            markRequestConsumed(request);
            logger.info("Failed to send REST Actions to extension " + this.discoveryExtensionNode.getName(), ex);
            // getMessage() may be null; fall back to toString() so the error body is never null.
            final String reason = ex.getMessage() != null ? ex.getMessage() : ex.toString();
            return channel -> channel.sendResponse(new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, reason));
        }

        // Mirror what the extension reports it consumed onto the local request.
        for (String consumedParam : response.getConsumedParams()) {
            request.param(consumedParam);
        }
        if (response.isContentConsumed()) {
            request.content();
        }
        final BytesRestResponse restResponse = new BytesRestResponse(response.getStatus(), response.getContentType(), response.getContent());
        response.getHeaders().forEach((name, values) -> values.forEach(value -> restResponse.addHeader(name, value)));
        return channel -> channel.sendResponse(restResponse);
    }

    /**
     * Marks every request parameter and the body as consumed, as the transport failure handler
     * originally did. Used on failure paths so that the intended error response is not replaced.
     * NOTE(review): assumes BaseRestHandler rejects requests with unconsumed parameters after
     * prepareRequest returns, which would otherwise turn a 408/500 into a parameter error —
     * confirm against BaseRestHandler.
     */
    private static void markRequestConsumed(RestRequest request) {
        for (String param : request.params().keySet()) {
            request.param(param);
        }
        request.content();
    }
}

