/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.timeseries.ml;

import com.amazon.randomcutforest.config.ImputationMethod;
import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.core.util.Throwables;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.core.action.ActionListener;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.timeseries.AnalysisType;
import org.opensearch.timeseries.CleanState;
import org.opensearch.timeseries.MaintenanceState;
import org.opensearch.timeseries.NodeStateManager;
import org.opensearch.timeseries.caching.DoorKeeper;
import org.opensearch.timeseries.common.exception.EndRunException;
import org.opensearch.timeseries.common.exception.TimeSeriesException;
import org.opensearch.timeseries.dataprocessor.ImputationOption;
import org.opensearch.timeseries.feature.FeatureManager;
import org.opensearch.timeseries.feature.SearchFeatureDao;
import org.opensearch.timeseries.indices.IndexManagement;
import org.opensearch.timeseries.ml.CheckpointDao;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.model.Entity;
import org.opensearch.timeseries.model.IntervalTimeConfiguration;
import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker;
import org.opensearch.timeseries.ratelimit.FeatureRequest;
import org.opensearch.timeseries.ratelimit.RequestPriority;
import org.opensearch.timeseries.util.ExceptionUtil;

/**
 * Base class for cold-starting thresholded-RCF time-series models.
 *
 * <p>{@link #trainModel} either trains directly from the samples already buffered in a {@link ModelState} (when
 * there are at least {@code numMinSamples} of them) or fetches historical feature samples through
 * {@link SearchFeatureDao}, in up to {@code maxRoundofColdStart} rounds, and hands them to
 * {@link #trainModelFromDataSegments}. Repeated cold starts of the same model are limited by a per-config
 * {@link DoorKeeper}. All cold starts are skipped for {@code coolDownMinutes} after a failure that
 * {@link ExceptionUtil#isOverloaded} classifies as overload.
 */
