/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query;

import java.io.IOException;
import java.util.Arrays;
import java.util.Locale;
import java.util.Optional;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.KNNQuery;
import org.opensearch.knn.index.util.KNNEngine;

public class KNNQueryFactory {
    @Generated
    private static final Logger log = LogManager.getLogger(KNNQueryFactory.class);

    /**
     * Creates a k-NN query from individual arguments.
     *
     * <p>Convenience overload: the arguments are packed into a {@link CreateQueryRequest}
     * (byte vector, filter and shard context are left unset) and the work is delegated to
     * {@link #create(CreateQueryRequest)}.
     *
     * @param knnEngine      engine used to pick the query implementation; must not be null
     * @param indexName      name of the index being searched; must not be null
     * @param fieldName      name of the vector field being searched
     * @param vector         float query vector
     * @param k              number of neighbors requested
     * @param vectorDataType data type of the vector field
     * @return the query produced by {@link #create(CreateQueryRequest)}
     * @throws NullPointerException if {@code knnEngine} or {@code indexName} is null
     */
    public static Query create(KNNEngine knnEngine, String indexName, String fieldName, float[] vector, int k, VectorDataType vectorDataType) {
        return KNNQueryFactory.create(
            CreateQueryRequest.builder()
                .knnEngine(knnEngine)
                .indexName(indexName)
                .fieldName(fieldName)
                .vector(vector)
                .k(k)
                .vectorDataType(vectorDataType)
                .build()
        );
    }

    /**
     * Creates the {@link Query} for the k-NN search described by {@code createQueryRequest}.
     *
     * <p>Selection rules:
     * <ul>
     *   <li>If the engine is one of {@code KNNEngine.getEnginesThatCreateCustomSegmentFiles()},
     *       a {@link KNNQuery} over the float vector is returned. The filter is passed along only
     *       when the engine is also in {@code KNNEngine.getEnginesThatSupportsFilters()};
     *       otherwise a resolved filter is not applied.</li>
     *   <li>Otherwise a {@link KnnByteVectorQuery} is returned for {@code VectorDataType.BYTE}
     *       and a {@link KnnFloatVectorQuery} for {@code VectorDataType.FLOAT}.</li>
     * </ul>
     *
     * @param createQueryRequest request carrying engine, index, field, vector(s), k and the
     *                           optional filter/shard context
     * @return the query matching the engine and vector data type
     * @throws IllegalArgumentException if the engine does not create custom segment files and the
     *                                  vector data type is neither BYTE nor FLOAT (including null)
     * @throws RuntimeException         if a filter is set but no shard context is available, or
     *                                  the filter cannot be converted to a query
     */
    public static Query create(CreateQueryRequest createQueryRequest) {
        final String indexName = createQueryRequest.getIndexName();
        final String fieldName = createQueryRequest.getFieldName();
        final int k = createQueryRequest.getK();
        final KNNEngine knnEngine = createQueryRequest.getKnnEngine();
        // Resolved before any branching so that a filter without a shard context fails
        // regardless of which query type would have been built.
        final Query filterQuery = getFilterQuery(createQueryRequest);

        if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) {
            final float[] vector = createQueryRequest.getVector();
            if (filterQuery != null && KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) {
                // Parameterized logging: the message is only assembled when debug is enabled.
                log.debug("Creating custom k-NN query with filters for index: {} \"\", field: {} \"\", k: {}", indexName, fieldName, k);
                return new KNNQuery(fieldName, vector, k, indexName, filterQuery);
            }
            log.debug("Creating custom k-NN query for index: {} \"\", field: {} \"\", k: {}", indexName, fieldName, k);
            return new KNNQuery(fieldName, vector, k, indexName);
        }

