/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.performanceanalyzer.rca.scheduler;

import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.performanceanalyzer.AppContext;
import org.opensearch.performanceanalyzer.commons.stats.ServiceMetrics;
import org.opensearch.performanceanalyzer.commons.stats.collectors.SampleAggregator;
import org.opensearch.performanceanalyzer.commons.stats.measurements.MeasurementSet;
import org.opensearch.performanceanalyzer.rca.framework.core.ConnectedComponent;
import org.opensearch.performanceanalyzer.rca.framework.core.Node;
import org.opensearch.performanceanalyzer.rca.framework.core.Queryable;
import org.opensearch.performanceanalyzer.rca.framework.core.RcaConf;
import org.opensearch.performanceanalyzer.rca.framework.core.Stats;
import org.opensearch.performanceanalyzer.rca.framework.metrics.RcaGraphMetrics;
import org.opensearch.performanceanalyzer.rca.framework.util.RcaUtil;
import org.opensearch.performanceanalyzer.rca.messages.IntentMsg;
import org.opensearch.performanceanalyzer.rca.net.WireHopper;
import org.opensearch.performanceanalyzer.rca.persistence.Persistable;
import org.opensearch.performanceanalyzer.rca.scheduler.GraphNodeOperations;
import org.opensearch.performanceanalyzer.rca.scheduler.Tasklet;

/**
 * One scheduler tick of the RCA graph.
 *
 * <p>The constructor computes, once, a level-ordered layout of {@link Tasklet}s for every graph
 * node that executes on this host (plus helper tasklets that read remote upstream nodes from the
 * wire). Each {@link #run()} submits that layout to the executor, waits for the final level, and
 * every {@code maxTicks} runs resets the tick counters of all tasklets.
 */
public class RCASchedulerTask implements Runnable {
    private static final Logger LOG = LogManager.getLogger(RCASchedulerTask.class);
    private static final String EMPTY_STRING = "";

    /**
     * Replacement DB to install into every tasklet at the start of the next {@link #run()};
     * cleared once applied.
     * NOTE(review): this field is neither volatile nor guarded by a lock. If {@link #setNewDb} is
     * called from a thread other than the one executing {@link #run()}, the update is not
     * guaranteed to be visible (and a set racing with the reset to null can be lost) — confirm
     * the callers.
     */
    private Queryable newDb = null;
    /** Number of runs after which all tasklet ticks are reset. */
    private final int maxTicks;
    /** Runs executed since the last tick reset. */
    private int currTick;
    private final ExecutorService executorPool;
    /**
     * For each locally executed node, the downstream nodes that consume it but are not executed
     * locally. The same map instance is handed to every tasklet.
     */
    private final Map<Node<?>, List<Node<?>>> remotelyDesirableNodeSet;
    /**
     * Tasklets grouped into levels, in submission order. Remote-read tasklets are placed in the
     * level before the level of the first local node that needs them.
     */
    private final List<List<Tasklet>> locallyExecutableTasklets;

    /**
     * Builds the tasklet layout for all connected components.
     *
     * <p>For each component, nodes are visited in the order returned by
     * {@code getAllNodesByDependencyOrder()}. Nodes for which {@link RcaUtil#shouldExecuteLocally}
     * holds get a local tasklet, and an intent is sent over {@code hopper} for each upstream that
     * has to be read from the wire. The per-component levels are then merged level by level with
     * {@link #mergeLists}.
     *
     * <p>NOTE(review): this invokes the overridable {@link #createTaskletAndSendIntent} while the
     * object is still under construction; an overriding subclass must not rely on its own fields
     * being initialized yet.
     *
     * @param maxTicks number of runs after which all tasklet ticks are reset
     * @param executorPool executor the tasklets are submitted to on every run
     * @param connectedComponents the graph components to schedule
     * @param db queryable handed to every tasklet (may later be swapped via {@link #setNewDb})
     * @param persistable persistence handle handed to every tasklet
     * @param conf RCA configuration used to decide locality; read by each local node
     * @param hopper transport used to send intents; also handed to every tasklet
     * @param appContext application context set on every visited node
     */
    public RCASchedulerTask(int maxTicks, ExecutorService executorPool, List<ConnectedComponent> connectedComponents, Queryable db, Persistable persistable, RcaConf conf, WireHopper hopper, AppContext appContext) {
        this.maxTicks = maxTicks;
        this.executorPool = executorPool;
        this.remotelyDesirableNodeSet = new HashMap<>();
        Map<Node<?>, Tasklet> nodeTaskletMap = new HashMap<>();
        List<List<Tasklet>> dependencyOrderedLocallyExecutables = Collections.emptyList();
        for (ConnectedComponent component : connectedComponents) {
            List<List<Tasklet>> orderedTasklets = getLocallyExecutableNodes(component.getAllNodesByDependencyOrder(), conf, hopper, db, persistable, nodeTaskletMap, appContext);
            dependencyOrderedLocallyExecutables = mergeLists(orderedTasklets, dependencyOrderedLocallyExecutables);
        }
        this.locallyExecutableTasklets = Collections.unmodifiableList(dependencyOrderedLocallyExecutables);
        LOG.debug("rca: locally executable tasklet size: {}", locallyExecutableTasklets.size());
    }

