/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.codec.KNN80Codec;

import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.HashMap;
import java.util.Map;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.DocValuesConsumer;
import org.apache.lucene.codecs.DocValuesProducer;
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.FilterDirectory;
import org.opensearch.common.StopWatch;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.KNN80Codec.KNN80DocValuesReader;
import org.opensearch.knn.index.codec.transfer.VectorTransfer;
import org.opensearch.knn.index.codec.transfer.VectorTransferByte;
import org.opensearch.knn.index.codec.transfer.VectorTransferFloat;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelCache;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.knn.plugin.stats.KNNGraphValue;

/**
 * {@link DocValuesConsumer} that delegates every doc-values write to {@code delegatee} and, for k-NN
 * binary doc-values fields whose engine writes its own segment files, additionally builds a native
 * engine index file during refresh ({@link #addBinaryField}) and merge ({@link #merge}).
 *
 * <p>After the native library writes an index file, a Lucene-compatible footer is appended to it
 * (footer magic, checksum algorithm id, CRC-32 of all preceding bytes); see {@link #writeFooter}.
 */
class KNN80DocValuesConsumer extends DocValuesConsumer implements Closeable {

    private static final Logger log = LogManager.getLogger(KNN80DocValuesConsumer.class);

    /** Lucene's {@code CodecUtil.FOOTER_MAGIC}, i.e. {@code ~CODEC_MAGIC} (== -1071082520). */
    private static final int FOOTER_MAGIC = ~0x3fd76c17;
    /** Checksum algorithm id written into the footer; Lucene uses 0 for CRC-32. */
    private static final int CHECKSUM_ALGORITHM_ID = 0;
    /** A CRC-32 value occupies only the low 32 bits; any of these high bits being set means it is corrupt. */
    private static final long CRC32_CHECKSUM_SANITY = 0xFFFFFFFF00000000L;

    private final DocValuesConsumer delegatee;
    private final SegmentWriteState state;

    KNN80DocValuesConsumer(DocValuesConsumer delegatee, SegmentWriteState state) {
        this.delegatee = delegatee;
        this.state = state;
    }

