/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.timeseries.ratelimit;

import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.Set;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.action.get.MultiGetItemResponse;
import org.opensearch.action.get.MultiGetRequest;
import org.opensearch.action.get.MultiGetResponse;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Provider;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.timeseries.AnalysisType;
import org.opensearch.timeseries.NodeStateManager;
import org.opensearch.timeseries.breaker.CircuitBreakerService;
import org.opensearch.timeseries.caching.TimeSeriesCache;
import org.opensearch.timeseries.common.exception.EndRunException;
import org.opensearch.timeseries.indices.IndexManagement;
import org.opensearch.timeseries.ml.CheckpointDao;
import org.opensearch.timeseries.ml.IntermediateResult;
import org.opensearch.timeseries.ml.ModelColdStart;
import org.opensearch.timeseries.ml.ModelManager;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.RealTimeInferencer;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.model.IndexableResult;
import org.opensearch.timeseries.ratelimit.BatchWorker;
import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker;
import org.opensearch.timeseries.ratelimit.ColdStartWorker;
import org.opensearch.timeseries.ratelimit.FeatureRequest;
import org.opensearch.timeseries.ratelimit.RateLimitedRequestWorker;
import org.opensearch.timeseries.ratelimit.RequestPriority;
import org.opensearch.timeseries.ratelimit.SaveResultStrategy;
import org.opensearch.timeseries.util.ActionListenerExecutor;
import org.opensearch.timeseries.util.ExceptionUtil;

/**
 * Batch worker that loads model checkpoints for queued {@link FeatureRequest}s.
 *
 * <p>It reads the checkpoints with one multi-get against {@code checkpointIndexName}. It then restores
 * each returned checkpoint into a {@link ModelState} and runs the real-time inferencer on the
 * request's current feature. Finally it tries to host the restored model in the cache.
 *
 * <p>Requests without a stored checkpoint go to the cold-start worker, as do requests whose
 * checkpoint cannot be restored. Requests whose read failed with a retryable error are re-queued on
 * this worker.
 */
