/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.processor;

import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Option;
import com.jayway.jsonpath.Predicate;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.client.Client;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.MediaType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.processor.InferenceProcessorAttributes;
import org.opensearch.ml.processor.ModelExecutor;
import org.opensearch.ml.utils.MapUtils;
import org.opensearch.search.SearchHit;
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

/**
 * Search response processor ({@value #TYPE}) that runs ML model predictions over the hits of a search
 * response and writes the model outputs back into each hit's {@code _source}.
 *
 * <p>Only the many-to-one mode is implemented. For each configured input map, the mapped fields of every
 * eligible hit are collected into per-field lists and sent to the model in a single prediction request.
 * When an output value is a list whose size equals the number of hits that fed that prediction, each hit
 * receives the element at its position. Otherwise every hit receives the whole value. Requesting
 * {@code one_to_one} fails the response with an {@link IllegalArgumentException}.
 */
public class MLInferenceSearchResponseProcessor extends AbstractProcessor implements SearchResponseProcessor, ModelExecutor {
    private static final Logger logger = LogManager.getLogger(MLInferenceSearchResponseProcessor.class);

    private final NamedXContentRegistry xContentRegistry;
    private final InferenceProcessorAttributes inferenceProcessorAttributes;
    private final boolean ignoreMissing;
    private final String functionName;
    private final boolean override;
    private final boolean fullResponsePath;
    private final boolean oneToOne;
    private final boolean ignoreFailure;
    private final String modelInput;
    // Per-instance client: a static field assigned from the constructor would be overwritten by every
    // new processor instance.
    private final Client client;

    public static final String TYPE = "ml_inference";
    public static final String IGNORE_MISSING = "ignore_missing";
    public static final String FUNCTION_NAME = "function_name";
    public static final String FULL_RESPONSE_PATH = "full_response_path";
    public static final String MODEL_INPUT = "model_input";
    public static final String ONE_TO_ONE = "one_to_one";
    public static final String DEFAULT_MODEL_INPUT = "{ \"parameters\": ${ml_inference.parameters} }";
    public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
    public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results";

    protected MLInferenceSearchResponseProcessor(String modelId, List<Map<String, String>> inputMaps, List<Map<String, String>> outputMaps, Map<String, String> modelConfigMaps, int maxPredictionTask, String tag, String description, boolean ignoreMissing, String functionName, boolean fullResponsePath, boolean ignoreFailure, boolean override, String modelInput, Client client, NamedXContentRegistry xContentRegistry, boolean oneToOne) {
        super(tag, description, ignoreFailure);
        this.oneToOne = oneToOne;
        this.inferenceProcessorAttributes = new InferenceProcessorAttributes(modelId, inputMaps, outputMaps, modelConfigMaps, maxPredictionTask);
        this.ignoreMissing = ignoreMissing;
        this.functionName = functionName;
        this.fullResponsePath = fullResponsePath;
        this.ignoreFailure = ignoreFailure;
        this.override = override;
        this.modelInput = modelInput;
        this.client = client;
        this.xContentRegistry = xContentRegistry;
    }

    /**
     * Synchronous entry point. Not supported, because this processor makes asynchronous prediction calls.
     *
     * @throws RuntimeException always
     */
    @Override
    public SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception {
        throw new RuntimeException("ML inference search response processor make asynchronous calls and does not call processRequest");
    }

    /**
     * Asynchronously enriches the hits of {@code response} with model predictions.
     *
     * <p>A response with no hits is passed through unchanged. If setting up the predictions throws, the
     * listener is completed exactly once: with the original response when {@code ignore_failure} is set,
     * otherwise with the exception.
     */
    @Override
    public void processResponseAsync(SearchRequest request, SearchResponse response, PipelineProcessingContext responseContext, ActionListener<SearchResponse> responseListener) {
        try {
            SearchHit[] hits = response.getHits().getHits();
            if (hits.length == 0) {
                responseListener.onResponse(response);
                return;
            }
            rewriteResponseDocuments(response, responseListener);
        } catch (Exception e) {
            // Complete the listener once only; previously both onResponse and onFailure ran when ignoreFailure was set.
            if (ignoreFailure) {
                logger.error("Failed to run ML inference on search response, returning the original response", e);
                responseListener.onResponse(response);
            } else {
                responseListener.onFailure(e);
            }
        }
    }

