/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.processor;

import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.DocumentContext;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Option;
import com.jayway.jsonpath.Predicate;
import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.apache.commons.text.StringSubstitutor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.processor.InferenceProcessorAttributes;
import org.opensearch.ml.processor.ModelExecutor;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;

/**
 * Search request processor ({@code ml_inference}) that sends values read from the incoming query to an ML model
 * and writes the model outputs back into the query before the search runs.
 *
 * <p>Each entry of {@code input_map} triggers one prediction: its keys are model parameter names and its values
 * are JSON paths into the query whose values are passed to the model. The {@code output_map} entry with the same
 * index maps a query field (a JSON path relative to the query root) - or, when {@code query_template} is set, a
 * template placeholder name - to a model output field name. All predictions are issued asynchronously and the
 * request is rewritten once, after every prediction has completed.
 */
public class MLInferenceSearchRequestProcessor extends AbstractProcessor implements SearchRequestProcessor, ModelExecutor {
    private static final Logger logger = LogManager.getLogger(MLInferenceSearchRequestProcessor.class);

    public static final String TYPE = "ml_inference";
    public static final String IGNORE_MISSING = "ignore_missing";
    public static final String QUERY_TEMPLATE = "query_template";
    public static final String FUNCTION_NAME = "function_name";
    public static final String FULL_RESPONSE_PATH = "full_response_path";
    public static final String MODEL_INPUT = "model_input";
    // NOTE: the lowercase "l" in the name is kept as-is because the constant is public.
    public static final String DEFAULT_MODEl_INPUT = "{ \"parameters\": ${ml_inference.parameters} }";
    public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;

    private final NamedXContentRegistry xContentRegistry;
    private final InferenceProcessorAttributes inferenceProcessorAttributes;
    private final boolean ignoreMissing;
    private final String functionName;
    private final String queryTemplate;
    private final boolean fullResponsePath;
    private final boolean ignoreFailure;
    private final String modelInput;
    // Per-instance client: a static field would be overwritten by every processor instance that gets created.
    private final Client client;

    /**
     * Creates the processor; normally invoked through {@link Factory}.
     *
     * @param modelId           id of the model to call
     * @param queryTemplate     optional template rewritten with model outputs; when {@code null} the outputs are
     *                          written into the original query instead
     * @param inputMaps         one map per prediction: model parameter name to query JSON path
     * @param outputMaps        one map per prediction: query field (or template placeholder) to model output field
     * @param modelConfigMaps   extra model parameters, may be {@code null}
     * @param maxPredictionTask maximum allowed number of predictions
     * @param tag               processor tag
     * @param description       processor description
     * @param ignoreMissing     respond with the unchanged request when a referenced query field is missing
     * @param functionName      model function name (e.g. {@code REMOTE})
     * @param fullResponsePath  whether output field names are paths into the full model response
     * @param ignoreFailure     respond with the unchanged request instead of failing
     * @param modelInput        model input template
     * @param client            client used to run prediction tasks
     * @param xContentRegistry  registry used to parse rewritten queries
     */
    protected MLInferenceSearchRequestProcessor(
        String modelId,
        String queryTemplate,
        List<Map<String, String>> inputMaps,
        List<Map<String, String>> outputMaps,
        Map<String, String> modelConfigMaps,
        int maxPredictionTask,
        String tag,
        String description,
        boolean ignoreMissing,
        String functionName,
        boolean fullResponsePath,
        boolean ignoreFailure,
        String modelInput,
        Client client,
        NamedXContentRegistry xContentRegistry
    ) {
        super(tag, description, ignoreFailure);
        this.inferenceProcessorAttributes = new InferenceProcessorAttributes(modelId, inputMaps, outputMaps, modelConfigMaps, maxPredictionTask);
        this.ignoreMissing = ignoreMissing;
        this.functionName = functionName;
        this.fullResponsePath = fullResponsePath;
        this.queryTemplate = queryTemplate;
        this.ignoreFailure = ignoreFailure;
        this.modelInput = modelInput;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
    }

    /**
     * Not supported: this processor only works asynchronously.
     *
     * @throws RuntimeException always
     */
    public SearchRequest processRequest(SearchRequest request) throws Exception {
        throw new RuntimeException("ML inference search request processor make asynchronous calls and does not call processRequest");
    }

    /**
     * Runs the configured predictions and completes {@code requestListener} exactly once with the rewritten
     * request, the unchanged request (when failures or missing fields are ignored), or the failure.
     *
     * @param request         incoming search request; must have a source
     * @param requestContext  pipeline context (unused)
     * @param requestListener notified with the resulting request or the failure
     */
    public void processRequestAsync(SearchRequest request, PipelineProcessingContext requestContext, ActionListener<SearchRequest> requestListener) {
        try {
            if (request.source() == null) {
                throw new IllegalArgumentException("query body is empty, cannot processor inference on empty query request.");
            }
            String queryString = request.source().toString();
            rewriteQueryString(request, queryString, requestListener);
        } catch (Exception e) {
            // Exactly one notification: either pass the request through or report the failure, never both.
            if (ignoreFailure) {
                requestListener.onResponse(request);
            } else {
                requestListener.onFailure(e);
            }
        }
    }

