/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.common.input;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import lombok.Generated;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DefaultDataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.QuestionAnsweringInputDataSet;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.search.builder.SearchSourceBuilder;

/**
 * Input for an ML algorithm.
 *
 * <p>Holds the target {@link FunctionName}, optional algorithm-specific {@link MLAlgoParams}, and an
 * optional {@link MLInputDataset}. The input can be serialized over the transport layer
 * ({@link #writeTo(StreamOutput)} / {@link #MLInput(StreamInput)}) and to and from XContent
 * ({@link #toXContent(XContentBuilder, ToXContent.Params)} / {@link #parse(XContentParser, String)}).
 */
public class MLInput implements Input {
    public static final String ALGORITHM_FIELD = "algorithm";
    public static final String ML_PARAMETERS_FIELD = "parameters";
    public static final String INPUT_INDEX_FIELD = "input_index";
    public static final String INPUT_QUERY_FIELD = "input_query";
    public static final String INPUT_DATA_FIELD = "input_data";
    public static final String RETURN_BYTES_FIELD = "return_bytes";
    public static final String RETURN_NUMBER_FIELD = "return_number";
    public static final String TARGET_RESPONSE_FIELD = "target_response";
    public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions";
    public static final String TEXT_DOCS_FIELD = "text_docs";
    public static final String QUERY_TEXT_FIELD = "query_text";
    /** Key under which a remote dataset's parameter map is written; same text as {@link #ML_PARAMETERS_FIELD}. */
    public static final String PARAMETERS_FIELD = "parameters";
    public static final String QUESTION_FIELD = "question";
    public static final String CONTEXT_FIELD = "context";

    protected FunctionName algorithm;
    protected MLAlgoParams parameters;
    protected MLInputDataset inputDataset;
    /** Serialized version marker. Defaults to 1; carried over the wire but not exposed by the builder. */
    private int version = 1;

    /**
     * Creates an input with an explicit dataset.
     *
     * @param algorithm    target algorithm; must not be null
     * @param parameters   algorithm-specific parameters, may be null
     * @param inputDataset dataset to run on, may be null
     * @throws IllegalArgumentException if {@code algorithm} is null
     */
    public MLInput(FunctionName algorithm, MLAlgoParams parameters, MLInputDataset inputDataset) {
        validate(algorithm);
        this.algorithm = algorithm;
        this.parameters = parameters;
        this.inputDataset = inputDataset;
    }

    /**
     * Creates an input. If {@code inputDataset} is null, a dataset is derived from the data frame or the
     * search query (see {@link #createInputDataSet}).
     *
     * @throws IllegalArgumentException if {@code algorithm} is null
     */
    public MLInput(
        FunctionName algorithm,
        MLAlgoParams parameters,
        SearchSourceBuilder searchSourceBuilder,
        List<String> sourceIndices,
        DataFrame dataFrame,
        MLInputDataset inputDataset
    ) {
        validate(algorithm);
        this.algorithm = algorithm;
        this.parameters = parameters;
        this.inputDataset = inputDataset != null
            ? inputDataset
            : createInputDataSet(searchSourceBuilder, sourceIndices, dataFrame);
    }

    /** Rejects a null algorithm. */
    private static void validate(FunctionName algorithm) {
        if (algorithm == null) {
            throw new IllegalArgumentException("algorithm can't be null");
        }
    }

    /**
     * Deserializes an input written by {@link #writeTo(StreamOutput)}. The read order must mirror the
     * write order exactly: algorithm, parameters flag [+ parameters], dataset flag [+ dataset], version.
     */
    public MLInput(StreamInput in) throws IOException {
        this.algorithm = in.readEnum(FunctionName.class);
        if (in.readBoolean()) {
            // Parameters are instantiated through the class loader keyed on the algorithm.
            this.parameters = (MLAlgoParams) MLCommonsClassLoader.initMLInstance(this.algorithm, in, StreamInput.class);
        }
        if (in.readBoolean()) {
            this.inputDataset = MLInputDataset.fromStream(in);
        }
        this.version = in.readInt();
    }

