/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.search.aggregations.bucket;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.OptionalLong;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.NumericUtils;
import org.opensearch.common.CheckedRunnable;
import org.opensearch.common.Rounding;
import org.opensearch.common.lucene.search.function.FunctionScoreQuery;
import org.opensearch.index.mapper.DateFieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.DateRangeIncludingNowQuery;
import org.opensearch.search.aggregations.bucket.composite.CompositeAggregator;
import org.opensearch.search.aggregations.bucket.composite.CompositeValuesSourceConfig;
import org.opensearch.search.aggregations.bucket.composite.RoundingValuesSource;
import org.opensearch.search.aggregations.bucket.histogram.LongBounds;
import org.opensearch.search.internal.SearchContext;

/**
 * Helpers for the "fast filter" optimization of date-histogram style aggregations.
 *
 * <p>When the aggregation shape is simple, the bucket doc counts of a segment can be computed by walking the
 * field's point tree against a list of precomputed, consecutive ranges (one per bucket) instead of collecting
 * documents one by one. "Simple" means: no parent, no sub-aggregations, no script, no missing value, and a
 * searchable date field. The query must also be a match-all / exists / point-range query on that field, or must
 * match every document of the segment.
 */
public final class FastFilterRewriteHelper {
    private static final Logger logger = LogManager.getLogger(FastFilterRewriteHelper.class);

    /**
     * Wrapper query types that can be peeled off to reach the query doing the actual filtering, each mapped to a
     * function returning the query it wraps.
     */
    private static final Map<Class<?>, Function<Query, Query>> queryWrappers = new HashMap<>();

    static {
        queryWrappers.put(ConstantScoreQuery.class, q -> ((ConstantScoreQuery) q).getQuery());
        queryWrappers.put(FunctionScoreQuery.class, q -> ((FunctionScoreQuery) q).getSubQuery());
        queryWrappers.put(DateRangeIncludingNowQuery.class, q -> ((DateRangeIncludingNowQuery) q).getQuery());
        queryWrappers.put(IndexOrDocValuesQuery.class, q -> ((IndexOrDocValuesQuery) q).getIndexQuery());
    }

    private FastFilterRewriteHelper() {}

    /** Repeatedly strips the known wrapper queries until a query of a non-wrapper type remains. */
    private static Query unwrapIntoConcreteQuery(Query query) {
        Function<Query, Query> unwrapper;
        while ((unwrapper = queryWrappers.get(query.getClass())) != null) {
            query = unwrapper.apply(query);
        }
        return query;
    }

    /**
     * Returns {@code [min, max]} of the indexed points of {@code fieldName} across all segments of the shard, or
     * {@code null} if no segment has points for the field.
     */
    private static long[] getShardBounds(SearchContext context, String fieldName) throws IOException {
        final List<LeafReaderContext> leaves = context.searcher().getIndexReader().leaves();
        long min = Long.MAX_VALUE;
        long max = Long.MIN_VALUE;
        for (LeafReaderContext leaf : leaves) {
            final PointValues values = leaf.reader().getPointValues(fieldName);
            if (values == null) {
                continue;
            }
            min = Math.min(min, NumericUtils.sortableBytesToLong(values.getMinPackedValue(), 0));
            max = Math.max(max, NumericUtils.sortableBytesToLong(values.getMaxPackedValue(), 0));
        }
        // Sentinels still in place: no segment contributed points (or the data sits exactly on the sentinel
        // values, in which case the optimization is conservatively skipped).
        if (min == Long.MAX_VALUE || max == Long.MIN_VALUE) {
            return null;
        }
        return new long[] { min, max };
    }

    /**
     * Returns {@code [min, max]} of the indexed points of {@code fieldName} in a single segment, or {@code null}
     * if the segment has no points for the field.
     */
    private static long[] getSegmentBounds(LeafReaderContext context, String fieldName) throws IOException {
        final PointValues values = context.reader().getPointValues(fieldName);
        if (values == null) {
            return null;
        }
        final long min = NumericUtils.sortableBytesToLong(values.getMinPackedValue(), 0);
        final long max = NumericUtils.sortableBytesToLong(values.getMaxPackedValue(), 0);
        // Same sentinel rule as getShardBounds.
        if (min == Long.MAX_VALUE || max == Long.MIN_VALUE) {
            return null;
        }
        return new long[] { min, max };
    }