    /**
     * Merges two level-ordered lists of lists, level by level.
     *
     * <p>The longer argument ({@code l1} when both have the same length) becomes the result and is
     * modified in place: every element of the shorter list's level {@code i} is appended to the
     * result's level {@code i}. The shorter list itself is not modified.
     *
     * @param l1 first level-ordered list
     * @param l2 second level-ordered list
     * @param <T> element type
     * @return the longer of the two arguments, after the merge
     */
    public static <T> List<List<T>> mergeLists(List<List<T>> l1, List<List<T>> l2) {
        if (l2.size() > l1.size()) {
            return mergeLists(l2, l1);
        }
        for (int idx = 0; idx < l2.size(); ++idx) {
            l1.get(idx).addAll(l2.get(idx));
        }
        return l1;
    }

    /**
     * Builds the level-ordered tasklets for one connected component.
     *
     * <p>Side effects: sets {@code appContext} on every node, makes every local node read
     * {@code conf}, stores each local node's tasklet in {@code nodeTaskletMap}, sends intents
     * for remote upstreams, and records non-local consumers of local nodes in
     * {@link #remotelyDesirableNodeSet}.
     *
     * @return levels of tasklets; levels that would contain no tasklet are omitted
     */
    private List<List<Tasklet>> getLocallyExecutableNodes(List<List<Node<?>>> orderedNodes, RcaConf conf, WireHopper hopper, Queryable db, Persistable persistable, Map<Node<?>, Tasklet> nodeTaskletMap, AppContext appContext) {
        Set<Node<?>> locallyExecutableSet = new HashSet<>();
        List<List<Tasklet>> dependencyOrderedLocallyExecutable = new ArrayList<>();
        for (List<Node<?>> levelNodes : orderedNodes) {
            List<Tasklet> locallyExecutableInThisLevel = new ArrayList<>();
            for (Node<?> node : levelNodes) {
                node.setAppContext(appContext);
                if (RcaUtil.shouldExecuteLocally(node, conf)) {
                    locallyExecutableSet.add(node);
                    node.readRcaConf(conf);
                    CreatedTasklets newTasklets = createTaskletAndSendIntent(node, locallyExecutableSet, hopper, db, persistable, nodeTaskletMap);
                    nodeTaskletMap.put(node, newTasklets.taskletForCurrentNode);
                    locallyExecutableInThisLevel.add(newTasklets.taskletForCurrentNode);
                    scheduleRemoteTasklets(newTasklets.remoteTasklets, dependencyOrderedLocallyExecutable);
                } else {
                    LOG.debug("rca: tag NOT matched for node: {}", node.name());
                    recordRemotelyDesiredNode(node, locallyExecutableSet);
                }
            }
            // The current level is appended only after all of its nodes are processed, so remote
            // tasklets created while processing it land in the previous level.
            if (!locallyExecutableInThisLevel.isEmpty()) {
                dependencyOrderedLocallyExecutable.add(locallyExecutableInThisLevel);
            }
        }
        return dependencyOrderedLocallyExecutable;
    }

    /** Appends remote-read tasklets to the last level built so far, or starts a first level. */
    private static void scheduleRemoteTasklets(List<Tasklet> remoteTasklets, List<List<Tasklet>> levels) {
        if (remoteTasklets.isEmpty()) {
            return;
        }
        if (levels.isEmpty()) {
            // Copy so that later appends to this level do not mutate the CreatedTasklets list.
            levels.add(new ArrayList<>(remoteTasklets));
        } else {
            levels.get(levels.size() - 1).addAll(remoteTasklets);
        }
    }

