/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.query;

import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;
import lombok.Generated;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.document.FeatureField;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.ParsingException;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.AbstractQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.query.ModelInferenceQueryBuilder;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.neuralsearch.util.TokenWeightUtil;

public class NeuralSparseQueryBuilder
extends AbstractQueryBuilder<NeuralSparseQueryBuilder>
implements ModelInferenceQueryBuilder {
    @Generated
    private static final Logger log = LogManager.getLogger(NeuralSparseQueryBuilder.class);
    public static final String NAME = "neural_sparse";
    @VisibleForTesting
    static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text", new String[0]);
    @VisibleForTesting
    static final ParseField MODEL_ID_FIELD = new ParseField("model_id", new String[0]);
    @Deprecated
    @VisibleForTesting
    static final ParseField MAX_TOKEN_SCORE_FIELD = new ParseField("max_token_score", new String[0]).withAllDeprecated();
    private static MLCommonsClientAccessor ML_CLIENT;
    private String fieldName;
    private String queryText;
    private String modelId;
    private Float maxTokenScore;
    private Supplier<Map<String, Float>> queryTokensSupplier;
    private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID;

    public static void initialize(MLCommonsClientAccessor mlClient) {
        ML_CLIENT = mlClient;
    }

    public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
        super(in);
        this.fieldName = in.readString();
        this.queryText = in.readString();
        this.modelId = NeuralSparseQueryBuilder.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() ? in.readOptionalString() : in.readString();
        this.maxTokenScore = in.readOptionalFloat();
        if (in.readBoolean()) {
            Map queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat);
            this.queryTokensSupplier = () -> queryTokens;
        }
    }

    protected void doWriteTo(StreamOutput out) throws IOException {
        out.writeString(this.fieldName);
        out.writeString(this.queryText);
        if (NeuralSparseQueryBuilder.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
            out.writeOptionalString(this.modelId);
        } else {
            out.writeString(this.modelId);
        }
        out.writeOptionalFloat(this.maxTokenScore);
        if (!Objects.isNull(this.queryTokensSupplier) && !Objects.isNull(this.queryTokensSupplier.get())) {
            out.writeBoolean(true);
            out.writeMap(this.queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
        } else {
            out.writeBoolean(false);
        }
    }

    protected void doXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException {
        xContentBuilder.startObject(NAME);
        xContentBuilder.startObject(this.fieldName);
        xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), this.queryText);
        if (Objects.nonNull(this.modelId)) {
            xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), this.modelId);
        }
        if (this.maxTokenScore != null) {
            xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), this.maxTokenScore);
        }
        this.printBoostAndQueryName(xContentBuilder);
        xContentBuilder.endObject();
        xContentBuilder.endObject();
    }

    public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throws IOException {
        NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder();
        if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
            throw new ParsingException(parser.getTokenLocation(), "First token of neural_sparsequery must be START_OBJECT", new Object[0]);
        }
        parser.nextToken();
        sparseEncodingQueryBuilder.fieldName(parser.currentName());
        parser.nextToken();
        NeuralSparseQueryBuilder.parseQueryParams(parser, sparseEncodingQueryBuilder);
        if (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            throw new ParsingException(parser.getTokenLocation(), String.format(Locale.ROOT, "[%s] query doesn't support multiple fields, found [%s] and [%s]", NAME, sparseEncodingQueryBuilder.fieldName(), parser.currentName()), new Object[0]);
        }
        NeuralSparseQueryBuilder.requireValue((Object)sparseEncodingQueryBuilder.fieldName(), (String)"Field name must be provided for neural_sparse query");
        NeuralSparseQueryBuilder.requireValue((Object)sparseEncodingQueryBuilder.queryText(), (String)String.format(Locale.ROOT, "%s field must be provided for [%s] query", QUERY_TEXT_FIELD.getPreferredName(), NAME));
        if (!NeuralSparseQueryBuilder.isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
            NeuralSparseQueryBuilder.requireValue((Object)sparseEncodingQueryBuilder.modelId(), (String)String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME));
        }
        return sparseEncodingQueryBuilder;
    }

    private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBuilder sparseEncodingQueryBuilder) throws IOException {
        XContentParser.Token token;
        String currentFieldName = "";
        while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
            if (token == XContentParser.Token.FIELD_NAME) {
                currentFieldName = parser.currentName();
                continue;
            }
            if (token.isValue()) {
                if (NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    sparseEncodingQueryBuilder.queryName(parser.text());
                    continue;
                }
                if (BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    sparseEncodingQueryBuilder.boost(parser.floatValue());
                    continue;
                }
                if (QUERY_TEXT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    sparseEncodingQueryBuilder.queryText(parser.text());
                    continue;
                }
                if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    sparseEncodingQueryBuilder.modelId(parser.text());
                    continue;
                }
                if (MAX_TOKEN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
                    sparseEncodingQueryBuilder.maxTokenScore(Float.valueOf(parser.floatValue()));
                    continue;
                }
                throw new ParsingException(parser.getTokenLocation(), String.format(Locale.ROOT, "[%s] query does not support [%s] field", NAME, currentFieldName), new Object[0]);
            }
            throw new ParsingException(parser.getTokenLocation(), String.format(Locale.ROOT, "[%s] unknown token [%s] after [%s]", NAME, token, currentFieldName), new Object[0]);
        }
    }

    protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException {
        if (null != this.queryTokensSupplier) {
            return this;
        }
        NeuralSparseQueryBuilder.validateForRewrite(this.queryText, this.modelId);
        SetOnce queryTokensSetOnce = new SetOnce();
        queryRewriteContext.registerAsyncAction((client, actionListener) -> ML_CLIENT.inferenceSentencesWithMapResult(this.modelId(), List.of(this.queryText), ActionListener.wrap(mapResultList -> {
            queryTokensSetOnce.set(TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0));
            actionListener.onResponse(null);
        }, arg_0 -> ((ActionListener)actionListener).onFailure(arg_0))));
        return new NeuralSparseQueryBuilder().fieldName(this.fieldName).queryText(this.queryText).modelId(this.modelId).maxTokenScore(this.maxTokenScore).queryTokensSupplier(() -> ((SetOnce)queryTokensSetOnce).get());
    }

    protected Query doToQuery(QueryShardContext context) throws IOException {
        MappedFieldType ft = context.fieldMapper(this.fieldName);
        NeuralSparseQueryBuilder.validateFieldType(ft);
        Map<String, Float> queryTokens = this.queryTokensSupplier.get();
        NeuralSparseQueryBuilder.validateQueryTokens(queryTokens);
        BooleanQuery.Builder builder = new BooleanQuery.Builder();
        for (Map.Entry<String, Float> entry : queryTokens.entrySet()) {
            builder.add(FeatureField.newLinearQuery((String)this.fieldName, (String)entry.getKey(), (float)entry.getValue().floatValue()), BooleanClause.Occur.SHOULD);
        }
        return builder.build();
    }

    private static void validateForRewrite(String queryText, String modelId) {
        if (StringUtils.isBlank((String)queryText) || StringUtils.isBlank((String)modelId)) {
            throw new IllegalArgumentException(String.format(Locale.ROOT, "%s and %s cannot be null", QUERY_TEXT_FIELD.getPreferredName(), MODEL_ID_FIELD.getPreferredName()));
        }
    }

    private static void validateFieldType(MappedFieldType fieldType) {
        if (null == fieldType || !fieldType.typeName().equals("rank_features")) {
            throw new IllegalArgumentException("[neural_sparse] query only works on [rank_features] fields");
        }
    }

    private static void validateQueryTokens(Map<String, Float> queryTokens) {
        if (null == queryTokens) {
            throw new IllegalArgumentException("Query tokens cannot be null.");
        }
        for (Map.Entry<String, Float> entry : queryTokens.entrySet()) {
            if (!(entry.getValue().floatValue() <= 0.0f)) continue;
            throw new IllegalArgumentException("Feature weight must be larger than 0, feature [" + entry.getValue() + "] has negative weight.");
        }
    }

    protected boolean doEquals(NeuralSparseQueryBuilder obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || this.getClass() != obj.getClass()) {
            return false;
        }
        if (this.queryTokensSupplier == null && obj.queryTokensSupplier != null) {
            return false;
        }
        if (this.queryTokensSupplier != null && obj.queryTokensSupplier == null) {
            return false;
        }
        EqualsBuilder equalsBuilder = new EqualsBuilder().append((Object)this.fieldName, (Object)obj.fieldName).append((Object)this.queryText, (Object)obj.queryText).append((Object)this.modelId, (Object)obj.modelId).append((Object)this.maxTokenScore, (Object)obj.maxTokenScore);
        if (this.queryTokensSupplier != null) {
            equalsBuilder.append(this.queryTokensSupplier.get(), obj.queryTokensSupplier.get());
        }
        return equalsBuilder.isEquals();
    }

    protected int doHashCode() {
        HashCodeBuilder builder = new HashCodeBuilder().append((Object)this.fieldName).append((Object)this.queryText).append((Object)this.modelId).append((Object)this.maxTokenScore);
        if (this.queryTokensSupplier != null) {
            builder.append(this.queryTokensSupplier.get());
        }
        return builder.toHashCode();
    }

    public String getWriteableName() {
        return NAME;
    }

    private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() {
        return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID);
    }

    @Override
    @Generated
    public String fieldName() {
        return this.fieldName;
    }

    @Generated
    public String queryText() {
        return this.queryText;
    }

    @Override
    @Generated
    public String modelId() {
        return this.modelId;
    }

    @Generated
    public Float maxTokenScore() {
        return this.maxTokenScore;
    }

    @Generated
    public Supplier<Map<String, Float>> queryTokensSupplier() {
        return this.queryTokensSupplier;
    }

    @Generated
    public NeuralSparseQueryBuilder fieldName(String fieldName) {
        this.fieldName = fieldName;
        return this;
    }

    @Generated
    public NeuralSparseQueryBuilder queryText(String queryText) {
        this.queryText = queryText;
        return this;
    }

    @Override
    @Generated
    public NeuralSparseQueryBuilder modelId(String modelId) {
        this.modelId = modelId;
        return this;
    }

    @Generated
    public NeuralSparseQueryBuilder maxTokenScore(Float maxTokenScore) {
        this.maxTokenScore = maxTokenScore;
        return this;
    }

    @Generated
    public NeuralSparseQueryBuilder queryTokensSupplier(Supplier<Map<String, Float>> queryTokensSupplier) {
        this.queryTokensSupplier = queryTokensSupplier;
        return this;
    }

    @Generated
    public NeuralSparseQueryBuilder() {
    }

    @Generated
    public NeuralSparseQueryBuilder(String fieldName, String queryText, String modelId, Float maxTokenScore, Supplier<Map<String, Float>> queryTokensSupplier) {
        this.fieldName = fieldName;
        this.queryText = queryText;
        this.modelId = modelId;
        this.maxTokenScore = maxTokenScore;
        this.queryTokensSupplier = queryTokensSupplier;
    }

    static {
        MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0;
    }
}