    /**
     * Computes the shard-level {@code [low, high]} bounds of the date histogram from the (unwrapped) top-level
     * query.
     *
     * @param context   the search context whose query and reader are inspected
     * @param fieldName the aggregated date field
     * @return the bounds, or {@code null} when the query shape is unsupported, targets another field, or no
     *         document can match
     */
    public static long[] getDateHistoAggBounds(SearchContext context, String fieldName) throws IOException {
        final Query cq = unwrapIntoConcreteQuery(context.query());
        if (cq instanceof PointRangeQuery) {
            final long[] indexBounds = getShardBounds(context, fieldName);
            if (indexBounds == null) {
                return null;
            }
            return getBoundsWithRangeQuery((PointRangeQuery) cq, fieldName, indexBounds);
        }
        final boolean matchesAllWithField = cq instanceof MatchAllDocsQuery
            || (cq instanceof FieldExistsQuery && ((FieldExistsQuery) cq).getField().equals(fieldName));
        if (matchesAllWithField) {
            return getShardBounds(context, fieldName);
        }
        return null;
    }

    /**
     * Intersects the range of {@code prq} with {@code indexBounds}. Returns {@code null} when the query is on a
     * different field or the intersection is empty.
     */
    private static long[] getBoundsWithRangeQuery(PointRangeQuery prq, String fieldName, long[] indexBounds) {
        if (!prq.getField().equals(fieldName)) {
            return null;
        }
        final long lower = Math.max(NumericUtils.sortableBytesToLong(prq.getLowerPoint(), 0), indexBounds[0]);
        final long upper = Math.min(NumericUtils.sortableBytesToLong(prq.getUpperPoint(), 0), indexBounds[1]);
        if (lower > upper) {
            return null;
        }
        return new long[] { lower, upper };
    }

    /** A composite aggregation is rewriteable only when it has exactly one source and that source is a rounding. */
    public static boolean isCompositeAggRewriteable(CompositeValuesSourceConfig[] sourceConfigs) {
        return sourceConfigs.length == 1 && sourceConfigs[0].valuesSource() instanceof RoundingValuesSource;
    }

    /**
     * Decodes a bucket ordinal: negative values {@code -1 - ord} (as returned for an already-existing bucket) are
     * mapped back to {@code ord}; non-negative values are returned unchanged.
     */
    public static long getBucketOrd(long bucketOrd) {
        return bucketOrd < 0L ? -1L - bucketOrd : bucketOrd;
    }

    /**
     * Tries to compute the bucket counts of one segment by traversing the field's point tree.
     *
     * @param ctx               the segment
     * @param fastFilterContext shared per-shard optimization state; its counters are updated
     * @param incrementDocCount receives {@code (rangeLowerBoundInMillis, docCount)} for every non-empty range
     * @return {@code true} if the segment was handled here, {@code false} if the optimization does not apply
     */
    public static boolean tryFastFilterAggregation(
        LeafReaderContext ctx,
        FastFilterContext fastFilterContext,
        BiConsumer<Long, Integer> incrementDocCount
    ) throws IOException {
        fastFilterContext.segments++;
        if (!fastFilterContext.rewriteable) {
            return false;
        }
        // The point tree is traversed without consulting live docs, so segments with deletions are skipped.
        if (ctx.reader().hasDeletions()) {
            return false;
        }
        final PointValues values = ctx.reader().getPointValues(fastFilterContext.fieldName);
        if (values == null) {
            return false;
        }
        // Each point counts as one document, so every document must hold exactly one value.
        if ((long) values.getDocCount() != values.size()) {
            return false;
        }
        // Documents carrying a _doc_count field must not be counted as 1 each.
        final NumericDocValues docCountValues = DocValues.getNumeric(ctx.reader(), "_doc_count");
        if (docCountValues.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) {
            logger.debug(
                "Shard {} segment {} has at least one document with _doc_count field, skip fast filter optimization",
                fastFilterContext.context.indexShard().shardId(),
                ctx.ord
            );
            return false;
        }
        // Without shard-level ranges the query must match every document of this segment.
        if (!fastFilterContext.rangesBuiltAtShardLevel && !segmentMatchAll(fastFilterContext.context, ctx)) {
            return false;
        }

        long[][] ranges = fastFilterContext.ranges;
        if (ranges == null) {
            logger.debug(
                "Shard {} segment {} functionally match all documents. Build the fast filter",
                fastFilterContext.context.indexShard().shardId(),
                ctx.ord
            );
            ranges = fastFilterContext.buildRanges(ctx);
            if (ranges == null) {
                return false;
            }
        }

        final AggregationType aggregationType = fastFilterContext.aggregationType;
        assert aggregationType instanceof AbstractDateHistogramAggregationType;
        final DateFieldMapper.DateFieldType fieldType = ((AbstractDateHistogramAggregationType) aggregationType).getFieldType();
        // Composite aggregations page their buckets, so the traversal may stop early.
        final int size = aggregationType instanceof CompositeAggregator.CompositeAggregationType
            ? ((CompositeAggregator.CompositeAggregationType) aggregationType).getSize()
            : Integer.MAX_VALUE;

        final DebugInfo debugInfo = multiRangesTraverse(values.getPointTree(), ranges, incrementDocCount, fieldType, size);
        fastFilterContext.consumeDebugInfo(debugInfo);

        fastFilterContext.optimizedSegments++;
        logger.debug("Fast filter optimization applied to shard {} segment {}", fastFilterContext.context.indexShard().shardId(), ctx.ord);
        logger.debug("crossed leaf nodes: {}, inner nodes: {}", fastFilterContext.leaf, fastFilterContext.inner);
        return true;
    }

