/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.codec.KNN990Codec;

import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.KNN990Codec.KNN990QuantizationStateWriter;
import org.opensearch.knn.index.codec.KNN990Codec.NativeEngineFieldVectorsWriter;
import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

/**
 * {@link KnnVectorsWriter} used by the native k-NN engines.
 *
 * <p>Raw vectors are delegated to a {@link FlatVectorsWriter}. After each flush (for every field
 * registered through {@link #addField}) and after each per-field merge, the field's vectors are
 * replayed and handed to a {@link NativeIndexWriter} ({@code flushIndex} / {@code mergeIndex}).
 * Fields for which {@link QuantizationService#getQuantizationParams} returns non-null parameters are
 * first trained via {@link QuantizationService}, and the resulting {@link QuantizationState} is
 * written with a lazily created {@link KNN990QuantizationStateWriter}.
 */
public class NativeEngines990KnnVectorsWriter
extends KnnVectorsWriter {
    @Generated
    private static final Logger log = LogManager.getLogger(NativeEngines990KnnVectorsWriter.class);
    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(NativeEngines990KnnVectorsWriter.class);
    private static final String FLUSH_OPERATION = "flush";
    private static final String MERGE_OPERATION = "merge";

    private final SegmentWriteState segmentWriteState;
    private final FlatVectorsWriter flatVectorsWriter;
    /** Created on first use, i.e. only once some field in this segment has quantization state to write. */
    private KNN990QuantizationStateWriter quantizationStateWriter;
    /** Per-field buffers registered through {@link #addField}, replayed on {@link #flush}. */
    private final List<NativeEngineFieldVectorsWriter<?>> fields = new ArrayList<>();
    private boolean finished;
    private final QuantizationService quantizationService = QuantizationService.getInstance();

    /**
     * @param segmentWriteState state of the segment being written; used for native index and quantization output
     * @param flatVectorsWriter writer that stores the raw (flat) vectors
     */
    public NativeEngines990KnnVectorsWriter(SegmentWriteState segmentWriteState, FlatVectorsWriter flatVectorsWriter) {
        this.segmentWriteState = segmentWriteState;
        this.flatVectorsWriter = flatVectorsWriter;
    }

    /**
     * Registers a field: creates a {@link NativeEngineFieldVectorsWriter} that is remembered for
     * {@link #flush}, and passes it to the flat writer, whose per-field writer is returned.
     */
    @Override
    public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
        NativeEngineFieldVectorsWriter<?> newField = NativeEngineFieldVectorsWriter.create(fieldInfo, this.segmentWriteState.infoStream);
        this.fields.add(newField);
        return this.flatVectorsWriter.addField(fieldInfo, newField);
    }

    /**
     * Flushes the flat vectors, then trains (if configured) and builds the native index for every
     * registered field from its buffered vectors.
     */
    @Override
    public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
        this.flatVectorsWriter.flush(maxDoc, sortMap);
        for (NativeEngineFieldVectorsWriter<?> field : this.fields) {
            this.trainAndIndex(
                field.getFieldInfo(),
                (dataType, ignoredFieldInfo, fieldWriter) -> this.getKNNVectorValues(dataType, fieldWriter),
                NativeIndexWriter::flushIndex,
                field,
                KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS,
                FLUSH_OPERATION
            );
        }
    }

    /**
     * Merges the flat vectors of one field, then trains (if configured) and builds the native index
     * from the merged vector values of all merging segments.
     */
    @Override
    public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
        this.flatVectorsWriter.mergeOneField(fieldInfo, mergeState);
        this.trainAndIndex(
            fieldInfo,
            this::getKNNVectorValuesForMerge,
            NativeIndexWriter::mergeIndex,
            mergeState,
            KNNGraphValue.MERGE_TOTAL_TIME_IN_MILLIS,
            MERGE_OPERATION
        );
    }

    /**
     * Writes the quantization-state footer (only if a state writer was created) and finishes the flat writer.
     *
     * @throws IllegalStateException if called more than once
     */
    @Override
    public void finish() throws IOException {
        if (this.finished) {
            throw new IllegalStateException("NativeEnginesKNNVectorsWriter is already finished");
        }
        this.finished = true;
        if (this.quantizationStateWriter != null) {
            this.quantizationStateWriter.writeFooter();
        }
        this.flatVectorsWriter.finish();
    }

    /**
     * Closes the quantization-state output (if any) and the flat writer.
     *
     * <p>Both are always attempted: {@link IOUtils#close(Closeable...)} skips nulls, closes every
     * argument, and rethrows the first failure with any later ones attached as suppressed.
     * Previously a failure in {@code closeOutput()} left {@code flatVectorsWriter} open.
     */
    @Override
    public void close() throws IOException {
        final KNN990QuantizationStateWriter stateWriter = this.quantizationStateWriter;
        final Closeable stateOutput = stateWriter == null ? null : stateWriter::closeOutput;
        IOUtils.close(stateOutput, this.flatVectorsWriter);
    }

    /** Shallow size of this writer plus the flat writer's and every registered field writer's usage. */
    @Override
    public long ramBytesUsed() {
        return SHALLOW_SIZE
            + this.flatVectorsWriter.ramBytesUsed()
            + this.fields.stream().mapToLong(NativeEngineFieldVectorsWriter::ramBytesUsed).sum();
    }

    /** Builds fresh vector values over the vectors buffered in {@code field} during indexing. */
    private <T> KNNVectorValues<T> getKNNVectorValues(VectorDataType vectorDataType, NativeEngineFieldVectorsWriter<?> field) {
        return KNNVectorValuesFactory.getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors());
    }

    /**
     * Builds fresh vector values over the merged view of {@code fieldInfo} across all segments in
     * {@code mergeState}.
     *
     * @throws IllegalStateException if the field's vector encoding is neither FLOAT32 nor BYTE
     */
    private <T> KNNVectorValues<T> getKNNVectorValuesForMerge(VectorDataType vectorDataType, FieldInfo fieldInfo, MergeState mergeState) throws IOException {
        switch (fieldInfo.getVectorEncoding()) {
            case FLOAT32: {
                final FloatVectorValues mergedFloats = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
                // Cast kept so the DocIdSetIterator-based factory method is the one selected.
                return KNNVectorValuesFactory.getVectorValues(vectorDataType, (DocIdSetIterator) mergedFloats);
            }
            case BYTE: {
                final ByteVectorValues mergedBytes = KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
                return KNNVectorValuesFactory.getVectorValues(vectorDataType, (DocIdSetIterator) mergedBytes);
            }
            default:
                throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]");
        }
    }

    /**
     * Shared flush/merge pipeline: count live docs, optionally train and persist quantization
     * state, then build the native index and record how long the build took.
     *
     * @param fieldInfo               field being written
     * @param vectorValuesRetriever   produces a fresh {@link KNNVectorValues} each time it is called
     * @param indexOperation          native index build step ({@code flushIndex} or {@code mergeIndex})
     * @param vectorProcessingContext source passed to the retriever (field writer on flush, merge state on merge)
     * @param graphBuildTime          stat that receives the build time in milliseconds
     * @param operationName           label used in the log line
     */
    private <T, C> void trainAndIndex(
        FieldInfo fieldInfo,
        VectorValuesRetriever<VectorDataType, FieldInfo, C, KNNVectorValues<T>> vectorValuesRetriever,
        IndexOperation<T> indexOperation,
        C vectorProcessingContext,
        KNNGraphValue graphBuildTime,
        String operationName
    ) throws IOException {
        final VectorDataType vectorDataType = FieldInfoExtractor.extractVectorDataType(fieldInfo);
        final QuantizationParams quantizationParams = this.quantizationService.getQuantizationParams(fieldInfo);
        // getLiveDocs() exhausts the values it is given, so every pass below retrieves a fresh instance.
        final int totalLiveDocs = this.getLiveDocs(vectorValuesRetriever.apply(vectorDataType, fieldInfo, vectorProcessingContext));

        QuantizationState quantizationState = null;
        if (quantizationParams != null && totalLiveDocs > 0) {
            this.initQuantizationStateWriterIfNecessary();
            // Only materialize training values when training actually happens.
            final KNNVectorValues<T> trainingValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, vectorProcessingContext);
            quantizationState = this.quantizationService.train(quantizationParams, trainingValues, totalLiveDocs);
            this.quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
        }

        final NativeIndexWriter writer = quantizationParams != null
            ? NativeIndexWriter.getWriter(fieldInfo, this.segmentWriteState, quantizationState)
            : NativeIndexWriter.getWriter(fieldInfo, this.segmentWriteState);
        final KNNVectorValues<T> indexingValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, vectorProcessingContext);

        final StopWatch stopWatch = new StopWatch().start();
        indexOperation.buildAndWrite(writer, indexingValues, totalLiveDocs);
        // StopWatch only accumulates total time in stop(); reading totalTime() without it always yields 0.
        final long timeInMillis = stopWatch.stop().totalTime().millis();
        graphBuildTime.incrementBy(timeInMillis);
        log.debug("Graph build took {} ms for {}", timeInMillis, operationName);
    }

    /** Counts the documents in {@code vectorValues} by advancing it to exhaustion. */
    private int getLiveDocs(KNNVectorValues<?> vectorValues) throws IOException {
        int liveDocs = 0;
        while (vectorValues.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
            ++liveDocs;
        }
        return liveDocs;
    }

    /** Creates the quantization-state writer and writes its header the first time it is needed. */
    private void initQuantizationStateWriterIfNecessary() throws IOException {
        if (this.quantizationStateWriter == null) {
            this.quantizationStateWriter = new KNN990QuantizationStateWriter(this.segmentWriteState);
            this.quantizationStateWriter.writeHeader(this.segmentWriteState);
        }
    }

    /**
     * Produces vector values for a field from some source context.
     *
     * @param <D> data-type argument
     * @param <F> field-info argument
     * @param <C> source context (field writer on flush, merge state on merge)
     * @param <R> result type
     */
    @FunctionalInterface
    private interface VectorValuesRetriever<D, F, C, R> {
        R apply(D dataType, F fieldInfo, C context) throws IOException;
    }

    /** A native index build step taking the writer, the vector values, and the live-doc count. */
    @FunctionalInterface
    private interface IndexOperation<T> {
        void buildAndWrite(NativeIndexWriter writer, KNNVectorValues<T> knnVectorValues, int totalLiveDocs) throws IOException;
    }
}

