/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.cluster;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Semaphore;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.index.IndexRequest;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.WriteRequest;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.sync.MLSyncUpAction;
import org.opensearch.ml.common.transport.sync.MLSyncUpInput;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodeResponse;
import org.opensearch.ml.common.transport.sync.MLSyncUpNodesRequest;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.indices.MLIndicesHandler;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;

/**
 * Periodic job that reconciles ML model deployment state across the cluster.
 *
 * <p>Each {@link #run()}:
 * <ol>
 *   <li>lazily initializes the ML config index and the encryption master key ({@link #initMLConfig()});</li>
 *   <li>does nothing further if the {@code .plugins-ml-model} index does not exist;</li>
 *   <li>asks all nodes (via {@link MLSyncUpAction}) which models they have deployed, which models they are
 *       deploying and which deploy tasks they are running;</li>
 *   <li>sends the aggregated model routing table and running deploy tasks back to all nodes;</li>
 *   <li>recomputes model states from the gathered worker-node info and bulk-updates the model index
 *       ({@link #refreshModelState(Map, Map)}).</li>
 * </ol>
 *
 * <p>At most one model-state refresh is in flight at a time; overlapping attempts are skipped via
 * {@link #updateModelStateSemaphore}.
 */
public class MLSyncUpCron implements Runnable {
    @Generated
    private static final Logger log = LogManager.getLogger(MLSyncUpCron.class);

    /**
     * Time (ms) after a model's last update during which a DEPLOYING model with no worker node is
     * not yet marked DEPLOY_FAILED.
     */
    public static final int DEPLOY_MODEL_TASK_GRACE_TIME_IN_MS = 20000;

    private static final String ML_MODEL_INDEX = ".plugins-ml-model";
    private static final String ML_CONFIG_INDEX = ".plugins-ml-config";
    /** Used both as the config document id and as the field holding the key. */
    private static final String MASTER_KEY = "master_key";
    private static final String CREATE_TIME_FIELD = "create_time";
    private static final String MODEL_STATE_FIELD = "model_state";
    private static final String ALGORITHM_FIELD = "algorithm";
    private static final String DEPLOY_TO_ALL_NODES_FIELD = "deploy_to_all_nodes";
    private static final String PLANNING_WORKER_NODES_FIELD = "planning_worker_nodes";
    private static final String PLANNING_WORKER_NODE_COUNT_FIELD = "planning_worker_node_count";
    private static final String LAST_UPDATED_TIME_FIELD = "last_updated_time";
    private static final String CURRENT_WORKER_NODE_COUNT_FIELD = "current_worker_node_count";
    /** Maximum number of model documents fetched per refresh. */
    private static final int MAX_MODELS_PER_REFRESH = 10000;

    private final Client client;
    private final ClusterService clusterService;
    private final DiscoveryNodeHelper nodeHelper;
    private final MLIndicesHandler mlIndicesHandler;
    private final Encryptor encryptor;
    /** Set once the master key has been loaded or generated; until then every run retries. */
    private volatile Boolean mlConfigInited;
    /**
     * Single permit, acquired at the start of {@link #refreshModelState(Map, Map)} and released when the
     * refresh finishes (bulk update done, nothing to update, or failure).
     */
    @VisibleForTesting
    Semaphore updateModelStateSemaphore;

    public MLSyncUpCron(Client client, ClusterService clusterService, DiscoveryNodeHelper nodeHelper, MLIndicesHandler mlIndicesHandler, Encryptor encryptor) {
        this.client = client;
        this.clusterService = clusterService;
        this.nodeHelper = nodeHelper;
        this.mlIndicesHandler = mlIndicesHandler;
        this.updateModelStateSemaphore = new Semaphore(1);
        this.mlConfigInited = false;
        this.encryptor = encryptor;
    }