    /**
     * Returns {@code true} if the context's query matches every live document of the segment. A count that the
     * weight cannot compute cheaply ({@code -1}) is treated as "no".
     */
    private static boolean segmentMatchAll(SearchContext ctx, LeafReaderContext leafCtx) throws IOException {
        final Weight weight = ctx.searcher().createWeight(ctx.query(), ScoreMode.COMPLETE_NO_SCORES, 1.0f);
        return weight != null && weight.count(leafCtx) == leafCtx.reader().numDocs();
    }

    /**
     * Splits {@code [low, high]} (expressed in the field's own resolution) into consecutive ranges, one per bucket
     * of {@code preparedRounding}. The first range starts at {@code low} and the last ends at {@code high}. Every
     * other boundary ends one unit before the next rounded value.
     *
     * @return the ranges in ascending order, or {@code null} if more than {@code context.maxAggRewriteFilters()}
     *         buckets would be needed
     */
    private static long[][] createRangesFromAgg(
        SearchContext context,
        DateFieldMapper.DateFieldType fieldType,
        long interval,
        Rounding.Prepared preparedRounding,
        long low,
        long high
    ) {
        final int maxNumFilterBuckets = context.maxAggRewriteFilters();
        final long highMillis = fieldType.convertNanosToMillis(high);
        final long firstRounded = preparedRounding.round(fieldType.convertNanosToMillis(low));

        // First pass: count the buckets, bailing out as soon as there are too many.
        int bucketCount = 0;
        long roundedLow = firstRounded;
        long prevRounded = firstRounded;
        while (roundedLow <= highMillis) {
            if (++bucketCount > maxNumFilterBuckets) {
                logger.debug("Max number of filters reached [{}], skip the fast filter optimization", maxNumFilterBuckets);
                return null;
            }
            roundedLow = preparedRounding.round(roundedLow + interval);
            // Stop if the rounding fails to advance; the loop would otherwise never terminate.
            if (prevRounded == roundedLow) {
                break;
            }
            prevRounded = roundedLow;
        }

        // Second pass: materialize the ranges.
        final long[][] ranges = new long[bucketCount][2];
        roundedLow = firstRounded;
        for (int i = 0; i < bucketCount; i++) {
            final long lower = i == 0 ? low : fieldType.convertRoundedMillisToNanos(roundedLow);
            roundedLow = preparedRounding.round(roundedLow + interval);
            final long upper = i + 1 == bucketCount ? high : fieldType.convertRoundedMillisToNanos(roundedLow) - 1L;
            ranges[i][0] = lower;
            ranges[i][1] = upper;
        }
        return ranges;
    }

