/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.search.query;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.opensearch.common.Nullable;
import org.opensearch.common.lucene.search.FilteredCollector;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;
import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.MultiCollectorWrapper;
import org.opensearch.search.query.ReduceableSearchResult;
import org.opensearch.search.sort.SortAndFormats;

/**
 * Collector manager for hybrid queries.
 *
 * <p>Creates {@link HybridTopScoreDocCollector} instances (optionally wrapped in a
 * {@link FilteredCollector} when a post filter is present) and, on {@link #reduce(Collection)},
 * converts the per-sub-query {@link TopDocs} of the collector into a single {@link TopDocs} whose
 * score docs are framed by start/stop and delimiter elements produced by
 * {@link HybridSearchResultFormatUtil}.
 *
 * <p>Instances are obtained through {@link #createHybridCollectorManager(SearchContext)}, which
 * picks one of the two nested implementations depending on whether the search context reports
 * concurrent search.
 */
public abstract class HybridCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {
    /** Boost passed to {@code createWeight} when building the post-filter weight. */
    private static final float BOOST_FACTOR = 1.0f;

    private final int numHits;
    private final HitsThresholdChecker hitsThresholdChecker;
    private final boolean isSingleShard;
    private final int trackTotalHitsUpTo;
    private final SortAndFormats sortAndFormats;
    /** Weight of the post filter, or {@code null} when the request has no post filter. */
    @Nullable
    private final Weight filterWeight;

    /**
     * Builds the collector manager that matches the given search context.
     *
     * <p>The number of hits to collect is {@code from + size}, capped by the number of documents
     * reported by the index reader. The hits threshold is the larger of that number and
     * {@code trackTotalHitsUpTo}. When the context carries a parsed post filter with a non-null
     * query, that query is rewritten and turned into a non-scoring {@link Weight} used to filter
     * collected documents.
     *
     * <p>The return type is intentionally left raw so that existing callers keep compiling; the
     * returned object is always a {@code CollectorManager<Collector, ReduceableSearchResult>}.
     *
     * @param searchContext context of the current search request
     * @return a concurrent-search manager when {@code searchContext.shouldUseConcurrentSearch()}
     *         is true, otherwise a non-concurrent manager
     * @throws IOException if rewriting the post-filter query or creating its weight fails
     */
    @SuppressWarnings("rawtypes")
    public static CollectorManager createHybridCollectorManager(SearchContext searchContext) throws IOException {
        IndexReader reader = searchContext.searcher().getIndexReader();
        int totalNumDocs = Math.max(0, reader.numDocs());
        boolean isSingleShard = searchContext.numberOfShards() == 1;
        int numDocs = Math.min(searchContext.from() + searchContext.size(), totalNumDocs);
        int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();

        Weight filteringWeight = null;
        if (Objects.nonNull(searchContext.parsedPostFilter()) && Objects.nonNull(searchContext.parsedPostFilter().query())) {
            Query filterQuery = searchContext.parsedPostFilter().query();
            ContextIndexSearcher searcher = searchContext.searcher();
            // The filter only decides membership, so scores are not needed.
            filteringWeight = searcher.createWeight(searcher.rewrite(filterQuery), ScoreMode.COMPLETE_NO_SCORES, BOOST_FACTOR);
        }

        HitsThresholdChecker thresholdChecker = new HitsThresholdChecker(Math.max(numDocs, trackTotalHitsUpTo));
        if (searchContext.shouldUseConcurrentSearch()) {
            return new HybridCollectorConcurrentSearchManager(
                numDocs,
                thresholdChecker,
                isSingleShard,
                trackTotalHitsUpTo,
                searchContext.sort(),
                filteringWeight
            );
        }
        return new HybridCollectorNonConcurrentManager(
            numDocs,
            thresholdChecker,
            isSingleShard,
            trackTotalHitsUpTo,
            searchContext.sort(),
            filteringWeight
        );
    }

    /**
     * Creates a new hybrid score collector, wrapped in a {@link FilteredCollector} when a
     * post-filter weight was supplied.
     *
     * @return the collector to use for a hybrid query
     */
    @Override
    public Collector newCollector() {
        HybridTopScoreDocCollector hybridCollector = new HybridTopScoreDocCollector(this.numHits, this.hitsThresholdChecker);
        return Objects.nonNull(this.filterWeight) ? new FilteredCollector(hybridCollector, this.filterWeight) : hybridCollector;
    }

