/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.query;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import lombok.Generated;
import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DisjunctionDISIApproximation;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TwoPhaseIterator;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.PriorityQueue;
import org.opensearch.neuralsearch.query.HybridScoreBlockBoundaryPropagator;

/**
 * Scorer for the hybrid query. It iterates over the union (disjunction) of the documents matched by its
 * sub-query scorers. Besides the combined {@link #score()}, which sums the scores of all matching
 * sub-scorers, it exposes the individual score of every sub-query for the current document through
 * {@link #hybridScores()}.
 *
 * <p>Entries of the sub-scorer list may be {@code null}. They are skipped when building the iteration
 * queue, but they still occupy a position, so the array returned by {@link #hybridScores()} has one slot
 * per entry of the list passed to the constructor.
 */
public final class HybridQueryScorer extends Scorer {
    /** Sub-scorers in sub-query order; may contain {@code null} entries. */
    private final List<Scorer> subScorers;
    /** Queue of the non-null sub-scorers, used to drive the disjunction. */
    private final DisiPriorityQueue subScorersPQ;
    /** Only its length (number of sub-scorer slots) is used, to size the result of {@link #hybridScores()}. */
    private final float[] subScores;
    /** Sub-query -> positions in {@link #subScorers}; a list because the same query may appear more than once. */
    private final Map<Query, List<Integer>> queryToIndex;
    private final DocIdSetIterator approximation;
    /** Only set when the score mode is {@link ScoreMode#TOP_SCORES}; otherwise {@code null}. */
    private final HybridScoreBlockBoundaryPropagator disjunctionBlockPropagator;
    /** {@code null} when no sub-scorer exposes a two-phase view. */
    private final TwoPhase twoPhase;

    /**
     * Creates a scorer that uses {@link ScoreMode#TOP_SCORES}.
     *
     * @param weight     weight handed to {@link Scorer}
     * @param subScorers one scorer per sub-query; {@code null} entries are allowed
     * @throws IOException propagated from the delegated constructor
     */
    public HybridQueryScorer(Weight weight, List<Scorer> subScorers) throws IOException {
        this(weight, subScorers, ScoreMode.TOP_SCORES);
    }

    /**
     * Creates a scorer for the given score mode.
     *
     * @param weight     weight handed to {@link Scorer}
     * @param subScorers one scorer per sub-query; {@code null} entries are allowed
     * @param scoreMode  {@link ScoreMode#TOP_SCORES} enables block-boundary propagation;
     *                   {@link ScoreMode#COMPLETE_NO_SCORES} lets two-phase matching stop at the first sub-match
     * @throws IOException declared for callers; the visible body does not throw it directly
     */
    HybridQueryScorer(Weight weight, List<Scorer> subScorers, ScoreMode scoreMode) throws IOException {
        super(weight);
        this.subScorers = Collections.unmodifiableList(subScorers);
        this.subScores = new float[subScorers.size()];
        this.queryToIndex = this.mapQueryToIndex();
        this.subScorersPQ = this.initializeSubScorersPQ();
        boolean needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;
        this.approximation = new HybridSubqueriesDISIApproximation(this.subScorersPQ);
        this.disjunctionBlockPropagator = scoreMode == ScoreMode.TOP_SCORES ? new HybridScoreBlockBoundaryPropagator(subScorers) : null;

        // The match cost of the disjunction is the average match cost of the two-phase sub-scorers,
        // weighted by each sub-scorer's cost (clamped to at least 1). Scorers without a two-phase view
        // still count towards the total cost.
        boolean hasApproximation = false;
        float sumMatchCost = 0.0f;
        long sumApproxCost = 0L;
        for (DisiWrapper w : this.subScorersPQ) {
            long costWeight = w.cost <= 1L ? 1L : w.cost;
            sumApproxCost += costWeight;
            if (w.twoPhaseView != null) {
                hasApproximation = true;
                sumMatchCost += w.matchCost * (float) costWeight;
            }
        }
        if (hasApproximation) {
            float matchCost = sumMatchCost / (float) sumApproxCost;
            this.twoPhase = new TwoPhase(this.approximation, matchCost, this.subScorersPQ, needsScores);
        } else {
            this.twoPhase = null;
        }
    }

