/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.codec.nativeindex;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.HashMap;
import java.util.Map;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.store.FilterDirectory;
import org.opensearch.common.Nullable;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.MediaType;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.common.KNNVectorUtil;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.nativeindex.DefaultIndexBuildStrategy;
import org.opensearch.knn.index.codec.nativeindex.MemOptimizedNativeIndexBuildStrategy;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexBuildStrategy;
import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams;
import org.opensearch.knn.index.codec.util.KNNCodecUtil;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.engine.qframe.QuantizationConfig;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.indices.Model;
import org.opensearch.knn.indices.ModelCache;
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

public class NativeIndexWriter {
    @Generated
    private static final Logger log = LogManager.getLogger(NativeIndexWriter.class);
    /**
     * Mask over the upper 32 bits of a checksum. A CRC-32 value only occupies the low 32 bits, so any bit
     * set under this mask means the computed checksum cannot be a valid CRC-32.
     */
    private static final long CRC32_CHECKSUM_SANITY = 0xFFFFFFFF00000000L;
    /** Footer magic number written before the checksum; equals Lucene's {@code CodecUtil.FOOTER_MAGIC} (~0x3fd76c17). */
    private static final int FOOTER_MAGIC = 0xC02893E8;
    /** Checksum algorithm id written into the footer. */
    private static final int CHECKSUM_ALGORITHM_ID = 0;
    private final SegmentWriteState state;
    private final FieldInfo fieldInfo;
    private final NativeIndexBuildStrategy indexBuilder;
    @Nullable
    private final QuantizationState quantizationState;

    /**
     * Returns a writer for the given field without quantization state.
     *
     * @param fieldInfo field whose native index will be written
     * @param state     segment write state supplying the directory and segment name
     * @return a new writer
     */
    public static NativeIndexWriter getWriter(FieldInfo fieldInfo, SegmentWriteState state) {
        return createWriter(fieldInfo, state, null);
    }

    /**
     * Returns a writer for the given field.
     *
     * @param fieldInfo         field whose native index will be written
     * @param state             segment write state supplying the directory and segment name
     * @param quantizationState quantization state to pass to the builder; may be {@code null}
     * @return a new writer
     */
    public static NativeIndexWriter getWriter(FieldInfo fieldInfo, SegmentWriteState state, QuantizationState quantizationState) {
        return createWriter(fieldInfo, state, quantizationState);
    }

    /**
     * Builds and writes the native index for a flushed segment, then records the refresh statistic.
     *
     * @param knnVectorValues vector values of the field
     * @param totalLiveDocs   number of live documents; nothing is built when this is 0
     * @throws IOException if building the index or writing the engine file fails
     */
    public void flushIndex(KNNVectorValues<?> knnVectorValues, int totalLiveDocs) throws IOException {
        // NOTE(review): presumably positions the iterator on its first vector; confirm in KNNVectorUtil.
        KNNVectorUtil.iterateVectorValuesOnce(knnVectorValues);
        buildAndWriteIndex(knnVectorValues, totalLiveDocs);
        recordRefreshStats();
    }

    /**
     * Builds and writes the native index for a merged segment, tracking merge statistics around the build.
     * The "current" merge gauges are always restored, even if the build fails.
     *
     * @param knnVectorValues vector values of the field
     * @param totalLiveDocs   number of live documents; nothing is built when this is 0
     * @throws IOException if building the index or writing the engine file fails
     */
    public void mergeIndex(KNNVectorValues<?> knnVectorValues, int totalLiveDocs) throws IOException {
        KNNVectorUtil.iterateVectorValuesOnce(knnVectorValues);
        // Integer.MAX_VALUE is the exhausted-iterator doc id: there is nothing left to index.
        if (knnVectorValues.docId() == Integer.MAX_VALUE) {
            log.debug("Skipping mergeIndex, vector values are already iterated for {}", fieldInfo.name);
            return;
        }
        long bytesPerVector = knnVectorValues.bytesPerVector();
        startMergeStats(totalLiveDocs, bytesPerVector);
        try {
            buildAndWriteIndex(knnVectorValues, totalLiveDocs);
        } finally {
            // Without this, a failed build would leave MERGE_CURRENT_* permanently incremented.
            endMergeStats(totalLiveDocs, bytesPerVector);
        }
    }

    /**
     * Creates the engine file for this field, has the build strategy write the native index into it,
     * and appends a Lucene-style checksum footer.
     */
    private void buildAndWriteIndex(KNNVectorValues<?> knnVectorValues, int totalLiveDocs) throws IOException {
        if (totalLiveDocs == 0) {
            log.debug("No live docs for field {}", fieldInfo.name);
            return;
        }
        KNNEngine knnEngine = FieldInfoExtractor.extractKNNEngine(fieldInfo);
        String engineFileName = KNNCodecUtil.buildEngineFileName(
            state.segmentInfo.name,
            knnEngine.getVersion(),
            fieldInfo.name,
            knnEngine.getExtension()
        );
        // The native builder writes by filesystem path, so the (possibly wrapped) directory must be an FSDirectory.
        String indexPath = Paths.get(((FSDirectory) FilterDirectory.unwrap(state.directory)).getDirectory().toString(), engineFileName)
            .toString();
        // Create the empty file through the Directory API before the native code writes to it by path.
        state.directory.createOutput(engineFileName, state.context).close();
        BuildIndexParams nativeIndexParams = indexParams(fieldInfo, indexPath, knnEngine, knnVectorValues, totalLiveDocs);
        indexBuilder.buildAndWriteIndex(nativeIndexParams);
        writeFooter(indexPath, engineFileName, state);
    }