    /**
     * Walks {@code tree} once and reports, per range, how many points fall into it.
     *
     * @param ranges              ranges sorted ascending, as produced by {@link #createRangesFromAgg}
     * @param maxNumNonZeroRanges limit on range transitions before the walk is terminated early
     */
    private static DebugInfo multiRangesTraverse(
        final PointValues.PointTree tree,
        final long[][] ranges,
        final BiConsumer<Long, Integer> incrementDocCount,
        final DateFieldMapper.DateFieldType fieldType,
        final int maxNumNonZeroRanges
    ) throws IOException {
        final DebugInfo debugInfo = new DebugInfo();
        if (ranges.length == 0) {
            logger.debug("No ranges match the query, skip the fast filter optimization");
            return debugInfo;
        }
        final Iterator<long[]> rangeIter = Arrays.asList(ranges).iterator();
        long[] activeRange = rangeIter.next();

        // All ranges lie above the largest value of the segment.
        if (activeRange[0] > NumericUtils.sortableBytesToLong(tree.getMaxPackedValue(), 0)) {
            logger.debug("No ranges match the query, skip the fast filter optimization");
            return debugInfo;
        }
        // Skip ranges that end below the smallest value of the segment.
        final long treeMin = NumericUtils.sortableBytesToLong(tree.getMinPackedValue(), 0);
        while (activeRange[1] < treeMin) {
            if (!rangeIter.hasNext()) {
                logger.debug("No ranges match the query, skip the fast filter optimization");
                return debugInfo;
            }
            activeRange = rangeIter.next();
        }

        final RangeCollectorForPointTree collector = new RangeCollectorForPointTree(
            incrementDocCount,
            fieldType,
            rangeIter,
            maxNumNonZeroRanges,
            activeRange
        );
        final ArrayUtil.ByteArrayComparator comparator = ArrayUtil.getUnsignedComparator(Long.BYTES);
        final PointValues.IntersectVisitor visitor = getIntersectVisitor(collector, comparator);
        try {
            intersectWithRanges(visitor, tree, collector, debugInfo);
        } catch (CollectionTerminatedException e) {
            // Thrown by the visitor once no further range needs collecting.
            logger.debug("Early terminate since no more range to collect");
        }
        collector.finalizePreviousRange();
        return debugInfo;
    }

    /**
     * Recursively visits {@code pointTree}. Cells fully inside the active range are counted in bulk. Crossing
     * cells are descended into, or visited point by point at the leaves.
     */
    private static void intersectWithRanges(
        PointValues.IntersectVisitor visitor,
        PointValues.PointTree pointTree,
        RangeCollectorForPointTree collector,
        DebugInfo debug
    ) throws IOException {
        final PointValues.Relation relation = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
        switch (relation) {
            case CELL_INSIDE_QUERY:
                collector.countNode((int) pointTree.size());
                debug.visitInner();
                break;
            case CELL_CROSSES_QUERY:
                if (pointTree.moveToChild()) {
                    do {
                        intersectWithRanges(visitor, pointTree, collector, debug);
                    } while (pointTree.moveToSibling());
                    pointTree.moveToParent();
                } else {
                    pointTree.visitDocValues(visitor);
                    debug.visitLeaf();
                }
                break;
            case CELL_OUTSIDE_QUERY:
                // Never produced by the visitor from getIntersectVisitor; nothing to count.
                break;
        }
    }