    /**
     * Delegates to the block-boundary propagator when one exists (TOP_SCORES mode), otherwise to the
     * default {@link Scorer#advanceShallow(int)}.
     */
    @Override
    public int advanceShallow(int target) throws IOException {
        if (this.disjunctionBlockPropagator != null) {
            return this.disjunctionBlockPropagator.advanceShallow(target);
        }
        return super.advanceShallow(target);
    }

    /** Returns the sum of the scores of all sub-scorers that match the current document. */
    @Override
    public float score() throws IOException {
        return this.score(this.getSubMatches());
    }

    /**
     * Sums the scores along a linked list of sub-scorer wrappers.
     * Scorers that are exhausted are skipped.
     */
    private float score(DisiWrapper topList) throws IOException {
        float totalScore = 0.0f;
        for (DisiWrapper w = topList; w != null; w = w.next) {
            if (w.scorer.docID() != DocIdSetIterator.NO_MORE_DOCS) {
                totalScore += w.scorer.score();
            }
        }
        return totalScore;
    }

    /**
     * Returns a linked list (via {@link DisiWrapper#next}) of the sub-scorers that match the current
     * document.
     *
     * <p>With two-phase sub-scorers, this is the verified list produced by {@link TwoPhase}, not the raw
     * queue top list. The raw top list may contain scorers whose approximation is merely positioned on the
     * document.
     */
    DisiWrapper getSubMatches() throws IOException {
        if (this.twoPhase == null) {
            return this.subScorersPQ.topList();
        }
        return this.twoPhase.getSubMatches();
    }

    @Override
    public DocIdSetIterator iterator() {
        if (this.twoPhase != null) {
            return TwoPhaseIterator.asDocIdSetIterator(this.twoPhase);
        }
        return this.approximation;
    }

    /** Returns the two-phase view, or {@code null} if no sub-scorer has one. */
    @Override
    public TwoPhaseIterator twoPhaseIterator() {
        return this.twoPhase;
    }

    /**
     * Returns the largest {@link Scorer#getMaxScore(int)} among the non-null sub-scorers positioned at or
     * before {@code upTo}, or {@code 0} if there is none.
     *
     * @throws IOException if a sub-scorer fails; propagated as-is rather than wrapped
     */
    @Override
    public float getMaxScore(int upTo) throws IOException {
        boolean found = false;
        float maxScore = 0.0f;
        for (Scorer scorer : this.subScorers) {
            if (scorer == null || scorer.docID() > upTo) {
                continue;
            }
            float subMaxScore = scorer.getMaxScore(upTo);
            // Float.compare keeps the ordering used by the previous stream-based max (NaN sorts highest).
            if (!found || Float.compare(subMaxScore, maxScore) > 0) {
                maxScore = subMaxScore;
                found = true;
            }
        }
        return maxScore;
    }

    /** Forwards the minimum competitive score to the block propagator (if any) and every non-null sub-scorer. */
    @Override
    public void setMinCompetitiveScore(float minScore) throws IOException {
        if (this.disjunctionBlockPropagator != null) {
            this.disjunctionBlockPropagator.setMinCompetitiveScore(minScore);
        }
        for (Scorer scorer : this.subScorers) {
            if (scorer != null) {
                scorer.setMinCompetitiveScore(minScore);
            }
        }
    }

    /**
     * Returns the doc id at the top of the sub-scorer queue.
     * If there are no non-null sub-scorers, returns {@link DocIdSetIterator#NO_MORE_DOCS}.
     */
    @Override
    public int docID() {
        if (this.subScorersPQ.size() == 0) {
            return DocIdSetIterator.NO_MORE_DOCS;
        }
        return this.subScorersPQ.top().doc;
    }

