/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.opensearch.planner.logical.rule;

import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import lombok.Generated;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.aggregation.NamedAggregator;
import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg;
import org.opensearch.sql.opensearch.planner.logical.rule.OptimizationRuleUtils;
import org.opensearch.sql.planner.logical.LogicalPlan;
import org.opensearch.sql.planner.logical.LogicalSort;
import org.opensearch.sql.planner.optimizer.Rule;
import org.opensearch.sql.planner.optimizer.pattern.Patterns;

/**
 * Optimization rule that folds a {@link LogicalSort} whose direct source is an
 * {@link OpenSearchLogicalIndexAgg} into a rebuilt {@link OpenSearchLogicalIndexAgg} carrying
 * the sort's sort list.
 *
 * <p>The rule only matches when the sort is accepted by
 * {@link OptimizationRuleUtils#sortByFieldsOnly} and no sort key names one of the index
 * aggregation's aggregators.
 */
public class MergeSortAndIndexAgg
implements Rule<LogicalSort> {
    /** Captures the index aggregation found directly beneath the sort. */
    private final Capture<OpenSearchLogicalIndexAgg> indexAggCapture = Capture.newCapture();
    private final Pattern<LogicalSort> pattern;

    /**
     * Builds the match pattern for this rule.
     *
     * <p>The nested source predicate needs the parent sort, which the nested pattern is not
     * given, so the outer predicate stores the sort being matched in {@code sortRef}.
     * NOTE(review): this assumes the outer predicates are evaluated before the nested source
     * pattern for the same match, and that a single rule instance is not matched from several
     * threads at once (the reference is shared by all matches of this instance) — confirm
     * against the optimizer.
     */
    public MergeSortAndIndexAgg() {
        AtomicReference<LogicalSort> sortRef = new AtomicReference<>();
        this.pattern = Pattern.typeOf(LogicalSort.class)
            .matching(OptimizationRuleUtils::sortByFieldsOnly)
            .matching(sort -> {
                sortRef.set(sort);
                return true;
            })
            .with(Patterns.source().matching(
                Pattern.typeOf(OpenSearchLogicalIndexAgg.class)
                    .matching(indexAgg -> !this.hasAggregatorInSortBy(sortRef.get(), indexAgg))
                    .capturedAs(this.indexAggCapture)));
    }

    /**
     * Replaces the matched sort and its index-aggregation child with a single index
     * aggregation that carries the sort list.
     *
     * @param sort the matched sort node
     * @param captures captures from the pattern; must contain the index aggregation
     * @return a new {@link OpenSearchLogicalIndexAgg} with the same relation name, filter,
     *     group-by list and aggregator list as the captured node, and the sort list of
     *     {@code sort}
     */
    public LogicalPlan apply(LogicalSort sort, Captures captures) {
        OpenSearchLogicalIndexAgg indexAgg = captures.get(this.indexAggCapture);
        // Only these five properties are set on the builder; the sort list comes from the sort.
        return OpenSearchLogicalIndexAgg.builder()
            .relationName(indexAgg.getRelationName())
            .filter(indexAgg.getFilter())
            .groupByList(indexAgg.getGroupByList())
            .aggregatorList(indexAgg.getAggregatorList())
            .sortList(sort.getSortList())
            .build();
    }

    /**
     * Reports whether any sort key refers, by attribute name, to one of the aggregators.
     *
     * @param sort the sort whose keys are inspected
     * @param agg the index aggregation whose aggregator names are compared against
     * @return {@code true} if some sort key's attribute equals an aggregator name
     */
    private boolean hasAggregatorInSortBy(LogicalSort sort, OpenSearchLogicalIndexAgg agg) {
        Set<String> aggregatorNames = agg.getAggregatorList().stream()
            .map(NamedAggregator::getName)
            .collect(Collectors.toSet());
        for (Pair<?, ?> sortPair : sort.getSortList()) {
            // The cast presumes every sort expression is a ReferenceExpression, which the
            // sortByFieldsOnly predicate evaluated earlier in the pattern is expected to ensure.
            ReferenceExpression sortField = (ReferenceExpression) sortPair.getRight();
            if (aggregatorNames.contains(sortField.getAttr())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the pattern this rule matches.
     *
     * @return a pattern matching a field-only {@link LogicalSort} over an
     *     {@link OpenSearchLogicalIndexAgg}
     */
    @Generated
    public Pattern<LogicalSort> pattern() {
        return this.pattern;
    }
}