    /** Records {@code node} as a non-local consumer of each of its upstreams that runs locally. */
    private void recordRemotelyDesiredNode(Node<?> node, Set<Node<?>> locallyExecutableSet) {
        for (Node<?> upstreamNode : node.getUpstreams()) {
            if (locallyExecutableSet.contains(upstreamNode)) {
                remotelyDesirableNodeSet.computeIfAbsent(upstreamNode, key -> new ArrayList<>()).add(node);
            }
        }
    }

    /**
     * Creates the local ({@code readFromLocal}) tasklet for {@code graphNode} and wires up its
     * upstreams.
     *
     * <p>An upstream that runs locally becomes a predecessor via its tasklet in
     * {@code nodeTaskletMap}. An upstream that does not run locally, or whose comma-separated
     * {@code locus} tag contains this node's {@code aggregate-upstream} tag, additionally gets
     * an intent and a wire-reading tasklet (see {@link #addReadFromRemoteTasklet}).
     *
     * @return the node's tasklet plus any remote-read tasklets created for it
     */
    protected CreatedTasklets createTaskletAndSendIntent(Node<?> graphNode, Set<Node<?>> locallyExecutableNodeSet, WireHopper hopper, Queryable db, Persistable persistable, Map<Node<?>, Tasklet> nodeTaskletMap) {
        Tasklet tasklet = new Tasklet(graphNode, db, persistable, remotelyDesirableNodeSet, hopper, GraphNodeOperations::readFromLocal);
        CreatedTasklets created = new CreatedTasklets(tasklet);
        String aggregationLocus = graphNode.getTags().get("aggregate-upstream");
        for (Node<?> upstreamNode : graphNode.getUpstreams()) {
            boolean upstreamIsLocal = locallyExecutableNodeSet.contains(upstreamNode);
            if (upstreamIsLocal) {
                tasklet.addPredecessor(nodeTaskletMap.get(upstreamNode));
            }
            if (!upstreamIsLocal || upstreamHasLocus(upstreamNode, aggregationLocus)) {
                addReadFromRemoteTasklet(graphNode, upstreamNode, hopper, db, persistable, tasklet, created);
            }
        }
        return created;
    }

    /** Returns true when {@code locus} is non-null and listed in the upstream's "locus" tag. */
    private static boolean upstreamHasLocus(Node<?> upstreamNode, String locus) {
        if (locus == null) {
            return false;
        }
        String loci = upstreamNode.getTags().getOrDefault("locus", EMPTY_STRING);
        return Arrays.asList(loci.split(",")).contains(locus);
    }

    /**
     * Sends an intent for {@code graphNode} to consume {@code upstreamNode}, creates a tasklet that
     * reads {@code upstreamNode} via {@code readFromWire}, makes it a predecessor of
     * {@code tasklet}, and records it in {@code ret.remoteTasklets}.
     */
    private void addReadFromRemoteTasklet(Node<?> graphNode, Node<?> upstreamNode, WireHopper hopper, Queryable db, Persistable persistable, Tasklet tasklet, CreatedTasklets ret) {
        LOG.debug("rca: Node '{}' sending intent to consume node: '{}'", graphNode.name(), upstreamNode.name());
        IntentMsg msg = new IntentMsg(graphNode.name(), upstreamNode.name(), upstreamNode.getTags());
        hopper.sendIntent(msg);
        Tasklet remoteTasklet = new Tasklet(upstreamNode, db, persistable, remotelyDesirableNodeSet, hopper, GraphNodeOperations::readFromWire);
        // The remote tasklet wraps the upstream node, so that is the node to report here.
        LOG.debug("Tasklet created for REMOTE node '{}' with readFromWire", upstreamNode.name());
        tasklet.addPredecessor(remoteTasklet);
        ret.remoteTasklets.add(remoteTasklet);
    }