    /**
     * Dispatches one prediction per input map (at least one when no input map is configured). Results are
     * gathered by a grouped listener that rewrites the hits once all predictions have answered.
     */
    private void rewriteResponseDocuments(SearchResponse response, ActionListener<SearchResponse> responseListener) throws IOException {
        if (oneToOne) {
            responseListener.onFailure(new IllegalArgumentException("one to one prediction is not supported yet."));
            return;
        }
        List<Map<String, String>> processInputMap = inferenceProcessorAttributes.getInputMaps();
        List<Map<String, String>> processOutputMap = inferenceProcessorAttributes.getOutputMaps();
        int inputMapSize = processInputMap == null ? 0 : processInputMap.size();
        // input-map index -> number of hits whose values were sent in that prediction
        Map<Integer, Integer> hitCountInPredictions = new HashMap<>();

        ActionListener<Map<Integer, MLOutput>> rewriteResponseListener = createRewriteResponseListenerManyToOne(response, responseListener, processInputMap, processOutputMap, hitCountInPredictions);
        GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener = createBatchPredictionListenerManyToOne(rewriteResponseListener, inputMapSize);

        SearchHit[] hits = response.getHits().getHits();
        for (int inputMapIndex = 0; inputMapIndex < Math.max(inputMapSize, 1); ++inputMapIndex) {
            processPredictionsManyToOne(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions);
        }
    }

    /**
     * Collects the model inputs for one input map across all hits and sends a single prediction request.
     * The result is reported to {@code batchPredictionListener} keyed by {@code inputMapIndex}.
     *
     * <p>With a non-empty input map, each hit that has every mapped field contributes the JsonPath-resolved
     * values (null results are skipped). A hit missing a field is skipped when {@code ignore_missing} is set;
     * otherwise an {@link IllegalArgumentException} is thrown. Without an input map, every top-level source
     * field of every hit that has a source is used.
     */
    private void processPredictionsManyToOne(SearchHit[] hits, List<Map<String, String>> processInputMap, final int inputMapIndex, final GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener, Map<Integer, Integer> hitCountInPredictions) throws IOException {
        Map<String, String> modelParameters = new HashMap<>();
        Map<String, String> modelConfigs = new HashMap<>();
        Map<String, String> configuredModelConfigs = inferenceProcessorAttributes.getModelConfigMaps();
        if (configuredModelConfigs != null) {
            modelParameters.putAll(configuredModelConfigs);
            modelConfigs.putAll(configuredModelConfigs);
        }

        // model input field name -> list of values, one per contributing hit
        Map<String, Object> modelInputParameters = new HashMap<>();
        if (processInputMap != null && !processInputMap.isEmpty()) {
            Map<String, String> inputMapping = processInputMap.get(inputMapIndex);
            // Built once per prediction instead of once per field per hit: missing paths yield null rather than throwing.
            Configuration lenientJsonPath = Configuration.builder().options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL).build();
            for (SearchHit hit : hits) {
                Map<String, Object> document = hit.getSourceAsMap();
                if (checkIsModelInputMissing(document, inputMapping)) {
                    if (ignoreMissing) {
                        continue;
                    }
                    throw new IllegalArgumentException("cannot find all required input fields: " + inputMapping.values() + " in hit:" + hit);
                }
                MapUtils.incrementCounter(hitCountInPredictions, inputMapIndex);
                Object documentJson = JsonPath.parse(document).read("$");
                for (Map.Entry<String, String> entry : inputMapping.entrySet()) {
                    Object documentValue = JsonPath.using(lenientJsonPath).parse(documentJson).read(entry.getValue());
                    if (documentValue != null) {
                        updateModelInputParametersManyToOne(modelInputParameters, entry.getKey(), documentValue);
                    }
                }
            }
        } else {
            for (SearchHit hit : hits) {
                // The write-back step skips hits without _source, so they must not be counted or read here either.
                if (!hit.hasSource()) {
                    continue;
                }
                Map<String, Object> document = hit.getSourceAsMap();
                MapUtils.incrementCounter(hitCountInPredictions, inputMapIndex);
                for (Map.Entry<String, Object> field : document.entrySet()) {
                    updateModelInputParametersManyToOne(modelInputParameters, field.getKey(), field.getValue());
                }
            }
        }

