/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.action.controller;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import lombok.Generated;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.controller.MLController;
import org.opensearch.ml.common.transport.controller.MLDeployControllerAction;
import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesRequest;
import org.opensearch.ml.common.transport.controller.MLDeployControllerNodesResponse;
import org.opensearch.ml.common.transport.controller.MLUpdateControllerRequest;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

/**
 * Transport action that updates an existing model controller.
 *
 * <p>The flow is:
 * <ol>
 *   <li>Load the model, excluding its content fields.</li>
 *   <li>Check that the model's algorithm is TEXT_EMBEDDING or REMOTE.</li>
 *   <li>Validate the caller's access to the model group.</li>
 *   <li>Apply the update to the stored controller and write it to the {@code .plugins-ml-controller} index.</li>
 *   <li>If the model currently has worker nodes and the update requires it, redeploy the controller to those nodes.</li>
 * </ol>
 */
public class UpdateControllerTransportAction extends HandledTransportAction<ActionRequest, UpdateResponse> {
    @Generated
    private static final Logger log = LogManager.getLogger(UpdateControllerTransportAction.class);
    private final Client client;
    private final MLModelManager mlModelManager;
    private final MLModelCacheHelper mlModelCacheHelper;
    private final ClusterService clusterService;
    private final ModelAccessControlHelper modelAccessControlHelper;

    @Inject
    public UpdateControllerTransportAction(
        TransportService transportService,
        ActionFilters actionFilters,
        Client client,
        ClusterService clusterService,
        ModelAccessControlHelper modelAccessControlHelper,
        MLModelCacheHelper mlModelCacheHelper,
        MLModelManager mlModelManager
    ) {
        super("cluster:admin/opensearch/ml/controllers/update", transportService, actionFilters, MLUpdateControllerRequest::new);
        this.client = client;
        this.mlModelManager = mlModelManager;
        this.clusterService = clusterService;
        this.mlModelCacheHelper = mlModelCacheHelper;
        this.modelAccessControlHelper = modelAccessControlHelper;
    }