    /**
     * Executes one tick: records the graph-node count metric, installs a pending DB (if any),
     * submits all tasklets, invokes {@link #preWait()}, waits for the final level's futures, then
     * calls {@link #postCompletion(long)}. If a {@code join} throws, {@code postCompletion} is
     * skipped for this run.
     */
    @Override
    public void run() {
        ++this.currTick;
        long runStartTime = System.currentTimeMillis();
        SampleAggregator graphMetricsAggregator = ServiceMetrics.RCA_GRAPH_METRICS_AGGREGATOR;
        graphMetricsAggregator.updateStat((MeasurementSet) RcaGraphMetrics.NUM_GRAPH_NODES, (Number) Stats.getInstance().getTotalNodesCount());
        changeDbForTasklets();
        List<CompletableFuture<Void>> lastLevelTasks = createAsyncTasks();
        preWait();
        lastLevelTasks.forEach(CompletableFuture::join);
        postCompletion(runStartTime);
    }

    /** Installs the pending {@link #newDb} into every tasklet and clears it; no-op when none. */
    private void changeDbForTasklets() {
        if (this.newDb == null) {
            return;
        }
        for (List<Tasklet> taskletsAtThisLevel : this.locallyExecutableTasklets) {
            for (Tasklet tasklet : taskletsAtThisLevel) {
                tasklet.setDb(this.newDb);
            }
        }
        this.newDb = null;
    }

    /**
     * Submits every tasklet, level by level, to {@link #executorPool}.
     *
     * <p>Each tasklet receives the map of futures created so far (presumably so it can chain on
     * its predecessors' futures — depends on {@code Tasklet.execute}).
     *
     * <p>NOTE(review): only the futures of the final level are returned, and {@link #run()} joins
     * only those. A tasklet on an earlier level that no final-level tasklet (transitively) depends
     * on — e.g. the sinks of a component with fewer levels than the deepest one after
     * {@link #mergeLists} — is not awaited by {@code run()}. Confirm this is intended.
     *
     * @return the futures of the tasklets in the final level (empty if there are no levels)
     */
    protected List<CompletableFuture<Void>> createAsyncTasks() {
        Map<Tasklet, CompletableFuture<Void>> taskletFutureMap = new HashMap<>();
        List<CompletableFuture<Void>> lastLevel = new ArrayList<>();
        for (List<Tasklet> taskletsAtThisLevel : this.locallyExecutableTasklets) {
            lastLevel.clear();
            for (Tasklet tasklet : taskletsAtThisLevel) {
                CompletableFuture<Void> taskletFuture = tasklet.execute(this.executorPool, taskletFutureMap);
                lastLevel.add(taskletFuture);
                taskletFutureMap.put(tasklet, taskletFuture);
            }
        }
        return lastLevel;
    }

    /** Hook invoked after submission and before waiting on the final level; no-op by default. */
    protected void preWait() {
    }

    /**
     * Resets tasklet ticks once {@code maxTicks} runs have elapsed, then records the run's
     * duration and the muted-node count metrics.
     *
     * @param runStartTime wall-clock start of the run, in milliseconds
     */
    protected void postCompletion(long runStartTime) {
        // '>=' rather than '==': if run() threw before reaching here on the tick where
        // currTick == maxTicks, an equality check would never match again and ticks would
        // never be reset.
        if (this.currTick >= this.maxTicks) {
            this.currTick = 0;
            this.locallyExecutableTasklets.forEach(level -> level.forEach(Tasklet::resetTicks));
            LOG.debug("Finished ticking.");
        }
        long durationMillis = System.currentTimeMillis() - runStartTime;
        ServiceMetrics.RCA_GRAPH_METRICS_AGGREGATOR.updateStat((MeasurementSet) RcaGraphMetrics.GRAPH_EXECUTION_TIME, (Number) durationMillis);
        ServiceMetrics.RCA_GRAPH_METRICS_AGGREGATOR.updateStat((MeasurementSet) RcaGraphMetrics.NUM_GRAPH_NODES_MUTED, (Number) Stats.getInstance().getMutedGraphNodesCount());
    }

    /** Schedules {@code newDb} to be installed into every tasklet at the start of the next run. */
    @VisibleForTesting
    public void setNewDb(Queryable newDb) {
        this.newDb = newDb;
    }

    /** A node's own tasklet plus the remote-read tasklets created for its upstreams. */
    private static class CreatedTasklets {
        final Tasklet taskletForCurrentNode;
        final List<Tasklet> remoteTasklets;

        CreatedTasklets(Tasklet taskletForCurrentNode) {
            this.taskletForCurrentNode = taskletForCurrentNode;
            this.remoteTasklets = new ArrayList<>();
        }
    }
}