    /**
     * Returns the score of each sub-query for the current document.
     *
     * <p>The array has one slot per entry of the sub-scorer list. Slots of sub-queries that do not match
     * the document stay {@code 0}.
     *
     * <p>The matching sub-scorers are taken from {@link #getSubMatches()}. Walking
     * {@code subScorersPQ.topList()} here would also score two-phase sub-scorers that are only
     * approximately positioned on the document. It would also rewrite the {@code next} links of the
     * verified-matches chain that {@link #score()} and {@link #getChildren()} rely on.
     *
     * @throws IllegalStateException if every slot of a (possibly duplicated) sub-query has already been
     *                               filled for this document
     */
    public float[] hybridScores() throws IOException {
        float[] scores = new float[this.subScores.length];
        for (DisiWrapper w = this.getSubMatches(); w != null; w = w.next) {
            Scorer scorer = w.scorer;
            if (scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
                continue;
            }
            Query query = scorer.getWeight().getQuery();
            List<Integer> indexes = this.queryToIndex.get(query);
            // The same query can occur more than once; fill the first slot of that query still at 0.
            int index = indexes.stream()
                .mapToInt(Integer::intValue)
                .filter(idx -> Float.compare(scores[idx], 0.0f) == 0)
                .findFirst()
                .orElseThrow(() -> new IllegalStateException(String.format(Locale.ROOT, "cannot set score for one of hybrid search subquery [%s] and document [%d]", query.toString(), scorer.docID())));
            scores[index] = scorer.score();
        }
        return scores;
    }

    /**
     * Maps each sub-query to the list of its positions in the sub-scorer list.
     * Positions of {@code null} scorers are skipped but still counted.
     */
    private Map<Query, List<Integer>> mapQueryToIndex() {
        Map<Query, List<Integer>> indexesByQuery = new HashMap<>();
        int idx = 0;
        for (Scorer scorer : this.subScorers) {
            if (scorer != null) {
                indexesByQuery.computeIfAbsent(scorer.getWeight().getQuery(), q -> new ArrayList<>()).add(idx);
            }
            idx++;
        }
        return indexesByQuery;
    }

    /**
     * Builds the disjunction queue from the non-null sub-scorers.
     * The queue is sized to the number of indexed sub-query positions.
     */
    private DisiPriorityQueue initializeSubScorersPQ() {
        Objects.requireNonNull(this.queryToIndex, "should not be null");
        Objects.requireNonNull(this.subScorers, "should not be null");
        int numOfSubQueries = this.queryToIndex.values().stream().mapToInt(List::size).sum();
        DisiPriorityQueue queue = new DisiPriorityQueue(numOfSubQueries);
        for (Scorer scorer : this.subScorers) {
            if (scorer != null) {
                queue.add(new DisiWrapper(scorer));
            }
        }
        return queue;
    }

    /** Reports every sub-scorer matching the current document as a {@code "SHOULD"} child. */
    @Override
    public Collection<Scorable.ChildScorable> getChildren() throws IOException {
        List<Scorable.ChildScorable> children = new ArrayList<>();
        for (DisiWrapper w = this.getSubMatches(); w != null; w = w.next) {
            children.add(new Scorable.ChildScorable(w.scorer, "SHOULD"));
        }
        return children;
    }

    @Generated
    public List<Scorer> getSubScorers() {
        return this.subScorers;
    }

    /**
     * Approximation over the disjunction of the sub-scorers.
     *
     * <p>It delegates to {@link DisjunctionDISIApproximation}. When the queue is empty, it reports
     * {@link #NO_MORE_DOCS} without delegating.
     */
    static class HybridSubqueriesDISIApproximation extends DocIdSetIterator {
        final DocIdSetIterator docIdSetIterator;
        final DisiPriorityQueue subIterators;

