/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.timeseries.ml;

import com.amazon.randomcutforest.parkservices.ThresholdedRandomCutForest;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.opensearch.timeseries.MemoryTracker;
import org.opensearch.timeseries.ml.ModelState;
import org.opensearch.timeseries.ml.SingleStreamModelIdMapper;

public class MemoryAwareConcurrentHashmap<RCFModelType extends ThresholdedRandomCutForest>
extends ConcurrentHashMap<String, ModelState<RCFModelType>> {
    protected final MemoryTracker memoryTracker;

    public MemoryAwareConcurrentHashmap(MemoryTracker memoryTracker) {
        this.memoryTracker = memoryTracker;
    }

    @Override
    public ModelState<RCFModelType> remove(Object key) {
        ModelState deletedModelState = (ModelState)super.remove(key);
        if (deletedModelState != null && deletedModelState.getModel().isPresent()) {
            long memoryToRelease = this.memoryTracker.estimateTRCFModelSize((ThresholdedRandomCutForest)deletedModelState.getModel().get());
            this.memoryTracker.releaseMemory(memoryToRelease, true, MemoryTracker.Origin.REAL_TIME_DETECTOR);
        }
        return deletedModelState;
    }

    @Override
    public ModelState<RCFModelType> put(String key, ModelState<RCFModelType> value) {
        ModelState<RCFModelType> previousAssociatedState = super.put(key, value);
        if (value != null && value.getModel().isPresent()) {
            long memoryToConsume = this.memoryTracker.estimateTRCFModelSize((ThresholdedRandomCutForest)value.getModel().get());
            this.memoryTracker.consumeMemory(memoryToConsume, true, MemoryTracker.Origin.REAL_TIME_DETECTOR);
        }
        return previousAssociatedState;
    }

    public Map<String, Long> getModelSize(String configId) {
        HashMap<String, Long> res = new HashMap<String, Long>();
        super.entrySet().stream().filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId((String)entry.getKey()).equals(configId)).forEach((? super T entry) -> {
            Optional modelOptional = ((ModelState)entry.getValue()).getModel();
            if (modelOptional.isPresent()) {
                res.put((String)entry.getKey(), this.memoryTracker.estimateTRCFModelSize((ThresholdedRandomCutForest)modelOptional.get()));
            }
        });
        return res;
    }

    public boolean doesModelExist(String configId) {
        return super.entrySet().stream().filter(entry -> SingleStreamModelIdMapper.getConfigIdForModelId((String)entry.getKey()).equals(configId)).anyMatch(n -> true);
    }

    public boolean hostIfPossible(String modelId, ModelState<RCFModelType> toUpdate) {
        return Optional.ofNullable(toUpdate).filter(state -> state.getModel().isPresent()).filter(state -> this.memoryTracker.isHostingAllowed(modelId, (ThresholdedRandomCutForest)state.getModel().get())).map(state -> {
            super.put(modelId, toUpdate);
            return true;
        }).orElse(false);
    }
}