        final VectorDataType vectorDataType = createQueryRequest.getVectorDataType();
        if (VectorDataType.BYTE == vectorDataType) {
            return getKnnByteVectorQuery(indexName, fieldName, createQueryRequest.getByteVector(), k, filterQuery);
        }
        if (VectorDataType.FLOAT == vectorDataType) {
            return getKnnFloatVectorQuery(indexName, fieldName, createQueryRequest.getVector(), k, filterQuery);
        }
        throw new IllegalArgumentException(String.format(Locale.ROOT, "Invalid value provided for [%s] field. Supported values are [%s]", "data_type", VectorDataType.SUPPORTED_VECTOR_DATA_TYPES));
    }

    /**
     * Builds a Lucene {@link KnnByteVectorQuery} for a byte vector field.
     *
     * @param indexName   index name, used only in the debug log message
     * @param fieldName   vector field to search
     * @param byteVector  byte query vector handed to the Lucene query
     * @param k           number of neighbors requested
     * @param filterQuery filter to apply, or {@code null} for an unfiltered query
     * @return a {@link KnnByteVectorQuery}, constructed with the filter when one is given
     */
    private static Query getKnnByteVectorQuery(String indexName, String fieldName, byte[] byteVector, int k, Query filterQuery) {
        if (filterQuery == null) {
            // Parameterized logging: the message is only assembled when debug is enabled.
            log.debug("Creating Lucene k-NN query for index: {} \"\", field: {} \"\", k: {}", indexName, fieldName, k);
            return new KnnByteVectorQuery(fieldName, byteVector, k);
        }
        log.debug("Creating Lucene k-NN query with filters for index: {} \"\", field: {} \"\", k: {}", indexName, fieldName, k);
        return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery);
    }

    /**
     * Builds a Lucene {@link KnnFloatVectorQuery} for a float vector field.
     *
     * @param indexName   index name, used only in the debug log message
     * @param fieldName   vector field to search
     * @param floatVector float query vector handed to the Lucene query
     * @param k           number of neighbors requested
     * @param filterQuery filter to apply, or {@code null} for an unfiltered query
     * @return a {@link KnnFloatVectorQuery}, constructed with the filter when one is given
     */
    private static Query getKnnFloatVectorQuery(String indexName, String fieldName, float[] floatVector, int k, Query filterQuery) {
        if (filterQuery == null) {
            // Parameterized logging: the message is only assembled when debug is enabled.
            log.debug("Creating Lucene k-NN query for index: {} \"\", field: {} \"\", k: {}", indexName, fieldName, k);
            return new KnnFloatVectorQuery(fieldName, floatVector, k);
        }
        log.debug("Creating Lucene k-NN query with filters for index: {} \"\", field: {} \"\", k: {}", indexName, fieldName, k);
        return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery);
    }

    /**
     * Converts the request's optional filter into a Lucene {@link Query}.
     *
     * @param createQueryRequest request that may carry a filter and a shard context
     * @return the filter converted via {@code QueryBuilder.toQuery}, or {@code null} when the
     *         request has no filter
     * @throws RuntimeException if a filter is present but the shard context is missing, or if the
     *                          conversion throws an {@link IOException} (kept as the cause)
     */
    private static Query getFilterQuery(CreateQueryRequest createQueryRequest) {
        final Optional<QueryBuilder> filter = createQueryRequest.getFilter();
        if (!filter.isPresent()) {
            return null;
        }
        // The shard context is only required once a filter actually has to be converted.
        final QueryShardContext queryShardContext = createQueryRequest.getContext()
            .orElseThrow(() -> new RuntimeException("Shard context cannot be null"));
        // Parameterized logging: the message is only assembled when debug is enabled.
        log.debug(
            "Creating k-NN query with filter for index [{}], field [{}] and k [{}]",
            createQueryRequest.getIndexName(),
            createQueryRequest.getFieldName(),
            createQueryRequest.getK()
        );
        try {
            return filter.get().toQuery(queryShardContext);
        } catch (IOException e) {
            throw new RuntimeException("Cannot create knn query with filter", e);
        }
    }

    /**
     * Mutable parameter object describing a k-NN query to be created by
     * {@link KNNQueryFactory#create(CreateQueryRequest)}.
     *
     * <p>{@code knnEngine} and {@code indexName} are mandatory: the constructor, the setters and
     * the builder methods reject {@code null} for them with a {@link NullPointerException}.
     * All other properties may be left unset; {@code filter} and {@code context} are exposed as
     * {@link Optional}s.
     */
    static class CreateQueryRequest {
        /** Engine used to choose the query implementation; never null. */
        @NonNull
        private KNNEngine knnEngine;
        /** Name of the index being searched; never null. */
        @NonNull
        private String indexName;
        /** Name of the vector field being searched. */
        private String fieldName;
        /** Float query vector. */
        private float[] vector;
        /** Byte query vector. */
        private byte[] byteVector;
        /** Data type of the vector field. */
        private VectorDataType vectorDataType;
        /** Number of neighbors requested. */
        private int k;
        /** Optional filter; may be null. */
        private QueryBuilder filter;
        /** Optional shard context; may be null. */
        private QueryShardContext context;

        /**
         * Creates a fully specified request.
         *
         * @throws NullPointerException if {@code knnEngine} or {@code indexName} is null
         */
        @Generated
        public CreateQueryRequest(@NonNull KNNEngine knnEngine, @NonNull String indexName, String fieldName, float[] vector, byte[] byteVector, VectorDataType vectorDataType, int k, QueryBuilder filter, QueryShardContext context) {
            // Both mandatory arguments are validated (engine first) before any field is assigned.
            rejectNull(knnEngine, "knnEngine is marked non-null but is null");
            rejectNull(indexName, "indexName is marked non-null but is null");
            this.knnEngine = knnEngine;
            this.indexName = indexName;
            this.fieldName = fieldName;
            this.vector = vector;
            this.byteVector = byteVector;
            this.vectorDataType = vectorDataType;
            this.k = k;
            this.filter = filter;
            this.context = context;
        }

        /** Throws a {@link NullPointerException} carrying {@code message} when {@code value} is null. */
        private static void rejectNull(Object value, String message) {
            if (value == null) {
                throw new NullPointerException(message);
            }
        }

        /** Returns a new, empty builder for {@link CreateQueryRequest}. */
        @Generated
        public static CreateQueryRequestBuilder builder() {
            return new CreateQueryRequestBuilder();
        }

        /** Returns the filter wrapped in an {@link Optional}; empty when no filter is set. */
        public Optional<QueryBuilder> getFilter() {
            return Optional.ofNullable(filter);
        }

        /** Returns the shard context wrapped in an {@link Optional}; empty when none is set. */
        public Optional<QueryShardContext> getContext() {
            return Optional.ofNullable(context);
        }

        /** Returns the engine; never null. */
        @NonNull
        @Generated
        public KNNEngine getKnnEngine() {
            return knnEngine;
        }

        /**
         * Replaces the engine.
         *
         * @throws NullPointerException if {@code knnEngine} is null
         */
        @Generated
        public void setKnnEngine(@NonNull KNNEngine knnEngine) {
            rejectNull(knnEngine, "knnEngine is marked non-null but is null");
            this.knnEngine = knnEngine;
        }

        /** Returns the index name; never null. */
        @NonNull
        @Generated
        public String getIndexName() {
            return indexName;
        }

        /**
         * Replaces the index name.
         *
         * @throws NullPointerException if {@code indexName} is null
         */
        @Generated
        public void setIndexName(@NonNull String indexName) {
            rejectNull(indexName, "indexName is marked non-null but is null");
            this.indexName = indexName;
        }

        /** Returns the vector field name. */
        @Generated
        public String getFieldName() {
            return fieldName;
        }

        /** Replaces the vector field name. */
        @Generated
        public void setFieldName(String fieldName) {
            this.fieldName = fieldName;
        }

        /** Returns the float query vector (the stored array itself, not a copy). */
        @Generated
        public float[] getVector() {
            return vector;
        }

        /** Replaces the float query vector (the array is stored as given, not copied). */
        @Generated
        public void setVector(float[] vector) {
            this.vector = vector;
        }

        /** Returns the byte query vector (the stored array itself, not a copy). */
        @Generated
        public byte[] getByteVector() {
            return byteVector;
        }

        /** Replaces the byte query vector (the array is stored as given, not copied). */
        @Generated
        public void setByteVector(byte[] byteVector) {
            this.byteVector = byteVector;
        }

        /** Returns the vector data type. */
        @Generated
        public VectorDataType getVectorDataType() {
            return vectorDataType;
        }

        /** Replaces the vector data type. */
        @Generated
        public void setVectorDataType(VectorDataType vectorDataType) {
            this.vectorDataType = vectorDataType;
        }

        /** Returns the number of neighbors requested. */
        @Generated
        public int getK() {
            return k;
        }

        /** Replaces the number of neighbors requested. */
        @Generated
        public void setK(int k) {
            this.k = k;
        }

        /** Replaces the filter; {@code null} clears it. */
        @Generated
        public void setFilter(QueryBuilder filter) {
            this.filter = filter;
        }

        /** Replaces the shard context; {@code null} clears it. */
        @Generated
        public void setContext(QueryShardContext context) {
            this.context = context;
        }

        /**
         * Fluent builder for {@link CreateQueryRequest}.
         *
         * <p>{@link #knnEngine(KNNEngine)} and {@link #indexName(String)} reject null arguments;
         * if either is never called, {@link #build()} fails in the request constructor.
         */
        @Generated
        public static class CreateQueryRequestBuilder {
            @Generated
            private KNNEngine knnEngine;
            @Generated
            private String indexName;
            @Generated
            private String fieldName;
            @Generated
            private float[] vector;
            @Generated
            private byte[] byteVector;
            @Generated
            private VectorDataType vectorDataType;
            @Generated
            private int k;
            @Generated
            private QueryBuilder filter;
            @Generated
            private QueryShardContext context;

            /** Package-private; instances are obtained through {@link CreateQueryRequest#builder()}. */
            @Generated
            CreateQueryRequestBuilder() {
            }

            /**
             * Sets the engine.
             *
             * @throws NullPointerException if {@code knnEngine} is null
             */
            @Generated
            public CreateQueryRequestBuilder knnEngine(@NonNull KNNEngine knnEngine) {
                rejectNull(knnEngine, "knnEngine is marked non-null but is null");
                this.knnEngine = knnEngine;
                return this;
            }

            /**
             * Sets the index name.
             *
             * @throws NullPointerException if {@code indexName} is null
             */
            @Generated
            public CreateQueryRequestBuilder indexName(@NonNull String indexName) {
                rejectNull(indexName, "indexName is marked non-null but is null");
                this.indexName = indexName;
                return this;
            }

            /** Sets the vector field name. */
            @Generated
            public CreateQueryRequestBuilder fieldName(String fieldName) {
                this.fieldName = fieldName;
                return this;
            }

            /** Sets the float query vector. */
            @Generated
            public CreateQueryRequestBuilder vector(float[] vector) {
                this.vector = vector;
                return this;
            }

            /** Sets the byte query vector. */
            @Generated
            public CreateQueryRequestBuilder byteVector(byte[] byteVector) {
                this.byteVector = byteVector;
                return this;
            }

            /** Sets the vector data type. */
            @Generated
            public CreateQueryRequestBuilder vectorDataType(VectorDataType vectorDataType) {
                this.vectorDataType = vectorDataType;
                return this;
            }

            /** Sets the number of neighbors requested. */
            @Generated
            public CreateQueryRequestBuilder k(int k) {
                this.k = k;
                return this;
            }

            /** Sets the optional filter. */
            @Generated
            public CreateQueryRequestBuilder filter(QueryBuilder filter) {
                this.filter = filter;
                return this;
            }

            /** Sets the optional shard context. */
            @Generated
            public CreateQueryRequestBuilder context(QueryShardContext context) {
                this.context = context;
                return this;
            }

            /**
             * Creates the request from the values collected so far.
             *
             * @throws NullPointerException if the engine or the index name was never set
             */
            @Generated
            public CreateQueryRequest build() {
                return new CreateQueryRequest(knnEngine, indexName, fieldName, vector, byteVector, vectorDataType, k, filter, context);
            }

            /** Returns a textual dump of every builder field; arrays are rendered element by element. */
            @Generated
            public String toString() {
                return "KNNQueryFactory.CreateQueryRequest.CreateQueryRequestBuilder(knnEngine=" + knnEngine
                    + ", indexName=" + indexName
                    + ", fieldName=" + fieldName
                    + ", vector=" + Arrays.toString(vector)
                    + ", byteVector=" + Arrays.toString(byteVector)
                    + ", vectorDataType=" + vectorDataType
                    + ", k=" + k
                    + ", filter=" + filter
                    + ", context=" + context
                    + ")";
            }
        }
    }
}