    /**
     * Validates the query against the input/output maps, then issues one prediction per input map. The request
     * is rewritten once after all predictions have completed.
     */
    private void rewriteQueryString(SearchRequest request, String queryString, ActionListener<SearchRequest> requestListener) throws IOException {
        List<Map<String, String>> processInputMap = inferenceProcessorAttributes.getInputMaps();
        List<Map<String, String>> processOutputMap = inferenceProcessorAttributes.getOutputMaps();
        int inputMapSize = processInputMap != null ? processInputMap.size() : 0;
        if (inputMapSize == 0) {
            requestListener.onResponse(request);
            return;
        }
        try {
            if (!validateQueryFieldInQueryString(processInputMap, processOutputMap, queryString)) {
                requestListener.onResponse(request);
                return;
            }
        } catch (Exception e) {
            if (ignoreMissing) {
                requestListener.onResponse(request);
            } else {
                requestListener.onFailure(e);
            }
            return;
        }
        ActionListener<Map<Integer, MLOutput>> rewriteRequestListener = createRewriteRequestListener(request, queryString, requestListener, processOutputMap);
        GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener = createBatchPredictionListener(rewriteRequestListener, inputMapSize);
        for (int inputMapIndex = 0; inputMapIndex < inputMapSize; ++inputMapIndex) {
            try {
                processPredictions(queryString, processInputMap, inputMapIndex, batchPredictionListener);
            } catch (Exception e) {
                // The grouped listener was sized for inputMapSize callbacks. Rethrowing here would let the caller
                // notify requestListener while earlier predictions are still in flight (a second notification).
                batchPredictionListener.onFailure(e);
            }
        }
    }

