/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.timeseries.ml;

import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.timeseries.caching.CacheProvider;
import org.opensearch.timeseries.caching.TimeSeriesCache;
import org.opensearch.timeseries.indices.IndexManagement;
import org.opensearch.timeseries.ml.CheckpointDao;
import org.opensearch.timeseries.ml.IntermediateResult;
import org.opensearch.timeseries.ml.ModelColdStart;
import org.opensearch.timeseries.ml.ModelManager;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.model.IndexableResult;
import org.opensearch.timeseries.model.IntervalTimeConfiguration;
import org.opensearch.timeseries.ratelimit.CheckpointWriteWorker;
import org.opensearch.timeseries.ratelimit.ColdStartWorker;
import org.opensearch.timeseries.ratelimit.FeatureRequest;
import org.opensearch.timeseries.ratelimit.RateLimitedRequestWorker;
import org.opensearch.timeseries.ratelimit.RequestPriority;
import org.opensearch.timeseries.ratelimit.SaveResultStrategy;
import org.opensearch.timeseries.stats.Stats;

/**
 * Scores a real-time sample against its model and hands the result to a {@link SaveResultStrategy}.
 *
 * <p>Processing for a given model id is serialized through a per-model {@link ReentrantLock}. If the lock is
 * already held, the sample is rescheduled on {@code threadPoolName} once per second until the next execution
 * window ends, after which it is dropped with a warning.
 *
 * <p>Any scoring/saving failure other than an {@link IllegalArgumentException} whose message contains
 * "incorrect ordering of time" is treated as likely model corruption. In that case this class:
 * <ul>
 *   <li>increments the corruption stat;</li>
 *   <li>evicts the model from the cache;</li>
 *   <li>deletes its checkpoint;</li>
 *   <li>queues a cold-start {@link FeatureRequest}.</li>
 * </ul>
 */