    /** Assembles the build parameters: model-template parameters when the field has a model_id, else field parameters. */
    private BuildIndexParams indexParams(
        FieldInfo fieldInfo,
        String indexPath,
        KNNEngine knnEngine,
        KNNVectorValues<?> vectorValues,
        int totalLiveDocs
    ) throws IOException {
        // With quantization, the data type handed to the engine comes from the quantization service.
        VectorDataType vectorDataType = quantizationState != null
            ? QuantizationService.getInstance().getVectorDataTypeForTransfer(fieldInfo)
            : FieldInfoExtractor.extractVectorDataType(fieldInfo);
        Map<String, Object> parameters;
        if (fieldInfo.attributes().containsKey("model_id")) {
            parameters = getTemplateParameters(fieldInfo, getModel(fieldInfo));
        } else {
            parameters = getParameters(fieldInfo, vectorDataType, knnEngine);
        }
        return BuildIndexParams.builder()
            .fieldName(fieldInfo.name)
            .parameters(parameters)
            .vectorDataType(vectorDataType)
            .knnEngine(knnEngine)
            .indexPath(indexPath)
            .quantizationState(quantizationState)
            .vectorValues(vectorValues)
            .totalLiveDocs(totalLiveDocs)
            .build();
    }

    /**
     * Builds engine parameters from the field attributes.
     *
     * If the "parameters" attribute is present, it is parsed as XContent. Otherwise the parameters are
     * assembled from the individual "spaceType", "efConstruction" and "M" attributes.
     */
    private Map<String, Object> getParameters(FieldInfo fieldInfo, VectorDataType vectorDataType, KNNEngine knnEngine)
        throws IOException {
        Map<String, Object> parameters = new HashMap<>();
        Map<String, String> fieldAttributes = fieldInfo.attributes();
        String parametersString = fieldAttributes.get("parameters");
        if (parametersString == null) {
            parameters.put("spaceType", fieldAttributes.getOrDefault("spaceType", SpaceType.DEFAULT.getValue()));
            Map<String, Integer> algoParams = new HashMap<>();
            String efConstruction = fieldAttributes.get("efConstruction");
            if (efConstruction != null) {
                algoParams.put("ef_construction", Integer.parseInt(efConstruction));
            }
            String m = fieldAttributes.get("M");
            if (m != null) {
                algoParams.put("m", Integer.parseInt(m));
            }
            parameters.put("parameters", algoParams);
        } else {
            parameters.putAll(
                XContentHelper.createParser(
                    NamedXContentRegistry.EMPTY,
                    DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
                    new BytesArray(parametersString),
                    MediaTypeRegistry.getDefaultMediaType()
                ).map()
            );
        }
        parameters.put("data_type", vectorDataType.getValue());
        maybeAddBinaryPrefixForFaissBWC(knnEngine, parameters, fieldAttributes);
        parameters.put("indexThreadQty", KNNSettings.state().getSettingValue("knn.algo_param.index_thread_qty"));
        return parameters;
    }

