/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.index.query;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.common.ParseField;
import org.opensearch.common.ParsingException;
import org.opensearch.common.Strings;
import org.opensearch.common.io.stream.NamedWriteable;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentLocation;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.KNNClusterUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.query.KNNQueryFactory;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.plugin.stats.KNNCounter;

/**
 * Builder for the {@code knn} query.
 *
 * <p>Holds the target field name, the query vector, the number of neighbours {@code k}
 * and an optional filter query. It can be created programmatically, read from a
 * {@link StreamInput}, or parsed from XContent via {@link #fromXContent(XContentParser)},
 * and is turned into a Lucene {@link Query} by {@link #doToQuery(QueryShardContext)}.
 */
public class KNNQueryBuilder
extends AbstractQueryBuilder<KNNQueryBuilder> {
    @Generated
    private static final Logger log = LogManager.getLogger(KNNQueryBuilder.class);

    /** Source of model metadata; must be set through {@link #initialize(ModelDao)} before model-based fields are queried. */
    private static ModelDao modelDao;

    public static final ParseField VECTOR_FIELD = new ParseField("vector");
    public static final ParseField K_FIELD = new ParseField("k");
    public static final ParseField FILTER_FIELD = new ParseField("filter");

    /** Upper bound (inclusive) accepted for {@code k}. */
    public static int K_MAX = 10000;

    /** Name under which this query is registered and serialized. */
    public static final String NAME = "knn";

    private final String fieldName;
    private final float[] vector;
    private int k = 0;
    private QueryBuilder filter;

    /** The filter is only parsed and (de)serialized when the cluster minimum version is on or after this version. */
    private static final Version MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER = Version.V_2_4_0;

    /**
     * Creates a query without a filter.
     *
     * @param fieldName name of the field to search
     * @param vector    query vector
     * @param k         number of neighbours to return
     * @throws IllegalArgumentException under the same conditions as
     *                                  {@link #KNNQueryBuilder(String, float[], int, QueryBuilder)}
     */
    public KNNQueryBuilder(String fieldName, float[] vector, int k) {
        this(fieldName, vector, k, null);
    }

    /**
     * Creates a query, validating every mandatory argument.
     *
     * @param fieldName name of the field to search; must not be null or empty
     * @param vector    query vector; must not be null or empty (the array is stored as-is, not copied)
     * @param k         number of neighbours to return; must satisfy {@code 0 < k <= K_MAX}
     * @param filter    optional filter query, may be {@code null}
     * @throws IllegalArgumentException if any of the constraints above is violated
     */
    public KNNQueryBuilder(String fieldName, float[] vector, int k, QueryBuilder filter) {
        if (Strings.isNullOrEmpty(fieldName)) {
            throw new IllegalArgumentException("[knn] requires fieldName");
        }
        if (vector == null) {
            throw new IllegalArgumentException("[knn] requires query vector");
        }
        if (vector.length == 0) {
            throw new IllegalArgumentException("[knn] query vector is empty");
        }
        if (k <= 0) {
            throw new IllegalArgumentException("[knn] requires k > 0");
        }
        if (k > K_MAX) {
            throw new IllegalArgumentException("[knn] requires k <= " + K_MAX);
        }
        this.fieldName = fieldName;
        this.vector = vector;
        this.k = k;
        this.filter = filter;
    }

    /**
     * Registers the {@link ModelDao} used to look up model metadata at query time.
     *
     * @param modelDao DAO to use for all instances of this builder
     */
    public static void initialize(ModelDao modelDao) {
        KNNQueryBuilder.modelDao = modelDao;
    }

    /**
     * Converts a parsed list of values into a primitive float array.
     *
     * @param objs parsed values, expected to all be {@link Number}s; must not be {@code null}
     * @return array holding {@code floatValue()} of every element, in order
     * @throws IllegalArgumentException if an element is {@code null} or not a {@link Number}
     */
    private static float[] ObjectsToFloats(List<Object> objs) {
        float[] vec = new float[objs.size()];
        for (int i = 0; i < vec.length; ++i) {
            Object element = objs.get(i);
            // Reject strings, nulls, nested arrays etc. with a client error instead of a ClassCastException/NPE.
            if (!(element instanceof Number)) {
                throw new IllegalArgumentException("[knn] field 'vector' requires to be an array of numbers");
            }
            vec[i] = ((Number)element).floatValue();
        }
        return vec;
    }

    /**
     * Reads a builder from a stream, mirroring {@link #doWriteTo(StreamOutput)}.
     *
     * <p>The optional filter is only read when the cluster minimum version is on or after
     * {@code MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER}.
     *
     * @param in stream to read from
     * @throws IOException      if the superclass constructor fails to read its part
     * @throws RuntimeException wrapping any {@link IOException} raised while reading the k-NN specific fields
     */
    public KNNQueryBuilder(StreamInput in) throws IOException {
        super(in);
        try {
            this.fieldName = in.readString();
            this.vector = in.readFloatArray();
            this.k = in.readInt();
            if (KNNQueryBuilder.isClusterOnOrAfterMinRequiredVersion()) {
                this.filter = (QueryBuilder)in.readOptionalNamedWriteable(QueryBuilder.class);
            }
        }
        catch (IOException ex) {
            throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
        }
    }

    /**
     * Parses a {@code knn} query from XContent.
     *
     * <p>Two shapes are handled: an object per field containing {@code vector}, {@code k},
     * optional {@code filter}, boost and query name; and a short form where the field name
     * is followed directly by the vector. Each call increments
     * {@code KNNCounter.KNN_QUERY_REQUESTS}; encountering a filter object increments
     * {@code KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS}.
     *
     * @param parser parser positioned inside the {@code knn} query object
     * @return the parsed builder
     * @throws IOException              if the parser fails
     * @throws ParsingException         on unsupported fields or unexpected tokens
     * @throws IllegalArgumentException if a filter is given but not supported by the cluster version,
     *                                  if the vector is missing or contains non-numeric values, or if
     *                                  constructor validation fails
     */
    public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOException {
        XContentParser.Token token;
        String fieldName = null;
        List<Object> vector = null;
        float boost = 1.0f;
        int k = 0;
        QueryBuilder filter = null;
        String queryName = null;
        String currentFieldName = null;
        KNNCounter.KNN_QUERY_REQUESTS.increment();
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
            } else if (token == XContentParser.Token.START_OBJECT) {
                KNNQueryBuilder.throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, currentFieldName);
                fieldName = currentFieldName;
                while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
                    if (token == XContentParser.Token.FIELD_NAME) {
                        currentFieldName = parser.currentName();
                    } else if (token.isValue() || token == XContentParser.Token.START_ARRAY) {
                        if (VECTOR_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            vector = parser.list();
                        } else if (AbstractQueryBuilder.BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            boost = parser.floatValue();
                        } else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            k = (Integer)NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false);
                        } else if (AbstractQueryBuilder.NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                            queryName = parser.text();
                        } else {
                            throw new ParsingException(parser.getTokenLocation(), "[knn] query does not support [" + currentFieldName + "]");
                        }
                    } else if (token == XContentParser.Token.START_OBJECT) {
                        String tokenName = parser.currentName();
                        // The only nested object accepted inside the field object is the filter.
                        if (!FILTER_FIELD.getPreferredName().equals(tokenName)) {
                            throw new ParsingException(parser.getTokenLocation(), "[knn] unknown token [" + token + "]");
                        }
                        log.debug("Start parsing filter for field [{}]", fieldName);
                        KNNCounter.KNN_QUERY_WITH_FILTER_REQUESTS.increment();
                        if (!KNNQueryBuilder.isClusterOnOrAfterMinRequiredVersion()) {
                            log.debug("This version of k-NN doesn't support [filter] field, minimal required version is [{}]", MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER);
                            throw new IllegalArgumentException(String.format("%s field is supported from version %s", FILTER_FIELD.getPreferredName(), MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER));
                        }
                        filter = KNNQueryBuilder.parseInnerQueryBuilder(parser);
                    } else {
                        throw new ParsingException(parser.getTokenLocation(), "[knn] unknown token [" + token + "] after [" + currentFieldName + "]");
                    }
                }
            } else {
                // Short form: the field name is followed directly by the vector values.
                KNNQueryBuilder.throwParsingExceptionOnMultipleFields(NAME, parser.getTokenLocation(), fieldName, parser.currentName());
                fieldName = parser.currentName();
                vector = parser.list();
            }
        }
        // A request without any vector leaves the list null; pass null through so the constructor
        // reports "[knn] requires query vector" instead of failing with a NullPointerException here.
        float[] queryVector = vector == null ? null : KNNQueryBuilder.ObjectsToFloats(vector);
        KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(fieldName, queryVector, k, filter);
        knnQueryBuilder.queryName(queryName);
        knnQueryBuilder.boost(boost);
        return knnQueryBuilder;
    }

    /**
     * Writes the k-NN specific fields in the order expected by {@link #KNNQueryBuilder(StreamInput)}.
     * The filter is only written when the cluster minimum version supports it.
     *
     * @param out stream to write to
     * @throws IOException if writing fails
     */
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(this.fieldName);
        out.writeFloatArray(this.vector);
        out.writeInt(this.k);
        if (KNNQueryBuilder.isClusterOnOrAfterMinRequiredVersion()) {
            out.writeOptionalNamedWriteable((NamedWriteable)this.filter);
        }
    }

    /**
     * @return name of the field this query targets
     */
    public String fieldName() {
        return this.fieldName;
    }

    /**
     * @return the query vector; the returned object is the internal {@code float[]}, not a copy
     */
    public Object vector() {
        return this.vector;
    }

    /**
     * @return number of neighbours requested
     */
    public int getK() {
        return this.k;
    }

    /**
     * @return the filter query, or {@code null} if none was set
     */
    public QueryBuilder getFilter() {
        return this.filter;
    }

    /**
     * Renders this query as {@code {"knn": {"<field>": {"vector": ..., "k": ..., ["filter": ...], boost/name}}}}.
     *
     * @param builder target builder
     * @param params  rendering parameters (unused here)
     * @throws IOException if writing to the builder fails
     */
    public void doXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject(NAME);
        builder.startObject(this.fieldName);
        builder.field(VECTOR_FIELD.getPreferredName(), (Object)this.vector);
        builder.field(K_FIELD.getPreferredName(), this.k);
        if (this.filter != null) {
            builder.field(FILTER_FIELD.getPreferredName(), (ToXContent)this.filter);
        }
        this.printBoostAndQueryName(builder);
        builder.endObject();
        builder.endObject();
    }

    /**
     * Builds the Lucene query for the shard.
     *
     * <p>The target field must be a k-NN vector field. When the mapped dimension is {@code -1}
     * the dimension and engine are taken from the field's model metadata; otherwise the engine
     * comes from the field's method context when one is present, falling back to
     * {@code KNNEngine.DEFAULT}.
     *
     * @param context shard context used to resolve the field mapping and index name
     * @return the query created by {@link KNNQueryFactory}
     * @throws IllegalArgumentException if the field is not a k-NN vector field, the model cannot be
     *                                  resolved, the vector length does not match the field dimension,
     *                                  or a filter is used with an engine that creates custom segment files
     */
    protected Query doToQuery(QueryShardContext context) {
        MappedFieldType mappedFieldType = context.fieldMapper(this.fieldName);
        if (!(mappedFieldType instanceof KNNVectorFieldMapper.KNNVectorFieldType)) {
            throw new IllegalArgumentException(String.format("Field '%s' is not knn_vector type.", this.fieldName));
        }
        KNNVectorFieldMapper.KNNVectorFieldType knnVectorFieldType = (KNNVectorFieldMapper.KNNVectorFieldType)mappedFieldType;
        int fieldDimension = knnVectorFieldType.getDimension();
        KNNMethodContext knnMethodContext = knnVectorFieldType.getKnnMethodContext();
        KNNEngine knnEngine = KNNEngine.DEFAULT;
        if (fieldDimension == -1) {
            // Dimension -1 means the mapping does not carry it; resolve it (and the engine) from the model.
            ModelMetadata modelMetadata = this.getModelMetadataForField(knnVectorFieldType);
            fieldDimension = modelMetadata.getDimension();
            knnEngine = modelMetadata.getKnnEngine();
        } else if (knnMethodContext != null) {
            knnEngine = knnMethodContext.getKnnEngine();
        }
        if (fieldDimension != this.vector.length) {
            throw new IllegalArgumentException(String.format("Query vector has invalid dimension: %d. Dimension should be: %d", this.vector.length, fieldDimension));
        }
        if (this.filter != null && KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine)) {
            throw new IllegalArgumentException(String.format("Engine [%s] does not support filters", knnEngine));
        }
        String indexName = context.index().getName();
        KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
            .knnEngine(knnEngine)
            .indexName(indexName)
            .fieldName(this.fieldName)
            .vector(this.vector)
            .k(this.k)
            .filter(this.filter)
            .context(context)
            .build();
        return KNNQueryFactory.create(createQueryRequest);
    }

    /**
     * Looks up the model metadata referenced by a model-based vector field.
     *
     * @param knnVectorField field type whose model id is resolved
     * @return metadata of the referenced model, never {@code null}
     * @throws IllegalArgumentException if the field has no model id or the model does not exist
     */
    private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFieldType knnVectorField) {
        String modelId = knnVectorField.getModelId();
        if (modelId == null) {
            throw new IllegalArgumentException(String.format("Field '%s' does not have model.", this.fieldName));
        }
        ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
        if (modelMetadata == null) {
            throw new IllegalArgumentException(String.format("Model ID '%s' does not exist.", modelId));
        }
        return modelMetadata;
    }

    /**
     * Compares the k-NN specific state: field name, vector contents, {@code k} and filter.
     *
     * @param other builder to compare with
     * @return {@code true} if all four components are equal
     */
    protected boolean doEquals(KNNQueryBuilder other) {
        // Arrays.equals compares element-wise; Objects.equals on arrays would only compare references.
        return Objects.equals(this.fieldName, other.fieldName)
            && Arrays.equals(this.vector, other.vector)
            && this.k == other.k
            && Objects.equals(this.filter, other.filter);
    }

    /**
     * @return hash over the same components used by {@link #doEquals(KNNQueryBuilder)}
     */
    protected int doHashCode() {
        // Hash the vector by content so that equal builders produce equal hash codes.
        return Objects.hash(this.fieldName, Arrays.hashCode(this.vector), this.k, this.filter);
    }

    /**
     * @return {@link #NAME}
     */
    public String getWriteableName() {
        return NAME;
    }

    /**
     * @return {@code true} if the cluster minimum version reported by {@link KNNClusterUtil} is on or
     *         after {@code MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER}
     */
    private static boolean isClusterOnOrAfterMinRequiredVersion() {
        return KNNClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER);
    }
}