    /**
     * Produces the reduce step that writes hybrid top docs into the query search result.
     *
     * <p>The hybrid score collector may be passed directly, as a member of a
     * {@link MultiCollectorWrapper}, or as the delegate of a {@link FilteredCollector}; all three
     * shapes are recognised, any other collector is ignored.
     *
     * <p>NOTE(review): only the first hybrid collector found is read. If several are supplied
     * their results are not merged - confirm this is intended when concurrent search is enabled.
     *
     * @param collectors collectors that took part in the search
     * @return a result that sets the formatted top docs, max score and sort value formats
     * @throws IllegalStateException if no hybrid score collector is found among {@code collectors}
     */
    @Override
    public ReduceableSearchResult reduce(Collection<Collector> collectors) {
        List<HybridTopScoreDocCollector> hybridCollectors = new ArrayList<>();
        for (Collector collector : collectors) {
            if (collector instanceof MultiCollectorWrapper) {
                for (Collector sub : ((MultiCollectorWrapper) collector).getCollectors()) {
                    if (sub instanceof HybridTopScoreDocCollector) {
                        hybridCollectors.add((HybridTopScoreDocCollector) sub);
                    }
                }
            } else if (collector instanceof HybridTopScoreDocCollector) {
                hybridCollectors.add((HybridTopScoreDocCollector) collector);
            } else if (collector instanceof FilteredCollector) {
                Collector delegate = ((FilteredCollector) collector).getCollector();
                if (delegate instanceof HybridTopScoreDocCollector) {
                    hybridCollectors.add((HybridTopScoreDocCollector) delegate);
                }
            }
        }

        if (hybridCollectors.isEmpty()) {
            throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
        }

        List<TopDocs> topDocs = hybridCollectors.get(0).topDocs();
        TopDocs newTopDocs = this.getNewTopDocs(this.getTotalHits(this.trackTotalHitsUpTo, topDocs, this.isSingleShard), topDocs);
        float maxScore = this.getMaxScore(topDocs);
        TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
        return result -> result.topDocs(topDocsAndMaxScore, this.getSortValueFormats(this.sortAndFormats));
    }

    /**
     * Flattens the per-sub-query top docs into one {@link TopDocs} in the layout
     * {@code start, (delimiter, docs of sub-query)*, stop}.
     *
     * <p>The special elements are created with the doc id of the first score doc found in any
     * sub-query. When no sub-query has a score doc (or {@code topDocs} is {@code null}) the result
     * has an empty score doc array. A {@code null} entry or an entry with {@code null} score docs
     * still contributes its delimiter, so the position of each sub-query is preserved.
     *
     * @param totalHits total hits to attach to the result
     * @param topDocs   top docs of each sub-query, may be {@code null}
     * @return combined top docs
     */
    private TopDocs getNewTopDocs(TotalHits totalHits, List<TopDocs> topDocs) {
        if (Objects.isNull(topDocs)) {
            return new TopDocs(totalHits, new ScoreDoc[0]);
        }

        int delimiterDocId = topDocs.stream()
            .filter(Objects::nonNull)
            .map(topDoc -> topDoc.scoreDocs)
            .filter(Objects::nonNull)
            .filter(docs -> docs.length > 0)
            .map(docs -> docs[0].doc)
            .findFirst()
            .orElse(-1);
        if (delimiterDocId == -1) {
            return new TopDocs(totalHits, new ScoreDoc[0]);
        }

        List<ScoreDoc> result = new ArrayList<>();
        result.add(HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(delimiterDocId));
        for (TopDocs topDoc : topDocs) {
            // Every sub-query gets a delimiter, even when it produced nothing.
            result.add(HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults(delimiterDocId));
            if (Objects.nonNull(topDoc) && Objects.nonNull(topDoc.scoreDocs)) {
                result.addAll(Arrays.asList(topDoc.scoreDocs));
            }
        }
        result.add(HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(delimiterDocId));

        // Copy each element into a fresh ScoreDoc carrying only doc, score and shardIndex.
        ScoreDoc[] scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
        return new TopDocs(totalHits, scoreDocs);
    }

    /**
     * Computes total hits as the number of distinct doc ids across all sub-queries.
     *
     * <p>The relation is {@code GREATER_THAN_OR_EQUAL_TO} when {@code trackTotalHitsUpTo} is
     * {@code -1} and {@code EQUAL_TO} otherwise. NOTE(review): {@code -1} is presumably
     * {@code SearchContext.TRACK_TOTAL_HITS_DISABLED} - confirm and use the constant.
     *
     * @param trackTotalHitsUpTo total-hits tracking setting of the request
     * @param topDocs            top docs of each sub-query, may be {@code null} or empty;
     *                           {@code null} entries and entries without score docs are skipped
     * @param isSingleShard      currently not consulted by this method
     * @return total hits with the number of unique documents
     */
    private TotalHits getTotalHits(int trackTotalHitsUpTo, List<TopDocs> topDocs, boolean isSingleShard) {
        TotalHits.Relation relation = trackTotalHitsUpTo == -1
            ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
            : TotalHits.Relation.EQUAL_TO;
        if (topDocs == null || topDocs.isEmpty()) {
            return new TotalHits(0L, relation);
        }

        // The same document can be returned by several sub-queries; count it once.
        long uniqueDocCount = topDocs.stream()
            .filter(Objects::nonNull)
            .map(topDoc -> topDoc.scoreDocs)
            .filter(Objects::nonNull)
            .flatMap(Arrays::stream)
            .map(scoreDoc -> scoreDoc.doc)
            .collect(Collectors.toCollection(HashSet::new))
            .size();
        return new TotalHits(uniqueDocCount, relation);
    }

