/*
 * Decompiled with CFR 0.152.
 */
package com.amazon.randomcutforest.serialize.json.v1;

import com.amazon.randomcutforest.CommonUtils;
import com.amazon.randomcutforest.config.Precision;
import com.amazon.randomcutforest.serialize.json.v1.V1SerializedRandomCutForest;
import com.amazon.randomcutforest.state.ExecutionContext;
import com.amazon.randomcutforest.state.RandomCutForestState;
import com.amazon.randomcutforest.state.sampler.CompactSamplerState;
import com.amazon.randomcutforest.state.store.PointStoreMapper;
import com.amazon.randomcutforest.state.store.PointStoreState;
import com.amazon.randomcutforest.store.IPointStore;
import com.amazon.randomcutforest.store.IPointStoreView;
import com.amazon.randomcutforest.store.PointStore;
import com.amazon.randomcutforest.tree.ITree;
import com.amazon.randomcutforest.tree.RandomCutTree;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import java.io.Reader;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

public class V1JsonToV3StateConverter {
    private final ObjectMapper mapper = new ObjectMapper();

    public RandomCutForestState convert(String json, Precision precision) throws IOException {
        CommonUtils.checkArgument((precision == Precision.FLOAT_32 ? 1 : 0) != 0, (String)"float 64 is deprecated in v3");
        V1SerializedRandomCutForest forest = (V1SerializedRandomCutForest)this.mapper.readValue(json, V1SerializedRandomCutForest.class);
        return this.convert(forest, precision);
    }

    public Optional<RandomCutForestState> convert(ArrayList<String> jsons, int numberOfTrees, Precision precision) throws IOException {
        ArrayList<V1SerializedRandomCutForest> forests = new ArrayList<V1SerializedRandomCutForest>(jsons.size());
        int sum = 0;
        for (int i = 0; i < jsons.size(); ++i) {
            V1SerializedRandomCutForest forest = (V1SerializedRandomCutForest)this.mapper.readValue(jsons.get(i), V1SerializedRandomCutForest.class);
            forests.add(forest);
            sum += forest.getNumberOfTrees();
        }
        if (sum < numberOfTrees) {
            return Optional.empty();
        }
        return Optional.ofNullable(this.convert((List<V1SerializedRandomCutForest>)forests, numberOfTrees, precision));
    }

    public RandomCutForestState convert(Reader reader, Precision precision) throws IOException {
        CommonUtils.checkArgument((precision == Precision.FLOAT_32 ? 1 : 0) != 0, (String)"float 64 is deprecated in v3");
        V1SerializedRandomCutForest forest = (V1SerializedRandomCutForest)this.mapper.readValue(reader, V1SerializedRandomCutForest.class);
        return this.convert(forest, precision);
    }

    public RandomCutForestState convert(URL url, Precision precision) throws IOException {
        CommonUtils.checkArgument((precision == Precision.FLOAT_32 ? 1 : 0) != 0, (String)"float 64 is deprecated in v3");
        V1SerializedRandomCutForest forest = (V1SerializedRandomCutForest)this.mapper.readValue(url, V1SerializedRandomCutForest.class);
        return this.convert(forest, precision);
    }

    public RandomCutForestState convert(V1SerializedRandomCutForest serializedForest, Precision precision) {
        return this.convert(Collections.singletonList(serializedForest), serializedForest.getNumberOfTrees(), precision);
    }