    /**
     * Serializes this input. Nullable members are preceded by a presence flag. The layout must stay in
     * sync with {@link #MLInput(StreamInput)}.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeEnum(this.algorithm);
        boolean hasParameters = this.parameters != null;
        out.writeBoolean(hasParameters);
        if (hasParameters) {
            this.parameters.writeTo(out);
        }
        boolean hasDataset = this.inputDataset != null;
        out.writeBoolean(hasDataset);
        if (hasDataset) {
            this.inputDataset.writeTo(out);
        }
        out.writeInt(this.version);
    }

    /**
     * Writes this input as a JSON object. The dataset's fields are inlined at the top level, using a
     * layout that depends on the dataset's input data type.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        builder.field(ALGORITHM_FIELD, this.algorithm.name());
        if (this.parameters != null) {
            // Cast pins the ToXContent overload of field(...).
            builder.field(ML_PARAMETERS_FIELD, (ToXContent) this.parameters);
        }
        if (this.inputDataset != null) {
            writeInputDataset(builder);
        }
        builder.endObject();
        return builder;
    }

    /** Writes the fields of {@link #inputDataset} (known to be non-null) into the currently open object. */
    private void writeInputDataset(XContentBuilder builder) throws IOException {
        switch (this.inputDataset.getInputDataType()) {
            case SEARCH_QUERY: {
                SearchQueryInputDataset searchQuery = (SearchQueryInputDataset) this.inputDataset;
                // (Object)/(ToXContent) casts pin the same field(...) overloads as before.
                builder.field(INPUT_INDEX_FIELD, (Object) searchQuery.getIndices().toArray(new String[0]));
                builder.field(INPUT_QUERY_FIELD, (ToXContent) searchQuery.getSearchSourceBuilder());
                break;
            }
            case DATA_FRAME: {
                builder.startObject(INPUT_DATA_FIELD);
                ((DataFrameInputDataset) this.inputDataset).getDataFrame().toXContent(builder, ToXContent.EMPTY_PARAMS);
                builder.endObject();
                break;
            }
            case TEXT_DOCS: {
                writeTextDocs(builder, (TextDocsInputDataSet) this.inputDataset);
                break;
            }
            case TEXT_SIMILARITY: {
                TextSimilarityInputDataSet similarity = (TextSimilarityInputDataSet) this.inputDataset;
                builder.field(QUERY_TEXT_FIELD, similarity.getQueryText());
                List<String> documents = similarity.getTextDocs();
                if (documents != null && !documents.isEmpty()) {
                    builder.startArray(TEXT_DOCS_FIELD);
                    for (String document : documents) {
                        builder.value(document);
                    }
                    builder.endArray();
                }
                break;
            }
            case QUESTION_ANSWERING: {
                QuestionAnsweringInputDataSet qa = (QuestionAnsweringInputDataSet) this.inputDataset;
                builder.field(QUESTION_FIELD, qa.getQuestion());
                builder.field(CONTEXT_FIELD, qa.getContext());
                break;
            }
            case REMOTE: {
                Map<String, String> remoteParameters = ((RemoteInferenceInputDataSet) this.inputDataset).getParameters();
                // NOTE(review): if this.parameters is also non-null, toXContent has already written a
                // "parameters" key, so the object ends up with that key twice.
                builder.field(PARAMETERS_FIELD, remoteParameters);
                break;
            }
            default:
                // Other dataset types contribute no fields.
                break;
        }
    }

    /** Writes the documents and, if present, the result filter of a text-docs dataset. */
    private static void writeTextDocs(XContentBuilder builder, TextDocsInputDataSet dataSet) throws IOException {
        List<String> docs = dataSet.getDocs();
        ModelResultFilter resultFilter = dataSet.getResultFilter();
        if (docs != null && !docs.isEmpty()) {
            builder.field(TEXT_DOCS_FIELD, (Object) docs.toArray(new String[0]));
        }
        if (resultFilter == null) {
            return;
        }
        builder.field(RETURN_BYTES_FIELD, resultFilter.isReturnBytes());
        builder.field(RETURN_NUMBER_FIELD, resultFilter.isReturnNumber());
        List<String> targetResponse = resultFilter.getTargetResponse();
        if (targetResponse != null && !targetResponse.isEmpty()) {
            builder.field(TARGET_RESPONSE_FIELD, (Object) targetResponse.toArray(new String[0]));
        }
        List<Integer> targetPositions = resultFilter.getTargetResponsePositions();
        if (targetPositions != null && !targetPositions.isEmpty()) {
            builder.field(TARGET_RESPONSE_POSITIONS_FIELD, (Object) targetPositions.toArray(new Integer[0]));
        }
    }

