/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import lombok.Generated;
import org.apache.commons.lang.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;
import org.opensearch.core.ParseField;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.Strings;
import org.opensearch.core.common.io.stream.NamedWriteable;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentLocation;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.common.KNNValidationUtil;
import org.opensearch.knn.index.IndexUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.VectorQueryType;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.BaseQueryFactory;
import org.opensearch.knn.index.query.KNNQueryFactory;
import org.opensearch.knn.index.query.RNNQueryFactory;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;

/**
 * Query builder for the {@code knn} query clause.
 *
 * <p>A builder targets a single vector field with a single query vector and exactly one search
 * mode: top-{@code k}, {@code max_distance}, or {@code min_score}. The mode is enforced by
 * {@link #validateSingleQueryType(Integer, Float, Float)}; an unset {@code k} is represented by
 * {@code 0} and unset thresholds by {@code null}.
 */
public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
    @Generated
    private static final Logger log = LogManager.getLogger(KNNQueryBuilder.class);

    /** Model metadata source; must be injected through {@link #initialize(ModelDao)} before model-backed fields are queried. */
    private static ModelDao modelDao;

    public static final ParseField VECTOR_FIELD = new ParseField("vector");
    public static final ParseField K_FIELD = new ParseField("k");
    public static final ParseField FILTER_FIELD = new ParseField("filter");
    public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
    public static final ParseField MAX_DISTANCE_FIELD = new ParseField("max_distance");
    public static final ParseField MIN_SCORE_FIELD = new ParseField("min_score");

    /** Largest accepted value of {@code k}. */
    public static final int K_MAX = 10000;
    public static final String NAME = "knn";

    // Feature keys handed to IndexUtil's minimum-version checks. Centralised so that every call
    // site uses the same spelling (presumably these match the keys of
    // IndexUtil.minimalRequiredVersionMap -- confirm against IndexUtil).
    private static final String FILTER_VERSION_KEY = "filter";
    private static final String IGNORE_UNMAPPED_VERSION_KEY = "ignore_unmapped";
    private static final String RADIAL_SEARCH_VERSION_KEY = "radial_search";

    private final String fieldName;
    private final float[] vector;
    /** Number of neighbours requested; {@code 0} means "not set". */
    private int k = 0;
    private Float maxDistance = null;
    private Float minScore = null;
    private QueryBuilder filter;
    private boolean ignoreUnmapped = false;

    /**
     * Creates a builder with no search mode selected yet; call {@link #k(Integer)},
     * {@link #maxDistance(Float)} or {@link #minScore(Float)} afterwards.
     *
     * @param fieldName name of the vector field; must be neither {@code null} nor empty
     * @param vector    query vector; must be neither {@code null} nor empty
     * @throws IllegalArgumentException if either argument is missing or empty
     */
    public KNNQueryBuilder(String fieldName, float[] vector) {
        if (Strings.isNullOrEmpty(fieldName)) {
            throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME));
        }
        requireQueryVector(vector);
        this.fieldName = fieldName;
        this.vector = vector;
    }

    /**
     * Selects top-k search.
     *
     * @param k number of neighbours, in the range {@code (0, K_MAX]}
     * @return this builder
     * @throws IllegalArgumentException if {@code k} is {@code null}, out of range, or another
     *                                  search mode is already set
     */
    public KNNQueryBuilder k(Integer k) {
        if (k == null) {
            throw new IllegalArgumentException(String.format("[%s] requires k to be set", NAME));
        }
        validateSingleQueryType(k, this.maxDistance, this.minScore);
        if (k <= 0 || k > K_MAX) {
            throw new IllegalArgumentException(String.format("[%s] requires k to be in the range (0, %d]", NAME, K_MAX));
        }
        this.k = k;
        return this;
    }

    /**
     * Selects radial search bounded by a maximum distance.
     *
     * @param maxDistance distance threshold; sign is validated later, in {@code doToQuery},
     *                    because the rule depends on the field's space type
     * @return this builder
     * @throws IllegalArgumentException if {@code maxDistance} is {@code null} or another search
     *                                  mode is already set
     */
    public KNNQueryBuilder maxDistance(Float maxDistance) {
        if (maxDistance == null) {
            throw new IllegalArgumentException(String.format("[%s] requires maxDistance to be set", NAME));
        }
        validateSingleQueryType(this.k, maxDistance, this.minScore);
        this.maxDistance = maxDistance;
        return this;
    }

    /**
     * Selects radial search bounded by a minimum score.
     *
     * @param minScore score threshold; must be greater than 0
     * @return this builder
     * @throws IllegalArgumentException if {@code minScore} is {@code null}, not positive, or
     *                                  another search mode is already set
     */
    public KNNQueryBuilder minScore(Float minScore) {
        if (minScore == null) {
            throw new IllegalArgumentException(String.format("[%s] requires minScore to be set", NAME));
        }
        validateSingleQueryType(this.k, this.maxDistance, minScore);
        if (minScore.floatValue() <= 0.0f) {
            throw new IllegalArgumentException(String.format("[%s] requires minScore to be greater than 0", NAME));
        }
        this.minScore = minScore;
        return this;
    }

    /**
     * Sets the optional filter query.
     *
     * @param filter filter to apply, or {@code null} for none
     * @return this builder
     */
    public KNNQueryBuilder filter(QueryBuilder filter) {
        this.filter = filter;
        return this;
    }

    /**
     * Creates a top-k builder without a filter.
     *
     * @see #KNNQueryBuilder(String, float[], int, QueryBuilder)
     */
    public KNNQueryBuilder(String fieldName, float[] vector, int k) {
        this(fieldName, vector, k, null);
    }

    /**
     * Creates a top-k builder.
     *
     * @param fieldName name of the vector field; must not be blank
     * @param vector    query vector; must be neither {@code null} nor empty
     * @param k         number of neighbours, in the range {@code (0, K_MAX]}
     * @param filter    optional filter query, may be {@code null}
     * @throws IllegalArgumentException if any argument violates the constraints above
     */
    public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder filter) {
        if (StringUtils.isBlank(fieldName)) {
            throw new IllegalArgumentException(String.format("[%s] requires fieldName", NAME));
        }
        requireQueryVector(vector);
        if (k <= 0) {
            throw new IllegalArgumentException(String.format("[%s] requires k > 0", NAME));
        }
        if (k > K_MAX) {
            throw new IllegalArgumentException(String.format("[%s] requires k <= %d", NAME, K_MAX));
        }
        this.fieldName = fieldName;
        this.vector = vector;
        this.k = k;
        this.filter = filter;
        // ignoreUnmapped, maxDistance and minScore keep their field-initializer defaults.
    }

    /**
     * Injects the {@link ModelDao} used to look up metadata for model-backed fields.
     *
     * @param modelDao DAO shared by all builders
     */
    public static void initialize(ModelDao modelDao) {
        KNNQueryBuilder.modelDao = modelDao;
    }

    /** Rejects a {@code null} or zero-length query vector. */
    private static void requireQueryVector(float[] vector) {
        if (vector == null) {
            throw new IllegalArgumentException(String.format("[%s] requires query vector", NAME));
        }
        if (vector.length == 0) {
            throw new IllegalArgumentException(String.format("[%s] query vector is empty", NAME));
        }
    }

    /**
     * Converts the parsed {@code vector} array into a {@code float[]}.
     *
     * @throws IllegalArgumentException if the list is {@code null}, empty, or contains a
     *                                  non-numeric element
     */
    private static float[] objectsToFloats(List<Object> objs) {
        if (Objects.isNull(objs) || objs.isEmpty()) {
            throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be non-null and non-empty", NAME));
        }
        float[] vec = new float[objs.size()];
        for (int i = 0; i < vec.length; ++i) {
            Object element = objs.get(i);
            if (!(element instanceof Number)) {
                throw new IllegalArgumentException(String.format("[%s] field 'vector' requires to be an array of numbers", NAME));
            }
            vec[i] = ((Number) element).floatValue();
        }
        return vec;
    }

    /**
     * Deserializes a builder from the transport stream. Field order must mirror
     * {@link #doWriteTo(StreamOutput)}; optional sections are read only when the cluster-wide
     * minimum version check for the corresponding feature passes.
     *
     * @param in stream to read from
     * @throws IOException      if the superclass fails to read its part
     * @throws RuntimeException wrapping any {@link IOException} raised while reading this
     *                          class's own fields
     */
    public KNNQueryBuilder(StreamInput in) throws IOException {
        super(in);
        try {
            this.fieldName = in.readString();
            this.vector = in.readFloatArray();
            this.k = in.readInt();
            if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(FILTER_VERSION_KEY)) {
                this.filter = in.readOptionalNamedWriteable(QueryBuilder.class);
            }
            if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(IGNORE_UNMAPPED_VERSION_KEY)) {
                // Guard against an absent optional value instead of failing on unboxing.
                Boolean unmapped = in.readOptionalBoolean();
                this.ignoreUnmapped = unmapped != null && unmapped;
            }
            if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(RADIAL_SEARCH_VERSION_KEY)) {
                this.maxDistance = in.readOptionalFloat();
                this.minScore = in.readOptionalFloat();
            }
        } catch (IOException ex) {
            throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
        }
    }

    /**
     * Parses a {@code knn} clause. Two shapes are accepted: the object form
     * {@code {"<field>": {"vector": [...], "k": n, ...}}} and the short form
     * {@code {"<field>": [...]}} where the value is read directly as the vector.
     *
     * <p>A {@code k} of {@code 0} is treated as "not set", consistently with
     * {@link #validateSingleQueryType(Integer, Float, Float)}; this lets the output of
     * {@link #doXContent(XContentBuilder, ToXContent.Params)} for radial queries (which contains
     * {@code "k": 0}) be parsed back.
     *
     * @param parser parser positioned inside the {@code knn} object
     * @return the configured builder
     * @throws IOException              on low-level parse failures
     * @throws ParsingException         on unknown fields or unexpected tokens
     * @throws IllegalArgumentException if required values are missing or invalid
     */
    public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOException {
        String fieldName = null;
        List<Object> vector = null;
        float boost = 1.0f;
        Integer k = null;
        Float maxDistance = null;
        Float minScore = null;
        QueryBuilder filter = null;
        boolean ignoreUnmapped = false;
        String queryName = null;
        String currentFieldName = null;

        XContentParser.Token token;
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
            } else if (token == XContentParser.Token.START_OBJECT) {
                throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, currentFieldName);
                fieldName = currentFieldName;
                while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
                    if (token == XContentParser.Token.FIELD_NAME) {
                        currentFieldName = parser.currentName();
                    } else if (token.isValue() || token == XContentParser.Token.START_ARRAY) {
                        if (VECTOR_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            vector = parser.list();
                        } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            boost = parser.floatValue();
                        } else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            k = (Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
                        } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            queryName = parser.text();
                        } else if (MAX_DISTANCE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            maxDistance = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
                        } else if (MIN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            minScore = (Float) NumberFieldMapper.NumberType.FLOAT.parse(parser.objectBytes(), false);
                        } else if (IGNORE_UNMAPPED_FIELD.getPreferredName().equals(currentFieldName)) {
                            // Silently skipped when the cluster does not yet support the flag.
                            if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(IGNORE_UNMAPPED_VERSION_KEY)) {
                                ignoreUnmapped = parser.booleanValue();
                            }
                        } else {
                            throw new ParsingException(parser.getTokenLocation(), "[knn] query does not support [" + currentFieldName + "]");
                        }
                    } else if (token == XContentParser.Token.START_OBJECT) {
                        String tokenName = parser.currentName();
                        if (!FILTER_FIELD.getPreferredName().equals(tokenName)) {
                            throw new ParsingException(parser.getTokenLocation(), "[knn] unknown token [" + token + "]");
                        }
                        log.debug("Start parsing filter for field [{}]", fieldName);
                        if (!IndexUtil.isClusterOnOrAfterMinRequiredVersion(FILTER_VERSION_KEY)) {
                            log.debug(
                                "This version of k-NN doesn't support [filter] field, minimal required version is [{}]",
                                IndexUtil.minimalRequiredVersionMap.get(FILTER_VERSION_KEY)
                            );
                            throw new IllegalArgumentException(
                                String.format(
                                    "%s field is supported from version %s",
                                    FILTER_FIELD.getPreferredName(),
                                    IndexUtil.minimalRequiredVersionMap.get(FILTER_VERSION_KEY)
                                )
                            );
                        }
                        filter = parseInnerQueryBuilder(parser);
                    } else {
                        throw new ParsingException(
                            parser.getTokenLocation(),
                            "[knn] unknown token [" + token + "] after [" + currentFieldName + "]"
                        );
                    }
                }
            } else {
                // Short form: the field's value is the vector itself.
                throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, parser.currentName());
                fieldName = parser.currentName();
                vector = parser.list();
            }
        }

        VectorQueryType vectorQueryType = validateSingleQueryType(k, maxDistance, minScore);
        vectorQueryType.getQueryStatCounter().increment();
        if (filter != null) {
            vectorQueryType.getQueryWithFilterStatCounter().increment();
        }

        KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, objectsToFloats(vector)).filter(filter)
            .boost(boost)
            .queryName(queryName);
        // The flag was already version-gated while parsing (it stays false on clusters that do
        // not support it), so it can be applied unconditionally here.
        knnQueryBuilder.ignoreUnmapped(ignoreUnmapped);

        // k == 0 means "unset" (see validateSingleQueryType); routing it to k(..) would reject a
        // valid radial query that merely carries the default "k": 0.
        if (k != null && k != 0) {
            knnQueryBuilder.k(k);
        } else if (maxDistance != null) {
            knnQueryBuilder.maxDistance(maxDistance);
        } else if (minScore != null) {
            knnQueryBuilder.minScore(minScore);
        }
        return knnQueryBuilder;
    }

    /**
     * Serializes this builder to the transport stream. Field order must mirror
     * {@link #KNNQueryBuilder(StreamInput)}.
     */
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(this.fieldName);
        out.writeFloatArray(this.vector);
        out.writeInt(this.k);
        if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(FILTER_VERSION_KEY)) {
            out.writeOptionalNamedWriteable((NamedWriteable) this.filter);
        }
        if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(IGNORE_UNMAPPED_VERSION_KEY)) {
            out.writeOptionalBoolean(Boolean.valueOf(this.ignoreUnmapped));
        }
        if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(RADIAL_SEARCH_VERSION_KEY)) {
            out.writeOptionalFloat(this.maxDistance);
            out.writeOptionalFloat(this.minScore);
        }
    }

    /** @return the name of the field being queried */
    public String fieldName() {
        return this.fieldName;
    }

    /** @return the query vector (the internal {@code float[]}, not a copy) */
    public Object vector() {
        return this.vector;
    }

    /** @return the requested number of neighbours, or {@code 0} if top-k search is not selected */
    public int getK() {
        return this.k;
    }

    /**
     * @return the maximum distance threshold
     * @throws NullPointerException if no maximum distance has been set
     */
    public float getMaxDistance() {
        return Objects.requireNonNull(this.maxDistance, "[knn] maxDistance is not set").floatValue();
    }

    /**
     * @return the minimum score threshold
     * @throws NullPointerException if no minimum score has been set
     */
    public float getMinScore() {
        return Objects.requireNonNull(this.minScore, "[knn] minScore is not set").floatValue();
    }

    /** @return the filter query, or {@code null} if none was set */
    public QueryBuilder getFilter() {
        return this.filter;
    }

    /**
     * Controls what happens when the target field has no mapping: when {@code true},
     * {@code doToQuery} returns a match-none query instead of throwing.
     *
     * @return this builder
     */
    public KNNQueryBuilder ignoreUnmapped(boolean ignoreUnmapped) {
        this.ignoreUnmapped = ignoreUnmapped;
        return this;
    }

    /** @return whether an unmapped field is tolerated */
    public boolean getIgnoreUnmapped() {
        return this.ignoreUnmapped;
    }

    /**
     * Renders this query as {@code {"knn": {"<field>": {...}}}}. {@code k} is always written;
     * the optional members are written only when set.
     */
    public void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME);
        builder.startObject(this.fieldName);
        builder.field(VECTOR_FIELD.getPreferredName(), (Object) this.vector);
        builder.field(K_FIELD.getPreferredName(), this.k);
        if (this.filter != null) {
            builder.field(FILTER_FIELD.getPreferredName(), (ToXContent) this.filter);
        }
        if (this.maxDistance != null) {
            builder.field(MAX_DISTANCE_FIELD.getPreferredName(), this.maxDistance);
        }
        if (this.ignoreUnmapped) {
            builder.field(IGNORE_UNMAPPED_FIELD.getPreferredName(), this.ignoreUnmapped);
        }
        if (this.minScore != null) {
            builder.field(MIN_SCORE_FIELD.getPreferredName(), this.minScore);
        }
        this.printBoostAndQueryName(builder);
        builder.endObject();
        builder.endObject();
    }

    /**
     * Builds the Lucene query for one shard.
     *
     * <p>Steps, in order: resolve the field mapping; resolve engine, space type and dimension
     * (from the model when the mapped dimension is {@code -1}, otherwise from the method context
     * when present); translate {@code maxDistance}/{@code minScore} into an engine radius;
     * validate the query vector; check filter support; then create either a k-NN or a radial
     * query.
     *
     * @param context shard context used to resolve the field and the index name
     * @return a k-NN query, a radial query, or a match-none query for an ignored unmapped field
     * @throws IllegalArgumentException      on a non-vector field, invalid threshold, dimension
     *                                       mismatch, unsupported filter, or missing search mode
     * @throws IllegalStateException         if a model-backed field also carries a space type
     * @throws UnsupportedOperationException if the engine does not support radial search
     */
    protected Query doToQuery(QueryShardContext context) {
        MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName);
        if (mappedFieldType == null && this.ignoreUnmapped) {
            return new MatchNoDocsQuery();
        }
        if (!(mappedFieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType)) {
            throw new IllegalArgumentException(String.format("Field '%s' is not knn_vector type.", this.fieldName));
        }
        KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldMapper.KNNVectorFieldType) mappedFieldType;

        int fieldDimension = knnVectorFieldType.getDimension();
        KNNMethodContext knnMethodContext = knnVectorFieldType.getKnnMethodContext();
        KNNEngine knnEngine = KNNEngine.DEFAULT;
        VectorDataType vectorDataType = knnVectorFieldType.getVectorDataType();
        SpaceType spaceType = knnVectorFieldType.getSpaceType();

        if (fieldDimension == -1) {
            // A dimension of -1 marks a model-backed field: everything comes from the model.
            if (spaceType != null) {
                throw new IllegalStateException("Space type should be null when the field uses a model");
            }
            ModelMetadata modelMetadata = this.getModelMetadataForField(knnVectorFieldType);
            fieldDimension = modelMetadata.getDimension();
            knnEngine = modelMetadata.getKnnEngine();
            spaceType = modelMetadata.getSpaceType();
        } else if (knnMethodContext != null) {
            knnEngine = knnMethodContext.getKnnEngine();
            spaceType = knnMethodContext.getSpaceType();
        }

        Float radius = null;
        if (this.maxDistance != null) {
            if (this.maxDistance.floatValue() < 0.0f && !SpaceType.INNER_PRODUCT.equals(spaceType)) {
                throw new IllegalArgumentException(
                    String.format("[knn] requires distance to be non-negative for space type: %s", spaceType)
                );
            }
            radius = knnEngine.distanceToRadialThreshold(this.maxDistance, spaceType);
        }
        if (this.minScore != null) {
            if (this.minScore.floatValue() > 1.0f && !SpaceType.INNER_PRODUCT.equals(spaceType)) {
                throw new IllegalArgumentException(
                    String.format("[knn] requires score to be in the range [0, 1] for space type: %s", spaceType)
                );
            }
            radius = knnEngine.scoreToRadialThreshold(this.minScore, spaceType);
        }

        if (fieldDimension != this.vector.length) {
            throw new IllegalArgumentException(
                String.format("Query vector has invalid dimension: %d. Dimension should be: %d", this.vector.length, fieldDimension)
            );
        }

        // Only populated for byte-typed fields; stays null otherwise so the request carries
        // exactly one of the two vector representations.
        byte[] byteVector = null;
        if (VectorDataType.BYTE == vectorDataType) {
            byteVector = toValidatedByteVector();
            spaceType.validateVector(byteVector);
        } else {
            spaceType.validateVector(this.vector);
        }

        if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)
            && this.filter != null
            && !KNNEngine.getEnginesThatSupportsFilters().contains(knnEngine)) {
            throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine));
        }

        String indexName = context.index().getName();
        float[] floatVector = VectorDataType.FLOAT == vectorDataType ? this.vector : null;

        if (this.k != 0) {
            BaseQueryFactory.CreateQueryRequest createQueryRequest = BaseQueryFactory.CreateQueryRequest.builder()
                .knnEngine(knnEngine)
                .indexName(indexName)
                .fieldName(this.fieldName)
                .vector(floatVector)
                .byteVector(byteVector)
                .vectorDataType(vectorDataType)
                .k(this.k)
                .filter(this.filter)
                .context(context)
                .build();
            return KNNQueryFactory.create(createQueryRequest);
        }
        if (radius != null) {
            if (!KNNEngine.ENGINES_SUPPORTING_RADIAL_SEARCH.contains(knnEngine)) {
                throw new UnsupportedOperationException(String.format("Engine [%s] does not support radial search", knnEngine));
            }
            BaseQueryFactory.CreateQueryRequest createQueryRequest = BaseQueryFactory.CreateQueryRequest.builder()
                .knnEngine(knnEngine)
                .indexName(indexName)
                .fieldName(this.fieldName)
                .vector(floatVector)
                .byteVector(byteVector)
                .vectorDataType(vectorDataType)
                .radius(radius)
                .filter(this.filter)
                .context(context)
                .build();
            return RNNQueryFactory.create(createQueryRequest);
        }
        throw new IllegalArgumentException(String.format("[%s] requires k or distance or score to be set", NAME));
    }

    /** Validates every component with {@link KNNValidationUtil} and narrows the query vector to bytes. */
    private byte[] toValidatedByteVector() {
        byte[] bytes = new byte[this.vector.length];
        for (int i = 0; i < this.vector.length; ++i) {
            KNNValidationUtil.validateByteVectorValue(this.vector[i]);
            bytes[i] = (byte) this.vector[i];
        }
        return bytes;
    }

    /**
     * Looks up the metadata of the model backing a field.
     *
     * @throws IllegalArgumentException if the field has no model id or the model is not created
     */
    private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
        String modelId = knnVectorField.getModelId();
        if (modelId == null) {
            throw new IllegalArgumentException(String.format("Field '%s' does not have model.", this.fieldName));
        }
        ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
        if (!ModelUtil.isModelCreated(modelMetadata)) {
            throw new IllegalArgumentException(String.format("Model ID '%s' is not created.", modelId));
        }
        return modelMetadata;
    }

    /**
     * Compares every field that influences the generated query, including the radial
     * thresholds, so that two radial queries with different thresholds are never equal.
     */
    protected boolean doEquals(KNNQueryBuilder other) {
        return Objects.equals(this.fieldName, other.fieldName)
            && Arrays.equals(this.vector, other.vector)
            && this.k == other.k
            && Objects.equals(this.maxDistance, other.maxDistance)
            && Objects.equals(this.minScore, other.minScore)
            && Objects.equals(this.filter, other.filter)
            && this.ignoreUnmapped == other.ignoreUnmapped;
    }

    /** Hashes the same fields that {@link #doEquals(KNNQueryBuilder)} compares. */
    protected int doHashCode() {
        return Objects.hash(
            this.fieldName,
            Arrays.hashCode(this.vector),
            this.k,
            this.maxDistance,
            this.minScore,
            this.filter,
            this.ignoreUnmapped
        );
    }

    public String getWriteableName() {
        return NAME;
    }

    /**
     * Ensures exactly one search mode is selected and reports which one.
     *
     * @param k        requested neighbours; {@code null} and {@code 0} both count as unset
     * @param distance maximum distance, or {@code null} if unset
     * @param score    minimum score, or {@code null} if unset
     * @return the single selected query type
     * @throws IllegalArgumentException if zero or more than one mode is set
     */
    private static VectorQueryType validateSingleQueryType(Integer k, Float distance, Float score) {
        int countSetFields = 0;
        VectorQueryType vectorQueryType = null;
        if (k != null && k != 0) {
            ++countSetFields;
            vectorQueryType = VectorQueryType.K;
        }
        if (distance != null) {
            ++countSetFields;
            vectorQueryType = VectorQueryType.MAX_DISTANCE;
        }
        if (score != null) {
            ++countSetFields;
            vectorQueryType = VectorQueryType.MIN_SCORE;
        }
        if (countSetFields != 1) {
            throw new IllegalArgumentException(String.format("[%s] requires exactly one of k, distance or score to be set", NAME));
        }
        return vectorQueryType;
    }
}