public abstract class CheckpointReadWorker<
    RCFModelType extends ThresholdedRandomCutForest,
    ResultType extends IndexableResult,
    RCFResultType extends IntermediateResult<ResultType>,
    IndexType extends Enum<IndexType>,
    IndexManagementType extends IndexManagement<IndexType>,
    CheckpointType extends CheckpointDao<RCFModelType, IndexType, IndexManagementType>,
    CheckpointWriteWorkerType extends CheckpointWriteWorker<RCFModelType, IndexType, IndexManagementType, CheckpointType>,
    ColdStarterType extends ModelColdStart<RCFModelType, IndexType, IndexManagementType, CheckpointType, CheckpointWriteWorkerType>,
    ModelManagerType extends ModelManager<RCFModelType, ResultType, RCFResultType, IndexType, IndexManagementType, CheckpointType, CheckpointWriteWorkerType, ColdStarterType>,
    CacheType extends TimeSeriesCache<RCFModelType>,
    SaveResultStrategyType extends SaveResultStrategy<ResultType, RCFResultType>,
    ColdStartWorkerType extends ColdStartWorker<RCFModelType, IndexType, IndexManagementType, CheckpointType, CheckpointWriteWorkerType, ColdStarterType, CacheType, ResultType, RCFResultType, ModelManagerType, SaveResultStrategyType>,
    InferencerType extends RealTimeInferencer<RCFModelType, ResultType, RCFResultType, IndexType, IndexManagementType, CheckpointType, CheckpointWriteWorkerType, ColdStarterType, ModelManagerType, SaveResultStrategyType, CacheType, ColdStartWorkerType>>
    extends BatchWorker<FeatureRequest, MultiGetRequest, MultiGetResponse> {
    private static final Logger LOG = LogManager.getLogger(CheckpointReadWorker.class);
    protected final ModelManagerType modelManager;
    protected final CheckpointType checkpointDao;
    protected final ColdStartWorkerType coldStartWorker;
    protected final CheckpointWriteWorkerType checkpointWriteWorker;
    protected final Provider<? extends TimeSeriesCache<RCFModelType>> cacheProvider;
    protected final String checkpointIndexName;
    protected final InferencerType inferencer;

    /**
     * Creates the worker.
     *
     * <p>The queue, throttling and batching parameters are passed unchanged to {@link BatchWorker}.
     * These are every parameter up to {@code executionTtl}, plus {@code stateManager},
     * {@code stateTtl}, {@code concurrencySetting}, {@code batchSizeSetting} and {@code context}.
     *
     * @param modelManager          model manager kept for subclasses
     * @param checkpointDao         DAO used to batch-read and parse checkpoints
     * @param entityColdStartWorker worker that receives requests without a usable checkpoint
     * @param cacheProvider         provides the cache that restored models are hosted in
     * @param checkpointWriteWorker worker used to persist restored models that could not be cached
     * @param checkpointIndexName   index the checkpoints are read from
     * @param inferencer            runs real-time inference on a restored model
     */
    public CheckpointReadWorker(
        String workerName,
        long heapSizeInBytes,
        int singleRequestSizeInBytes,
        Setting<Float> maxHeapPercentForQueueSetting,
        ClusterService clusterService,
        Random random,
        CircuitBreakerService adCircuitBreakerService,
        ThreadPool threadPool,
        String threadPoolName,
        Settings settings,
        float maxQueuedTaskRatio,
        Clock clock,
        float mediumSegmentPruneRatio,
        float lowSegmentPruneRatio,
        int maintenanceFreqConstant,
        Duration executionTtl,
        ModelManagerType modelManager,
        CheckpointType checkpointDao,
        ColdStartWorkerType entityColdStartWorker,
        NodeStateManager stateManager,
        Provider<? extends TimeSeriesCache<RCFModelType>> cacheProvider,
        Duration stateTtl,
        CheckpointWriteWorkerType checkpointWriteWorker,
        Setting<Integer> concurrencySetting,
        Setting<Integer> batchSizeSetting,
        String checkpointIndexName,
        AnalysisType context,
        InferencerType inferencer
    ) {
        super(
            workerName,
            heapSizeInBytes,
            singleRequestSizeInBytes,
            maxHeapPercentForQueueSetting,
            clusterService,
            random,
            adCircuitBreakerService,
            threadPool,
            threadPoolName,
            settings,
            maxQueuedTaskRatio,
            clock,
            mediumSegmentPruneRatio,
            lowSegmentPruneRatio,
            maintenanceFreqConstant,
            concurrencySetting,
            executionTtl,
            batchSizeSetting,
            stateTtl,
            stateManager,
            context
        );
        this.modelManager = modelManager;
        this.checkpointDao = checkpointDao;
        this.coldStartWorker = entityColdStartWorker;
        this.cacheProvider = cacheProvider;
        this.checkpointWriteWorker = checkpointWriteWorker;
        this.checkpointIndexName = checkpointIndexName;
        this.inferencer = inferencer;
    }

    /** Sends the multi-get through {@link CheckpointDao#batchRead}. */
    @Override
    protected void executeBatchRequest(MultiGetRequest request, ActionListener<MultiGetResponse> listener) {
        checkpointDao.batchRead(request, listener);
    }

    /**
     * Builds one multi-get item per request, keyed by model id, against {@code checkpointIndexName}.
     * Requests without a model id are skipped, so the result may have fewer items than
     * {@code toProcess}.
     */
    @Override
    protected MultiGetRequest toBatchRequest(List<FeatureRequest> toProcess) {
        MultiGetRequest multiGetRequest = new MultiGetRequest();
        for (FeatureRequest request : toProcess) {
            String modelId = request.getModelId();
            if (modelId != null) {
                multiGetRequest.add(new MultiGetRequest.Item(checkpointIndexName, modelId));
            }
        }
        return multiGetRequest;
    }

    /**
     * Returns the listener that sorts the multi-get results and then restores the successful ones.
     *
     * <p>Each item of the response is handled as follows:
     * <ul>
     *   <li>{@link IndexNotFoundException}: every request in {@code toProcess} is sent to cold start,
     *       and nothing else in this response is processed.</li>
     *   <li>Retryable failure: the request is re-queued on this worker later, during
     *       {@link #processCheckpointIteration}.</li>
     *   <li>Overloaded failure: a cool-down is started and the request is not retried here.</li>
     *   <li>Any other failure: logged once per response. An {@link EndRunException} is set on the
     *       config of the first matching request.</li>
     *   <li>Checkpoint document missing: the request is sent to cold start.</li>
     *   <li>Otherwise: the checkpoint is restored by {@link #processCheckpointIteration}.</li>
     * </ul>
     *
     * <p>If the whole multi-get fails, the listener does one of three things. It starts a cool-down
     * when the cluster is overloaded. It re-queues every request when the error is retryable.
     * Otherwise it logs the error.
     */
    @Override
    protected ActionListener<MultiGetResponse> getResponseListener(List<FeatureRequest> toProcess, MultiGetRequest batchRequest) {
        return ActionListener.wrap(response -> {
            MultiGetItemResponse[] itemResponses = response.getResponses();
            Map<String, MultiGetItemResponse> successfulRequests = new HashMap<>();
            // Created lazily: failures are expected to be rare.
            Set<String> retryableRequests = null;
            Set<String> notFoundModels = null;
            Map<String, Exception> stopDetectorRequests = null;
            // A failing cluster can fail every item the same way, so each kind of error is logged
            // at most once per response instead of once per entity.
            boolean printedUnexpectedFailure = false;
            boolean printedOverloadedFailure = false;

            for (MultiGetItemResponse itemResponse : itemResponses) {
                String modelId = itemResponse.getId();
                if (itemResponse.isFailed()) {
                    Exception failure = itemResponse.getFailure().getFailure();
                    if (failure instanceof IndexNotFoundException) {
                        // The checkpoint index is missing, so no request in this batch can be
                        // restored: cold start all of them and ignore the rest of the response.
                        for (FeatureRequest origRequest : toProcess) {
                            coldStartWorker.put(origRequest);
                        }
                        return;
                    }
                    if (ExceptionUtil.isRetryAble(failure)) {
                        if (retryableRequests == null) {
                            retryableRequests = new HashSet<>();
                        }
                        retryableRequests.add(modelId);
                    } else if (ExceptionUtil.isOverloaded(failure)) {
                        if (!printedOverloadedFailure) {
                            LOG.error("too many get model checkpoint requests or shard not available");
                            printedOverloadedFailure = true;
                        }
                        setCoolDownStart();
                    } else {
                        if (!printedUnexpectedFailure) {
                            LOG.error("Unexpected failure", failure);
                            printedUnexpectedFailure = true;
                        }
                        if (stopDetectorRequests == null) {
                            stopDetectorRequests = new HashMap<>();
                        }
                        stopDetectorRequests.put(modelId, failure);
                    }
                } else if (!itemResponse.getResponse().isExists()) {
                    if (notFoundModels == null) {
                        notFoundModels = new HashSet<>();
                    }
                    notFoundModels.add(modelId);
                } else {
                    successfulRequests.put(modelId, itemResponse);
                }
            }

            if (notFoundModels != null) {
                // No checkpoint stored yet: these models must be cold started.
                for (FeatureRequest origRequest : toProcess) {
                    String modelId = origRequest.getModelId();
                    if (modelId != null && notFoundModels.contains(modelId)) {
                        coldStartWorker.put(origRequest);
                    }
                }
            }

            if (stopDetectorRequests != null) {
                for (FeatureRequest origRequest : toProcess) {
                    String modelId = origRequest.getModelId();
                    if (modelId != null && stopDetectorRequests.containsKey(modelId)) {
                        String configID = origRequest.getConfigId();
                        nodeStateManager.setException(
                            configID,
                            new EndRunException(configID, "We might have bugs.", stopDetectorRequests.get(modelId), false)
                        );
                        // Only the config of the first matching request is notified.
                        // NOTE(review): if one batch can hold requests from several configs, the
                        // other configs are not notified here — confirm this is intended.
                        break;
                    }
                }
            }

            if (successfulRequests.isEmpty() && (retryableRequests == null || retryableRequests.isEmpty())) {
                return;
            }
            processCheckpointIteration(0, toProcess, successfulRequests, retryableRequests);
        }, exception -> {
            if (ExceptionUtil.isOverloaded(exception)) {
                LOG.error("too many get model checkpoint requests or shard not available");
                setCoolDownStart();
            } else if (ExceptionUtil.isRetryAble(exception)) {
                putAll(toProcess);
            } else {
                LOG.error("Fail to restore models", exception);
            }
        });
    }

    /**
     * Handles {@code toProcess.get(i)} and then moves on to the next request.
     *
     * <p>How the request is handled:
     * <ul>
     *   <li>If its checkpoint was read successfully, it is restored. If restoring yields no model
     *       state, the request goes to cold start. Otherwise the config is fetched, and
     *       {@link #processIterationUsingConfig} continues from {@code i + 1} in its callback.</li>
     *   <li>If its read failed with a retryable error, it is re-queued on this worker.</li>
     *   <li>Requests without a model id are skipped.</li>
     * </ul>
     *
     * <p>Apart from the config-lookup case, the {@code finally} block continues with {@code i + 1}.
     * It does so even when this step returns early or throws; a thrown exception propagates only
     * after the rest of the iteration has run. Each consecutive request that needs no config lookup
     * therefore adds one stack frame.
     *
     * @param i                  index into {@code toProcess} of the request to handle; the call does
     *                           nothing once {@code i >= toProcess.size()}
     * @param toProcess          requests of the current batch
     * @param successfulRequests successful multi-get items, keyed by model id
     * @param retryableRequests  model ids whose read failed with a retryable error; may be null
     */
    protected void processCheckpointIteration(int i, List<FeatureRequest> toProcess, Map<String, MultiGetItemResponse> successfulRequests, Set<String> retryableRequests) {
        if (i >= toProcess.size()) {
            return;
        }
        // Set to true once the next step has been handed to an asynchronous callback, so the
        // finally block below does not continue the iteration a second time.
        boolean processNextInCallBack = false;
        try {
            FeatureRequest origRequest = toProcess.get(i);
            String modelId = origRequest.getModelId();
            if (modelId == null) {
                return;
            }
            String configId = origRequest.getConfigId();
            MultiGetItemResponse checkpointResponse = successfulRequests.get(modelId);
            if (checkpointResponse != null) {
                ModelState<RCFModelType> modelState = checkpointDao.processHCGetResponse(checkpointResponse.getResponse(), modelId, configId);
                if (modelState == null) {
                    // The checkpoint exists but produced no model state: cold start instead.
                    coldStartWorker.put(origRequest);
                    return;
                }
                nodeStateManager.getConfig(
                    configId,
                    context,
                    processIterationUsingConfig(origRequest, i, configId, toProcess, successfulRequests, retryableRequests, modelState, modelId)
                );
                processNextInCallBack = true;
            } else if (retryableRequests != null && retryableRequests.contains(modelId)) {
                super.put(origRequest);
            }
        } finally {
            if (!processNextInCallBack) {
                processCheckpointIteration(i + 1, toProcess, successfulRequests, retryableRequests);
            }
        }
    }

    /**
     * Builds the config-lookup listener for one restored model. Both callbacks run on
     * {@code threadPoolName}'s executor and continue the iteration at {@code index + 1}.
     *
     * <p>When the config is present, the request's current feature is passed to the inferencer as
     * a {@link Sample}. The sample spans {@code [dataStart, dataStart + interval)}. If the inferencer
     * processed the model but the cache could not host it, the restored state is written through
     * the checkpoint write worker at {@link RequestPriority#LOW}. An absent config is only logged.
     * A lookup failure is logged and recorded on the config through the node state manager.
     *
     * <p>NOTE(review): suppose {@link ActionListenerExecutor#wrap} passes exceptions thrown by the
     * response consumer to the failure consumer. Then an exception escaping the continuation call
     * would make the failure handler re-run the iteration from {@code index + 1} — confirm.
     *
     * @param origRequest        request whose checkpoint was restored
     * @param index              position of {@code origRequest} in {@code toProcess}
     * @param configId           id of the config to look up
     * @param toProcess          requests of the current batch
     * @param successfulRequests successful multi-get items, keyed by model id
     * @param retryableRequests  model ids whose read failed with a retryable error; may be null
     * @param restoredModelState model state restored from the checkpoint
     * @param modelId            model id of {@code origRequest}
     * @return listener that receives the optional config
     */
    protected ActionListener<Optional<? extends Config>> processIterationUsingConfig(FeatureRequest origRequest, int index, String configId, List<FeatureRequest> toProcess, Map<String, MultiGetItemResponse> successfulRequests, Set<String> retryableRequests, ModelState<RCFModelType> restoredModelState, String modelId) {
        return ActionListenerExecutor.wrap(configOptional -> {
            if (configOptional.isEmpty()) {
                LOG.warn(new ParameterizedMessage("Config [{}] is not available.", configId));
                processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests);
                return;
            }
            Config config = configOptional.get();
            Sample sample = new Sample(
                origRequest.getCurrentFeature(),
                Instant.ofEpochMilli(origRequest.getDataStartTimeMillis()),
                Instant.ofEpochMilli(origRequest.getDataStartTimeMillis() + config.getIntervalInMilliseconds())
            );
            boolean processed = inferencer.process(sample, restoredModelState, config, origRequest.getTaskId());
            if (processed && !cacheProvider.get().hostIfPossible(config, restoredModelState)) {
                // The cache did not take the model, so persist the restored state instead.
                checkpointWriteWorker.write(restoredModelState, true, RequestPriority.LOW);
            }
            processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests);
        }, exception -> {
            // The trailing exception argument becomes the logged throwable.
            LOG.error(new ParameterizedMessage("fail to get checkpoint [{}]", modelId, exception));
            nodeStateManager.setException(configId, exception);
            processCheckpointIteration(index + 1, toProcess, successfulRequests, retryableRequests);
        }, threadPool.executor(threadPoolName));
    }
}