    /**
     * Adds a "B" prefix to the Faiss index description in one case: the field's data_type attribute is binary,
     * an index_description is present, and it is not already prefixed. The data type in the parameters is then
     * set to BINARY.
     */
    private void maybeAddBinaryPrefixForFaissBWC(KNNEngine knnEngine, Map<String, Object> parameters, Map<String, String> fieldAttributes) {
        if (KNNEngine.FAISS != knnEngine) {
            return;
        }
        if (!VectorDataType.BINARY.getValue().equals(fieldAttributes.getOrDefault("data_type", VectorDataType.DEFAULT.getValue()))) {
            return;
        }
        Object indexDescription = parameters.get("index_description");
        if (indexDescription == null || indexDescription.toString().startsWith("B")) {
            return;
        }
        parameters.put("index_description", "B" + indexDescription);
        IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY);
    }

    /**
     * Builds parameters for a model-based (template) field.
     *
     * The data type is BINARY when the field has a non-empty quantization config. Otherwise it is the
     * model's own data type.
     */
    private Map<String, Object> getTemplateParameters(FieldInfo fieldInfo, Model model) throws IOException {
        Map<String, Object> parameters = new HashMap<>();
        parameters.put("indexThreadQty", KNNSettings.state().getSettingValue("knn.algo_param.index_thread_qty"));
        parameters.put("model_id", fieldInfo.attributes().get("model_id"));
        parameters.put("model_blob", model.getModelBlob());
        if (FieldInfoExtractor.extractQuantizationConfig(fieldInfo) != QuantizationConfig.EMPTY) {
            IndexUtil.updateVectorDataTypeToParameters(parameters, VectorDataType.BINARY);
        } else {
            IndexUtil.updateVectorDataTypeToParameters(parameters, model.getModelMetadata().getVectorDataType());
        }
        return parameters;
    }

    /**
     * Looks up the field's model in the model cache.
     *
     * @throws RuntimeException if the cached model has no blob, i.e. it is not trained
     */
    private Model getModel(FieldInfo fieldInfo) {
        String modelId = fieldInfo.attributes().get("model_id");
        Model model = ModelCache.getInstance().get(modelId);
        if (model.getModelBlob() == null) {
            throw new RuntimeException(String.format("There is no trained model with id \"%s\"", modelId));
        }
        return model;
    }

    // NOTE(review): the SIZE_IN_BYTES gauges receive the per-vector byte size, not numDocs * bytesPerVector;
    // confirm whether a total size was intended.
    private void startMergeStats(int numDocs, long bytesPerVector) {
        KNNGraphValue.MERGE_CURRENT_OPERATIONS.increment();
        KNNGraphValue.MERGE_CURRENT_DOCS.incrementBy(numDocs);
        KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.incrementBy(bytesPerVector);
        KNNGraphValue.MERGE_TOTAL_OPERATIONS.increment();
        KNNGraphValue.MERGE_TOTAL_DOCS.incrementBy(numDocs);
        KNNGraphValue.MERGE_TOTAL_SIZE_IN_BYTES.incrementBy(bytesPerVector);
    }

    /** Reverses the "current" gauges incremented by {@link #startMergeStats(int, long)}; totals are left as-is. */
    private void endMergeStats(int numDocs, long bytesPerVector) {
        KNNGraphValue.MERGE_CURRENT_OPERATIONS.decrement();
        KNNGraphValue.MERGE_CURRENT_DOCS.decrementBy(numDocs);
        KNNGraphValue.MERGE_CURRENT_SIZE_IN_BYTES.decrementBy(bytesPerVector);
    }

    private void recordRefreshStats() {
        KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment();
    }

    /** Returns true when {@code value} has any of its upper 32 bits set, which a CRC-32 checksum never does. */
    private boolean isChecksumInvalid(long value) {
        return (value & CRC32_CHECKSUM_SANITY) != 0L;
    }

    /**
     * Appends a footer to the engine file, in the same layout Lucene's CodecUtil uses:
     * a magic int, a checksum algorithm id int, and then a long checksum.
     * The checksum covers every byte written before it, including the magic and the algorithm id.
     *
     * @throws IllegalStateException if the computed checksum is not a valid 32-bit CRC
     * @throws IOException           if the file cannot be appended to or re-read
     */
    private void writeFooter(String indexPath, String engineFileName, SegmentWriteState state) throws IOException {
        try (OutputStream os = Files.newOutputStream(Paths.get(indexPath), StandardOpenOption.APPEND)) {
            ByteBuffer byteBuffer = ByteBuffer.allocate(Long.BYTES).order(ByteOrder.BIG_ENDIAN);
            byteBuffer.putInt(FOOTER_MAGIC);
            byteBuffer.putInt(CHECKSUM_ALGORITHM_ID);
            os.write(byteBuffer.array());
            // Make the magic and algorithm id visible to the checksum read below.
            os.flush();

            long checksum;
            try (ChecksumIndexInput checksumIndexInput = state.directory.openChecksumInput(engineFileName, state.context)) {
                // Reading to the end accumulates the checksum over the entire file written so far.
                checksumIndexInput.seek(checksumIndexInput.length());
                checksum = checksumIndexInput.getChecksum();
            }
            if (isChecksumInvalid(checksum)) {
                throw new IllegalStateException("Illegal CRC-32 checksum: " + checksum + " (resource=" + indexPath + ")");
            }
            byteBuffer.putLong(0, checksum);
            os.write(byteBuffer.array());
        }
    }

    /**
     * Picks the build strategy: the memory-optimized strategy for non-template Faiss fields,
     * and the default strategy otherwise.
     */
    private static NativeIndexWriter createWriter(FieldInfo fieldInfo, SegmentWriteState state, @Nullable QuantizationState quantizationState) {
        KNNEngine knnEngine = FieldInfoExtractor.extractKNNEngine(fieldInfo);
        boolean isTemplate = fieldInfo.attributes().containsKey("model_id");
        boolean iterative = !isTemplate && KNNEngine.FAISS == knnEngine;
        NativeIndexBuildStrategy strategy = iterative
            ? MemOptimizedNativeIndexBuildStrategy.getInstance()
            : DefaultIndexBuildStrategy.getInstance();
        return new NativeIndexWriter(state, fieldInfo, strategy, quantizationState);
    }

    @Generated
    public NativeIndexWriter(SegmentWriteState state, FieldInfo fieldInfo, NativeIndexBuildStrategy indexBuilder, QuantizationState quantizationState) {
        this.state = state;
        this.fieldInfo = fieldInfo;
        this.indexBuilder = indexBuilder;
        this.quantizationState = quantizationState;
    }
}