    /**
     * Builds the visitor that classifies cells and points against the collector's active range. The visitor
     * advances the range when a value lies beyond it. It throws {@link CollectionTerminatedException} when no
     * range is left or the range limit is hit.
     */
    private static PointValues.IntersectVisitor getIntersectVisitor(
        final RangeCollectorForPointTree collector,
        final ArrayUtil.ByteArrayComparator comparator
    ) {
        return new PointValues.IntersectVisitor() {
            @Override
            public void visit(int docID) throws IOException {
                // Expected to be unreachable: inside cells are counted in bulk by intersectWithRanges.
                throw new UnsupportedOperationException("This IntersectVisitor does not perform any actions on a docID=" + docID + " node being visited");
            }

            @Override
            public void visit(int docID, byte[] packedValue) throws IOException {
                visitPoints(packedValue, collector::count);
            }

            @Override
            public void visit(DocIdSetIterator iterator, byte[] packedValue) throws IOException {
                visitPoints(packedValue, () -> {
                    for (int doc = iterator.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iterator.nextDoc()) {
                        collector.count();
                    }
                });
            }

            /**
             * Advances the active range if {@code packedValue} lies beyond it. Then runs {@code collect} if the
             * value falls inside the (possibly new) active range. Advancement is forward-only. NOTE(review): this
             * assumes values arrive in ascending order, as in a 1-D point tree; confirm.
             */
            private void visitPoints(byte[] packedValue, CheckedRunnable<IOException> collect) throws IOException {
                if (compareByteValue(packedValue, collector.activeRangeAsByteArray[1]) > 0) {
                    collector.finalizePreviousRange();
                    if (collector.iterateRangeEnd(packedValue, this::compareByteValue)) {
                        throw new CollectionTerminatedException();
                    }
                }
                if (pointCompare(collector.activeRangeAsByteArray[0], collector.activeRangeAsByteArray[1], packedValue)) {
                    collect.run();
                }
            }

            /** Returns whether {@code lower <= packedValue <= upper}. */
            private boolean pointCompare(byte[] lower, byte[] upper, byte[] packedValue) {
                return compareByteValue(packedValue, lower) >= 0 && compareByteValue(packedValue, upper) <= 0;
            }

            private int compareByteValue(byte[] value1, byte[] value2) {
                return comparator.compare(value1, 0, value2, 0);
            }

            @Override
            public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue) {
                if (compareByteValue(collector.activeRangeAsByteArray[1], minPackedValue) < 0) {
                    // The cell starts past the active range: close it and move forward.
                    collector.finalizePreviousRange();
                    if (collector.iterateRangeEnd(minPackedValue, this::compareByteValue)) {
                        throw new CollectionTerminatedException();
                    }
                }
                // Read both ends after any advancement so the comparison uses the current range.
                final byte[] rangeMin = collector.activeRangeAsByteArray[0];
                final byte[] rangeMax = collector.activeRangeAsByteArray[1];
                if (compareByteValue(rangeMin, minPackedValue) > 0 || compareByteValue(rangeMax, maxPackedValue) < 0) {
                    return PointValues.Relation.CELL_CROSSES_QUERY;
                }
                return PointValues.Relation.CELL_INSIDE_QUERY;
            }
        };
    }

    /** Counts of point-tree leaf and inner nodes resolved during one traversal. */
    private static class DebugInfo {
        private int leaf = 0;
        private int inner = 0;

        private DebugInfo() {}

        private void visitLeaf() {
            leaf++;
        }

        private void visitInner() {
            inner++;
        }
    }

    /**
     * Accumulates the doc count of the active range and reports it when the range is left. Ranges are consumed
     * forward-only from {@code rangeIter}.
     */
    private static class RangeCollectorForPointTree {
        private final BiConsumer<Long, Integer> incrementDocCount;
        private final DateFieldMapper.DateFieldType fieldType;
        private final Iterator<long[]> rangeIter;
        private final int maxNumNonZeroRange;
        private int counter = 0;
        private int visitedRange = 0;
        private long[] activeRange;
        /** {@link #activeRange} encoded as sortable bytes: {@code [lower, upper]}. */
        private byte[][] activeRangeAsByteArray;

        RangeCollectorForPointTree(
            BiConsumer<Long, Integer> incrementDocCount,
            DateFieldMapper.DateFieldType fieldType,
            Iterator<long[]> rangeIter,
            int maxNumNonZeroRange,
            long[] activeRange
        ) {
            this.incrementDocCount = incrementDocCount;
            this.fieldType = fieldType;
            this.rangeIter = rangeIter;
            this.maxNumNonZeroRange = maxNumNonZeroRange;
            this.activeRange = activeRange;
            this.activeRangeAsByteArray = encode(activeRange);
        }

        private void count() {
            counter++;
        }

        private void countNode(int count) {
            counter += count;
        }

        /**
         * Reports the pending count, if any, keyed by the active range's lower bound in millis, and then resets
         * the count.
         */
        private void finalizePreviousRange() {
            if (counter > 0) {
                logger.debug("finalize previous range: {}", activeRange[0]);
                logger.debug("counter: {}", counter);
                incrementDocCount.accept(fieldType.convertNanosToMillis(activeRange[0]), counter);
                counter = 0;
            }
        }

        /**
         * Advances to the first range whose upper bound is {@code >= value}.
         *
         * <p>NOTE(review): the range active at the start is not counted in {@code visitedRange}. As a result, up
         * to {@code maxNumNonZeroRange + 1} ranges may be reported; confirm callers tolerate this.
         *
         * @return {@code true} if collection should stop, either because ranges are exhausted or because the
         *         transition limit is exceeded
         */
        private boolean iterateRangeEnd(byte[] value, BiFunction<byte[], byte[], Integer> comparator) {
            while (comparator.apply(activeRangeAsByteArray[1], value) < 0) {
                if (!rangeIter.hasNext()) {
                    return true;
                }
                activeRange = rangeIter.next();
                activeRangeAsByteArray = encode(activeRange);
            }
            visitedRange++;
            return visitedRange > maxNumNonZeroRange;
        }

        /** Encodes {@code [lower, upper]} as two 8-byte sortable arrays. */
        private static byte[][] encode(long[] range) {
            final byte[] lower = new byte[Long.BYTES];
            final byte[] upper = new byte[Long.BYTES];
            NumericUtils.longToSortableBytes(range[0], lower, 0);
            NumericUtils.longToSortableBytes(range[1], upper, 0);
            return new byte[][] { lower, upper };
        }
    }

    /**
     * Base for date-histogram-like aggregation types. It decides rewriteability and builds the ranges from the
     * query/segment bounds, optional hard bounds and the subclass's rounding.
     */
    public abstract static class AbstractDateHistogramAggregationType implements AggregationType {
        private final MappedFieldType fieldType;
        private final boolean missing;
        private final boolean hasScript;
        private final LongBounds hardBounds;

        public AbstractDateHistogramAggregationType(MappedFieldType fieldType, boolean missing, boolean hasScript) {
            this(fieldType, missing, hasScript, null);
        }

        public AbstractDateHistogramAggregationType(MappedFieldType fieldType, boolean missing, boolean hasScript, LongBounds hardBounds) {
            this.fieldType = fieldType;
            this.missing = missing;
            this.hasScript = hasScript;
            this.hardBounds = hardBounds;
        }

        /**
         * Rewriteable only for a top-level aggregation without sub-aggregations, missing value or script, over a
         * searchable date field.
         */
        @Override
        public boolean isRewriteable(Object parent, int subAggLength) {
            final boolean simpleShape = parent == null && subAggLength == 0 && !missing && !hasScript;
            return simpleShape && fieldType instanceof DateFieldMapper.DateFieldType && fieldType.isSearchable();
        }

        /** Builds ranges from the shard-level bounds implied by the query; {@code null} if not possible. */
        @Override
        public long[][] buildRanges(SearchContext context) throws IOException {
            final long[] bounds = getDateHistoAggBounds(context, fieldType.name());
            logger.debug("Bounds are {} for shard {}", bounds, context.indexShard().shardId());
            return buildRanges(context, bounds);
        }

        /** Clamps {@code bounds} to the hard bounds and turns them into per-bucket ranges. */
        private long[][] buildRanges(SearchContext context, long[] bounds) throws IOException {
            bounds = processHardBounds(bounds);
            if (bounds == null) {
                return null;
            }
            assert bounds[0] <= bounds[1] : "Low bound should be less than high bound";

            final Rounding rounding = getRounding(bounds[0], bounds[1]);
            final OptionalLong intervalOpt = Rounding.getInterval(rounding);
            if (intervalOpt.isEmpty()) {
                return null;
            }
            final long interval = intervalOpt.getAsLong();

            processAfterKey(bounds, interval);
            return createRangesFromAgg(
                context,
                (DateFieldMapper.DateFieldType) fieldType,
                interval,
                getRoundingPrepared(),
                bounds[0],
                bounds[1]
            );
        }

        /** Builds ranges from one segment's point bounds; {@code null} if not possible. */
        @Override
        public long[][] buildRanges(LeafReaderContext leaf, SearchContext context) throws IOException {
            final long[] bounds = getSegmentBounds(leaf, fieldType.name());
            logger.debug("Bounds are {} for shard {} segment {}", bounds, context.indexShard().shardId(), leaf.ord);
            return buildRanges(context, bounds);
        }

        protected abstract Rounding getRounding(long low, long high);

        protected abstract Rounding.Prepared getRoundingPrepared();

        /** Hook allowing subclasses to adjust {@code bound} in place (e.g. for an after key); no-op by default. */
        protected void processAfterKey(long[] bound, long interval) {}

        /**
         * Clamps {@code bounds} in place to the configured hard bounds. Either side of the hard bounds may be
         * absent. The hard max is treated as exclusive (hence the {@code - 1}).
         *
         * @return {@code bounds}, or {@code null} if it was {@code null} or the clamped range is empty
         */
        protected long[] processHardBounds(long[] bounds) {
            if (bounds == null || hardBounds == null) {
                return bounds;
            }
            final Long hardMin = hardBounds.getMin();
            if (hardMin != null && hardMin > bounds[0]) {
                bounds[0] = hardMin;
            }
            final Long hardMax = hardBounds.getMax();
            if (hardMax != null && hardMax - 1L < bounds[1]) {
                bounds[1] = hardMax - 1L;
            }
            if (bounds[0] > bounds[1]) {
                return null;
            }
            return bounds;
        }

        public DateFieldMapper.DateFieldType getFieldType() {
            assert fieldType instanceof DateFieldMapper.DateFieldType;
            return (DateFieldMapper.DateFieldType) fieldType;
        }
    }

    /** Strategy deciding whether an aggregation can use the fast filter and how to build its ranges. */
    interface AggregationType {
        boolean isRewriteable(Object parent, int subAggLength);

        long[][] buildRanges(SearchContext context) throws IOException;

        long[][] buildRanges(LeafReaderContext leaf, SearchContext context) throws IOException;
    }

    /**
     * Per-aggregator state of the fast filter optimization. It records whether the optimization applies, the
     * ranges built at shard level (if any) and counters describing how often it was applied.
     */
    public static class FastFilterContext {
        private boolean rewriteable = false;
        private boolean rangesBuiltAtShardLevel = false;
        private AggregationType aggregationType;
        private final SearchContext context;
        private String fieldName;
        private long[][] ranges;
        public int leaf;
        public int inner;
        public int segments;
        public int optimizedSegments;

        public FastFilterContext(SearchContext context) {
            this.context = context;
        }

        public void setFieldName(String fieldName) {
            this.fieldName = fieldName;
        }

        public AggregationType getAggregationType() {
            return aggregationType;
        }

        public void setAggregationType(AggregationType aggregationType) {
            this.aggregationType = aggregationType;
        }

        /**
         * Evaluates and remembers whether the optimization applies. It is always disabled when
         * {@code maxAggRewriteFilters()} is 0. Requires {@link #setAggregationType} to have been called.
         */
        public boolean isRewriteable(Object parent, int subAggLength) {
            if (context.maxAggRewriteFilters() == 0) {
                return false;
            }
            final boolean isRewriteable = aggregationType.isRewriteable(parent, subAggLength);
            logger.debug("Fast filter rewriteable: {} for shard {}", isRewriteable, context.indexShard().shardId());
            this.rewriteable = isRewriteable;
            return isRewriteable;
        }

        /** Builds the shard-level ranges once. On failure they stay {@code null} and segments build their own. */
        public void buildRanges() throws IOException {
            assert ranges == null : "Ranges should only be built once at shard level, but they are already built";
            ranges = aggregationType.buildRanges(context);
            if (ranges != null) {
                logger.debug("Ranges built for shard {}", context.indexShard().shardId());
                rangesBuiltAtShardLevel = true;
            }
        }

        /** Builds ranges for a single segment without storing them; {@code null} if not possible. */
        public long[][] buildRanges(LeafReaderContext leaf) throws IOException {
            final long[][] segmentRanges = aggregationType.buildRanges(leaf, context);
            if (segmentRanges != null) {
                logger.debug("Ranges built for shard {} segment {}", context.indexShard().shardId(), leaf.ord);
            }
            return segmentRanges;
        }

        private void consumeDebugInfo(DebugInfo debug) {
            leaf += debug.leaf;
            inner += debug.inner;
        }
    }
}