public abstract class ModelColdStart<RCFModelType extends ThresholdedRandomCutForest, IndexType extends Enum<IndexType>, IndexManagementType extends IndexManagement<IndexType>, CheckpointDaoType extends CheckpointDao<RCFModelType, IndexType, IndexManagementType>, CheckpointWriteWorkerType extends CheckpointWriteWorker<RCFModelType, IndexType, IndexManagementType, CheckpointDaoType>>
implements MaintenanceState,
CleanState {
    private static final Logger logger = LogManager.getLogger(ModelColdStart.class);
    /** Maximum-insertion argument passed to every per-config {@link DoorKeeper}. */
    private static final long DOOR_KEEPER_MAX_INSERTION = 100000L;
    /**
     * Number of config intervals turned into the duration handed to every per-config {@link DoorKeeper}. It is also
     * reported in the log as the window within which real-time cold start is not retried.
     */
    private static final long DOOR_KEEPER_INTERVAL_MULTIPLIER = 60L;
    /** Count threshold passed to every {@link DoorKeeper} (presumably sightings before a model is blocked — confirm). */
    private static final int DOOR_KEEPER_COUNT_THRESHOLD = 3;

    private final Duration modelTtl;
    protected Map<String, DoorKeeper> doorKeepers;
    protected Instant lastThrottledColdStartTime;
    protected int coolDownMinutes;
    protected final Clock clock;
    protected final ThreadPool threadPool;
    protected final int numMinSamples;
    protected CheckpointWriteWorkerType checkpointWriteWorker;
    protected final long rcfSeed;
    protected final int numberOfTrees;
    protected final int rcfSampleSize;
    protected final double thresholdMinPvalue;
    protected final double initialAcceptFraction;
    protected final NodeStateManager nodeStateManager;
    protected final int defaulStrideLength;
    protected final int defaultNumberOfSamples;
    protected final SearchFeatureDao searchFeatureDao;
    protected final FeatureManager featureManager;
    protected final int maxRoundofColdStart;
    protected final String threadPoolName;
    protected final AnalysisType context;

    public ModelColdStart(Duration modelTtl, int coolDownMinutes, Clock clock, ThreadPool threadPool, int numMinSamples, CheckpointWriteWorkerType checkpointWriteWorker, long rcfSeed, int numberOfTrees, int rcfSampleSize, double thresholdMinPvalue, NodeStateManager nodeStateManager, int defaultSampleStride, int defaultTrainSamples, SearchFeatureDao searchFeatureDao, FeatureManager featureManager, int maxRoundofColdStart, String threadPoolName, AnalysisType context) {
        this.modelTtl = modelTtl;
        this.coolDownMinutes = coolDownMinutes;
        this.clock = clock;
        this.threadPool = threadPool;
        this.numMinSamples = numMinSamples;
        this.checkpointWriteWorker = checkpointWriteWorker;
        this.rcfSeed = rcfSeed;
        this.numberOfTrees = numberOfTrees;
        this.rcfSampleSize = rcfSampleSize;
        this.thresholdMinPvalue = thresholdMinPvalue;
        this.doorKeepers = new ConcurrentHashMap<>();
        // Instant.MIN means "never throttled", so the cool-down check passes until the first overload.
        this.lastThrottledColdStartTime = Instant.MIN;
        this.initialAcceptFraction = (double) numMinSamples * 1.0 / (double) rcfSampleSize;
        this.nodeStateManager = nodeStateManager;
        this.defaulStrideLength = defaultSampleStride;
        this.defaultNumberOfSamples = defaultTrainSamples;
        this.searchFeatureDao = searchFeatureDao;
        this.featureManager = featureManager;
        this.maxRoundofColdStart = maxRoundofColdStart;
        this.threadPoolName = threadPoolName;
        this.context = context;
    }

    /**
     * Drops door keepers that have been idle longer than the model TTL and runs maintenance on the rest.
     */
    @Override
    public void maintenance() {
        // Removing while iterating is tolerated by the ConcurrentHashMap created in the constructor.
        doorKeepers.forEach((id, doorKeeper) -> {
            if (doorKeeper.expired(modelTtl)) {
                doorKeepers.remove(id);
            } else {
                doorKeeper.maintenance();
            }
        });
    }

    /**
     * Forgets the door keeper registered for the given config id.
     *
     * @param id config id
     */
    @Override
    public void clear(String id) {
        doorKeepers.remove(id);
    }

    /**
     * Trains the model held by {@code modelState}.
     *
     * <p>If fewer than {@code numMinSamples} samples are buffered, historical data is fetched asynchronously (cold
     * start). Otherwise the buffered samples are used directly and the listener receives {@code null}.
     *
     * @param coldStartRequest request carrying entity, data start time and task id
     * @param configId id of the config to load
     * @param modelState state of the model to train; must not be null
     * @param listener notified exactly once per path with processed training data, {@code null} (nothing trained or
     *        trained from buffered samples), or a failure
     */
    public void trainModel(FeatureRequest coldStartRequest, String configId, ModelState<RCFModelType> modelState, ActionListener<List<Sample>> listener) {
        nodeStateManager.getConfig(configId, context, ActionListener.wrap(configOptional -> {
            if (!configOptional.isPresent()) {
                logger.warn(new ParameterizedMessage("Config [{}] is not available.", configId));
                listener.onFailure(new TimeSeriesException(configId, "fail to find config"));
                return;
            }
            // Check before the first dereference below, so a missing state yields the intended error instead of an NPE.
            if (modelState == null) {
                listener.onFailure(new IllegalArgumentException("Cannot have empty model state"));
                return;
            }
            Config config = configOptional.get();
            String modelId = modelState.getModelId();
            if (modelState.getSamples().size() < numMinSamples) {
                coldStart(modelId, coldStartRequest, modelState, config, listener);
            } else {
                try {
                    trainModelFromExistingSamples(modelState, config, coldStartRequest.getTaskId());
                    listener.onResponse(null);
                } catch (Exception e) {
                    listener.onFailure(e);
                }
            }
        }, listener::onFailure));
    }

    /**
     * Trains from the samples buffered in {@code modelState} and then clears them, but only when at least
     * {@code numMinSamples} samples are buffered. Otherwise this is a no-op.
     *
     * @param modelState state whose buffered samples are consumed
     * @param config model configuration
     * @param taskId task id forwarded to {@link #trainModelFromDataSegments}
     */
    public void trainModelFromExistingSamples(ModelState<RCFModelType> modelState, Config config, String taskId) {
        if (modelState.getSamples().size() >= numMinSamples) {
            Deque<Sample> samples = modelState.getSamples();
            trainModelFromDataSegments(new ArrayList<>(samples), modelState, config, taskId);
            modelState.clearSamples();
        }
    }

    /**
     * Starts an asynchronous cold start unless throttled or blocked.
     *
     * <p>The listener is notified exactly once:
     * <ul>
     * <li>{@code null} when in the overload cool-down or blocked by the door keeper;</li>
     * <li>{@code onFailure} when setting up the asynchronous work throws;</li>
     * <li>otherwise later, by the submitted task.</li>
     * </ul>
     */
    private void coldStart(String modelId, FeatureRequest coldStartRequest, ModelState<RCFModelType> modelState, Config config, ActionListener<List<Sample>> listener) {
        logger.debug("Trigger cold start for {}", modelId);
        if (modelState == null) {
            listener.onFailure(new IllegalArgumentException("Cannot have empty model state"));
            return;
        }
        // Skip every cold start for coolDownMinutes after the last overload-induced failure.
        if (lastThrottledColdStartTime.plus(Duration.ofMinutes(coolDownMinutes)).isAfter(clock.instant())) {
            listener.onResponse(null);
            return;
        }
        boolean scheduled;
        try {
            scheduled = scheduleColdStart(modelId, coldStartRequest, modelState, config, listener);
        } catch (Exception e) {
            // Setup failures (e.g. the executor rejecting the task) are reported via onFailure only, so the
            // listener is never answered with both a response and a failure.
            logger.error(new ParameterizedMessage("Fail to schedule cold start for {}", modelId), e);
            listener.onFailure(e);
            return;
        }
        if (!scheduled) {
            listener.onResponse(null);
        }
    }

    /**
     * Records the model with its config's door keeper and, unless the door keeper blocks it, submits the data fetch
     * to the {@code threadPoolName} executor.
     *
     * @return true if the asynchronous cold start was submitted (the task will notify {@code listener}); false if the
     *         door keeper blocked it and the caller must notify {@code listener}
     */
    private boolean scheduleColdStart(String modelId, FeatureRequest coldStartRequest, ModelState<RCFModelType> modelState, Config config, ActionListener<List<Sample>> listener) {
        String configId = config.getId();
        DoorKeeper doorKeeper = doorKeepers.computeIfAbsent(configId, id -> new DoorKeeper(DOOR_KEEPER_MAX_INSERTION, config.getIntervalDuration().multipliedBy(DOOR_KEEPER_INTERVAL_MULTIPLIER), clock, DOOR_KEEPER_COUNT_THRESHOLD));
        if (doorKeeper.appearsMoreThanOrEqualToThreshold(modelId)) {
            logger.info("Won't retry real-time cold start within {} intervals for model {}", DOOR_KEEPER_INTERVAL_MULTIPLIER, modelId);
            return false;
        }
        doorKeeper.put(modelId);
        ActionListener<List<Sample>> coldStartCallBack = ActionListener.wrap(
            trainingData -> onColdStartData(trainingData, modelId, coldStartRequest, modelState, config, listener),
            exception -> onColdStartFailure(exception, modelId, configId, listener)
        );
        threadPool.executor(threadPoolName).execute(() -> getColdStartData(configId, coldStartRequest, new ThreadedActionListener<>(logger, threadPool, threadPoolName, coldStartCallBack, false)));
        return true;
    }

    /**
     * Handles fetched cold-start data.
     *
     * <p>Trains when at least {@code numMinSamples} samples arrived. When fewer arrived, buffers them in
     * {@code modelState} and queues a checkpoint write. Exceptions thrown here are routed by the enclosing
     * {@link ActionListener#wrap} to {@link #onColdStartFailure}.
     */
    private void onColdStartData(List<Sample> trainingData, String modelId, FeatureRequest coldStartRequest, ModelState<RCFModelType> modelState, Config config, ActionListener<List<Sample>> listener) {
        modelState.clearSamples();
        if (trainingData == null || trainingData.isEmpty()) {
            logger.info("Cannot get training data for {}", modelId);
            listener.onResponse(null);
            return;
        }
        int dataSize = trainingData.size();
        if (dataSize >= numMinSamples) {
            List<Sample> processedTrainingData = trainModelFromDataSegments(trainingData, modelState, config, coldStartRequest.getTaskId());
            logger.info("Succeeded in training entity: {}", modelId);
            listener.onResponse(processedTrainingData);
        } else {
            logger.info("Not enough data to train model: {}, currently we have {}", modelId, dataSize);
            trainingData.forEach(modelState::addSample);
            checkpointWriteWorker.write(modelState, true, RequestPriority.MEDIUM);
            listener.onResponse(null);
        }
    }

    /**
     * Records a cold-start failure, then forwards the original exception to the listener.
     *
     * <p>If the root cause is an overload, the cool-down timestamp is set. Otherwise the exception is stored on the
     * node state (wrapped in {@link TimeSeriesException} unless it already is one).
     */
    private void onColdStartFailure(Exception exception, String modelId, String configId, ActionListener<List<Sample>> listener) {
        logger.error(new ParameterizedMessage("Error while cold start {}", modelId), exception);
        try {
            Throwable cause = Throwables.getRootCause(exception);
            if (ExceptionUtil.isOverloaded(cause)) {
                logger.error("too many requests");
                // Use the injected clock, the same one the cool-down check in coldStart reads.
                lastThrottledColdStartTime = clock.instant();
            } else if (cause instanceof TimeSeriesException || exception instanceof TimeSeriesException) {
                nodeStateManager.setException(configId, exception);
            } else {
                nodeStateManager.setException(configId, new TimeSeriesException(configId, cause));
            }
        } catch (Exception bookkeepingError) {
            // Bookkeeping problems must not mask the original failure reported below.
            logger.error(new ParameterizedMessage("Fail to record cold start failure for {}", modelId), bookkeepingError);
        }
        listener.onFailure(exception);
    }

    /**
     * Loads the config, finds the earliest data time, and collects samples between it and the request's data start
     * time. The listener receives an empty list when no earliest time exists.
     */
    private void getColdStartData(String configId, FeatureRequest coldStartRequest, ActionListener<List<Sample>> listener) {
        ActionListener<Optional<? extends Config>> getConfigListener = ActionListener.wrap(configOp -> {
            if (!configOp.isPresent()) {
                listener.onFailure(new EndRunException(configId, "Config is not available.", false));
                return;
            }
            Config config = configOp.get();
            ActionListener<Optional<Long>> minTimeListener = ActionListener.wrap(earliest -> {
                if (earliest.isPresent()) {
                    long startTimeMs = earliest.get();
                    long endTimeMs = coldStartRequest.getDataStartTimeMillis();
                    int numberOfSamples = selectNumberOfSamples(config);
                    getFeatures(listener, 0, new ArrayList<>(), config, coldStartRequest.getEntity(), numberOfSamples, startTimeMs, endTimeMs);
                } else {
                    listener.onResponse(new ArrayList<>());
                }
            }, listener::onFailure);
            searchFeatureDao.getMinDataTime(config, coldStartRequest.getEntity(), context, new ThreadedActionListener<>(logger, threadPool, threadPoolName, minTimeListener, false));
        }, listener::onFailure);
        nodeStateManager.getConfig(configId, context, new ThreadedActionListener<>(logger, threadPool, threadPoolName, getConfigListener, false));
    }

    /** Returns the larger of {@code numMinSamples} and the config's history intervals. */
    private int selectNumberOfSamples(Config config) {
        return Math.max(numMinSamples, config.getHistoryIntervals());
    }

    /**
     * Fetches samples for {@code [startTimeMs, endTimeMs)} and prepends them to {@code lastRounddataSample}.
     *
     * <p>Recurses toward earlier data until at least {@code numMinSamples} samples are collected or
     * {@code maxRoundofColdStart} rounds have run. Ranges whose features are absent are skipped.
     */
    private void getFeatures(ActionListener<List<Sample>> listener, int round, List<Sample> lastRounddataSample, Config config, Optional<Entity> entity, int numberOfSamples, long startTimeMs, long endTimeMs) {
        if (startTimeMs >= endTimeMs || endTimeMs - startTimeMs < config.getIntervalInMilliseconds()) {
            listener.onResponse(lastRounddataSample);
            return;
        }
        List<Map.Entry<Long, Long>> sampleRanges = searchFeatureDao.getTrainSampleRanges((IntervalTimeConfiguration) config.getInterval(), startTimeMs, endTimeMs, numberOfSamples);
        if (sampleRanges.isEmpty()) {
            listener.onResponse(lastRounddataSample);
            return;
        }
        ActionListener<List<Optional<double[]>>> getFeaturelistener = ActionListener.wrap(featureSamples -> {
            int totalNumSamples = featureSamples.size();
            if (totalNumSamples != sampleRanges.size()) {
                String err = String.format(Locale.ROOT, "length mismatch: totalNumSamples %d != time range length %d", totalNumSamples, sampleRanges.size());
                listener.onFailure(new IllegalArgumentException(err));
                return;
            }
            List<Sample> samples = new ArrayList<>();
            for (int index = 0; index < totalNumSamples; index++) {
                Optional<double[]> featuresOptional = featureSamples.get(index);
                if (featuresOptional.isPresent()) {
                    Map.Entry<Long, Long> curRange = sampleRanges.get(index);
                    samples.add(new Sample(featuresOptional.get(), Instant.ofEpochMilli(curRange.getKey()), Instant.ofEpochMilli(curRange.getValue())));
                }
            }
            List<Sample> concatenatedDataSample;
            if (lastRounddataSample != null && !lastRounddataSample.isEmpty()) {
                // This round's samples come before the previous round's in the combined list.
                concatenatedDataSample = new ArrayList<>(samples.size() + lastRounddataSample.size());
                concatenatedDataSample.addAll(samples);
                concatenatedDataSample.addAll(lastRounddataSample);
            } else {
                concatenatedDataSample = samples;
            }
            if (concatenatedDataSample.size() >= numMinSamples || round + 1 >= maxRoundofColdStart) {
                listener.onResponse(concatenatedDataSample);
            } else {
                // NOTE(review): assumes getTrainSampleRanges returns ranges oldest-first; the next round
                // searches up to the start of the first range.
                long earliestSampleStartTime = sampleRanges.get(0).getKey();
                getFeatures(listener, round + 1, concatenatedDataSample, config, entity, numberOfSamples, startTimeMs, earliestSampleStartTime);
            }
        }, listener::onFailure);
        try {
            searchFeatureDao.getColdStartSamplesForPeriods(config, sampleRanges, entity, true, context, new ThreadedActionListener<>(logger, threadPool, threadPoolName, getFeaturelistener, false));
        } catch (Exception e) {
            listener.onFailure(e);
        }
    }

    /**
     * Configures the builder's imputation method from the config's {@link ImputationOption}.
     *
     * <p>PREVIOUS is used when there is no option or the method is not one of ZERO, FIXED_VALUES or PREVIOUS. For
     * FIXED_VALUES, one fill value per enabled feature is taken from the option's default-fill map, in enabled-feature
     * order.
     *
     * @param config configuration holding the imputation option and enabled feature names
     * @param builder builder to configure
     * @return the configured builder
     * @throws IllegalArgumentException if FIXED_VALUES lacks a fill value for an enabled feature
     */
    public static <T extends ThresholdedRandomCutForest.Builder<T>> T applyImputationMethod(Config config, T builder) {
        ImputationOption imputationOption = config.getImputationOption();
        if (imputationOption == null) {
            return builder.imputationMethod(ImputationMethod.PREVIOUS);
        }
        switch (imputationOption.getMethod()) {
            case ZERO: {
                return builder.imputationMethod(ImputationMethod.ZERO);
            }
            case FIXED_VALUES: {
                List<String> enabledFeatureName = config.getEnabledFeatureNames();
                double[] fillValues = new double[enabledFeatureName.size()];
                Map<String, Double> defaultFillMap = imputationOption.getDefaultFill();
                for (int i = 0; i < enabledFeatureName.size(); ++i) {
                    String featureName = enabledFeatureName.get(i);
                    Double fill = defaultFillMap == null ? null : defaultFillMap.get(featureName);
                    if (fill == null) {
                        // Fail with a descriptive message instead of an unboxing NullPointerException.
                        throw new IllegalArgumentException(String.format(Locale.ROOT, "Missing default fill value for feature [%s]", featureName));
                    }
                    fillValues[i] = fill;
                }
                return builder.imputationMethod(ImputationMethod.FIXED_VALUES).fillValues(fillValues);
            }
            case PREVIOUS: {
                return builder.imputationMethod(ImputationMethod.PREVIOUS);
            }
            default: {
                return builder.imputationMethod(ImputationMethod.PREVIOUS);
            }
        }
    }

    /**
     * Trains the model from the given samples.
     *
     * @param dataPoints samples to train on
     * @param modelState state of the model being trained
     * @param config model configuration
     * @param taskId task id of the triggering request
     * @return samples as processed by the implementation; the cold-start path forwards this list to the listener
     */
    protected abstract List<Sample> trainModelFromDataSegments(List<Sample> dataPoints, ModelState<RCFModelType> modelState, Config config, String taskId);
}

