/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.query;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.List;
import java.util.function.Supplier;
import lombok.Generated;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.SetOnce;
import org.opensearch.action.ActionListener;
import org.opensearch.common.ParseField;
import org.opensearch.common.ParsingException;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.neuralsearch.common.VectorUtil;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

public class NeuralQueryBuilder
extends AbstractQueryBuilder<NeuralQueryBuilder> {
    @Generated
    private static final Logger log = LogManager.getLogger(NeuralQueryBuilder.class);

    /** Query name: used as the outer XContent object key and returned by {@link #getWriteableName()}. */
    public static final String NAME = "neural";

    /** XContent key holding the text that is turned into a query vector. */
    @VisibleForTesting
    static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");

    /** XContent key holding the id of the model passed to the ML client. */
    @VisibleForTesting
    static final ParseField MODEL_ID_FIELD = new ParseField("model_id");

    /** XContent key holding {@code k}. */
    @VisibleForTesting
    static final ParseField K_FIELD = new ParseField("k");

    /** Value of {@code k} used when none is supplied. */
    private static final int DEFAULT_K = 10;

    /** ML client shared by all instances; assigned by {@link #initialize(MLCommonsClientAccessor)}. */
    private static MLCommonsClientAccessor ML_CLIENT;

    private String fieldName;
    private String queryText;
    private String modelId;
    private int k = DEFAULT_K;

    /**
     * Supplier of the query vector. It is {@code null} unless set through the all-args constructor
     * or the package-private setter; it is neither serialized nor part of equals/hashCode.
     */
    @VisibleForTesting
    private Supplier<float[]> vectorSupplier;

    /**
     * Stores the ML client in the static {@code ML_CLIENT} field, which {@code doRewrite}
     * uses to request inference.
     *
     * @param mlClient accessor to assign; no validation is performed here
     */
    public static void initialize(MLCommonsClientAccessor mlClient) {
        ML_CLIENT = mlClient;
    }

    /**
     * Deserializing constructor. After the superclass has consumed its part of the stream,
     * reads field name, query text and model id as strings and {@code k} via {@code readVInt},
     * in the same order {@code doWriteTo} writes them. The vector supplier is not read.
     *
     * @param in stream to read from
     * @throws IOException if reading from the stream fails
     */
    public NeuralQueryBuilder(StreamInput in) throws IOException {
        super(in);
        fieldName = in.readString();
        queryText = in.readString();
        modelId = in.readString();
        k = in.readVInt();
    }

    /**
     * Writes field name, query text, model id (strings) and {@code k} (via {@code writeVInt})
     * to the stream, in that order. The vector supplier is not written.
     *
     * @param out stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(fieldName);
        out.writeString(queryText);
        out.writeString(modelId);
        out.writeVInt(k);
    }

    /**
     * Renders this query as an object keyed by {@link #NAME} that contains one object keyed by
     * the field name. The inner object holds the query text, model id, {@code k}, and whatever
     * {@code printBoostAndQueryName} emits.
     *
     * @param xContentBuilder builder receiving the output
     * @param params          unused by this implementation
     * @throws IOException if the builder fails to write
     */
    @Override
    protected void doXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException {
        xContentBuilder.startObject(NAME)
            .startObject(fieldName)
            .field(QUERY_TEXT_FIELD.getPreferredName(), queryText)
            .field(MODEL_ID_FIELD.getPreferredName(), modelId)
            .field(K_FIELD.getPreferredName(), k);
        printBoostAndQueryName(xContentBuilder);
        // Close the per-field object first, then the outer NAME object.
        xContentBuilder.endObject().endObject();
    }

    /**
     * Builds a {@code NeuralQueryBuilder} from XContent. The parser must be positioned on a
     * START_OBJECT token. The next token's name is taken as the field name, and the object that
     * follows is handed to {@code parseQueryParams}.
     *
     * @param parser parser positioned on the START_OBJECT of the query body
     * @return the populated builder
     * @throws IOException      if the parser fails to read
     * @throws ParsingException if the current token is not START_OBJECT, or if the token after
     *                          the parameter object is not END_OBJECT (more than one field)
     */
    public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOException {
        NeuralQueryBuilder builder = new NeuralQueryBuilder();
        if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
            throw new ParsingException(parser.getTokenLocation(), "Token must be START_OBJECT");
        }

        // Advance onto the field name, record it, then advance onto its value.
        parser.nextToken();
        builder.fieldName(parser.currentName());
        parser.nextToken();
        parseQueryParams(parser, builder);

        // Anything other than the closing brace here means a second field was supplied.
        if (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            throw new ParsingException(
                parser.getTokenLocation(),
                "[neural] query doesn't support multiple fields, found [" + builder.fieldName() + "] and [" + parser.currentName() + "]"
            );
        }

        // Checked in this order: query text, field name, model id.
        requireValue(builder.queryText(), "Query text must be provided for neural query");
        requireValue(builder.fieldName(), "Field name must be provided for neural query");
        requireValue(builder.modelId(), "Model ID must be provided for neural query");
        return builder;
    }

    /**
     * Consumes tokens up to the next END_OBJECT and copies recognised values into the builder.
     * Recognised keys are query text, model id, {@code k}, query name and boost.
     *
     * @param parser             parser positioned just before the first parameter token
     * @param neuralQueryBuilder builder to populate
     * @throws IOException      if the parser fails to read
     * @throws ParsingException on a value whose key is not recognised, or on a token that is
     *                          neither a field name nor a value
     */
    private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder neuralQueryBuilder) throws IOException {
        String key = "";
        for (XContentParser.Token tok = parser.nextToken(); tok != XContentParser.Token.END_OBJECT; tok = parser.nextToken()) {
            if (tok == XContentParser.Token.FIELD_NAME) {
                // Remember the key; its value arrives on the next iteration.
                key = parser.currentName();
            } else if (!tok.isValue()) {
                throw new ParsingException(parser.getTokenLocation(), "[neural] unknown token [" + tok + "] after [" + key + "]");
            } else if (QUERY_TEXT_FIELD.match(key, parser.getDeprecationHandler())) {
                neuralQueryBuilder.queryText(parser.text());
            } else if (MODEL_ID_FIELD.match(key, parser.getDeprecationHandler())) {
                neuralQueryBuilder.modelId(parser.text());
            } else if (K_FIELD.match(key, parser.getDeprecationHandler())) {
                neuralQueryBuilder.k((Integer) NumberFieldMapper.NumberType.INTEGER.parse(parser.objectBytes(), false));
            } else if (NAME_FIELD.match(key, parser.getDeprecationHandler())) {
                neuralQueryBuilder.queryName(parser.text());
            } else if (BOOST_FIELD.match(key, parser.getDeprecationHandler())) {
                neuralQueryBuilder.boost(parser.floatValue());
            } else {
                throw new ParsingException(parser.getTokenLocation(), "[neural] query does not support [" + key + "]");
            }
        }
    }

    /**
     * Rewrites this query towards a {@code KNNQueryBuilder}.
     *
     * <ul>
     *   <li>If a vector supplier is present and yields a vector, returns a
     *       {@code KNNQueryBuilder} built from the field name, that vector and {@code k}.</li>
     *   <li>If a vector supplier is present but yields {@code null}, returns {@code this}.</li>
     *   <li>If no supplier is present, registers an async action that asks {@code ML_CLIENT}
     *       to run inference on the query text with the configured model id, and returns a new
     *       {@code NeuralQueryBuilder} whose supplier reads the result of that action.</li>
     * </ul>
     *
     * Boost and query name are not copied onto the returned builders by this method.
     *
     * @param queryRewriteContext context on which the async inference action is registered
     * @return {@code this}, a {@code KNNQueryBuilder}, or a new {@code NeuralQueryBuilder}
     *         carrying a vector supplier
     */
    @Override
    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) {
        if (vectorSupplier != null) {
            // Read the supplier once so the null check and the vector handed on are the same value.
            float[] vector = vectorSupplier.get();
            return vector == null ? this : new KNNQueryBuilder(fieldName, vector, k);
        }

        // Typed holder: the inference callback fills it, the returned builder's supplier reads it.
        SetOnce<float[]> vectorSetOnce = new SetOnce<>();
        queryRewriteContext.registerAsyncAction(
            (client, actionListener) -> ML_CLIENT.inferenceSentence(modelId, queryText, ActionListener.wrap(floatList -> {
                vectorSetOnce.set(VectorUtil.vectorAsListToArray(floatList));
                // Signal completion only; the vector itself travels through vectorSetOnce.
                actionListener.onResponse(null);
            }, actionListener::onFailure))
        );
        return new NeuralQueryBuilder(fieldName, queryText, modelId, k, vectorSetOnce::get);
    }

    /**
     * Always throws: this builder does not produce a Lucene query itself.
     *
     * @param queryShardContext unused
     * @return never returns normally
     * @throws UnsupportedOperationException always
     */
    @Override
    protected Query doToQuery(QueryShardContext queryShardContext) {
        throw new UnsupportedOperationException("Query cannot be created by NeuralQueryBuilder directly");
    }

    /**
     * Compares field name, query text, model id and {@code k}. The vector supplier is not
     * part of the comparison.
     *
     * @param obj builder to compare with; may be {@code null}
     * @return {@code true} if {@code obj} is this instance, or has the same runtime class and
     *         equal values for the four compared fields
     */
    @Override
    protected boolean doEquals(NeuralQueryBuilder obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return new EqualsBuilder()
            .append(fieldName, obj.fieldName)
            .append(queryText, obj.queryText)
            .append(modelId, obj.modelId)
            .append(k, obj.k)
            .isEquals();
    }

    /**
     * Hashes the same four fields that {@code doEquals} compares: field name, query text,
     * model id and {@code k}.
     *
     * @return hash code derived from those fields
     */
    @Override
    protected int doHashCode() {
        HashCodeBuilder hash = new HashCodeBuilder();
        hash.append(fieldName);
        hash.append(queryText);
        hash.append(modelId);
        hash.append(k);
        return hash.toHashCode();
    }

    /**
     * @return {@link #NAME}
     */
    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * @return the field name this query is keyed by, or {@code null} if unset
     */
    @Generated
    public String fieldName() {
        return fieldName;
    }

    /**
     * @return the query text sent for inference, or {@code null} if unset
     */
    @Generated
    public String queryText() {
        return queryText;
    }

    /**
     * @return the model id passed to the ML client, or {@code null} if unset
     */
    @Generated
    public String modelId() {
        return modelId;
    }

    /**
     * @return the current value of {@code k} (10 unless changed)
     */
    @Generated
    public int k() {
        return k;
    }

    /**
     * Sets the field name.
     *
     * @param fieldName new field name; stored as given
     * @return this builder
     */
    @Generated
    public NeuralQueryBuilder fieldName(String fieldName) {
        this.fieldName = fieldName;
        return this;
    }

    /**
     * Sets the query text.
     *
     * @param queryText new query text; stored as given
     * @return this builder
     */
    @Generated
    public NeuralQueryBuilder queryText(String queryText) {
        this.queryText = queryText;
        return this;
    }

    /**
     * Sets the model id.
     *
     * @param modelId new model id; stored as given
     * @return this builder
     */
    @Generated
    public NeuralQueryBuilder modelId(String modelId) {
        this.modelId = modelId;
        return this;
    }

    /**
     * Sets {@code k}.
     *
     * @param k new value; stored as given, no range check is performed here
     * @return this builder
     */
    @Generated
    public NeuralQueryBuilder k(int k) {
        this.k = k;
        return this;
    }

    /**
     * Creates an empty builder: field name, query text, model id and vector supplier are
     * left {@code null}, and {@code k} keeps its field-initializer value of 10.
     */
    @Generated
    public NeuralQueryBuilder() {
    }

    /**
     * Creates a builder with every field supplied. Values are stored as given, without validation.
     *
     * @param fieldName      field name
     * @param queryText      query text
     * @param modelId        model id
     * @param k              value of {@code k}
     * @param vectorSupplier supplier of the query vector; may be {@code null}
     */
    @Generated
    public NeuralQueryBuilder(String fieldName, String queryText, String modelId, int k, Supplier<float[]> vectorSupplier) {
        this.fieldName = fieldName;
        this.queryText = queryText;
        this.modelId = modelId;
        this.k = k;
        this.vectorSupplier = vectorSupplier;
    }

    /**
     * @return the vector supplier, or {@code null} if none has been set
     */
    @Generated
    Supplier<float[]> vectorSupplier() {
        return vectorSupplier;
    }

    /**
     * Sets the vector supplier.
     *
     * @param vectorSupplier supplier of the query vector; stored as given
     * @return this builder
     */
    @Generated
    NeuralQueryBuilder vectorSupplier(Supplier<float[]> vectorSupplier) {
        this.vectorSupplier = vectorSupplier;
        return this;
    }
}