    /**
     * Returns the highest score among the first score doc of each sub-query.
     *
     * <p>A sub-query without score docs (empty, {@code null} array or {@code null} entry)
     * contributes {@code 0.0f}. A {@code null} or empty list yields {@code 0.0f}.
     *
     * @param topDocs top docs of each sub-query, may be {@code null}
     * @return the maximum of the leading scores, or {@code 0.0f} when there is nothing to compare
     */
    private float getMaxScore(List<TopDocs> topDocs) {
        if (topDocs == null || topDocs.isEmpty()) {
            return 0.0f;
        }
        return topDocs.stream()
            .map(docs -> hasScoreDocs(docs) ? docs.scoreDocs[0].score : 0.0f)
            .max(Float::compare)
            .orElse(0.0f);
    }

    /** Tells whether the given top docs hold at least one score doc. */
    private static boolean hasScoreDocs(TopDocs docs) {
        return docs != null && docs.scoreDocs != null && docs.scoreDocs.length > 0;
    }

    /**
     * Extracts the doc value formats of the sort, if any.
     *
     * @param sortAndFormats sort definition, may be {@code null}
     * @return the formats of {@code sortAndFormats}, or {@code null} when it is {@code null}
     */
    private DocValueFormat[] getSortValueFormats(SortAndFormats sortAndFormats) {
        return sortAndFormats == null ? null : sortAndFormats.formats;
    }

    /**
     * Stores the settings shared by all collectors created by this manager.
     *
     * @param numHits              number of hits passed to each hybrid score collector
     * @param hitsThresholdChecker threshold checker passed to each hybrid score collector
     * @param isSingleShard        whether the search targets exactly one shard
     * @param trackTotalHitsUpTo   total-hits tracking setting of the request
     * @param sortAndFormats       sort definition of the request, may be {@code null}
     * @param filterWeight         post-filter weight, or {@code null} when there is no post filter
     */
    @Generated
    public HybridCollectorManager(int numHits, HitsThresholdChecker hitsThresholdChecker, boolean isSingleShard, int trackTotalHitsUpTo, SortAndFormats sortAndFormats, Weight filterWeight) {
        this.numHits = numHits;
        this.hitsThresholdChecker = hitsThresholdChecker;
        this.isSingleShard = isSingleShard;
        this.trackTotalHitsUpTo = trackTotalHitsUpTo;
        this.sortAndFormats = sortAndFormats;
        this.filterWeight = filterWeight;
    }

    /**
     * Manager used when the search context reports concurrent search. It inherits the base
     * behaviour unchanged: a fresh collector per {@link #newCollector()} call.
     */
    static class HybridCollectorConcurrentSearchManager extends HybridCollectorManager {
        public HybridCollectorConcurrentSearchManager(int numHits, HitsThresholdChecker hitsThresholdChecker, boolean isSingleShard, int trackTotalHitsUpTo, SortAndFormats sortAndFormats, Weight filteringWeight) {
            super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight);
        }
    }

    /**
     * Manager used when concurrent search is not in effect. It creates a single collector up
     * front, hands that same instance out from every {@link #newCollector()} call, and reduces
     * over it regardless of the collection passed to {@link #reduce(Collection)}.
     */
    static class HybridCollectorNonConcurrentManager extends HybridCollectorManager {
        private final Collector scoreCollector;

        public HybridCollectorNonConcurrentManager(int numHits, HitsThresholdChecker hitsThresholdChecker, boolean isSingleShard, int trackTotalHitsUpTo, SortAndFormats sortAndFormats, Weight filteringWeight) {
            super(numHits, hitsThresholdChecker, isSingleShard, trackTotalHitsUpTo, sortAndFormats, filteringWeight);
            // Must go through super: this class's own newCollector() returns the field being initialised.
            this.scoreCollector = Objects.requireNonNull(super.newCollector(), "collector for hybrid query cannot be null");
        }

        /**
         * Returns the single collector owned by this manager.
         *
         * @return the shared score collector
         */
        @Override
        public Collector newCollector() {
            return this.scoreCollector;
        }

        /**
         * Reduces over the collector owned by this manager.
         *
         * @param collectors expected to be empty (checked by assertion only); its content is not used
         * @return the reduce step built from the owned collector
         */
        @Override
        public ReduceableSearchResult reduce(Collection<Collector> collectors) {
            assert (collectors.isEmpty()) : "reduce on HybridCollectorNonConcurrentManager called with non-empty collectors";
            return super.reduce(List.of(this.scoreCollector));
        }
    }
}

