/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.training;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.search.SearchRequestBuilder;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchScrollRequest;
import org.opensearch.action.search.SearchScrollRequestBuilder;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.ValidationException;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.search.SearchHit;
import org.opensearch.search.sort.SortOrder;

/**
 * Reads vectors stored in a k-NN field of an index by scrolling over every document in which the
 * field exists, handing each page of extracted vectors to a caller-supplied consumer.
 */
public class VectorReader {
    public static Logger logger = LogManager.getLogger(VectorReader.class);
    private final Client client;
    // Keep-alive applied to the initial search and to every follow-up scroll request (60 seconds).
    private final TimeValue scrollTime = new TimeValue(60000L);

    /**
     * @param client client used to execute the search, scroll and clear-scroll requests
     */
    public VectorReader(Client client) {
        this.client = client;
    }

    /**
     * Validates the arguments and then starts an asynchronous scroll over {@code indexName},
     * passing each page of vectors found in {@code fieldName} to {@code vectorConsumer}.
     *
     * @param clusterService  used to look up the metadata of {@code indexName}
     * @param indexName       index to read vectors from
     * @param fieldName       field holding the vectors; dots separate nested object levels
     * @param maxVectorCount  maximum number of vectors to hand to the consumer; must be &gt; 0
     * @param searchSize      page size of each search/scroll request; must be in (0, 10000]
     * @param vectorConsumer  receives the vectors extracted from each page
     * @param listener        notified with the last search response once reading finishes, or with
     *                        the failure if a search, scroll or clear-scroll request fails
     * @throws ValidationException if any argument is invalid, the index does not exist, or
     *                             {@code IndexUtil.validateKnnField} reports errors for the field
     */
    public void read(ClusterService clusterService, String indexName, String fieldName, int maxVectorCount, int searchSize, Consumer<List<Float[]>> vectorConsumer, ActionListener<SearchResponse> listener) {
        ValidationException validationException = new ValidationException();
        if (maxVectorCount <= 0) {
            validationException.addValidationError("maxVectorCount must be > 0");
        }
        if (searchSize > 10000 || searchSize <= 0) {
            validationException.addValidationError("searchSize must be > 0 and <= 10000");
        }
        IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName);
        if (indexMetadata == null) {
            // The field cannot be validated without index metadata, so report what we have now.
            validationException.addValidationError("index \"" + indexName + "\" does not exist");
            throw validationException;
        }
        ValidationException fieldValidationException = IndexUtil.validateKnnField(indexMetadata, fieldName, -1, null);
        if (fieldValidationException != null) {
            // Merge the field validator's errors (not our own list) into the aggregate exception.
            validationException.addValidationErrors(fieldValidationException.validationErrors());
        }
        if (!validationException.validationErrors().isEmpty()) {
            throw validationException;
        }
        SearchScrollRequestBuilder searchScrollRequestBuilder = this.createSearchScrollRequestBuilder();
        VectorReaderListener vectorReaderListener = new VectorReaderListener(this.client, fieldName, maxVectorCount, 0, listener, vectorConsumer, searchScrollRequestBuilder);
        // Never request more documents per page than the caller is willing to collect in total.
        this.createSearchRequestBuilder(indexName, fieldName, Integer.min(maxVectorCount, searchSize)).execute(vectorReaderListener);
    }

    /**
     * Builds the initial scrolling search: documents where {@code fieldName} exists, sorted by
     * {@code _doc}, fetching only {@code fieldName} from the source.
     */
    private SearchRequestBuilder createSearchRequestBuilder(String indexName, String fieldName, int resultSize) {
        ExistsQueryBuilder queryBuilder = new ExistsQueryBuilder(fieldName);
        SearchRequestBuilder searchRequestBuilder = this.client.prepareSearch(indexName);
        searchRequestBuilder.setScroll(this.scrollTime);
        searchRequestBuilder.setQuery(queryBuilder);
        searchRequestBuilder.setSize(resultSize);
        searchRequestBuilder.addSort("_doc", SortOrder.ASC);
        searchRequestBuilder.setFetchSource(fieldName, null);
        return searchRequestBuilder;
    }

    /**
     * Builds the reusable scroll request. The scroll id is left unset here; the listener fills it
     * in from each search response before executing the next scroll.
     */
    private SearchScrollRequestBuilder createSearchScrollRequestBuilder() {
        SearchScrollRequestBuilder searchScrollRequestBuilder = this.client.prepareSearchScroll(null);
        searchScrollRequestBuilder.setScroll(this.scrollTime);
        return searchScrollRequestBuilder;
    }

    /**
     * Listener that consumes one page of search hits, forwards the extracted vectors to the
     * consumer and either requests the next page or clears the scroll and notifies the caller.
     */
    private static class VectorReaderListener
    implements ActionListener<SearchResponse> {
        final Client client;
        final String fieldName;
        final int maxVectorCount;
        int collectedVectorCount;
        final ActionListener<SearchResponse> listener;
        final Consumer<List<Float[]>> vectorConsumer;
        SearchScrollRequestBuilder searchScrollRequestBuilder;

        public VectorReaderListener(Client client, String fieldName, int maxVectorCount, int collectedVectorCount, ActionListener<SearchResponse> listener, Consumer<List<Float[]>> vectorConsumer, SearchScrollRequestBuilder searchScrollRequestBuilder) {
            this.client = client;
            this.fieldName = fieldName;
            this.maxVectorCount = maxVectorCount;
            this.collectedVectorCount = collectedVectorCount;
            this.listener = listener;
            this.vectorConsumer = vectorConsumer;
            this.searchScrollRequestBuilder = searchScrollRequestBuilder;
        }

        /**
         * Processes one page of hits. Stops when the page contributed no candidate hits or the
         * collected count reached {@code maxVectorCount}; otherwise executes the next scroll.
         */
        @Override
        public void onResponse(SearchResponse searchResponse) {
            SearchHit[] hits = searchResponse.getHits().getHits();
            int vectorsToAdd = Integer.min(this.maxVectorCount - this.collectedVectorCount, hits.length);
            List<Float[]> trainingData = this.extractVectorsFromHits(searchResponse, vectorsToAdd);
            this.collectedVectorCount += trainingData.size();
            this.vectorConsumer.accept(trainingData);
            if (vectorsToAdd <= 0 || this.collectedVectorCount >= this.maxVectorCount) {
                String scrollId = searchResponse.getScrollId();
                if (scrollId != null) {
                    // Clear the scroll before reporting completion to the caller.
                    this.client.prepareClearScroll().addScrollId(scrollId).execute(ActionListener.wrap(clearScrollResponse -> this.listener.onResponse(searchResponse), this.listener::onFailure));
                } else {
                    this.listener.onResponse(searchResponse);
                }
            } else {
                this.searchScrollRequestBuilder.setScrollId(searchResponse.getScrollId());
                this.searchScrollRequestBuilder.execute(this);
            }
        }

        /**
         * Clears the scroll recorded on the scroll request (if any) and then reports the failure.
         * If clearing the scroll fails too, that failure is reported with the original failure
         * attached as a suppressed exception so the root cause is not lost.
         */
        @Override
        public void onFailure(Exception e) {
            String scrollId = this.searchScrollRequestBuilder.request().scrollId();
            if (scrollId != null) {
                this.client.prepareClearScroll().addScrollId(scrollId).execute(ActionListener.wrap(clearScrollResponse -> this.listener.onFailure(e), clearScrollFailure -> {
                    // Throwable.addSuppressed rejects self-suppression, so guard against it.
                    if (clearScrollFailure != e) {
                        clearScrollFailure.addSuppressed(e);
                    }
                    this.listener.onFailure(clearScrollFailure);
                }));
            } else {
                this.listener.onFailure(e);
            }
        }

        /**
         * Extracts the vectors from the first {@code vectorsToAdd} hits of the response.
         * Hits whose field value cannot be resolved to a list are skipped and counted; a single
         * warning with the count is logged if any were skipped.
         *
         * @param searchResponse response whose hits are read
         * @param vectorsToAdd   number of leading hits to examine
         * @return the extracted vectors, one {@code Float[]} per hit that held a list value
         */
        private List<Float[]> extractVectorsFromHits(SearchResponse searchResponse, int vectorsToAdd) {
            SearchHit[] hits = searchResponse.getHits().getHits();
            List<Float[]> trainingData = new ArrayList<>();
            String[] fieldPath = this.fieldName.split("\\.");
            int nullVectorCount = 0;
            for (int hitIndex = 0; hitIndex < vectorsToAdd; ++hitIndex) {
                Object fieldValue = VectorReaderListener.resolveFieldValue(hits[hitIndex].getSourceAsMap(), fieldPath);
                if (!(fieldValue instanceof List)) {
                    ++nullVectorCount;
                    continue;
                }
                // Elements are converted via Number::floatValue, exactly as before.
                @SuppressWarnings("unchecked")
                List<Number> components = (List<Number>) fieldValue;
                trainingData.add(components.stream().map(Number::floatValue).toArray(Float[]::new));
            }
            if (nullVectorCount > 0) {
                logger.warn("Found {} documents with null vectors in field {}", nullVectorCount, this.fieldName);
            }
            return trainingData;
        }

        /**
         * Walks {@code fieldPath} through nested maps starting at {@code source}.
         *
         * @return the value at the end of the path, or {@code null} if the source is missing or any
         *         step along the path is absent or not a map
         */
        private static Object resolveFieldValue(Map<String, Object> source, String[] fieldPath) {
            Object current = source;
            for (String pathPart : fieldPath) {
                if (!(current instanceof Map)) {
                    return null;
                }
                current = ((Map<?, ?>) current).get(pathPart);
            }
            return current;
        }
    }
}

