/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.forecast.ml;

import com.amazon.randomcutforest.config.ForestMode;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.config.TransformMethod;
import com.amazon.randomcutforest.parkservices.RCFCaster;
import com.amazon.randomcutforest.parkservices.config.Calibration;
import java.time.Clock;
import java.time.Duration;
import java.util.List;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.forecast.indices.ForecastIndex;
import org.opensearch.forecast.indices.ForecastIndexManagement;
import org.opensearch.forecast.ml.ForecastCheckpointDao;
import org.opensearch.forecast.model.Forecaster;
import org.opensearch.forecast.ratelimit.ForecastCheckpointWriteWorker;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.timeseries.AnalysisType;
import org.opensearch.timeseries.NodeStateManager;
import org.opensearch.timeseries.feature.FeatureManager;
import org.opensearch.timeseries.feature.SearchFeatureDao;
import org.opensearch.timeseries.ml.ModelColdStart;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.Sample;
import org.opensearch.timeseries.model.Config;
import org.opensearch.timeseries.ratelimit.RequestPriority;

public class ForecastColdStart
extends ModelColdStart<RCFCaster, ForecastIndex, ForecastIndexManagement, ForecastCheckpointDao, ForecastCheckpointWriteWorker> {
    private static final Logger logger = LogManager.getLogger(ForecastColdStart.class);

    public ForecastColdStart(Clock clock, ThreadPool threadPool, NodeStateManager nodeStateManager, int rcfSampleSize, int numberOfTrees, int numMinSamples, SearchFeatureDao searchFeatureDao, double thresholdMinPvalue, FeatureManager featureManager, Duration modelTtl, ForecastCheckpointWriteWorker checkpointWriteWorker, int coolDownMinutes, long rcfSeed, int defaultTrainSamples, int maxRoundofColdStart) {
        super(modelTtl, coolDownMinutes, clock, threadPool, numMinSamples, checkpointWriteWorker, rcfSeed, numberOfTrees, rcfSampleSize, thresholdMinPvalue, nodeStateManager, 1, defaultTrainSamples, searchFeatureDao, featureManager, maxRoundofColdStart, "forecast-threadpool", AnalysisType.FORECAST);
    }

    @Override
    protected List<Sample> trainModelFromDataSegments(List<Sample> pointSamples, ModelState<RCFCaster> modelState, Config config, String taskId) {
        if (pointSamples == null || pointSamples.size() == 0) {
            logger.info("Return early since data points must not be empty.");
            return null;
        }
        double[] firstPoint = pointSamples.get(0).getValueList();
        if (firstPoint == null || firstPoint.length == 0) {
            logger.info("Return early since data points must not be empty.");
            return null;
        }
        int shingleSize = config.getShingleSize();
        int forecastHorizon = ((Forecaster)config).getHorizon();
        int dimensions = firstPoint.length * shingleSize;
        RCFCaster.Builder casterBuilder = (RCFCaster.Builder)((RCFCaster.Builder)((RCFCaster.Builder)((RCFCaster.Builder)((RCFCaster.Builder)((RCFCaster.Builder)((RCFCaster.Builder)((RCFCaster.Builder)((RCFCaster.Builder)((RCFCaster.Builder)((RCFCaster.Builder)((RCFCaster.Builder)((RCFCaster.Builder)((RCFCaster.Builder)((RCFCaster.Builder)RCFCaster.builder().dimensions(dimensions)).numberOfTrees(this.numberOfTrees)).shingleSize(shingleSize)).sampleSize(this.rcfSampleSize)).internalShinglingEnabled(true)).precision(Precision.FLOAT_32)).anomalyRate(1.0 - this.thresholdMinPvalue)).outputAfter(Math.max(shingleSize, this.numMinSamples))).calibration(Calibration.MINIMAL).timeDecay(config.getTimeDecay().doubleValue())).parallelExecutionEnabled(false)).boundingBoxCacheFraction(0.0)).transformDecay(config.getTimeDecay().doubleValue())).forecastHorizon(forecastHorizon).initialAcceptFraction(this.initialAcceptFraction)).transformMethod(TransformMethod.NORMALIZE)).forestMode(ForestMode.STANDARD);
        casterBuilder = ForecastColdStart.applyImputationMethod(config, casterBuilder);
        if (this.rcfSeed > 0L) {
            casterBuilder.randomSeed(this.rcfSeed);
        }
        RCFCaster caster = casterBuilder.build();
        for (int i = 0; i < pointSamples.size(); ++i) {
            Sample dataSample = pointSamples.get(i);
            double[] dataValue = dataSample.getValueList();
            caster.process(dataValue, dataSample.getDataEndTime().getEpochSecond());
        }
        modelState.setModel(caster);
        modelState.setLastUsedTime(this.clock.instant());
        if (null == taskId) {
            ((ForecastCheckpointWriteWorker)this.checkpointWriteWorker).write(modelState, true, RequestPriority.MEDIUM);
        }
        return pointSamples;
    }
}