    /**
     * Parses an input from XContent.
     *
     * <p>If the class loader reports a dedicated input type for the algorithm
     * ({@code MLCommonsClassLoader.canInitMLInput}), parsing is delegated to it. Otherwise the generic
     * fields are read and the dataset is chosen by algorithm:
     * <ul>
     *   <li>TEXT_EMBEDDING / SPARSE_ENCODING / SPARSE_TOKENIZE: {@link TextDocsInputDataSet} with a
     *       {@link ModelResultFilter}</li>
     *   <li>TEXT_SIMILARITY: {@link TextSimilarityInputDataSet}</li>
     *   <li>QUESTION_ANSWERING: {@link QuestionAnsweringInputDataSet}</li>
     *   <li>otherwise: derived from {@code input_data} or {@code input_index} + {@code input_query}</li>
     * </ul>
     * Unknown fields are skipped.
     *
     * @param parser        parser positioned on the object's START_OBJECT token
     * @param inputAlgoName algorithm name; upper-cased with {@link Locale#ROOT} before lookup
     */
    public static MLInput parse(XContentParser parser, String inputAlgoName) throws IOException {
        String algorithmName = inputAlgoName.toUpperCase(Locale.ROOT);
        FunctionName algorithm = FunctionName.from(algorithmName);
        if (MLCommonsClassLoader.canInitMLInput(algorithm)) {
            MLInput mlInput = (MLInput) MLCommonsClassLoader.initMLInput(
                algorithm,
                new Object[] { parser, algorithm },
                XContentParser.class,
                FunctionName.class
            );
            mlInput.setAlgorithm(algorithm);
            return mlInput;
        }

        MLAlgoParams mlParameters = null;
        SearchSourceBuilder searchSourceBuilder = null;
        List<String> sourceIndices = new ArrayList<>();
        DataFrame dataFrame = null;
        boolean returnBytes = false;
        boolean returnNumber = true;
        List<String> targetResponse = new ArrayList<>();
        List<Integer> targetResponsePositions = new ArrayList<>();
        List<String> textDocs = new ArrayList<>();
        String queryText = null;
        String question = null;
        String context = null;

        XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
            String fieldName = parser.currentName();
            parser.nextToken();
            switch (fieldName) {
                case ML_PARAMETERS_FIELD:
                    mlParameters = parser.namedObject(MLAlgoParams.class, algorithmName, null);
                    break;
                case INPUT_INDEX_FIELD:
                    parseStringArray(parser, sourceIndices);
                    break;
                case INPUT_QUERY_FIELD:
                    XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
                    searchSourceBuilder = SearchSourceBuilder.fromXContent(parser, false);
                    break;
                case INPUT_DATA_FIELD:
                    dataFrame = DefaultDataFrame.parse(parser);
                    break;
                case RETURN_BYTES_FIELD:
                    returnBytes = parser.booleanValue();
                    break;
                case RETURN_NUMBER_FIELD:
                    returnNumber = parser.booleanValue();
                    break;
                case TARGET_RESPONSE_FIELD:
                    parseStringArray(parser, targetResponse);
                    break;
                case TARGET_RESPONSE_POSITIONS_FIELD:
                    XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
                    while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
                        targetResponsePositions.add(parser.intValue());
                    }
                    break;
                case TEXT_DOCS_FIELD:
                    parseStringArray(parser, textDocs);
                    break;
                case QUERY_TEXT_FIELD:
                    queryText = parser.text();
                    break;
                case QUESTION_FIELD:
                    question = parser.text();
                    break;
                case CONTEXT_FIELD:
                    context = parser.text();
                    break;
                default:
                    parser.skipChildren();
                    break;
            }
        }

        MLInputDataset inputDataSet = null;
        if (algorithm == FunctionName.TEXT_EMBEDDING
            || algorithm == FunctionName.SPARSE_ENCODING
            || algorithm == FunctionName.SPARSE_TOKENIZE) {
            ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
            inputDataSet = new TextDocsInputDataSet(textDocs, filter);
        } else if (algorithm == FunctionName.TEXT_SIMILARITY) {
            inputDataSet = new TextSimilarityInputDataSet(queryText, textDocs);
        } else if (algorithm == FunctionName.QUESTION_ANSWERING) {
            inputDataSet = new QuestionAnsweringInputDataSet(question, context);
        }
        return new MLInput(algorithm, mlParameters, searchSourceBuilder, sourceIndices, dataFrame, inputDataSet);
    }

    /** Reads a JSON array of strings at the parser's current token into {@code target}. */
    private static void parseStringArray(XContentParser parser, List<String> target) throws IOException {
        XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
        while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
            target.add(parser.text());
        }
    }

    /**
     * Derives a dataset when none was given. A non-null data frame takes precedence. A search-query
     * dataset needs both indices and a query. Returns null otherwise.
     */
    private static MLInputDataset createInputDataSet(
        SearchSourceBuilder searchSourceBuilder,
        List<String> sourceIndices,
        DataFrame dataFrame
    ) {
        if (dataFrame != null) {
            return new DataFrameInputDataset(dataFrame);
        }
        if (sourceIndices != null && searchSourceBuilder != null) {
            return new SearchQueryInputDataset(sourceIndices, searchSourceBuilder);
        }
        return null;
    }

    @Override
    public FunctionName getFunctionName() {
        return this.algorithm;
    }

    /** Returns a new builder; {@link MLInputBuilder#build()} uses the validating three-argument constructor. */
    @Generated
    public static MLInputBuilder builder() {
        return new MLInputBuilder();
    }

    /** Returns a builder pre-populated with algorithm, parameters and dataset (not {@code version}). */
    @Generated
    public MLInputBuilder toBuilder() {
        return new MLInputBuilder().algorithm(this.algorithm).parameters(this.parameters).inputDataset(this.inputDataset);
    }

    @Generated
    public FunctionName getAlgorithm() {
        return this.algorithm;
    }

    @Generated
    public MLAlgoParams getParameters() {
        return this.parameters;
    }

    @Generated
    public MLInputDataset getInputDataset() {
        return this.inputDataset;
    }

    @Generated
    public int getVersion() {
        return this.version;
    }

    /** Sets the algorithm without the null check applied by the constructors. */
    @Generated
    public void setAlgorithm(FunctionName algorithm) {
        this.algorithm = algorithm;
    }

    @Generated
    public void setParameters(MLAlgoParams parameters) {
        this.parameters = parameters;
    }

    @Generated
    public void setInputDataset(MLInputDataset inputDataset) {
        this.inputDataset = inputDataset;
    }

    @Generated
    public void setVersion(int version) {
        this.version = version;
    }

    /** Value equality over version, algorithm, parameters and dataset (read through the getters). */
    @Override
    @Generated
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MLInput)) {
            return false;
        }
        MLInput other = (MLInput) o;
        if (!other.canEqual(this)) {
            return false;
        }
        // FunctionName is an enum, so reference comparison is equivalent to equals().
        return this.getVersion() == other.getVersion()
            && this.getAlgorithm() == other.getAlgorithm()
            && Objects.equals(this.getParameters(), other.getParameters())
            && Objects.equals(this.getInputDataset(), other.getInputDataset());
    }

    /** Lets subclasses restrict equality to their own type (Lombok canEqual pattern). */
    @Generated
    protected boolean canEqual(Object other) {
        return other instanceof MLInput;
    }

    @Override
    @Generated
    public int hashCode() {
        final int prime = 59;
        final int nullHash = 43;
        int result = 1;
        result = result * prime + this.getVersion();
        FunctionName algo = this.getAlgorithm();
        result = result * prime + (algo == null ? nullHash : algo.hashCode());
        MLAlgoParams params = this.getParameters();
        result = result * prime + (params == null ? nullHash : params.hashCode());
        MLInputDataset dataset = this.getInputDataset();
        result = result * prime + (dataset == null ? nullHash : dataset.hashCode());
        return result;
    }

    @Override
    @Generated
    public String toString() {
        return "MLInput(algorithm=" + this.getAlgorithm()
            + ", parameters=" + this.getParameters()
            + ", inputDataset=" + this.getInputDataset()
            + ", version=" + this.getVersion() + ")";
    }

    /** No-arg constructor; leaves algorithm, parameters and dataset null and version at 1. */
    @Generated
    public MLInput() {
    }

    /** Fluent builder for {@link MLInput}. {@code version} is not configurable and stays at its default. */
    @Generated
    public static class MLInputBuilder {
        @Generated
        private FunctionName algorithm;
        @Generated
        private MLAlgoParams parameters;
        @Generated
        private MLInputDataset inputDataset;

        @Generated
        MLInputBuilder() {
        }

        @Generated
        public MLInputBuilder algorithm(FunctionName algorithm) {
            this.algorithm = algorithm;
            return this;
        }

        @Generated
        public MLInputBuilder parameters(MLAlgoParams parameters) {
            this.parameters = parameters;
            return this;
        }

        @Generated
        public MLInputBuilder inputDataset(MLInputDataset inputDataset) {
            this.inputDataset = inputDataset;
            return this;
        }

        /** @throws IllegalArgumentException if no algorithm was set */
        @Generated
        public MLInput build() {
            return new MLInput(this.algorithm, this.parameters, this.inputDataset);
        }

        @Override
        @Generated
        public String toString() {
            return "MLInput.MLInputBuilder(algorithm=" + this.algorithm
                + ", parameters=" + this.parameters
                + ", inputDataset=" + this.inputDataset + ")";
        }
    }
}

