/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.query;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Matches;
import org.apache.lucene.search.MatchesUtils;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.opensearch.neuralsearch.query.HybridQuery;
import org.opensearch.neuralsearch.query.HybridQueryScorer;

public final class HybridQueryWeight
extends Weight {
    private final List<Weight> weights;
    private final ScoreMode scoreMode;

    public HybridQueryWeight(HybridQuery hybridQuery, IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
        super((Query)hybridQuery);
        this.weights = hybridQuery.getSubQueries().stream().map(q -> {
            try {
                return searcher.createWeight(q, scoreMode, boost);
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        }).collect(Collectors.toList());
        this.scoreMode = scoreMode;
    }

    public Matches matches(LeafReaderContext context, int doc) throws IOException {
        List mis = this.weights.stream().map(weight -> {
            try {
                return weight.matches(context, doc);
            }
            catch (IOException e) {
                throw new RuntimeException(e);
            }
        }).filter(Objects::nonNull).collect(Collectors.toList());
        return MatchesUtils.fromSubMatches(mis);
    }

    public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
        ArrayList<ScorerSupplier> scorerSuppliers = new ArrayList<ScorerSupplier>();
        for (Weight w : this.weights) {
            ScorerSupplier ss = w.scorerSupplier(context);
            scorerSuppliers.add(ss);
        }
        if (scorerSuppliers.isEmpty()) {
            return null;
        }
        return new HybridScorerSupplier(scorerSuppliers, this, this.scoreMode);
    }

    public Scorer scorer(LeafReaderContext context) throws IOException {
        ScorerSupplier supplier = this.scorerSupplier(context);
        if (supplier == null) {
            return null;
        }
        supplier.setTopLevelScoringClause();
        return supplier.get(Long.MAX_VALUE);
    }

    public boolean isCacheable(LeafReaderContext ctx) {
        if (this.weights.size() > 5) {
            return false;
        }
        return this.weights.stream().allMatch(w -> w.isCacheable(ctx));
    }

    public Explanation explain(LeafReaderContext context, int doc) throws IOException {
        throw new UnsupportedOperationException("Explain is not supported");
    }

    static class HybridScorerSupplier
    extends ScorerSupplier {
        private long cost = -1L;
        private final List<ScorerSupplier> scorerSuppliers;
        private final Weight weight;
        private final ScoreMode scoreMode;

        public Scorer get(long leadCost) throws IOException {
            ArrayList<Scorer> tScorers = new ArrayList<Scorer>();
            for (ScorerSupplier ss : this.scorerSuppliers) {
                if (Objects.nonNull(ss)) {
                    tScorers.add(ss.get(leadCost));
                    continue;
                }
                tScorers.add(null);
            }
            return new HybridQueryScorer(this.weight, tScorers, this.scoreMode);
        }

        public long cost() {
            if (this.cost == -1L) {
                long cost = 0L;
                for (ScorerSupplier ss : this.scorerSuppliers) {
                    if (!Objects.nonNull(ss)) continue;
                    cost += ss.cost();
                }
                this.cost = cost;
            }
            return this.cost;
        }

        public void setTopLevelScoringClause() throws IOException {
            for (ScorerSupplier ss : this.scorerSuppliers) {
                if (!Objects.nonNull(ss)) continue;
                ss.setTopLevelScoringClause();
            }
        }

        @Generated
        public HybridScorerSupplier(List<ScorerSupplier> scorerSuppliers, Weight weight, ScoreMode scoreMode) {
            this.scorerSuppliers = scorerSuppliers;
            this.weight = weight;
            this.scoreMode = scoreMode;
        }
    }
}