        // Merge the collected inputs on top of the model config instead of replacing it, so config-only
        // parameters still reach the request. Inputs take precedence on key collisions.
        modelParameters.putAll(StringUtils.getParameterMap(modelInputParameters));
        Map<String, String> inputMappings = new HashMap<>();
        for (Map.Entry<String, String> parameter : modelParameters.entrySet()) {
            if (!modelConfigs.containsKey(parameter.getKey())) {
                inputMappings.put(parameter.getKey(), parameter.getValue());
            }
        }

        ActionRequest request = getMLModelInferenceRequest(xContentRegistry, modelParameters, modelConfigs, inputMappings, inferenceProcessorAttributes.getModelId(), functionName, modelInput);
        client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<MLTaskResponse>() {
            @Override
            public void onResponse(MLTaskResponse mlTaskResponse) {
                Map<Integer, MLOutput> mlOutputMap = new HashMap<>();
                mlOutputMap.put(inputMapIndex, mlTaskResponse.getOutput());
                batchPredictionListener.onResponse(mlOutputMap);
            }

            @Override
            public void onFailure(Exception e) {
                batchPredictionListener.onFailure(e);
            }
        });
    }

    /** Appends {@code documentValue} to the list stored under {@code modelInputFieldName}, creating the list if needed. */
    @SuppressWarnings("unchecked")
    private void updateModelInputParametersManyToOne(Map<String, Object> modelInputParameters, String modelInputFieldName, Object documentValue) {
        ((List<Object>) modelInputParameters.computeIfAbsent(modelInputFieldName, key -> new ArrayList<>())).add(documentValue);
    }

    /**
     * Creates a listener that waits for {@code max(inputMapSize, 1)} per-input-map results. It then merges
     * them into one index-to-output map and forwards that map to {@code rewriteResponseListener}.
     */
    private GroupedActionListener<Map<Integer, MLOutput>> createBatchPredictionListenerManyToOne(final ActionListener<Map<Integer, MLOutput>> rewriteResponseListener, int inputMapSize) {
        return new GroupedActionListener<>(new ActionListener<Collection<Map<Integer, MLOutput>>>() {
            @Override
            public void onResponse(Collection<Map<Integer, MLOutput>> mlOutputMapCollection) {
                Map<Integer, MLOutput> mlOutputMaps = new HashMap<>();
                for (Map<Integer, MLOutput> mlOutputMap : mlOutputMapCollection) {
                    mlOutputMaps.putAll(mlOutputMap);
                }
                rewriteResponseListener.onResponse(mlOutputMaps);
            }

            @Override
            public void onFailure(Exception e) {
                logger.error("Prediction Failed:", e);
                rewriteResponseListener.onFailure(e);
            }
        }, Math.max(inputMapSize, 1));
    }

    /**
     * Creates the listener that writes the merged prediction outputs into the hits and then completes
     * {@code responseListener} exactly once. On failure, the original response is returned when
     * {@code ignore_failure} is set; otherwise the error is propagated.
     */
    private ActionListener<Map<Integer, MLOutput>> createRewriteResponseListenerManyToOne(final SearchResponse response, final ActionListener<SearchResponse> responseListener, final List<Map<String, String>> processInputMap, final List<Map<String, String>> processOutputMap, final Map<Integer, Integer> hitCountInPredictions) {
        return new ActionListener<Map<Integer, MLOutput>>() {
            @Override
            public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
                try {
                    writePredictionsToHitsManyToOne(response, multipleMLOutputs, processInputMap, processOutputMap, hitCountInPredictions);
                } catch (Exception e) {
                    // Route through onFailure so the listener is completed once; previously up to three calls could happen.
                    onFailure(e);
                    return;
                }
                responseListener.onResponse(response);
            }

            @Override
            public void onFailure(Exception e) {
                if (ignoreFailure) {
                    logger.error("Failed in writing prediction outcomes to search response", e);
                    responseListener.onResponse(response);
                } else {
                    responseListener.onFailure(e);
                }
            }
        };
    }

    /**
     * Rewrites the {@code _source} of every hit that has one, adding the output fields of each prediction
     * whose input fields the hit contains. Existing fields are replaced only when {@code override} is set.
     */
    private void writePredictionsToHitsManyToOne(SearchResponse response, Map<Integer, MLOutput> multipleMLOutputs, List<Map<String, String>> processInputMap, List<Map<String, String>> processOutputMap, Map<Integer, Integer> hitCountInPredictions) throws IOException {
        // Must match the condition used when collecting inputs; an empty input_map means "use all fields".
        boolean hasInputMaps = processInputMap != null && !processInputMap.isEmpty();
        // input-map index -> output field -> number of hits written so far (drives per-hit list indexing)
        Map<Integer, Map<String, Integer>> writeOutputMapDocCounter = new HashMap<>();
        for (SearchHit hit : response.getHits().getHits()) {
            if (!hit.hasSource()) {
                continue;
            }
            Tuple<? extends MediaType, Map<String, Object>> typeAndSourceMap = XContentHelper.convertToMap(hit.getSourceRef(), false, (MediaType) null);
            Map<String, Object> sourceAsMap = typeAndSourceMap.v2();
            Map<String, Object> sourceAsMapWithInference = new HashMap<>(sourceAsMap);
            for (Map.Entry<Integer, MLOutput> entry : multipleMLOutputs.entrySet()) {
                Integer mappingIndex = entry.getKey();
                if (hasInputMaps && checkIsModelInputMissing(sourceAsMap, processInputMap.get(mappingIndex))) {
                    continue;
                }
                MLOutput mlOutput = entry.getValue();
                Integer hitCount = hitCountInPredictions.get(mappingIndex);
                for (Map.Entry<String, String> outputMapEntry : getDefaultOutputMapping(mappingIndex, processOutputMap).entrySet()) {
                    String newDocumentFieldName = outputMapEntry.getKey();
                    String modelOutputFieldName = outputMapEntry.getValue();
                    MapUtils.incrementCounter(writeOutputMapDocCounter, mappingIndex, modelOutputFieldName);
                    Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);
                    Object modelOutputValuePerDoc = modelOutputValue;
                    // A list with one element per contributing hit is split across hits; anything else is copied whole.
                    if (modelOutputValue instanceof List && hitCount != null && ((List<?>) modelOutputValue).size() == hitCount) {
                        modelOutputValuePerDoc = ((List<?>) modelOutputValue).get(MapUtils.getCounter(writeOutputMapDocCounter, mappingIndex, modelOutputFieldName));
                    }
                    if (sourceAsMap.containsKey(newDocumentFieldName) && !override) {
                        logger.debug("{} already exists in the search response hit. Skip processing this field.", newDocumentFieldName);
                        continue;
                    }
                    sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc);
                }
            }
            XContentBuilder builder = XContentBuilder.builder(typeAndSourceMap.v1().xContent());
            builder.map(sourceAsMapWithInference);
            hit.sourceRef(BytesReference.bytes(builder));
        }
    }

    /** Returns true when any document path referenced by {@code inputMapping} is absent from {@code document}. */
    private boolean checkIsModelInputMissing(Map<String, Object> document, Map<String, String> inputMapping) {
        for (String documentFieldPath : inputMapping.values()) {
            if (!hasField(document, documentFieldPath)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the output mapping for {@code mappingIndex}. When no output map is configured, returns a
     * mapping of {@value #DEFAULT_OUTPUT_FIELD_NAME} to {@code $.inference_results}.
     */
    private static Map<String, String> getDefaultOutputMapping(Integer mappingIndex, List<Map<String, String>> processOutputMap) {
        if (processOutputMap == null || processOutputMap.isEmpty()) {
            Map<String, String> outputMapping = new HashMap<>();
            outputMapping.put(DEFAULT_OUTPUT_FIELD_NAME, "$.inference_results");
            return outputMapping;
        }
        return processOutputMap.get(mappingIndex);
    }

    @Override
    public String getType() {
        return TYPE;
    }

    /** Builds {@link MLInferenceSearchResponseProcessor} instances from search pipeline configuration. */
    public static class Factory implements Processor.Factory<SearchResponseProcessor> {
        private final Client client;
        private final NamedXContentRegistry xContentRegistry;

        public Factory(Client client, NamedXContentRegistry xContentRegistry) {
            this.client = client;
            this.xContentRegistry = xContentRegistry;
        }

        /**
         * Reads and validates the processor configuration.
         *
         * @throws IllegalArgumentException if a local model has no {@code model_input}, if {@code input_map}
         *     has more entries than {@code max_prediction_tasks}, or if {@code input_map} and
         *     {@code output_map} are both given with different lengths
         */
        @Override
        public MLInferenceSearchResponseProcessor create(Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories, String processorTag, String description, boolean ignoreFailure, Map<String, Object> config, Processor.PipelineContext pipelineContext) {
            String modelId = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, "model_id");
            Map<String, Object> modelConfigInput = ConfigurationUtils.readOptionalMap(TYPE, processorTag, config, "model_config");
            List<Map<String, String>> inputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, "input_map");
            List<Map<String, String>> outputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, "output_map");
            int maxPredictionTask = ConfigurationUtils.readIntProperty(TYPE, processorTag, config, "max_prediction_tasks", DEFAULT_MAX_PREDICTION_TASKS);
            boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false);
            String functionName = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name());
            boolean override = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, "override", false);
            boolean oneToOne = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, ONE_TO_ONE, false);
            String modelInput = ConfigurationUtils.readOptionalStringProperty(TYPE, processorTag, config, MODEL_INPUT);

            boolean isRemote = functionName.equalsIgnoreCase(FunctionName.REMOTE.name());
            if (isRemote) {
                modelInput = modelInput != null ? modelInput : DEFAULT_MODEL_INPUT;
            } else if (modelInput == null) {
                throw new IllegalArgumentException("Please provide model input when using a local model in ML Inference Processor");
            }
            boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, !isRemote);
            // Default to the value supplied by the caller rather than false. If "ignore_failure" is no longer in
            // config (presumably already consumed by the pipeline), the caller's setting would otherwise be lost.
            boolean effectiveIgnoreFailure = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, "ignore_failure", ignoreFailure);

            Map<String, String> modelConfigMaps = null;
            if (modelConfigInput != null) {
                modelConfigMaps = StringUtils.getParameterMap(modelConfigInput);
            }
            if (inputMaps != null && inputMaps.size() > maxPredictionTask) {
                throw new IllegalArgumentException("The number of prediction task setting in this process is " + inputMaps.size() + ". It exceeds the max_prediction_tasks of " + maxPredictionTask + ". Please reduce the size of input_map or increase max_prediction_tasks.");
            }
            if (outputMaps != null && inputMaps != null && outputMaps.size() != inputMaps.size()) {
                throw new IllegalArgumentException("when output_maps and input_maps are provided, their length needs to match. The input_maps is in length of " + inputMaps.size() + ", while output_maps is in the length of " + outputMaps.size() + ". Please adjust mappings.");
            }
            return new MLInferenceSearchResponseProcessor(modelId, inputMaps, outputMaps, modelConfigMaps, maxPredictionTask, processorTag, description, ignoreMissing, functionName, fullResponsePath, effectiveIgnoreFailure, override, modelInput, this.client, this.xContentRegistry, oneToOne);
        }
    }
}