public abstract class RealTimeInferencer<
        RCFModelType extends ThresholdedRandomCutForest,
        ResultType extends IndexableResult,
        RCFResultType extends IntermediateResult<ResultType>,
        IndexType extends Enum<IndexType>,
        IndexManagementType extends IndexManagement<IndexType>,
        CheckpointDaoType extends CheckpointDao<RCFModelType, IndexType, IndexManagementType>,
        CheckpointWriterType extends CheckpointWriteWorker<RCFModelType, IndexType, IndexManagementType, CheckpointDaoType>,
        ColdStarterType extends ModelColdStart<RCFModelType, IndexType, IndexManagementType, CheckpointDaoType, CheckpointWriterType>,
        ModelManagerType extends ModelManager<RCFModelType, ResultType, RCFResultType, IndexType, IndexManagementType, CheckpointDaoType, CheckpointWriterType, ColdStarterType>,
        SaveResultStrategyType extends SaveResultStrategy<ResultType, RCFResultType>,
        CacheType extends TimeSeriesCache<RCFModelType>,
        ColdStartWorkerType extends ColdStartWorker<RCFModelType, IndexType, IndexManagementType, CheckpointDaoType, CheckpointWriterType, ColdStarterType, CacheType, ResultType, RCFResultType, ModelManagerType, SaveResultStrategyType>> {
    private static final Logger LOG = LogManager.getLogger(RealTimeInferencer.class);

    /** Message fragment that marks an out-of-order sample; such samples are skipped, not treated as corruption. */
    private static final String INCORRECT_TIME_ORDERING = "incorrect ordering of time";

    protected ModelManagerType modelManager;
    protected Stats stats;
    private final String modelCorruptionStat;
    protected CheckpointDaoType checkpointDao;
    protected ColdStartWorkerType coldStartWorker;
    protected SaveResultStrategyType resultWriteWorker;
    private final CacheProvider<RCFModelType, CacheType> cache;
    // One lock per model id so that samples for the same model are never scored concurrently.
    // NOTE(review): WeakHashMap matches keys by equals() but keeps only the first-inserted String instance,
    // weakly. If that instance becomes unreachable while another thread still holds the lock through an equal
    // String, a second lock could be created for the same model id. Confirm the model id's lifetime (e.g. held
    // by ModelState) makes this safe.
    private final Map<String, ReentrantLock> modelLocks = Collections.synchronizedMap(new WeakHashMap<>());
    private final ThreadPool threadPool;
    private final String threadPoolName;

    public RealTimeInferencer(
        ModelManagerType modelManager,
        Stats stats,
        String modelCorruptionStat,
        CheckpointDaoType checkpointDao,
        ColdStartWorkerType coldStartWorker,
        SaveResultStrategyType resultWriteWorker,
        CacheProvider<RCFModelType, CacheType> cache,
        ThreadPool threadPool,
        String threadPoolName
    ) {
        this.modelManager = modelManager;
        this.stats = stats;
        this.modelCorruptionStat = modelCorruptionStat;
        this.checkpointDao = checkpointDao;
        this.coldStartWorker = coldStartWorker;
        this.resultWriteWorker = resultWriteWorker;
        this.cache = cache;
        this.threadPool = threadPool;
        this.threadPoolName = threadPoolName;
    }

    /**
     * Scores {@code sample} with the model held in {@code modelState} and saves the result.
     *
     * <p>The retry deadline is the end of the next execution window: the sample's data end time, plus the
     * config's window delay (if any), plus one config interval.
     *
     * @param sample     the sample to score
     * @param modelState the model to score against
     * @param config     the owning config
     * @param taskId     task id forwarded to scoring and result saving
     * @return {@code true} if the model lock was acquired and the sample was processed on the calling thread
     *         (regardless of whether scoring succeeded); {@code false} if the model was busy and the sample was
     *         rescheduled or dropped after the deadline
     */
    public boolean process(Sample sample, ModelState<RCFModelType> modelState, Config config, String taskId) {
        long windowDelayMillis = config.getWindowDelay() == null
            ? 0L
            : ((IntervalTimeConfiguration) config.getWindowDelay()).toDuration().toMillis();
        long curExecutionEnd = sample.getDataEndTime().toEpochMilli() + windowDelayMillis;
        long nextExecutionEnd = curExecutionEnd + config.getIntervalInMilliseconds();
        return processWithTimeout(sample, modelState, config, taskId, curExecutionEnd, nextExecutionEnd);
    }

    /**
     * Processes the sample under the model's lock.
     *
     * <p>If the lock is busy and the wall clock has not reached {@code nextExecutionEnd}, a retry is scheduled
     * one second later on {@code threadPoolName}. Otherwise the sample is dropped with a warning.
     */
    private boolean processWithTimeout(
        Sample sample,
        ModelState<RCFModelType> modelState,
        Config config,
        String taskId,
        long curExecutionEnd,
        long nextExecutionEnd
    ) {
        String modelId = modelState.getModelId();
        ReentrantLock lock = modelLocks.computeIfAbsent(modelId, k -> new ReentrantLock());
        if (lock.tryLock()) {
            try {
                tryProcess(sample, modelState, config, taskId, curExecutionEnd);
            } finally {
                if (lock.isHeldByCurrentThread()) {
                    lock.unlock();
                }
            }
            return true;
        }

        if (System.currentTimeMillis() >= nextExecutionEnd) {
            LOG.warn("Timeout reached for model [{}], not retrying.", modelId);
        } else {
            threadPool.schedule(
                () -> processWithTimeout(sample, modelState, config, taskId, curExecutionEnd, nextExecutionEnd),
                new TimeValue(1L, TimeUnit.SECONDS),
                threadPoolName
            );
        }
        return false;
    }

    /**
     * Scores and saves one sample. Must be called while holding the model's lock.
     *
     * @return {@code true} if the result was computed and handed to the result writer; {@code false} if the
     *         sample was skipped or a failure triggered a re-cold-start
     */
    private boolean tryProcess(
        Sample sample,
        ModelState<RCFModelType> modelState,
        Config config,
        String taskId,
        long curExecutionEnd
    ) {
        String modelId = modelState.getModelId();
        try {
            RCFResultType result = modelManager.getResult(sample, modelState, modelId, config, taskId);
            resultWriteWorker.saveResult(
                result,
                config,
                sample.getDataStartTime(),
                sample.getDataEndTime(),
                modelId,
                sample.getValueList(),
                modelState.getEntity(),
                taskId
            );
            return true;
        } catch (IllegalArgumentException e) {
            if (e.getMessage() != null && e.getMessage().contains(INCORRECT_TIME_ORDERING)) {
                // The timestamp ordering was rejected; skip this sample instead of treating it as corruption.
                LOG.warn(
                    String.format(
                        Locale.ROOT,
                        "incorrect ordering of time for config %s model %s at data end time %d",
                        config.getId(),
                        modelId,
                        sample.getDataEndTime().toEpochMilli()
                    )
                );
            } else {
                reColdStart(config, modelId, e, sample, taskId);
            }
            return false;
        } catch (Exception e) {
            reColdStart(config, modelId, e, sample, taskId);
            return false;
        }
    }

    /**
     * Handles a likely-corrupted model. It performs these steps:
     * <ol>
     *   <li>logs the failure and increments the corruption stat;</li>
     *   <li>evicts the model from the cache;</li>
     *   <li>deletes its checkpoint asynchronously (only when {@code modelId} is non-null);</li>
     *   <li>queues a medium-priority cold-start request that expires one config interval from now.</li>
     * </ol>
     */
    private void reColdStart(Config config, String modelId, Exception e, Sample sample, String taskId) {
        LOG.error(new ParameterizedMessage("Likely model corruption for [{}]", modelId), e);
        stats.getStat(modelCorruptionStat).increment();
        cache.get().removeModel(config.getId(), modelId);
        if (modelId != null) {
            checkpointDao.deleteModelCheckpoint(
                modelId,
                ActionListener.wrap(
                    r -> LOG.debug(new ParameterizedMessage("Succeeded in deleting checkpoint [{}].", modelId)),
                    ex -> LOG.error(new ParameterizedMessage("Failed to delete checkpoint [{}].", modelId), ex)
                )
            );
        }
        coldStartWorker.put(
            new FeatureRequest(
                System.currentTimeMillis() + config.getIntervalInMilliseconds(),
                config.getId(),
                RequestPriority.MEDIUM,
                modelId,
                sample.getValueList(),
                sample.getDataStartTime().toEpochMilli(),
                taskId
            )
        );
    }
}