    /**
     * Creates the listener that receives all prediction outputs (keyed by input map index), rewrites the request
     * once and notifies {@code requestListener} once.
     */
    private ActionListener<Map<Integer, MLOutput>> createRewriteRequestListener(
        final SearchRequest request,
        final String queryString,
        final ActionListener<SearchRequest> requestListener,
        final List<Map<String, String>> processOutputMap
    ) {
        return new ActionListener<Map<Integer, MLOutput>>() {

            public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
                SearchSourceBuilder searchSourceBuilder;
                try {
                    String newQueryString = buildRewrittenQuery(queryString, multipleMLOutputs, processOutputMap);
                    searchSourceBuilder = getSearchSourceBuilder(xContentRegistry, newQueryString);
                } catch (Exception e) {
                    onFailure(e);
                    return;
                }
                // Notify outside the try so an exception thrown by the listener cannot trigger a second callback.
                request.source(searchSourceBuilder);
                requestListener.onResponse(request);
            }

            public void onFailure(Exception e) {
                if (ignoreFailure) {
                    logger.error("Failed in writing prediction outcomes to new query", e);
                    requestListener.onResponse(request);
                } else {
                    requestListener.onFailure(e);
                }
            }
        };
    }

    /**
     * Builds the rewritten query JSON from all prediction outputs.
     *
     * <p>Without a query template, every output mapping is applied to one parsed copy of {@code queryString}.
     * With a template, the values of every output mapping are collected and substituted into the template once.
     */
    private String buildRewrittenQuery(String queryString, Map<Integer, MLOutput> mlOutputs, List<Map<String, String>> processOutputMap) {
        if (queryTemplate == null) {
            Object incomeQueryObject = JsonPath.parse(queryString).read("$");
            for (Map.Entry<Integer, MLOutput> entry : mlOutputs.entrySet()) {
                updateIncomeQueryObject(incomeQueryObject, processOutputMap.get(entry.getKey()), entry.getValue());
            }
            return StringUtils.toJson(incomeQueryObject);
        }
        Map<String, Object> valuesMap = new HashMap<>();
        for (Map.Entry<Integer, MLOutput> entry : mlOutputs.entrySet()) {
            collectTemplateValues(valuesMap, processOutputMap.get(entry.getKey()), entry.getValue());
        }
        return new StringSubstitutor(valuesMap).replace(queryTemplate);
    }

    /**
     * Sets each mapped model output at {@code "$." + queryField} on the parsed query object. The returned
     * document context is discarded; the rewritten query is read back from {@code incomeQueryObject}.
     */
    private void updateIncomeQueryObject(Object incomeQueryObject, Map<String, String> outputMapping, MLOutput mlOutput) {
        for (Map.Entry<String, String> outputMapEntry : outputMapping.entrySet()) {
            String newQueryField = outputMapEntry.getKey();
            String modelOutputFieldName = outputMapEntry.getValue();
            Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);
            JsonPath.parse(incomeQueryObject).set("$." + newQueryField, modelOutputValue);
        }
    }

    /** Adds placeholder-name to model-output-value pairs for one prediction to {@code valuesMap}. */
    private void collectTemplateValues(Map<String, Object> valuesMap, Map<String, String> outputMapping, MLOutput mlOutput) {
        for (Map.Entry<String, String> outputMapEntry : outputMapping.entrySet()) {
            String newQueryField = outputMapEntry.getKey();
            String modelOutputFieldName = outputMapEntry.getValue();
            Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);
            valuesMap.put(newQueryField, modelOutputValue);
        }
    }

    /**
     * Creates a grouped listener that waits for {@code inputMapSize} prediction callbacks (at least one), merges
     * the per-prediction output maps and forwards the result (or the failure) to {@code rewriteRequestListener}.
     */
    private GroupedActionListener<Map<Integer, MLOutput>> createBatchPredictionListener(
        final ActionListener<Map<Integer, MLOutput>> rewriteRequestListener,
        int inputMapSize
    ) {
        return new GroupedActionListener<>(new ActionListener<Collection<Map<Integer, MLOutput>>>() {

            public void onResponse(Collection<Map<Integer, MLOutput>> mlOutputMapCollection) {
                Map<Integer, MLOutput> mlOutputMaps = new HashMap<>();
                for (Map<Integer, MLOutput> mlOutputMap : mlOutputMapCollection) {
                    mlOutputMaps.putAll(mlOutputMap);
                }
                rewriteRequestListener.onResponse(mlOutputMaps);
            }

            public void onFailure(Exception e) {
                logger.error("Prediction Failed:", e);
                rewriteRequestListener.onFailure(e);
            }
        }, Math.max(inputMapSize, 1));
    }

    /**
     * Checks that every query field referenced by the input maps exists in the query. When no query template is
     * configured, also checks every output-map target field, because outputs are written into the query itself.
     *
     * @return always {@code true}
     * @throws IllegalArgumentException if a referenced field is missing (or holds JSON null)
     */
    private boolean validateQueryFieldInQueryString(List<Map<String, String>> processInputMap, List<Map<String, String>> processOutputMap, String queryString) {
        Configuration suppressExceptionConfiguration = Configuration.defaultConfiguration().addOptions(Option.SUPPRESS_EXCEPTIONS);
        // With SUPPRESS_EXCEPTIONS a missing path reads as null instead of throwing.
        DocumentContext jsonData = JsonPath.using(suppressExceptionConfiguration).parse(queryString);
        for (Map<String, String> inputMap : processInputMap) {
            for (String queryField : inputMap.values()) {
                requireQueryField(jsonData, queryField);
            }
        }
        if (queryTemplate == null) {
            for (Map<String, String> outputMap : processOutputMap) {
                for (String queryField : outputMap.keySet()) {
                    requireQueryField(jsonData, queryField);
                }
            }
        }
        return true;
    }

    /** Throws if {@code queryField} reads as null from {@code jsonData}. */
    private static void requireQueryField(DocumentContext jsonData, String queryField) {
        Object pathData = jsonData.read(queryField);
        if (pathData == null) {
            throw new IllegalArgumentException("cannot find field: " + queryField + " in query string: " + jsonData.jsonString());
        }
    }

    /**
     * Builds the model parameters for input map {@code inputMapIndex} and dispatches one prediction task. The
     * output is delivered to {@code batchPredictionListener} keyed by {@code inputMapIndex}.
     */
    private void processPredictions(
        String queryString,
        List<Map<String, String>> processInputMap,
        final int inputMapIndex,
        final GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener
    ) throws IOException {
        Map<String, String> modelParameters = new HashMap<>();
        Map<String, String> modelConfigs = new HashMap<>();
        if (inferenceProcessorAttributes.getModelConfigMaps() != null) {
            modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps());
            modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps());
        }
        if (processInputMap != null) {
            Map<String, String> inputMapping = processInputMap.get(inputMapIndex);
            Object newQuery = JsonPath.parse(queryString).read("$");
            for (Map.Entry<String, String> entry : inputMapping.entrySet()) {
                String modelInputFieldName = entry.getKey();
                String queryFieldName = entry.getValue();
                Object queryFieldValue = JsonPath.parse(newQuery).read(queryFieldName);
                modelParameters.put(modelInputFieldName, StringUtils.toJson(queryFieldValue));
            }
        }
        // Parameters whose key was not supplied by model_config, i.e. the ones that came only from the query.
        Map<String, String> inputMappings = new HashMap<>();
        for (Map.Entry<String, String> parameter : modelParameters.entrySet()) {
            if (!modelConfigs.containsKey(parameter.getKey())) {
                inputMappings.put(parameter.getKey(), parameter.getValue());
            }
        }
        ActionRequest actionRequest = getMLModelInferenceRequest(
            xContentRegistry,
            modelParameters,
            modelConfigs,
            inputMappings,
            inferenceProcessorAttributes.getModelId(),
            functionName,
            modelInput
        );
        client.execute(MLPredictionTaskAction.INSTANCE, actionRequest, new ActionListener<MLTaskResponse>() {

            public void onResponse(MLTaskResponse mlTaskResponse) {
                Map<Integer, MLOutput> mlOutputMap = new HashMap<>();
                mlOutputMap.put(inputMapIndex, mlTaskResponse.getOutput());
                batchPredictionListener.onResponse(mlOutputMap);
            }

            public void onFailure(Exception e) {
                batchPredictionListener.onFailure(e);
            }
        });
    }

    /**
     * Parses a JSON query string into a {@link SearchSourceBuilder}.
     *
     * @throws IOException if the string cannot be parsed
     */
    private static SearchSourceBuilder getSearchSourceBuilder(NamedXContentRegistry xContentRegistry, String queryString) throws IOException {
        SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
        // try-with-resources: the parser is Closeable and was previously never closed.
        try (XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryString)) {
            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, queryParser.nextToken(), queryParser);
            searchSourceBuilder.parseXContent(queryParser);
        }
        return searchSourceBuilder;
    }

    /** @return the processor type, {@value #TYPE} */
    public String getType() {
        return TYPE;
    }

    /** Factory that builds {@link MLInferenceSearchRequestProcessor} instances from pipeline configuration. */
    public static class Factory implements Processor.Factory<SearchRequestProcessor> {
        private final Client client;
        private final NamedXContentRegistry xContentRegistry;

        /**
         * @param client           client handed to every created processor
         * @param xContentRegistry registry handed to every created processor
         */
        public Factory(Client client, NamedXContentRegistry xContentRegistry) {
            this.client = client;
            this.xContentRegistry = xContentRegistry;
        }

        /**
         * Reads the processor settings from {@code config} and creates the processor.
         *
         * @throws IllegalArgumentException if a non-remote model has no {@code model_input}, or if
         *                                  {@code input_map} has more entries than {@code max_prediction_tasks}
         */
        public MLInferenceSearchRequestProcessor create(
            Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
            String processorTag,
            String description,
            boolean ignoreFailure,
            Map<String, Object> config,
            Processor.PipelineContext pipelineContext
        ) {
            String modelId = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, "model_id");
            String queryTemplate = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, QUERY_TEMPLATE);
            Map<String, Object> modelConfigInput = ConfigurationUtils.readOptionalMap(TYPE, processorTag, config, "model_config");
            List<Map<String, String>> inputMaps = ConfigurationUtils.readList(TYPE, processorTag, config, "input_map");
            List<Map<String, String>> outputMaps = ConfigurationUtils.readList(TYPE, processorTag, config, "output_map");
            int maxPredictionTask = ConfigurationUtils.readIntProperty(TYPE, processorTag, config, "max_prediction_tasks", DEFAULT_MAX_PREDICTION_TASKS);
            boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false);
            String functionName = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name());
            String modelInput = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, MODEL_INPUT);
            boolean isRemote = functionName.equalsIgnoreCase(FunctionName.REMOTE.name());
            if (isRemote) {
                modelInput = modelInput != null ? modelInput : DEFAULT_MODEl_INPUT;
            } else if (modelInput == null) {
                throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor");
            }
            boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, !isRemote);
            // Default to the caller-supplied value: defaulting to false silently dropped ignoreFailure=true
            // whenever "ignore_failure" was absent from this config map.
            boolean effectiveIgnoreFailure = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, "ignore_failure", ignoreFailure);
            Map<String, String> modelConfigMaps = modelConfigInput != null ? StringUtils.getParameterMap(modelConfigInput) : null;
            if (inputMaps != null && inputMaps.size() > maxPredictionTask) {
                throw new IllegalArgumentException("The number of prediction task setting in this process is " + inputMaps.size() + ". It exceeds the max_prediction_tasks of " + maxPredictionTask + ". Please reduce the size of input_map or increase max_prediction_tasks.");
            }
            return new MLInferenceSearchRequestProcessor(
                modelId,
                queryTemplate,
                inputMaps,
                outputMaps,
                modelConfigMaps,
                maxPredictionTask,
                processorTag,
                description,
                ignoreMissing,
                functionName,
                fullResponsePath,
                effectiveIgnoreFailure,
                modelInput,
                this.client,
                this.xContentRegistry
            );
        }
    }
}

