/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query;

import java.util.Locale;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.ByteVectorSimilarityQuery;
import org.apache.lucene.search.FloatVectorSimilarityQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.BaseQueryFactory;
import org.opensearch.knn.index.query.KNNQuery;

/**
 * Factory for radial (r-NN) search queries.
 *
 * <p>Engines listed by {@link KNNEngine#getEnginesThatCreateCustomSegmentFiles()} get a
 * {@link KNNQuery} carrying the radius. All other engines get a Lucene
 * {@link FloatVectorSimilarityQuery} or {@link ByteVectorSimilarityQuery}, chosen by the request's
 * {@link VectorDataType}.
 */
public class RNNQueryFactory
extends BaseQueryFactory {
    @Generated
    private static final Logger log = LogManager.getLogger(RNNQueryFactory.class);

    /** Message used when a Lucene r-NN query is requested without a radius. */
    private static final String RADIUS_REQUIRED_MESSAGE = "radius is required to create a Lucene r-NN query";

    /**
     * Creates a radial search query from individual arguments.
     *
     * <p>The request built here has no query shard context attached. NOTE(review): engines that create
     * custom segment files need that context, so this overload presumably cannot serve them. Confirm
     * with callers.
     *
     * @param knnEngine engine that backs the field
     * @param indexName name of the index being queried
     * @param fieldName name of the vector field
     * @param vector float query vector
     * @param radius radius threshold; for Lucene queries it is passed as the result similarity
     * @param vectorDataType data type of the field's vectors
     * @return the radial search query
     */
    public static Query create(KNNEngine knnEngine, String indexName, String fieldName, float[] vector, Float radius, VectorDataType vectorDataType) {
        BaseQueryFactory.CreateQueryRequest createQueryRequest = BaseQueryFactory.CreateQueryRequest.builder()
            .knnEngine(knnEngine)
            .indexName(indexName)
            .fieldName(fieldName)
            .vector(vector)
            .vectorDataType(vectorDataType)
            .radius(radius)
            .build();
        return RNNQueryFactory.create(createQueryRequest);
    }

    /**
     * Creates a radial search query from a prepared request.
     *
     * @param createQueryRequest the request describing the query
     * @return a {@link KNNQuery} for custom-segment engines, otherwise a Lucene vector-similarity query
     * @throws NoSuchElementException if the engine creates custom segment files and the request has no
     *     query shard context
     * @throws NullPointerException if a Lucene query is requested without a radius
     * @throws IllegalArgumentException if a Lucene query is requested for a vector data type other than
     *     {@code BYTE} or {@code FLOAT}
     */
    public static Query create(BaseQueryFactory.CreateQueryRequest createQueryRequest) {
        String indexName = createQueryRequest.getIndexName();
        String fieldName = createQueryRequest.getFieldName();
        Float radius = createQueryRequest.getRadius();
        float[] vector = createQueryRequest.getVector();
        byte[] byteVector = createQueryRequest.getByteVector();
        VectorDataType vectorDataType = createQueryRequest.getVectorDataType();
        Query filterQuery = RNNQueryFactory.getFilterQuery(createQueryRequest);
        Map<String, ?> methodParameters = createQueryRequest.getMethodParameters();

        if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
            // The shard context supplies both the parent filter and the index settings, so it is
            // mandatory here. Resolve it once and fail with a descriptive message if it is missing.
            // The previous code called Optional.get() before checking isPresent(), which made the
            // check dead code.
            QueryShardContext context = createQueryRequest.getContext()
                .orElseThrow(() -> new NoSuchElementException(String.format(Locale.ROOT,
                    "Query shard context is required to create r-NN query for index: %s, field: %s", indexName, fieldName)));
            BitSetProducer parentFilter = context.getParentFilter();
            IndexSettings indexSettings = context.getIndexSettings();
            KNNQuery.Context knnQueryContext = new KNNQuery.Context(indexSettings.getMaxResultWindow());
            return KNNQuery.builder()
                .field(fieldName)
                .queryVector(vector)
                .indexName(indexName)
                .parentsFilter(parentFilter)
                .radius(radius)
                .methodParameters(methodParameters)
                .context(knnQueryContext)
                .filterQuery(filterQuery)
                .build();
        }

        log.debug("Creating Lucene r-NN query for index: {}, field: {}, radius: {}", indexName, fieldName, radius);
        switch (vectorDataType) {
            case BYTE: {
                return RNNQueryFactory.getByteVectorSimilarityQuery(
                    fieldName, byteVector, Objects.requireNonNull(radius, RADIUS_REQUIRED_MESSAGE).floatValue(), filterQuery);
            }
            case FLOAT: {
                return RNNQueryFactory.getFloatVectorSimilarityQuery(
                    fieldName, vector, Objects.requireNonNull(radius, RADIUS_REQUIRED_MESSAGE).floatValue(), filterQuery);
            }
        }
        throw new IllegalArgumentException(String.format(Locale.ROOT, "Invalid value provided for [%s] field. Supported values are [%s], but got: %s", "data_type", VectorDataType.SUPPORTED_VECTOR_DATA_TYPES, vectorDataType));
    }

    /**
     * Builds a Lucene float-vector similarity query whose traversal similarity is derived from the
     * result similarity.
     */
    private static Query getFloatVectorSimilarityQuery(String fieldName, float[] floatVector, float resultSimilarity, Query filterQuery) {
        return new FloatVectorSimilarityQuery(fieldName, floatVector, traversalSimilarity(resultSimilarity), resultSimilarity, filterQuery);
    }

    /**
     * Builds a Lucene byte-vector similarity query whose traversal similarity is derived from the
     * result similarity.
     */
    private static Query getByteVectorSimilarityQuery(String fieldName, byte[] byteVector, float resultSimilarity, Query filterQuery) {
        return new ByteVectorSimilarityQuery(fieldName, byteVector, traversalSimilarity(resultSimilarity), resultSimilarity, filterQuery);
    }

    /**
     * Scales the result similarity by the default traversal ratio. The result is passed as the
     * traversal-similarity argument of the Lucene similarity queries.
     */
    private static float traversalSimilarity(float resultSimilarity) {
        return KNNConstants.DEFAULT_LUCENE_RADIAL_SEARCH_TRAVERSAL_SIMILARITY_RATIO.floatValue() * resultSimilarity;
    }
}

