/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.search.asynchronous.listener;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.LongSupplier;
import java.util.function.Supplier;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.SetOnce;
import org.opensearch.action.search.SearchProgressActionListener;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.SearchShard;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.asynchronous.listener.CompositeSearchProgressActionListener;
import org.opensearch.search.asynchronous.listener.PartialResponseProvider;
import org.opensearch.search.asynchronous.response.AsynchronousSearchResponse;
import org.opensearch.search.internal.InternalSearchResponse;

/**
 * Search progress listener that accumulates shard-level progress (shard counts, total hits and
 * partially reduced aggregations) while a search is running, so that a partial
 * {@link SearchResponse} can be built on demand through {@link #partialResponse()}.
 *
 * <p>When the search completes, successfully or not, the outcome is converted into an
 * {@link AsynchronousSearchResponse} on the supplied executor and forwarded to the
 * {@link CompositeSearchProgressActionListener}. The accumulated partial state is released
 * afterwards.
 *
 * <p>Thread-safety: the successful-shard bookkeeping is guarded by this listener's monitor. The
 * holder reference is {@code volatile} because it is cleared on the executor thread while other
 * threads may still be reading it.
 */
public class AsynchronousSearchProgressListener
extends SearchProgressActionListener
implements PartialResponseProvider {
    /** Accumulated partial state; set to {@code null} once the search has completed. */
    private volatile PartialResultsHolder partialResultsHolder;
    private final CompositeSearchProgressActionListener<AsynchronousSearchResponse> searchProgressActionListener;
    private final Function<SearchResponse, AsynchronousSearchResponse> successFunction;
    private final Function<Exception, AsynchronousSearchResponse> failureFunction;
    private final ExecutorService executor;

    /**
     * Creates a listener with an empty partial-results holder.
     *
     * @param relativeStartMillis  start time, on the same clock as {@code relativeTimeSupplier}
     * @param successFunction      maps the final search response to the asynchronous response
     * @param failureFunction      maps a search failure to the asynchronous response
     * @param executor             executor on which completion callbacks are processed
     * @param relativeTimeSupplier supplies the current relative time in milliseconds
     * @param reduceContextBuilder supplies the builder used to reduce partial aggregations
     */
    public AsynchronousSearchProgressListener(long relativeStartMillis, Function<SearchResponse, AsynchronousSearchResponse> successFunction, Function<Exception, AsynchronousSearchResponse> failureFunction, ExecutorService executor, LongSupplier relativeTimeSupplier, Supplier<InternalAggregation.ReduceContextBuilder> reduceContextBuilder) {
        this.successFunction = successFunction;
        this.failureFunction = failureFunction;
        this.executor = executor;
        this.partialResultsHolder = new PartialResultsHolder(relativeStartMillis, relativeTimeSupplier, reduceContextBuilder);
        this.searchProgressActionListener = new CompositeSearchProgressActionListener<>();
    }

    /**
     * Builds a response from the progress recorded so far.
     *
     * @return the partial response, or {@code null} if the shard list has not been received yet
     *         or the partial state has already been released after completion
     */
    @Override
    public SearchResponse partialResponse() {
        // Snapshot the field: it may be cleared concurrently by clearPartialResult().
        PartialResultsHolder holder = this.partialResultsHolder;
        return holder == null ? null : holder.partialResponse();
    }

    /**
     * Records the shard layout of the search and marks the holder as initialized. Skipped shards
     * are counted as successful from the start.
     */
    protected void onListShards(List<SearchShard> shards, List<SearchShard> skippedShards, SearchResponse.Clusters clusters, boolean fetchPhase) {
        PartialResultsHolder holder = this.partialResultsHolder;
        if (holder == null) {
            return;
        }
        holder.hasFetchPhase.set(fetchPhase);
        holder.totalShards.set(shards.size());
        holder.skippedShards.set(skippedShards.size());
        holder.successfulShards.set(skippedShards.size());
        holder.clusters.set(clusters);
        // Written last so that readers observing the flag also see the values set above.
        holder.isInitialized = true;
    }

    /** Stores the latest partially reduced aggregations together with the reduce phase and total hits. */
    protected void onPartialReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
        PartialResultsHolder holder = this.partialResultsHolder;
        if (holder == null) {
            return;
        }
        assert (reducePhase > holder.reducePhase.get()) : "reduce phase " + reducePhase + " less than previous phase " + holder.reducePhase.get();
        holder.partialInternalAggregations.set(aggs);
        holder.reducePhase.set(reducePhase);
        holder.totalHits.set(totalHits);
    }

    /** Stores the fully reduced aggregations and drops the partial ones, which are no longer needed. */
    protected void onFinalReduce(List<SearchShard> shards, TotalHits totalHits, InternalAggregations aggs, int reducePhase) {
        PartialResultsHolder holder = this.partialResultsHolder;
        if (holder == null) {
            return;
        }
        assert (reducePhase > holder.reducePhase.get()) : "reduce phase " + reducePhase + " less than previous phase " + holder.reducePhase.get();
        holder.internalAggregations.set(aggs);
        holder.partialInternalAggregations.set(null);
        holder.reducePhase.set(reducePhase);
        holder.totalHits.set(totalHits);
    }

    /** Treats a fetch-phase failure on a shard like any other shard failure. */
    protected void onFetchFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
        this.assertKnownShard(shardIndex);
        this.onSearchFailure(shardIndex, shardTarget, exc);
    }

    /** Counts the shard as successful once its fetch result arrives. */
    protected void onFetchResult(int shardIndex) {
        this.assertKnownShard(shardIndex);
        this.onShardResult(shardIndex);
    }

    /** Treats a query-phase failure on a shard like any other shard failure. */
    protected void onQueryFailure(int shardIndex, SearchShardTarget shardTarget, Exception exc) {
        this.assertKnownShard(shardIndex);
        this.onSearchFailure(shardIndex, shardTarget, exc);
    }

    /** Counts the shard as successful once its query result arrives. */
    protected void onQueryResult(int shardIndex) {
        this.assertKnownShard(shardIndex);
        this.onShardResult(shardIndex);
    }

    /** Asserts that the shard index lies below the recorded total shard count. */
    private void assertKnownShard(int shardIndex) {
        PartialResultsHolder holder = this.partialResultsHolder;
        assert (holder == null || shardIndex < holder.totalShards.get());
    }

    /**
     * Marks a shard as successful. Each shard is counted at most once, even if both its query and
     * fetch results are reported.
     */
    private synchronized void onShardResult(int shardIndex) {
        PartialResultsHolder holder = this.partialResultsHolder;
        if (holder == null) {
            return;
        }
        // Set.add returns true only when the id was not present yet.
        if (holder.successfulShardIds.add(shardIndex)) {
            holder.successfulShards.incrementAndGet();
        }
    }

    /**
     * Reverts a previously counted success for the shard, if any. The shard target and the
     * exception are currently not recorded.
     */
    private synchronized void onSearchFailure(int shardIndex, SearchShardTarget shardTarget, Exception e) {
        PartialResultsHolder holder = this.partialResultsHolder;
        if (holder == null) {
            return;
        }
        // Set.remove returns true only when the id was present.
        if (holder.successfulShardIds.remove(shardIndex)) {
            holder.successfulShards.decrementAndGet();
        }
    }

    /** Returns the composite listener that receives the final asynchronous search response. */
    public CompositeSearchProgressActionListener<AsynchronousSearchResponse> searchProgressActionListener() {
        return this.searchProgressActionListener;
    }

    /**
     * Converts the final search response on the executor and forwards it to the composite
     * listener. A failure during conversion or notification is reported through the composite
     * listener's failure path. The partial state is released in all cases.
     */
    public void onResponse(SearchResponse searchResponse) {
        this.executor.execute(() -> {
            try {
                AsynchronousSearchResponse result = this.successFunction.apply(searchResponse);
                this.searchProgressActionListener.onResponse(result);
            }
            catch (Exception ex) {
                this.searchProgressActionListener.onFailure(ex);
            }
            finally {
                this.clearPartialResult();
            }
        });
    }

    /**
     * Converts the search failure into an asynchronous response on the executor and forwards it
     * through the composite listener's response path. A failure during conversion or notification
     * is reported through its failure path. The partial state is released in all cases.
     */
    public void onFailure(Exception e) {
        this.executor.execute(() -> {
            try {
                AsynchronousSearchResponse result = this.failureFunction.apply(e);
                this.searchProgressActionListener.onResponse(result);
            }
            catch (Exception ex) {
                this.searchProgressActionListener.onFailure(ex);
            }
            finally {
                this.clearPartialResult();
            }
        });
    }

    /**
     * Drops the reference to the partial state so it becomes eligible for garbage collection once
     * nothing else refers to it. After this call {@link #partialResponse()} returns {@code null}
     * and further progress callbacks are ignored.
     */
    private void clearPartialResult() {
        this.partialResultsHolder = null;
    }

    /** Mutable container for the progress of a running search. */
    static class PartialResultsHolder {
        /** Set to {@code true} once the shard list has been recorded. */
        volatile boolean isInitialized = false;
        final AtomicInteger reducePhase;
        final SetOnce<Integer> totalShards;
        final SetOnce<Integer> skippedShards;
        final SetOnce<SearchResponse.Clusters> clusters;
        /** Ids of shards counted as successful; callers guard access with the listener's monitor. */
        final Set<Integer> successfulShardIds;
        final SetOnce<Boolean> hasFetchPhase;
        final AtomicInteger successfulShards;
        final AtomicReference<TotalHits> totalHits;
        final AtomicReference<InternalAggregations> internalAggregations = new AtomicReference<>();
        final AtomicReference<InternalAggregations> partialInternalAggregations;
        final long relativeStartMillis;
        final LongSupplier relativeTimeSupplier;
        final Supplier<InternalAggregation.ReduceContextBuilder> reduceContextBuilder;

        PartialResultsHolder(long relativeStartMillis, LongSupplier relativeTimeSupplier, Supplier<InternalAggregation.ReduceContextBuilder> reduceContextBuilder) {
            this.totalShards = new SetOnce<>();
            this.successfulShards = new AtomicInteger();
            this.skippedShards = new SetOnce<>();
            this.reducePhase = new AtomicInteger();
            this.hasFetchPhase = new SetOnce<>();
            this.totalHits = new AtomicReference<>();
            this.clusters = new SetOnce<>();
            this.partialInternalAggregations = new AtomicReference<>();
            this.relativeStartMillis = relativeStartMillis;
            this.successfulShardIds = new HashSet<>(1);
            this.relativeTimeSupplier = relativeTimeSupplier;
            this.reduceContextBuilder = reduceContextBuilder;
        }

        /**
         * Builds a response from the current state. The response carries no hits (only the total
         * hit count), an empty shard-failure array, and either the final aggregations or, if those
         * are not available yet, a final reduction of the latest partial aggregations.
         *
         * @return the partial response, or {@code null} if the shard list has not been recorded yet
         */
        public SearchResponse partialResponse() {
            if (!this.isInitialized) {
                return null;
            }
            SearchHits searchHits = new SearchHits(SearchHits.EMPTY, this.totalHits.get(), Float.NaN);
            // Read each reference once so a concurrent update cannot change it between check and use.
            InternalAggregations finalAggregation = this.internalAggregations.get();
            if (finalAggregation == null) {
                InternalAggregations partialAggregation = this.partialInternalAggregations.get();
                if (partialAggregation != null) {
                    finalAggregation = InternalAggregations.topLevelReduce(Collections.singletonList(partialAggregation), this.reduceContextBuilder.get().forFinalReduction());
                }
            }
            InternalSearchResponse internalSearchResponse = new InternalSearchResponse(searchHits, finalAggregation, null, null, false, null, this.reducePhase.get());
            long tookInMillis = this.relativeTimeSupplier.getAsLong() - this.relativeStartMillis;
            return new SearchResponse(internalSearchResponse, null, this.totalShards.get(), this.successfulShards.get(), this.skippedShards.get(), tookInMillis, ShardSearchFailure.EMPTY_ARRAY, this.clusters.get());
        }
    }
}