    /**
     * Delegates the binary field write, then, if {@code field} is a k-NN field whose engine creates
     * custom segment files, builds the native index for the segment being refreshed.
     */
    @Override
    public void addBinaryField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
        delegatee.addBinaryField(field, valuesProducer);
        if (!isKNNBinaryFieldRequired(field)) {
            return;
        }
        StopWatch stopWatch = new StopWatch();
        stopWatch.start();
        addKNNBinaryField(field, valuesProducer, false, true);
        stopWatch.stop();
        long timeInMillis = stopWatch.totalTime().millis();
        // A single incrementBy avoids the lost-update race of a separate getValue() followed by set().
        KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.incrementBy(timeInMillis);
        log.warn("Refresh operation complete in {} ms", timeInMillis);
    }

    /**
     * Returns true when {@code field} carries the {@code knn_field} attribute and its resolved engine is
     * one of {@link KNNEngine#getEnginesThatCreateCustomSegmentFiles()}.
     */
    private boolean isKNNBinaryFieldRequired(FieldInfo field) {
        // Cheap attribute check first so non-k-NN fields never trigger engine/model resolution.
        if (!field.attributes().containsKey("knn_field")) {
            return false;
        }
        KNNEngine knnEngine = getKNNEngine(field);
        log.debug("Read engine [{}] for field [{}]", knnEngine.getName(), field.getName());
        return KNNEngine.getEnginesThatCreateCustomSegmentFiles().stream().anyMatch(engine -> engine == knnEngine);
    }

    /**
     * Resolves the engine for {@code field}: the trained model's engine when a {@code model_id}
     * attribute is present, otherwise the {@code engine} attribute (default {@link KNNEngine#DEFAULT}).
     *
     * @throws NullPointerException if {@code field} is null
     */
    private KNNEngine getKNNEngine(FieldInfo field) {
        if (field == null) {
            throw new NullPointerException("field is marked non-null but is null");
        }
        String modelId = field.attributes().get("model_id");
        if (modelId != null) {
            Model model = ModelCache.getInstance().get(modelId);
            return model.getModelMetadata().getKnnEngine();
        }
        String engineName = field.attributes().getOrDefault("engine", KNNEngine.DEFAULT.getName());
        return KNNEngine.getEngine(engineName);
    }

    /**
     * Builds the native engine index file for {@code field} and appends a checksum footer to it.
     *
     * <p>Vectors are read from {@code valuesProducer.getBinary(field)}. If the field has a
     * {@code model_id} attribute, the index is created from that trained model's template; otherwise
     * it is built from scratch from the field's attributes. Nothing is written when the segment has no
     * vectors or docs.
     *
     * @param field          the k-NN field to index
     * @param valuesProducer source of the field's binary doc values
     * @param isMerge        whether this call is part of a segment merge (updates merge stats)
     * @param isRefresh      whether this call is part of a refresh (updates refresh stats)
     * @throws IOException      if reading doc values or writing the index file fails
     * @throws RuntimeException if the referenced model has no trained model blob
     */
    public void addKNNBinaryField(FieldInfo field, DocValuesProducer valuesProducer, boolean isMerge, boolean isRefresh) throws IOException {
        BinaryDocValues values = valuesProducer.getBinary(field);
        KNNEngine knnEngine = getKNNEngine(field);
        String engineFileName = KNNCodecUtil.buildEngineFileName(state.segmentInfo.name, knnEngine.getVersion(), field.name, knnEngine.getExtension());
        String indexPath = Paths.get(((FSDirectory) FilterDirectory.unwrap(state.directory)).getDirectory().toString(), engineFileName).toString();

        Map<String, String> fieldAttributes = field.attributes();
        final KNNCodecUtil.Pair pair;
        final NativeIndexCreator indexCreator;
        if (fieldAttributes.containsKey("model_id")) {
            String modelId = fieldAttributes.get("model_id");
            Model model = ModelCache.getInstance().get(modelId);
            if (model.getModelBlob() == null) {
                throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId));
            }
            pair = KNNCodecUtil.getPair(values, getVectorTransfer(model.getModelMetadata().getVectorDataType()));
            indexCreator = () -> createKNNIndexFromTemplate(model, pair, knnEngine, indexPath);
        } else {
            VectorDataType vectorDataType = VectorDataType.get(fieldAttributes.getOrDefault("data_type", VectorDataType.DEFAULT.getValue()));
            pair = KNNCodecUtil.getPair(values, getVectorTransfer(vectorDataType));
            indexCreator = () -> createKNNIndexFromScratch(field, pair, knnEngine, indexPath);
        }

        if (pair.getVectorAddress() == 0L || pair.docs.length == 0) {
            log.info("Skipping engine index creation as there are no vectors or docs in the segment");
            return;
        }
        int docCount = pair.docs.length;
        long arraySize = KNNCodecUtil.calculateArraySize(docCount, pair.getDimension(), pair.serializationMode);

        KNNCounter.GRAPH_INDEX_REQUESTS.increment();
        if (isRefresh) {
            recordRefreshStats();
        }
        if (isMerge) {
            // "Current" gauges must stay raised for the whole build, so they are lowered in finally.
            startMergeStats(docCount, arraySize);
        }
        try {
            // Create (empty) the file through the Lucene Directory before native code writes to indexPath;
            // presumably so the directory tracks it as a segment file — TODO confirm.
            state.directory.createOutput(engineFileName, state.context).close();
            indexCreator.createIndex();
            writeFooter(indexPath, engineFileName);
        } finally {
            if (isMerge) {
                endMergeStats(docCount, arraySize);
            }
        }
        if (isMerge) {
            recordMergeStats(docCount, arraySize);
        }
    }

    /** Raises the in-flight merge gauges for a native index build that is about to start. */
    private void startMergeStats(int length, long arraySize) {
        KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
        KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(length);
        KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(arraySize);
    }

    /** Lowers the in-flight merge gauges raised by {@link #startMergeStats}, whether or not the build succeeded. */
    private void endMergeStats(int length, long arraySize) {
        KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement();
        KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(length);
        KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(arraySize);
    }

    /** Adds a successfully completed merge build to the cumulative merge totals. */
    private void recordMergeStats(int length, long arraySize) {
        KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment();
        KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(length);
        KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(arraySize);
    }

    /** Counts one refresh-triggered native index build. */
    private void recordRefreshStats() {
        KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment();
    }

    /**
     * Builds the native index at {@code indexPath} from the trained {@code model}'s blob, passing the
     * configured index thread quantity and the model's vector data type as parameters.
     */
    private void createKNNIndexFromTemplate(Model model, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) {
        HashMap<String, Object> parameters = new HashMap<>();
        parameters.put("indexThreadQty", KNNSettings.state().getSettingValue("knn.algo_param.index_thread_qty"));
        IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType());
        // Explicit PrivilegedAction cast: a bare lambda is ambiguous with the PrivilegedExceptionAction overload.
        AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
            JNIService.createIndexFromTemplate(pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, model.getModelBlob(), parameters, knnEngine);
            return null;
        });
    }

    /**
     * Builds the native index at {@code indexPath} from the field's own attributes.
     *
     * <p>If the field has a serialized {@code parameters} attribute, it is parsed as JSON and used
     * as-is. Otherwise parameters are assembled from the individual {@code spaceType},
     * {@code efConstruction} and {@code M} attributes. For FAISS binary fields, the index description is
     * prefixed with {@code "B"} and the binary data type is recorded.
     *
     * @throws IOException if the {@code parameters} JSON cannot be parsed
     */
    private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pair, KNNEngine knnEngine, String indexPath) throws IOException {
        HashMap<String, Object> parameters = new HashMap<>();
        Map<String, String> fieldAttributes = fieldInfo.attributes();
        String parametersString = fieldAttributes.get("parameters");
        if (parametersString == null) {
            parameters.put("spaceType", fieldAttributes.getOrDefault("spaceType", SpaceType.DEFAULT.getValue()));
            HashMap<String, Integer> algoParams = new HashMap<>();
            String efConstruction = fieldAttributes.get("efConstruction");
            if (efConstruction != null) {
                algoParams.put("ef_construction", Integer.parseInt(efConstruction));
            }
            String m = fieldAttributes.get("M");
            if (m != null) {
                algoParams.put("m", Integer.parseInt(m));
            }
            parameters.put("parameters", algoParams);
        } else {
            // The parser is Closeable; close it so its underlying resources are released.
            try (
                var parser = XContentType.JSON.xContent()
                    .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString)
            ) {
                parameters.putAll(parser.map());
            }
        }

        boolean isBinary = VectorDataType.BINARY.getValue().equals(fieldAttributes.getOrDefault("data_type", VectorDataType.DEFAULT.getValue()));
        Object indexDescription = parameters.get("index_description");
        if (KNNEngine.FAISS == knnEngine && isBinary && indexDescription != null) {
            parameters.put("index_description", "B" + indexDescription);
            IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY);
        }
        parameters.put("indexThreadQty", KNNSettings.state().getSettingValue("knn.algo_param.index_thread_qty"));
        // Explicit PrivilegedAction cast: a bare lambda is ambiguous with the PrivilegedExceptionAction overload.
        AccessController.doPrivileged((PrivilegedAction<Void>) () -> {
            JNIService.createIndex(pair.docs, pair.getVectorAddress(), pair.getDimension(), indexPath, parameters, knnEngine);
            return null;
        });
    }

    /**
     * Delegates the merge, then rebuilds the native index for every merged BINARY k-NN field whose
     * engine creates custom segment files (the same check the refresh path applies). Any exception is
     * rethrown wrapped in a {@link RuntimeException}.
     */
    @Override
    public void merge(MergeState mergeState) {
        assert mergeState != null;
        assert mergeState.mergeFieldInfos != null;
        try {
            delegatee.merge(mergeState);
            for (FieldInfo fieldInfo : mergeState.mergeFieldInfos) {
                if (fieldInfo.getDocValuesType() != DocValuesType.BINARY || !isKNNBinaryFieldRequired(fieldInfo)) {
                    continue;
                }
                StopWatch stopWatch = new StopWatch();
                stopWatch.start();
                addKNNBinaryField(fieldInfo, new KNN80DocValuesReader(mergeState), true, false);
                stopWatch.stop();
                long timeInMillis = stopWatch.totalTime().millis();
                // A single incrementBy avoids the lost-update race of a separate getValue() followed by set().
                KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS.incrementBy(timeInMillis);
                log.warn("Merge operation complete in {} ms", timeInMillis);
            }
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public void addSortedSetField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
        delegatee.addSortedSetField(field, valuesProducer);
    }

    @Override
    public void addSortedNumericField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
        delegatee.addSortedNumericField(field, valuesProducer);
    }

    @Override
    public void addSortedField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
        delegatee.addSortedField(field, valuesProducer);
    }

    @Override
    public void addNumericField(FieldInfo field, DocValuesProducer valuesProducer) throws IOException {
        delegatee.addNumericField(field, valuesProducer);
    }

    @Override
    public void close() throws IOException {
        delegatee.close();
    }

    /**
     * Appends a Lucene-style footer to the native index file at {@code indexPath}.
     *
     * <p>The footer has three parts, all big-endian:
     * <ol>
     *   <li>the footer magic (int);</li>
     *   <li>the checksum algorithm id (int);</li>
     *   <li>the CRC-32 (long) of every byte before it, including the first two fields.</li>
     * </ol>
     * The checksum is read back through the Lucene Directory under {@code engineFileName}.
     *
     * @throws IOException           if the file cannot be appended to or read back
     * @throws IllegalStateException if the computed checksum does not fit in 32 bits
     */
    private void writeFooter(String indexPath, String engineFileName) throws IOException {
        try (OutputStream os = Files.newOutputStream(Paths.get(indexPath), StandardOpenOption.APPEND)) {
            ByteBuffer byteBuffer = ByteBuffer.allocate(Long.BYTES).order(ByteOrder.BIG_ENDIAN);
            byteBuffer.putInt(FOOTER_MAGIC);
            byteBuffer.putInt(CHECKSUM_ALGORITHM_ID);
            os.write(byteBuffer.array());
            // The magic and algorithm id must be on disk before the checksum over them is computed.
            os.flush();

            long checksum;
            try (ChecksumIndexInput checksumIndexInput = state.directory.openChecksumInput(engineFileName, state.context)) {
                checksumIndexInput.seek(checksumIndexInput.length());
                checksum = checksumIndexInput.getChecksum();
            }
            if (isChecksumInvalid(checksum)) {
                throw new IllegalStateException("Illegal CRC-32 checksum: " + checksum + " (resource=" + indexPath + ")");
            }

            // Reuse the 8-byte buffer for the checksum long.
            byteBuffer.putLong(0, checksum);
            os.write(byteBuffer.array());
        }
    }

    /** Returns true when {@code value} has any of the high 32 bits set, which a CRC-32 never does. */
    private boolean isChecksumInvalid(long value) {
        return (value & CRC32_CHECKSUM_SANITY) != 0L;
    }

    /** Picks the byte-vector transfer for BINARY data and the float-vector transfer otherwise. */
    private VectorTransfer getVectorTransfer(VectorDataType vectorDataType) {
        if (VectorDataType.BINARY == vectorDataType) {
            return new VectorTransferByte(KNNSettings.getVectorStreamingMemoryLimit().getBytes());
        }
        return new VectorTransferFloat(KNNSettings.getVectorStreamingMemoryLimit().getBytes());
    }

    /** Deferred native index build, chosen per field (template vs. from scratch). */
    @FunctionalInterface
    private interface NativeIndexCreator {
        void createIndex() throws IOException;
    }
}