    public RandomCutForestState convert(List<V1SerializedRandomCutForest> serializedForests, int numberOfTrees, Precision precision) {
        CommonUtils.checkArgument((serializedForests.size() > 0 ? 1 : 0) != 0, (String)"incorrect usage of convert");
        CommonUtils.checkArgument((numberOfTrees > 0 ? 1 : 0) != 0, (String)"incorrect parameter");
        int sum = 0;
        for (int i = 0; i < serializedForests.size(); ++i) {
            sum += serializedForests.get(i).getNumberOfTrees();
        }
        CommonUtils.checkArgument((sum >= numberOfTrees ? 1 : 0) != 0, (String)"incorrect parameters");
        RandomCutForestState state = new RandomCutForestState();
        state.setNumberOfTrees(numberOfTrees);
        state.setDimensions(serializedForests.get(0).getDimensions());
        state.setTimeDecay(serializedForests.get(0).getLambda());
        state.setSampleSize(serializedForests.get(0).getSampleSize());
        state.setShingleSize(1);
        state.setCenterOfMassEnabled(serializedForests.get(0).isCenterOfMassEnabled());
        state.setOutputAfter(serializedForests.get(0).getOutputAfter());
        state.setStoreSequenceIndexesEnabled(serializedForests.get(0).isStoreSequenceIndexesEnabled());
        state.setTotalUpdates(serializedForests.get(0).getExecutor().getExecutor().getTotalUpdates());
        state.setCompact(true);
        state.setInternalShinglingEnabled(false);
        state.setBoundingBoxCacheFraction(1.0);
        state.setSaveSamplerStateEnabled(true);
        state.setSaveTreeStateEnabled(false);
        state.setSaveCoordinatorStateEnabled(true);
        state.setPrecision(precision.name());
        state.setCompressed(false);
        state.setPartialTreeState(false);
        ExecutionContext executionContext = new ExecutionContext();
        executionContext.setParallelExecutionEnabled(serializedForests.get(0).isParallelExecutionEnabled());
        executionContext.setThreadPoolSize(serializedForests.get(0).getThreadPoolSize());
        state.setExecutionContext(executionContext);
        SamplerConverter samplerConverter = new SamplerConverter(state.getDimensions(), state.getNumberOfTrees() * state.getSampleSize() + 1, precision, numberOfTrees);
        serializedForests.stream().flatMap(f -> Arrays.stream(f.getExecutor().getExecutor().getTreeUpdaters())).limit(numberOfTrees).map(V1SerializedRandomCutForest.TreeUpdater::getSampler).forEach(samplerConverter::addSampler);
        state.setPointStoreState(samplerConverter.getPointStoreState(precision));
        state.setCompactSamplerStates(samplerConverter.compactSamplerStates);
        return state;
    }

    static class SamplerConverter {
        private final IPointStore pointStore;
        private final List<CompactSamplerState> compactSamplerStates;
        private final Precision precision;
        private final ITree globalTree;
        private final int maxNumberOfTrees;

        public SamplerConverter(int dimensions, int capacity, Precision precision, int maxNumberOfTrees) {
            this.pointStore = PointStore.builder().dimensions(dimensions).capacity(capacity).shingleSize(1).initialSize(capacity).build();
            this.globalTree = new RandomCutTree.Builder().pointStoreView((IPointStoreView)this.pointStore).capacity(this.pointStore.getCapacity() + 1).storeSequenceIndexesEnabled(false).centerOfMassEnabled(false).boundingBoxCacheFraction(1.0).build();
            this.compactSamplerStates = new ArrayList<CompactSamplerState>();
            this.maxNumberOfTrees = maxNumberOfTrees;
            this.precision = precision;
        }

        public PointStoreState getPointStoreState(Precision precision) {
            return new PointStoreMapper().toState((PointStore)this.pointStore);
        }

        public void addSampler(V1SerializedRandomCutForest.Sampler sampler) {
            if (this.compactSamplerStates.size() < this.maxNumberOfTrees) {
                V1SerializedRandomCutForest.WeightedSamples[] samples = sampler.getWeightedSamples();
                int[] pointIndex = new int[samples.length];
                float[] weight = new float[samples.length];
                long[] sequenceIndex = new long[samples.length];
                for (int i = 0; i < samples.length; ++i) {
                    V1SerializedRandomCutForest.WeightedSamples sample = samples[i];
                    double[] point = sample.getPoint();
                    int index = this.pointStore.add(point, sample.getSequenceIndex());
                    pointIndex[i] = (Integer)this.globalTree.addPoint((Object)index, 0L);
                    if (pointIndex[i] != index) {
                        this.pointStore.incrementRefCount(pointIndex[i]);
                        this.pointStore.decrementRefCount(index);
                    }
                    weight[i] = (float)sample.getWeight();
                    sequenceIndex[i] = sample.getSequenceIndex();
                }
                CompactSamplerState samplerState = new CompactSamplerState();
                samplerState.setSize(samples.length);
                samplerState.setCapacity(sampler.getSampleSize());
                samplerState.setTimeDecay(sampler.getLambda());
                samplerState.setPointIndex(pointIndex);
                samplerState.setWeight(weight);
                samplerState.setSequenceIndex(sequenceIndex);
                samplerState.setSequenceIndexOfMostRecentTimeDecayUpdate(0L);
                samplerState.setMaxSequenceIndex(sampler.getEntriesSeen());
                samplerState.setInitialAcceptFraction(1.0);
                this.compactSamplerStates.add(samplerState);
            }
        }
    }
}

