/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query;

import java.util.Locale;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.BaseQueryFactory;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.util.KNNEngine;

/**
 * Builds the Lucene {@link Query} that executes a k-NN search on a vector field.
 *
 * <p>Engines contained in {@link KNNEngine#getEnginesThatCreateCustomSegmentFiles()} are served by the
 * plugin's own {@link KNNQuery}. Every other engine is served by Lucene's built-in k-NN vector queries,
 * selected by the request's {@link VectorDataType}.
 */
public class KNNQueryFactory extends BaseQueryFactory {
    @Generated
    private static final Logger log = LogManager.getLogger(KNNQueryFactory.class);

    /**
     * Convenience overload that packs the arguments into a {@link BaseQueryFactory.CreateQueryRequest}
     * and delegates to {@link #create(BaseQueryFactory.CreateQueryRequest)}.
     *
     * <p>Only the engine, index name, field name, float vector, data type and k are set on the request.
     * All other request fields are left at whatever the builder defaults to.
     *
     * @param knnEngine      engine backing the field
     * @param indexName      name of the index being searched
     * @param fieldName      name of the vector field being searched
     * @param vector         float query vector
     * @param k              number of nearest neighbours requested
     * @param vectorDataType data type of the field's vectors
     * @return the k-NN query for the request
     * @throws IllegalArgumentException if a Lucene query is needed and the data type is unsupported or null
     */
    public static Query create(KNNEngine knnEngine, String indexName, String fieldName, float[] vector, int k, VectorDataType vectorDataType) {
        BaseQueryFactory.CreateQueryRequest createQueryRequest = BaseQueryFactory.CreateQueryRequest.builder()
            .knnEngine(knnEngine)
            .indexName(indexName)
            .fieldName(fieldName)
            .vector(vector)
            .vectorDataType(vectorDataType)
            .k(k)
            .build();
        return KNNQueryFactory.create(createQueryRequest);
    }

    /**
     * Creates the k-NN query described by {@code createQueryRequest}.
     *
     * <p>For engines that create custom segment files, a {@link KNNQuery} is returned. The filter query
     * is passed along only when the engine is also listed in {@link KNNEngine#getEnginesThatSupportsFilters()}.
     * Otherwise a non-null filter query is NOT applied.
     *
     * <p>For all other engines, a Lucene byte or float vector query is returned, depending on the
     * request's {@link VectorDataType}.
     *
     * @param createQueryRequest the request describing the search
     * @return the k-NN query for the request
     * @throws IllegalArgumentException if a Lucene query is needed and the data type is neither
     *                                  {@code BYTE} nor {@code FLOAT} (including {@code null})
     */
    public static Query create(BaseQueryFactory.CreateQueryRequest createQueryRequest) {
        final String indexName = createQueryRequest.getIndexName();
        final String fieldName = createQueryRequest.getFieldName();
        final int k = createQueryRequest.getK();
        final float[] vector = createQueryRequest.getVector();
        final byte[] byteVector = createQueryRequest.getByteVector();
        final VectorDataType vectorDataType = createQueryRequest.getVectorDataType();
        final Query filterQuery = KNNQueryFactory.getFilterQuery(createQueryRequest);
        // The parent filter can only come from a shard context. Without a context it stays null,
        // which selects the plain (non-diversifying) Lucene queries below.
        final BitSetProducer parentFilter = createQueryRequest.getContext()
            .map(QueryShardContext::getParentFilter)
            .orElse(null);
        final KNNEngine knnEngine = createQueryRequest.getKnnEngine();

        if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) {
            if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) {
                log.debug("Creating custom k-NN query with filters for index: {}, field: {} , k: {}", indexName, fieldName, k);
                return new KNNQuery(fieldName, vector, k, indexName, filterQuery, parentFilter);
            }
            // NOTE: any non-null filterQuery is not applied on this path (engine lacks filter support).
            log.debug("Creating custom k-NN query for index: {}, field: {}, k: {}", indexName, fieldName, k);
            return new KNNQuery(fieldName, vector, k, indexName, parentFilter);
        }

        log.debug("Creating Lucene k-NN query for index: {}, field: {}, k: {}", indexName, fieldName, k);
        // Guard against null: switching on a null enum would throw a message-less NullPointerException.
        if (vectorDataType != null) {
            switch (vectorDataType) {
                case BYTE:
                    return KNNQueryFactory.getKnnByteVectorQuery(fieldName, byteVector, k, filterQuery, parentFilter);
                case FLOAT:
                    return KNNQueryFactory.getKnnFloatVectorQuery(fieldName, vector, k, filterQuery, parentFilter);
                default:
                    break;
            }
        }
        throw new IllegalArgumentException(
            String.format(
                Locale.ROOT,
                "Invalid value provided for [%s] field. Supported values are [%s], but got: %s",
                "data_type",
                VectorDataType.SUPPORTED_VECTOR_DATA_TYPES,
                vectorDataType
            )
        );
    }

    /**
     * Builds a Lucene byte-vector k-NN query.
     *
     * <p>Without a parent filter this is a {@link KnnByteVectorQuery}. With one, it is a
     * {@link DiversifyingChildrenByteKnnVectorQuery}. Note that the two constructors take
     * {@code k} and the filter in different positions.
     *
     * @param fieldName    vector field to search
     * @param byteVector   byte query vector
     * @param k            number of nearest neighbours requested
     * @param filterQuery  optional filter query; may be {@code null}
     * @param parentFilter parent-document filter; {@code null} selects the non-diversifying query
     * @return the Lucene byte-vector query
     */
    private static Query getKnnByteVectorQuery(String fieldName, byte[] byteVector, int k, Query filterQuery, BitSetProducer parentFilter) {
        if (parentFilter == null) {
            return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery);
        }
        return new DiversifyingChildrenByteKnnVectorQuery(fieldName, byteVector, filterQuery, k, parentFilter);
    }

    /**
     * Builds a Lucene float-vector k-NN query.
     *
     * <p>Without a parent filter this is a {@link KnnFloatVectorQuery}. With one, it is a
     * {@link DiversifyingChildrenFloatKnnVectorQuery}. Note that the two constructors take
     * {@code k} and the filter in different positions.
     *
     * @param fieldName    vector field to search
     * @param floatVector  float query vector
     * @param k            number of nearest neighbours requested
     * @param filterQuery  optional filter query; may be {@code null}
     * @param parentFilter parent-document filter; {@code null} selects the non-diversifying query
     * @return the Lucene float-vector query
     */
    private static Query getKnnFloatVectorQuery(String fieldName, float[] floatVector, int k, Query filterQuery, BitSetProducer parentFilter) {
        if (parentFilter == null) {
            return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery);
        }
        return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter);
    }
}