        public HybridSubqueriesDISIApproximation(DisiPriorityQueue subIterators) {
            this.docIdSetIterator = new DisjunctionDISIApproximation(subIterators);
            this.subIterators = subIterators;
        }

        @Override
        public long cost() {
            return this.docIdSetIterator.cost();
        }

        @Override
        public int docID() {
            if (this.subIterators.size() == 0) {
                return NO_MORE_DOCS;
            }
            return this.docIdSetIterator.docID();
        }

        @Override
        public int nextDoc() throws IOException {
            if (this.subIterators.size() == 0) {
                return NO_MORE_DOCS;
            }
            return this.docIdSetIterator.nextDoc();
        }

        @Override
        public int advance(int target) throws IOException {
            if (this.subIterators.size() == 0) {
                return NO_MORE_DOCS;
            }
            return this.docIdSetIterator.advance(target);
        }
    }

    /**
     * Two-phase view of the disjunction.
     *
     * <p>{@link #matches()} returns as soon as one sub-scorer is confirmed to match. Sub-scorers without
     * a two-phase view count as confirmed. Two-phase sub-scorers left unchecked are parked in
     * {@link #unverifiedMatches}; they are only verified when {@link #getSubMatches()} is called.
     */
    static class TwoPhase extends TwoPhaseIterator {
        private final float matchCost;
        /** Linked list (via {@code next}) of sub-scorers confirmed to match the current document. */
        DisiWrapper verifiedMatches;
        /** Two-phase sub-scorers not yet verified, ordered by their match cost (cheapest first). */
        final PriorityQueue<DisiWrapper> unverifiedMatches;
        DisiPriorityQueue subScorers;
        /** When {@code false}, {@link #matches()} may return after the first implicitly verified match. */
        boolean needsScores;

        private TwoPhase(DocIdSetIterator approximation, float matchCost, DisiPriorityQueue subScorers, boolean needsScores) {
            super(approximation);
            this.matchCost = matchCost;
            this.subScorers = subScorers;
            this.unverifiedMatches = new PriorityQueue<DisiWrapper>(subScorers.size()) {
                @Override
                protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
                    return a.matchCost < b.matchCost;
                }
            };
            this.needsScores = needsScores;
        }

        /**
         * Verifies any remaining candidates and returns every sub-scorer matching the current document.
         * Calling it again for the same document returns the same list.
         */
        DisiWrapper getSubMatches() throws IOException {
            // iteration order does not matter here
            for (DisiWrapper wrapper : this.unverifiedMatches) {
                if (wrapper.twoPhaseView.matches()) {
                    wrapper.next = this.verifiedMatches;
                    this.verifiedMatches = wrapper;
                }
            }
            this.unverifiedMatches.clear();
            return this.verifiedMatches;
        }

        @Override
        public boolean matches() throws IOException {
            this.verifiedMatches = null;
            this.unverifiedMatches.clear();
            DisiWrapper wrapper = this.subScorers.topList();
            while (wrapper != null) {
                // capture next before relinking the wrapper into verifiedMatches
                DisiWrapper next = wrapper.next;
                if (wrapper.twoPhaseView == null) {
                    // no second phase: positioned on the doc means it matches
                    wrapper.next = this.verifiedMatches;
                    this.verifiedMatches = wrapper;
                    if (!this.needsScores) {
                        return true;
                    }
                } else {
                    this.unverifiedMatches.add(wrapper);
                }
                wrapper = next;
            }
            if (this.verifiedMatches != null) {
                return true;
            }
            // only two-phase candidates: verify the cheapest ones first and stop at the first match
            while (this.unverifiedMatches.size() > 0) {
                wrapper = this.unverifiedMatches.pop();
                if (wrapper.twoPhaseView.matches()) {
                    wrapper.next = null;
                    this.verifiedMatches = wrapper;
                    return true;
                }
            }
            return false;
        }

        @Override
        public float matchCost() {
            return this.matchCost;
        }
    }
}