    /**
     * Entry point: validates the target model and updates its controller.
     *
     * @param task           the transport task (unused)
     * @param request        an action request convertible to {@link MLUpdateControllerRequest}
     * @param actionListener receives the index {@link UpdateResponse}, or the failure
     */
    @Override
    protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> actionListener) {
        MLUpdateControllerRequest updateControllerRequest = MLUpdateControllerRequest.fromActionRequest(request);
        MLController updateControllerInput = updateControllerRequest.getUpdateControllerInput();
        String modelId = updateControllerInput.getModelId();
        // Captured before stashing so access checks use the caller's identity, not the stashed context.
        User user = RestActionUtils.getUserContext(client);
        // The model content is not needed here, so skip loading it.
        String[] excludes = new String[] { "model_content", "content" };

        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            // Restore the caller's thread context before handing the result back.
            ActionListener<UpdateResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
            mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
                FunctionName functionName = mlModel.getAlgorithm();
                if (functionName != FunctionName.TEXT_EMBEDDING && functionName != FunctionName.REMOTE) {
                    // NOTE(review): message says "Creating" although this is the update API; kept verbatim for
                    // compatibility with clients/tests that may match on it.
                    wrappedListener
                        .onFailure(
                            new OpenSearchStatusException(
                                "Creating model controller on this operation on the function category "
                                    + functionName.toString()
                                    + " is not supported.",
                                RestStatus.FORBIDDEN
                            )
                        );
                    return;
                }
                validateAccessAndUpdateController(user, mlModel, modelId, updateControllerInput, wrappedListener);
            }, e -> {
                // The client-facing error is always NOT_FOUND; log the real cause so it is not lost.
                log.error("Failed to find model {} to update its model controller", modelId, e);
                wrappedListener
                    .onFailure(
                        new OpenSearchStatusException(
                            "Failed to find model to create the corresponding model controller with the provided model ID: " + modelId,
                            RestStatus.NOT_FOUND
                        )
                    );
            }));
        } catch (Exception e) {
            log.error("Failed to update model controller for {}", modelId, e);
            actionListener.onFailure(e);
        }
    }

    /**
     * Checks the caller's access to the model's model group.
     * If access is granted, loads the existing controller, applies the update input and persists it.
     *
     * @param user                  the calling user (may be null when security is disabled — TODO confirm)
     * @param mlModel               the model whose controller is updated
     * @param modelId               ID of the model, taken from the update input
     * @param updateControllerInput the requested controller changes
     * @param listener              completion listener
     */
    private void validateAccessAndUpdateController(
        User user,
        MLModel mlModel,
        String modelId,
        MLController updateControllerInput,
        ActionListener<UpdateResponse> listener
    ) {
        modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(hasPermission -> {
            // A null answer is treated as "no permission" instead of failing with an NPE.
            if (!Boolean.TRUE.equals(hasPermission)) {
                listener
                    .onFailure(
                        new OpenSearchStatusException(
                            "User doesn't have privilege to perform this operation on this model controller, model ID: " + modelId,
                            RestStatus.FORBIDDEN
                        )
                    );
                return;
            }
            mlModelManager.getController(modelId, ActionListener.wrap(controller -> {
                // Evaluated before controller.update(...), presumably so the check compares against the
                // controller's existing settings.
                boolean isDeployRequiredAfterUpdate = controller.isDeployRequiredAfterUpdate(updateControllerInput);
                controller.update(updateControllerInput);
                updateController(mlModel, controller, isDeployRequiredAfterUpdate, listener);
            }, e -> {
                if (!Boolean.TRUE.equals(mlModel.getIsControllerEnabled())) {
                    // The model never had a controller, so point the user at the create API.
                    log.error("Model controller haven't been created for the model: {}", modelId, e);
                    listener
                        .onFailure(
                            new OpenSearchStatusException(
                                "Model controller haven't been created for the model. Consider calling create model controller api instead. Model ID: "
                                    + modelId,
                                RestStatus.CONFLICT
                            )
                        );
                } else {
                    log.error("Failed to get model controller for model: {}", modelId, e);
                    listener.onFailure(e);
                }
            }));
        }, exception -> {
            log.error("Permission denied: Unable to update the model controller for the model with ID {}.", modelId, exception);
            listener.onFailure(exception);
        }));
    }

    /**
     * Writes the updated controller document to {@code .plugins-ml-controller} with an immediate refresh.
     * If the write result is UPDATED, the model has worker nodes, and {@code isDeployRequiredAfterUpdate}
     * is true, the controller is then redeployed to those nodes.
     *
     * @param mlModel                    the model owning the controller
     * @param controller                 the controller with the update already applied
     * @param isDeployRequiredAfterUpdate whether cached copies must be refreshed after the write
     * @param actionListener             receives the index update response, or the failure
     */
    private void updateController(
        MLModel mlModel,
        MLController controller,
        boolean isDeployRequiredAfterUpdate,
        ActionListener<UpdateResponse> actionListener
    ) {
        try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
            String modelId = mlModel.getModelId();
            ActionListener<UpdateResponse> updateResponseListener = ActionListener.wrap(updateResponse -> {
                if (updateResponse == null) {
                    log.error("Failed to update model controller with model ID: {}", modelId);
                    actionListener.onFailure(new RuntimeException("Failed to update model controller with model ID: " + modelId));
                } else if (updateResponse.getResult() != DocWriteResponse.Result.UPDATED) {
                    // e.g. a no-op update; report it to the caller without redeploying.
                    log
                        .warn(
                            "Update model controller for model {} got a result status other than update, result status: {}",
                            modelId,
                            updateResponse.getResult()
                        );
                    actionListener.onResponse(updateResponse);
                } else {
                    log.info("Model controller for model {} successfully updated to index, result: {}", modelId, updateResponse.getResult());
                    if (isDeployRequiredAfterUpdate && !ArrayUtils.isEmpty(mlModelCacheHelper.getWorkerNodes(modelId))) {
                        deployControllerToWorkerNodes(mlModel, updateResponse, actionListener);
                    } else {
                        actionListener.onResponse(updateResponse);
                    }
                }
            }, actionListener::onFailure);

            UpdateRequest updateRequest = new UpdateRequest(".plugins-ml-controller", modelId);
            updateRequest.doc(controller.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS));
            updateRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
            client.update(updateRequest, ActionListener.runBefore(updateResponseListener, () -> context.restore()));
        } catch (Exception e) {
            log.error("Failed to update model controller.", e);
            actionListener.onFailure(e);
        }
    }

    /**
     * Sends an {@link MLDeployControllerAction} to the model's worker nodes.
     * Completes with {@code updateResponse} only if every node succeeded; otherwise fails and lists the
     * failing node IDs.
     */
    private void deployControllerToWorkerNodes(MLModel mlModel, UpdateResponse updateResponse, ActionListener<UpdateResponse> actionListener) {
        String modelId = mlModel.getModelId();
        log
            .info(
                "Model {} is deployed and the user rate limiter config is constructable. Start to deploy the model controller into cache.",
                modelId
            );
        String[] targetNodeIds = mlModelManager.getWorkerNodes(modelId, mlModel.getAlgorithm());
        MLDeployControllerNodesRequest deployControllerNodesRequest = new MLDeployControllerNodesRequest(targetNodeIds, modelId);
        client.execute(MLDeployControllerAction.INSTANCE, deployControllerNodesRequest, ActionListener.wrap(nodesResponse -> {
            if (nodesResponse != null && isDeployControllerSuccessOnAllNodes(nodesResponse)) {
                log.info("Successfully update model controller and deploy it into cache with model ID {}", modelId);
                actionListener.onResponse(updateResponse);
            } else {
                String failedNodes = Arrays.toString(getDeployControllerFailedNodesList(nodesResponse));
                log
                    .error(
                        "Successfully update model controller index with model ID {} but deploy model controller to cache was failed on following nodes {}, please retry.",
                        modelId,
                        failedNodes
                    );
                actionListener
                    .onFailure(
                        new RuntimeException(
                            "Successfully update model controller index with model ID "
                                + modelId
                                + " but deploy model controller to cache was failed on following nodes "
                                + failedNodes
                                + ", please retry."
                        )
                    );
            }
        }, e -> {
            log.error("Failed to deploy model controller for model: {}", modelId, e);
            actionListener.onFailure(e);
        }));
    }

    /** Returns true when the nodes response reports no node failures (a null failure list counts as none). */
    private boolean isDeployControllerSuccessOnAllNodes(MLDeployControllerNodesResponse deployControllerNodesResponse) {
        return deployControllerNodesResponse.failures() == null || deployControllerNodesResponse.failures().isEmpty();
    }

    /**
     * Returns the IDs of nodes that failed the controller deploy.
     * If there is no response at all, every node currently in the cluster is returned.
     */
    private String[] getDeployControllerFailedNodesList(MLDeployControllerNodesResponse deployControllerNodesResponse) {
        if (deployControllerNodesResponse == null) {
            // No response to inspect; report every node in the cluster.
            return getAllNodes();
        }
        ArrayList<String> nodeIds = new ArrayList<>();
        if (deployControllerNodesResponse.failures() != null) {
            for (FailedNodeException failedNodeException : deployControllerNodesResponse.failures()) {
                nodeIds.add(failedNodeException.nodeId());
            }
        }
        return nodeIds.toArray(new String[0]);
    }

    /** Returns the IDs of all nodes in the current cluster state. */
    private String[] getAllNodes() {
        ArrayList<String> nodeIds = new ArrayList<>();
        for (DiscoveryNode node : clusterService.state().nodes()) {
            nodeIds.add(node.getId());
        }
        return nodeIds.toArray(new String[0]);
    }
}