    @Override
    public void run() {
        this.initMLConfig();
        if (!this.clusterService.state().metadata().indices().containsKey(ML_MODEL_INDEX)) {
            // Nothing to sync until the model index exists.
            return;
        }
        log.debug("ML sync job starts");
        DiscoveryNode[] allNodes = this.nodeHelper.getAllNodes();
        MLSyncUpInput gatherInfoInput = MLSyncUpInput.builder().getDeployedModels(true).build();
        MLSyncUpNodesRequest gatherInfoRequest = new MLSyncUpNodesRequest(allNodes, gatherInfoInput);
        // Phase 1: gather deployed/deploying models and running deploy tasks from every node.
        this.client.execute(MLSyncUpAction.INSTANCE, gatherInfoRequest, ActionListener.wrap(r -> {
            List<MLSyncUpNodeResponse> responses = r.getNodes();
            // model id -> ids of nodes reporting the model as deployed
            Map<String, Set<String>> modelWorkerNodes = new HashMap<>();
            // deploy task id -> ids of nodes reporting the task as running
            Map<String, Set<String>> runningDeployModelTasks = new HashMap<>();
            // model id -> ids of nodes reporting the model as being deployed
            Map<String, Set<String>> deployingModels = new HashMap<>();
            for (MLSyncUpNodeResponse response : responses) {
                String nodeId = response.getNode().getId();
                addNodeToEach(modelWorkerNodes, response.getDeployedModelIds(), nodeId);
                addNodeToEach(deployingModels, response.getRunningDeployModelIds(), nodeId);
                addNodeToEach(runningDeployModelTasks, response.getRunningDeployModelTaskIds(), nodeId);
            }
            for (Map.Entry<String, Set<String>> entry : modelWorkerNodes.entrySet()) {
                log.debug("will sync model worker nodes for model: {}: {}", entry.getKey(), (Object) entry.getValue().toArray(new String[0]));
            }
            for (Map.Entry<String, Set<String>> entry : runningDeployModelTasks.entrySet()) {
                log.debug("will sync running task: {}: {}", entry.getKey(), (Object) entry.getValue().toArray(new String[0]));
            }

            // Phase 2: push the aggregated routing table and running tasks back to every node.
            MLSyncUpInput.MLSyncUpInputBuilder inputBuilder = MLSyncUpInput.builder()
                .syncRunningDeployModelTasks(true)
                .runningDeployModelTasks(runningDeployModelTasks);
            if (modelWorkerNodes.isEmpty()) {
                log.debug("No deployed model found. Will clear model routing on all nodes");
                inputBuilder.clearRoutingTable(true);
            } else {
                inputBuilder.modelRoutingTable(modelWorkerNodes);
            }
            MLSyncUpNodesRequest syncUpRequest = new MLSyncUpNodesRequest(allNodes, inputBuilder.build());
            this.client.execute(MLSyncUpAction.INSTANCE, syncUpRequest, ActionListener.wrap(
                re -> log.debug("sync model routing job finished"),
                ex -> log.error("Failed to sync model routing", ex)));

            // Phase 3: reconcile persisted model states with what the nodes reported.
            this.mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(
                res -> this.refreshModelState(modelWorkerNodes, deployingModels),
                e -> log.error("Failed to init model index", e)));
        }, e -> log.error("Failed to sync model routing", e)));
    }

    /**
     * Adds {@code nodeId} to the node set of every key in {@code keys}, creating sets on first use.
     * A {@code null} or empty {@code keys} array leaves {@code target} unchanged.
     */
    private static void addNodeToEach(Map<String, Set<String>> target, String[] keys, String nodeId) {
        if (keys == null) {
            return;
        }
        for (String key : keys) {
            target.computeIfAbsent(key, k -> new HashSet<>()).add(nodeId);
        }
    }

    /**
     * Ensures the ML config index exists, then loads the encryption master key into the
     * {@link Encryptor}, generating and persisting a new key if none is stored yet.
     *
     * <p>No-op once initialization has succeeded. Failures are logged at debug level and leave
     * {@code mlConfigInited} false, so the next call retries.
     */
    @VisibleForTesting
    void initMLConfig() {
        if (this.mlConfigInited) {
            return;
        }
        this.mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> {
            GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
            // The get is issued under a stashed thread context, restored when this try block exits.
            try (ThreadContext.StoredContext context = this.client.threadPool().getThreadContext().stashContext()) {
                this.client.get(getRequest, ActionListener.wrap(getResponse -> {
                    if (getResponse.isExists()) {
                        String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY);
                        this.encryptor.setMasterKey(masterKey);
                        this.mlConfigInited = true;
                        log.info("ML configuration already initialized, no action needed");
                        return;
                    }
                    // First initialization: generate a key and persist it before using it.
                    IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
                    String masterKey = this.encryptor.generateMasterKey();
                    indexRequest.source(ImmutableMap.<String, Object>of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli()));
                    indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
                    this.client.index(indexRequest, ActionListener.wrap(indexResponse -> {
                        log.info("ML configuration initialized successfully");
                        this.encryptor.setMasterKey(masterKey);
                        this.mlConfigInited = true;
                    }, e -> log.debug("Failed to save ML encryption master key", e)));
                }, e -> log.debug("Failed to get ML encryption master key", e)));
            }
        }, e -> log.debug("Failed to init ML config index", e)));
    }

    /**
     * Recomputes model states from the worker-node information gathered in {@link #run()} and
     * bulk-updates the model index.
     *
     * <p>Returns immediately if another refresh still holds {@link #updateModelStateSemaphore}. Otherwise
     * the permit is held until the bulk update completes or any step fails.
     *
     * @param modelWorkerNodes model id to ids of nodes that have the model deployed
     * @param deployingModels model id to ids of nodes currently deploying the model
     */
    @VisibleForTesting
    void refreshModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, Set<String>> deployingModels) {
        if (!this.updateModelStateSemaphore.tryAcquire()) {
            return;
        }
        try {
            SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX);
            BoolQueryBuilder queryBuilder = new BoolQueryBuilder();
            queryBuilder.filter(new TermsQueryBuilder(MODEL_STATE_FIELD, Arrays.asList(
                MLModelState.LOADING.name(),
                MLModelState.PARTIALLY_LOADED.name(),
                MLModelState.LOADED.name(),
                MLModelState.LOAD_FAILED.name(),
                MLModelState.DEPLOYING.name(),
                MLModelState.PARTIALLY_DEPLOYED.name(),
                MLModelState.DEPLOYED.name(),
                MLModelState.DEPLOY_FAILED.name())));
            SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
            sourceBuilder.query(queryBuilder);
            sourceBuilder.size(MAX_MODELS_PER_REFRESH);
            sourceBuilder.fetchSource(new String[] {
                MODEL_STATE_FIELD,
                ALGORITHM_FIELD,
                DEPLOY_TO_ALL_NODES_FIELD,
                PLANNING_WORKER_NODES_FIELD,
                PLANNING_WORKER_NODE_COUNT_FIELD,
                LAST_UPDATED_TIME_FIELD,
                CURRENT_WORKER_NODE_COUNT_FIELD }, null);
            searchRequest.source(sourceBuilder);
            this.client.search(searchRequest, ActionListener.wrap(res -> {
                Map<String, MLModelState> newModelStates = new HashMap<>();
                Map<String, List<String>> newPlanningWorkerNodes = new HashMap<>();
                for (SearchHit hit : res.getHits().getHits()) {
                    this.collectModelUpdate(hit, modelWorkerNodes, deployingModels, newModelStates, newPlanningWorkerNodes);
                }
                // bulkUpdateModelState takes over responsibility for releasing the semaphore.
                this.bulkUpdateModelState(modelWorkerNodes, newModelStates, newPlanningWorkerNodes);
            }, e -> {
                this.updateModelStateSemaphore.release();
                log.error("Failed to search models", e);
            }));
        } catch (Exception ex) {
            this.updateModelStateSemaphore.release();
            log.error("Failed to refresh model state", ex);
        }
    }

    /**
     * Computes the pending updates for one model document. Records a new state when it differs
     * from the stored one and, for deploy-to-all-nodes models, a new planning worker node list when the
     * eligible node set changed.
     */
    private void collectModelUpdate(
        SearchHit hit,
        Map<String, Set<String>> modelWorkerNodes,
        Map<String, Set<String>> deployingModels,
        Map<String, MLModelState> newModelStates,
        Map<String, List<String>> newPlanningWorkerNodes
    ) {
        String modelId = hit.getId();
        Map<String, Object> source = hit.getSourceAsMap();
        FunctionName functionName = FunctionName.from((String) source.get(ALGORITHM_FIELD));
        MLModelState state = MLModelState.from((String) source.get(MODEL_STATE_FIELD));
        Long lastUpdateTime = asLong(source.get(LAST_UPDATED_TIME_FIELD));
        int planningWorkerNodeCount = asInt(source.get(PLANNING_WORKER_NODE_COUNT_FIELD));
        int currentWorkerNodeCountInIndex = asInt(source.get(CURRENT_WORKER_NODE_COUNT_FIELD));
        boolean deployToAllNodes = Boolean.TRUE.equals(source.get(DEPLOY_TO_ALL_NODES_FIELD));
        if (deployToAllNodes) {
            // The plan tracks the current eligible nodes rather than the stored count.
            List<String> planningWorkerNodes = asStringList(source.get(PLANNING_WORKER_NODES_FIELD));
            DiscoveryNode[] eligibleNodes = this.nodeHelper.getEligibleNodes(functionName);
            planningWorkerNodeCount = eligibleNodes.length;
            List<String> eligibleNodeIds = Arrays.stream(eligibleNodes).map(DiscoveryNode::getId).collect(Collectors.toList());
            if (eligibleNodeIds.size() != planningWorkerNodes.size() || !eligibleNodeIds.containsAll(planningWorkerNodes)) {
                newPlanningWorkerNodes.put(modelId, eligibleNodeIds);
            }
        }
        MLModelState newState = this.getNewModelState(deployingModels, modelWorkerNodes, modelId, state, lastUpdateTime, planningWorkerNodeCount, currentWorkerNodeCountInIndex);
        if (newState != null) {
            newModelStates.put(modelId, newState);
        }
    }

    /** Returns {@code value} as a long if it is numeric (Integer or Long after JSON parsing), else {@code null}. */
    private static Long asLong(Object value) {
        return value instanceof Number ? ((Number) value).longValue() : null;
    }

    /** Returns {@code value} as an int if it is numeric, else 0. */
    private static int asInt(Object value) {
        return value instanceof Number ? ((Number) value).intValue() : 0;
    }

    /** Returns {@code value} as a list (assumed to hold node-id strings), or a new empty list otherwise. */
    @SuppressWarnings("unchecked")
    private static List<String> asStringList(Object value) {
        return value instanceof List ? (List<String>) value : new ArrayList<>();
    }

    /**
     * Decides the state a model should move to, or returns {@code null} if no update is needed.
     *
     * <ul>
     *   <li>DEPLOYING if some node is deploying it and it is not already DEPLOYING;</li>
     *   <li>DEPLOY_FAILED if no node has it deployed, it is not already DEPLOY_FAILED, and it is not a
     *       DEPLOYING model still within {@link #DEPLOY_MODEL_TASK_GRACE_TIME_IN_MS} of its last update;</li>
     *   <li>PARTIALLY_DEPLOYED if deployed on fewer nodes than planned and the stored state or stored
     *       worker count is stale;</li>
     *   <li>DEPLOYED if deployed on at least the planned (positive) number of nodes and not already DEPLOYED.</li>
     * </ul>
     */
    private MLModelState getNewModelState(Map<String, Set<String>> deployingModels, Map<String, Set<String>> modelWorkerNodes, String modelId, MLModelState state, Long lastUpdateTime, int planningWorkerNodeCount, int currentWorkerNodeCountInIndex) {
        Set<String> deployModelTaskNodes = deployingModels.get(modelId);
        if (deployModelTaskNodes != null && !deployModelTaskNodes.isEmpty() && state != MLModelState.DEPLOYING) {
            return MLModelState.DEPLOYING;
        }
        Set<String> workerNodes = modelWorkerNodes.get(modelId);
        int currentWorkerNodeCount = workerNodes == null ? 0 : workerNodes.size();
        if (currentWorkerNodeCount == 0) {
            boolean withinDeployGracePeriod = state == MLModelState.DEPLOYING
                && lastUpdateTime != null
                && lastUpdateTime + DEPLOY_MODEL_TASK_GRACE_TIME_IN_MS > Instant.now().toEpochMilli();
            return state != MLModelState.DEPLOY_FAILED && !withinDeployGracePeriod ? MLModelState.DEPLOY_FAILED : null;
        }
        if (currentWorkerNodeCount < planningWorkerNodeCount) {
            boolean stale = state != MLModelState.PARTIALLY_DEPLOYED || currentWorkerNodeCountInIndex != currentWorkerNodeCount;
            return stale ? MLModelState.PARTIALLY_DEPLOYED : null;
        }
        if (planningWorkerNodeCount > 0 && state != MLModelState.DEPLOYED) {
            if (currentWorkerNodeCount > planningWorkerNodeCount) {
                log.warn("Model {} deployed on more nodes [{}] than planning worker node [{}]", modelId, currentWorkerNodeCount, planningWorkerNodeCount);
            }
            return MLModelState.DEPLOYED;
        }
        return null;
    }

    /**
     * Writes the computed state and planning-node changes to the model index in one bulk request and
     * releases {@link #updateModelStateSemaphore} when done. If there is nothing to update, it releases
     * the semaphore immediately.
     */
    private void bulkUpdateModelState(Map<String, Set<String>> modelWorkerNodes, Map<String, MLModelState> newModelStates, Map<String, List<String>> newPlanningWorkNodes) {
        Set<String> updatedModelIds = new HashSet<>();
        updatedModelIds.addAll(newModelStates.keySet());
        updatedModelIds.addAll(newPlanningWorkNodes.keySet());
        if (updatedModelIds.isEmpty()) {
            this.updateModelStateSemaphore.release();
            return;
        }
        BulkRequest bulkUpdateRequest = new BulkRequest();
        for (String modelId : updatedModelIds) {
            ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
            MLModelState newState = newModelStates.get(modelId);
            if (newState != null) {
                builder.put(MODEL_STATE_FIELD, newState.name());
            }
            List<String> planningWorkerNodes = newPlanningWorkNodes.get(modelId);
            if (planningWorkerNodes != null) {
                builder.put(PLANNING_WORKER_NODES_FIELD, planningWorkerNodes);
                builder.put(PLANNING_WORKER_NODE_COUNT_FIELD, planningWorkerNodes.size());
            }
            builder.put(LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
            Set<String> workerNodes = modelWorkerNodes.get(modelId);
            builder.put(CURRENT_WORKER_NODE_COUNT_FIELD, workerNodes == null ? 0 : workerNodes.size());
            UpdateRequest updateRequest = new UpdateRequest();
            updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(builder.build());
            bulkUpdateRequest.add(updateRequest);
        }
        log.info("Refresh model state: {}", newModelStates);
        this.client.bulk(bulkUpdateRequest, ActionListener.wrap(br -> {
            this.updateModelStateSemaphore.release();
            log.debug("Refresh model state successfully");
        }, e -> {
            this.updateModelStateSemaphore.release();
            log.error("Failed to bulk update model state", e);
        }));
    }
}

