/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.controller;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Map;
import lombok.Generated;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.controller.MLController;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.controller.MLCreateControllerRequest;
import org.opensearch.ml.common.transport.controller.MLCreateControllerResponse;
import org.opensearch.ml.common.transport.controller.MLDeployControllerAction;
import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesRequest;
import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesResponse;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

/**
 * Transport action that creates a model controller for an existing model.
 *
 * <p>Flow: load the target model (without its content fields), reject models whose function
 * category is not {@code TEXT_EMBEDDING} or {@code REMOTE}, check that the calling user may access
 * the model's group, reject models currently in {@code DEPLOYING} state, persist the controller
 * document into the {@code .plugins-ml-controller} index under the model ID, and — when the model
 * cache reports worker nodes for the model — deploy the controller to those nodes.
 */
public class CreateControllerTransportAction
extends HandledTransportAction<ActionRequest, MLCreateControllerResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(CreateControllerTransportAction.class);
    private final MLIndicesHandler mlIndicesHandler;
    private final Client client;
    private final MLModelManager mlModelManager;
    private final ClusterService clusterService;
    private final MLModelCacheHelper mlModelCacheHelper;
    private final ModelAccessControlHelper modelAccessControlHelper;

    @Inject
    public CreateControllerTransportAction(TransportService transportService, ActionFilters actionFilters, MLIndicesHandler mlIndicesHandler, Client client, ClusterService clusterService, ModelAccessControlHelper modelAccessControlHelper, MLModelCacheHelper mlModelCacheHelper, MLModelManager mlModelManager) {
        super("cluster:admin/opensearch/ml/controllers/create", transportService, actionFilters, MLCreateControllerRequest::new);
        this.mlIndicesHandler = mlIndicesHandler;
        this.client = client;
        this.mlModelManager = mlModelManager;
        this.clusterService = clusterService;
        this.mlModelCacheHelper = mlModelCacheHelper;
        this.modelAccessControlHelper = modelAccessControlHelper;
    }

    /**
     * Validates the create-controller request against its target model and, when allowed,
     * indexes (and possibly deploys) the controller.
     *
     * <p>Failures reported to {@code actionListener}:
     * <ul>
     *   <li>{@code FORBIDDEN} – unsupported function category, or user lacks model-group access;</li>
     *   <li>{@code CONFLICT} – the model is in {@code DEPLOYING} state;</li>
     *   <li>{@code NOT_FOUND} – the model lookup failed (the underlying cause is attached, since the
     *       real reason may not be a missing model).</li>
     * </ul>
     *
     * @param task the task handle (unused)
     * @param request the incoming request, converted via {@link MLCreateControllerRequest#fromActionRequest}
     * @param actionListener receives the {@link MLCreateControllerResponse} or a failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<MLCreateControllerResponse> actionListener) {
        MLCreateControllerRequest createControllerRequest = MLCreateControllerRequest.fromActionRequest(request);
        MLController controller = createControllerRequest.getControllerInput();
        String modelId = controller.getModelId();
        // The user is read before the thread context is stashed below.
        User user = RestActionUtils.getUserContext(client);
        // Skip the (potentially large) model payload; only metadata is needed here.
        String[] excludes = new String[] { "model_content", "content" };
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            // Restore the caller's thread context before handing any result back.
            ActionListener<MLCreateControllerResponse> wrappedListener = ActionListener.runBefore(actionListener, context::restore);
            mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
                FunctionName functionName = mlModel.getAlgorithm();
                if (functionName != FunctionName.TEXT_EMBEDDING && functionName != FunctionName.REMOTE) {
                    // String concatenation (not toString()) keeps a null algorithm from raising an NPE.
                    wrappedListener.onFailure(new OpenSearchStatusException("Creating model controller on this operation on the function category " + functionName + " is not supported.", RestStatus.FORBIDDEN));
                    return;
                }
                modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> {
                    if (!hasPermission) {
                        wrappedListener.onFailure(new OpenSearchStatusException("User doesn't have privilege to perform this operation on this model controller, model ID: " + modelId, RestStatus.FORBIDDEN));
                    } else if (mlModel.getModelState() == MLModelState.DEPLOYING) {
                        wrappedListener.onFailure(new OpenSearchStatusException("Creating a model controller during its corresponding model in DEPLOYING state is not allowed, please either create the model controller after it is deployed or before deploying it. Model ID: " + modelId, RestStatus.CONFLICT));
                        log.error("Failed to create a model controller during its corresponding model in DEPLOYING state. Model ID: " + modelId);
                    } else {
                        indexAndCreateController(mlModel, controller, wrappedListener);
                    }
                }, exception -> {
                    log.error("Permission denied: Unable to create the model controller for the model with ID {}. Details: {}", modelId, exception);
                    wrappedListener.onFailure(exception);
                }));
            }, e -> {
                // Keep the original cause: previously it was discarded, hiding the real failure.
                log.error("Failed to find model to create the corresponding model controller with the provided model ID: {}", modelId, e);
                wrappedListener.onFailure(new OpenSearchStatusException("Failed to find model to create the corresponding model controller with the provided model ID: " + modelId, RestStatus.NOT_FOUND, e));
            }));
        } catch (Exception e) {
            log.error("Failed to create model controller for {}", modelId, e);
            actionListener.onFailure(e);
        }
    }

    /**
     * Ensures the controller index exists, then writes the controller document into
     * {@code .plugins-ml-controller} (id = controller's model ID, refresh = IMMEDIATE).
     *
     * @param mlModel the model the controller belongs to
     * @param controller the controller to persist
     * @param actionListener receives the create response or a failure
     */
    private void indexAndCreateController(MLModel mlModel, MLController controller, ActionListener<MLCreateControllerResponse> actionListener) {
        mlIndicesHandler.initMLControllerIndex(ActionListener.wrap(indexCreated -> {
            if (!indexCreated) {
                actionListener.onFailure(new RuntimeException("Failed to create model controller index."));
                return;
            }
            try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
                IndexRequest indexRequest = new IndexRequest(".plugins-ml-controller").id(controller.getModelId());
                indexRequest.source(controller.toXContent(XContentBuilder.builder(XContentType.JSON.xContent()), ToXContent.EMPTY_PARAMS));
                indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                // The thread context is restored before the index response is processed.
                client.index(indexRequest, ActionListener.runBefore(
                    ActionListener.wrap(indexResponse -> onControllerIndexed(indexResponse, mlModel, controller, actionListener), actionListener::onFailure),
                    context::restore
                ));
            } catch (Exception e) {
                log.error("Failed to save model controller", e);
                actionListener.onFailure(e);
            }
        }, e -> {
            log.error("Failed to init model controller index", e);
            actionListener.onFailure(e);
        }));
    }

    /**
     * Handles a successful controller write: flags the model as controller-enabled when the
     * document was newly created, then either responds immediately or, if the model cache reports
     * worker nodes for the model, deploys the controller to them first.
     */
    private void onControllerIndexed(DocWriteResponse indexResponse, MLModel mlModel, MLController controller, ActionListener<MLCreateControllerResponse> actionListener) {
        String modelId = indexResponse.getId();
        MLCreateControllerResponse response = new MLCreateControllerResponse(modelId, indexResponse.getResult().name());
        log.info("Model controller for model id {} saved into index, result:{}", modelId, indexResponse.getResult());
        if (indexResponse.getResult() == DocWriteResponse.Result.CREATED) {
            mlModelManager.updateModel(modelId, Map.of("is_controller_enabled", true));
        }
        if (ArrayUtils.isEmpty(mlModelCacheHelper.getWorkerNodes(modelId))) {
            // Model not deployed: nothing to push into the cache.
            actionListener.onResponse(response);
            return;
        }
        deployControllerToWorkerNodes(modelId, mlModel, controller, response, actionListener);
    }

    /**
     * Deploys the just-indexed controller to the model's worker nodes and completes the listener
     * with {@code response} only if every node succeeded; otherwise fails with the list of
     * failed node IDs.
     */
    private void deployControllerToWorkerNodes(String modelId, MLModel mlModel, MLController controller, MLCreateControllerResponse response, ActionListener<MLCreateControllerResponse> actionListener) {
        log.info("Model {} is deployed. Start to deploy the model controller into cache.", modelId);
        String[] targetNodeIds = mlModelManager.getWorkerNodes(modelId, mlModel.getAlgorithm());
        MLDeployControllerNodesRequest deployControllerNodesRequest = new MLDeployControllerNodesRequest(targetNodeIds, controller.getModelId());
        client.execute(MLDeployControllerAction.INSTANCE, deployControllerNodesRequest, ActionListener.wrap(nodesResponse -> {
            if (nodesResponse != null && isDeployControllerSuccessOnAllNodes(nodesResponse)) {
                log.info("Successfully create model controller and deploy it into cache with model ID {}", modelId);
                actionListener.onResponse(response);
            } else {
                String[] failedNodeIds = getDeployControllerFailedNodesList(nodesResponse);
                log.error("Successfully create model controller index with model ID {} but deploy model controller to cache was failed on following nodes {}, please retry.", modelId, Arrays.toString(failedNodeIds));
                actionListener.onFailure(new RuntimeException("Successfully create model controller index with model ID " + modelId + " but deploy model controller to cache was failed on following nodes " + Arrays.toString(failedNodeIds) + ", please retry."));
            }
        }, e -> {
            // Previously "{}" + modelId left the placeholder unfilled; use proper parameterization.
            log.error("Failed to deploy model controller for model: {}", modelId, e);
            actionListener.onFailure(e);
        }));
    }

    /** Returns true when the nodes response carries no per-node failures. */
    private boolean isDeployControllerSuccessOnAllNodes(MLDeployControllerNodesResponse deployControllerNodesResponse) {
        return deployControllerNodesResponse.failures() == null || deployControllerNodesResponse.failures().isEmpty();
    }

    /**
     * Lists the IDs of nodes on which controller deployment failed. A {@code null} response is
     * treated as a failure on every node in the cluster state.
     */
    private String[] getDeployControllerFailedNodesList(MLDeployControllerNodesResponse deployControllerNodesResponse) {
        if (deployControllerNodesResponse == null) {
            return getAllNodes();
        }
        if (deployControllerNodesResponse.failures() == null) {
            return new String[0];
        }
        return deployControllerNodesResponse.failures().stream().map(FailedNodeException::nodeId).toArray(String[]::new);
    }

    /** Returns the IDs of all nodes in the current cluster state. */
    private String[] getAllNodes() {
        ArrayList<String> nodeIds = new ArrayList<>();
        for (DiscoveryNode node : clusterService.state().nodes()) {
            nodeIds.add(node.getId());
        }
        return nodeIds.toArray(new String[0]);
    }
}

