/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query.nativelib;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.Bits;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.query.ExactSearcher;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.index.query.ResultUtil;
import org.opensearch.knn.index.query.nativelib.DocAndScoreQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

public class NativeEngineKnnVectorQuery
extends Query {
    @Generated
    private static final Logger log = LogManager.getLogger(NativeEngineKnnVectorQuery.class);
    private final KNNQuery knnQuery;

    public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, float boost) throws IOException {
        List<Map<Integer, Float>> perLeafResults;
        IndexReader reader = indexSearcher.getIndexReader();
        KNNWeight knnWeight = (KNNWeight)this.knnQuery.createWeight(indexSearcher, ScoreMode.COMPLETE, 1.0f);
        List leafReaderContexts = reader.leaves();
        RescoreContext rescoreContext = this.knnQuery.getRescoreContext();
        int finalK = this.knnQuery.getK();
        if (rescoreContext == null) {
            perLeafResults = this.doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK);
        } else {
            int firstPassK = rescoreContext.getFirstPassK(finalK);
            perLeafResults = this.doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK);
            ResultUtil.reduceToTopK(perLeafResults, firstPassK);
            StopWatch stopWatch = new StopWatch().start();
            perLeafResults = this.doRescore(indexSearcher, leafReaderContexts, knnWeight, perLeafResults, finalK);
            long rescoreTime = stopWatch.stop().totalTime().millis();
            log.debug("Rescoring results took {} ms. oversampled k:{}, segments:{}", (Object)rescoreTime, (Object)firstPassK, (Object)leafReaderContexts.size());
        }
        ResultUtil.reduceToTopK(perLeafResults, finalK);
        TopDocs[] topDocs = new TopDocs[perLeafResults.size()];
        for (int i = 0; i < perLeafResults.size(); ++i) {
            topDocs[i] = ResultUtil.resultMapToTopDocs(perLeafResults.get(i), ((LeafReaderContext)leafReaderContexts.get((int)i)).docBase);
        }
        TopDocs topK = TopDocs.merge((int)this.knnQuery.getK(), (TopDocs[])topDocs);
        if (topK.scoreDocs.length == 0) {
            return new MatchNoDocsQuery().createWeight(indexSearcher, scoreMode, boost);
        }
        return this.createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost);
    }

    private List<Map<Integer, Float>> doSearch(IndexSearcher indexSearcher, List<LeafReaderContext> leafReaderContexts, KNNWeight knnWeight, int k) throws IOException {
        ArrayList<Callable<Map>> tasks = new ArrayList<Callable<Map>>(leafReaderContexts.size());
        for (LeafReaderContext leafReaderContext : leafReaderContexts) {
            tasks.add(() -> this.searchLeaf(leafReaderContext, knnWeight, k));
        }
        return indexSearcher.getTaskExecutor().invokeAll(tasks);
    }

    private List<Map<Integer, Float>> doRescore(IndexSearcher indexSearcher, List<LeafReaderContext> leafReaderContexts, KNNWeight knnWeight, List<Map<Integer, Float>> perLeafResults, int k) throws IOException {
        ArrayList<Callable<Map>> rescoreTasks = new ArrayList<Callable<Map>>(leafReaderContexts.size());
        int i = 0;
        while (i < perLeafResults.size()) {
            LeafReaderContext leafReaderContext = leafReaderContexts.get(i);
            int finalI = i++;
            rescoreTasks.add(() -> {
                BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet((Map)perLeafResults.get(finalI));
                ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder().matchedDocs(convertedBitSet).useQuantizedVectorsForSearch(false).k(k).isParentHits(false).knnQuery(this.knnQuery).build();
                return knnWeight.exactSearch(leafReaderContext, exactSearcherContext);
            });
        }
        return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks);
    }

    private Query createDocAndScoreQuery(IndexReader reader, TopDocs topK) {
        int len = topK.scoreDocs.length;
        Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
        int[] docs = new int[len];
        float[] scores = new float[len];
        for (int i = 0; i < len; ++i) {
            docs[i] = topK.scoreDocs[i].doc;
            scores[i] = topK.scoreDocs[i].score;
        }
        int[] segmentStarts = NativeEngineKnnVectorQuery.findSegmentStarts(reader, docs);
        return new DocAndScoreQuery(this.knnQuery.getK(), docs, scores, segmentStarts, reader.getContext().id());
    }

    static int[] findSegmentStarts(IndexReader reader, int[] docs) {
        int[] starts = new int[reader.leaves().size() + 1];
        starts[starts.length - 1] = docs.length;
        if (starts.length == 2) {
            return starts;
        }
        int resultIndex = 0;
        for (int i = 1; i < starts.length - 1; ++i) {
            int upper = ((LeafReaderContext)reader.leaves().get((int)i)).docBase;
            if ((resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper)) < 0) {
                resultIndex = -1 - resultIndex;
            }
            starts[i] = resultIndex;
        }
        return starts;
    }

    private Map<Integer, Float> searchLeaf(LeafReaderContext ctx, KNNWeight queryWeight, int k) throws IOException {
        Map<Integer, Float> leafDocScores = queryWeight.searchLeaf(ctx, k);
        Bits liveDocs = ctx.reader().getLiveDocs();
        if (liveDocs != null) {
            leafDocScores.entrySet().removeIf(entry -> !liveDocs.get(((Integer)entry.getKey()).intValue()));
        }
        return leafDocScores;
    }

    public String toString(String field) {
        return ((Object)((Object)this)).getClass().getSimpleName() + "[" + field + "]..." + KNNQuery.class.getSimpleName() + "[" + this.knnQuery.toString() + "]";
    }

    public void visit(QueryVisitor visitor) {
        visitor.visitLeaf((Query)this);
    }

    public boolean equals(Object obj) {
        if (!this.sameClassAs(obj)) {
            return false;
        }
        return this.knnQuery == ((NativeEngineKnnVectorQuery)((Object)obj)).knnQuery;
    }

    public int hashCode() {
        return Objects.hash(this.classHash(), this.knnQuery.hashCode());
    }

    @Generated
    public KNNQuery getKnnQuery() {
        return this.knnQuery;
    }

    @Generated
    public NativeEngineKnnVectorQuery(KNNQuery knnQuery) {
        this.knnQuery = knnQuery;
    }
}

