/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.processor;

import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Option;
import com.jayway.jsonpath.Predicate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionType;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.Processor;
import org.opensearch.ingest.ValueSource;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.processor.InferenceProcessorAttributes;
import org.opensearch.ml.processor.ModelExecutor;
import org.opensearch.script.ScriptService;
import org.opensearch.script.TemplateScript;

/**
 * Ingest processor of type {@value #TYPE} that calls an ML model for each ingested document and writes
 * the model output back into that document.
 *
 * <p>Every entry of {@code input_map} produces one prediction request (the factory rejects more than
 * {@code max_prediction_tasks} entries). The {@code output_map} entry at the same index selects which
 * document fields receive the output. Without an {@code input_map}, a single request is built from every
 * top-level key of the document's source and metadata. Without an {@code output_map}, the output is
 * written to {@value #DEFAULT_OUTPUT_FIELD_NAME}.
 *
 * <p>NOTE(review): with several input maps, all prediction callbacks write into the same
 * {@link IngestDocument}. Confirm that the client never delivers those callbacks concurrently.
 */
public class MLInferenceIngestProcessor
extends AbstractProcessor
implements ModelExecutor {
    public static final String DOT_SYMBOL = ".";
    public static final String TYPE = "ml_inference";
    public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results";
    public static final String IGNORE_MISSING = "ignore_missing";
    public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;

    private final InferenceProcessorAttributes inferenceProcessorAttributes;
    private final boolean ignoreMissing;
    private final boolean ignoreFailure;
    private final ScriptService scriptService;
    // Per-instance client. The previous static field was overwritten by every constructor call,
    // so all processors silently shared whichever client had been passed in last.
    private final Client client;
    // JsonPath configuration for the dotted-path fallback lookup: path-evaluation exceptions are
    // suppressed and results are always returned as a list.
    private final Configuration suppressExceptionConfiguration = Configuration.builder()
        .options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL, Option.ALWAYS_RETURN_LIST)
        .build();

    /**
     * Creates the processor; normally invoked through {@link Factory}.
     *
     * @param modelId id of the model to run predictions against
     * @param inputMaps per-prediction mappings of model input name to document field name (may be null)
     * @param outputMaps per-prediction mappings of new document field name to model output field name (may be null)
     * @param modelConfigMaps extra model parameters sent with every request (may be null)
     * @param maxPredictionTask configured upper bound on the number of input maps
     * @param tag processor tag
     * @param description processor description
     * @param ignoreMissing whether missing input fields are skipped instead of failing
     * @param ignoreFailure whether prediction failures still pass the document on
     * @param scriptService used to compile output field paths as templates
     * @param client client used to execute the prediction transport action
     */
    protected MLInferenceIngestProcessor(String modelId, List<Map<String, String>> inputMaps, List<Map<String, String>> outputMaps, Map<String, String> modelConfigMaps, int maxPredictionTask, String tag, String description, boolean ignoreMissing, boolean ignoreFailure, ScriptService scriptService, Client client) {
        super(tag, description);
        this.inferenceProcessorAttributes = new InferenceProcessorAttributes(modelId, inputMaps, outputMaps, modelConfigMaps, maxPredictionTask);
        this.ignoreMissing = ignoreMissing;
        this.ignoreFailure = ignoreFailure;
        this.scriptService = scriptService;
        this.client = client;
    }

    /**
     * Issues one prediction per input-map entry, or a single prediction when no input map is
     * configured, and calls {@code handler} exactly once after all of them have completed.
     *
     * @param ingestDocument the document being ingested; model outputs are written into it
     * @param handler receives {@code (document, null)} on success, or on failure when
     *        {@code ignoreFailure} is set; otherwise receives {@code (null, failure)}
     */
    public void execute(final IngestDocument ingestDocument, final BiConsumer<IngestDocument, Exception> handler) {
        List<Map<String, String>> processInputMap = this.inferenceProcessorAttributes.getInputMaps();
        List<Map<String, String>> processOutputMap = this.inferenceProcessorAttributes.getOutputMaps();
        int inputMapSize = processInputMap != null ? processInputMap.size() : 0;
        // Even without an input map, one prediction over the whole document is issued.
        int predictionCount = Math.max(inputMapSize, 1);

        GroupedActionListener<Void> batchPredictionListener = new GroupedActionListener<Void>(new ActionListener<Collection<Void>>() {

            public void onResponse(Collection<Void> voids) {
                handler.accept(ingestDocument, null);
            }

            public void onFailure(Exception e) {
                if (MLInferenceIngestProcessor.this.ignoreFailure) {
                    handler.accept(ingestDocument, null);
                } else {
                    handler.accept(null, e);
                }
            }
        }, predictionCount);

        for (int inputMapIndex = 0; inputMapIndex < predictionCount; ++inputMapIndex) {
            try {
                this.processPredictions(ingestDocument, batchPredictionListener, processInputMap, processOutputMap, inputMapIndex, inputMapSize);
            } catch (Exception e) {
                // A synchronous failure (e.g. a missing input field) still counts toward the group
                // so the handler is eventually invoked.
                batchPredictionListener.onFailure(e);
            }
        }
    }

    /**
     * Synchronous execution is not supported; this processor only runs through
     * {@link #execute(IngestDocument, BiConsumer)}.
     *
     * @throws UnsupportedOperationException always
     */
    public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
        throw new UnsupportedOperationException("this method should not get executed.");
    }

    /**
     * Builds the model parameters for prediction {@code inputMapIndex} and sends the request. The
     * grouped listener receives exactly one response or failure for this prediction.
     */
    private void processPredictions(final IngestDocument ingestDocument, final GroupedActionListener<Void> batchPredictionListener, List<Map<String, String>> processInputMap, final List<Map<String, String>> processOutputMap, final int inputMapIndex, int inputMapSize) {
        Map<String, String> modelParameters = new HashMap<String, String>();
        if (this.inferenceProcessorAttributes.getModelConfigMaps() != null) {
            modelParameters.putAll(this.inferenceProcessorAttributes.getModelConfigMaps());
        }
        if (inputMapSize == 0) {
            // No input map: every top-level source/metadata key is sent under its own name.
            for (String field : ingestDocument.getSourceAndMetadata().keySet()) {
                this.getMappedModelInputFromDocuments(ingestDocument, modelParameters, field, field);
            }
        } else {
            Map<String, String> inputMapping = processInputMap.get(inputMapIndex);
            for (Map.Entry<String, String> entry : inputMapping.entrySet()) {
                // input_map entries are (model input name -> document field name).
                this.getMappedModelInputFromDocuments(ingestDocument, modelParameters, entry.getValue(), entry.getKey());
            }
        }

        ActionRequest request = this.getRemoteModelInferenceRequest(modelParameters, this.inferenceProcessorAttributes.getModelId());
        this.client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<MLTaskResponse>() {

            public void onResponse(MLTaskResponse mlTaskResponse) {
                try {
                    MLInferenceIngestProcessor.this.writeModelOutputs(mlTaskResponse, ingestDocument, processOutputMap, inputMapIndex);
                } catch (Exception e) {
                    // Route output-writing errors to the group instead of letting them escape the
                    // transport callback, where they would never reach the handler reliably.
                    batchPredictionListener.onFailure(e);
                    return;
                }
                batchPredictionListener.onResponse(null);
            }

            public void onFailure(Exception e) {
                batchPredictionListener.onFailure(e);
            }
        });
    }

    /**
     * Writes one prediction's output into the document. The output goes to the output-map entry at
     * {@code inputMapIndex}, or to {@value #DEFAULT_OUTPUT_FIELD_NAME} when no output map is configured.
     *
     * @throws IllegalArgumentException if a mapped target field already exists in the document
     */
    private void writeModelOutputs(MLTaskResponse mlTaskResponse, IngestDocument ingestDocument, List<Map<String, String>> processOutputMap, int inputMapIndex) {
        ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput();
        if (processOutputMap == null || processOutputMap.isEmpty()) {
            this.appendFieldValue(modelTensorOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument);
            return;
        }
        Map<String, String> outputMapping = processOutputMap.get(inputMapIndex);
        for (Map.Entry<String, String> entry : outputMapping.entrySet()) {
            // output_map entries are (new document field name -> model output field name).
            String newDocumentFieldName = entry.getKey();
            String modelOutputFieldName = entry.getValue();
            if (ingestDocument.hasField(newDocumentFieldName)) {
                throw new IllegalArgumentException("document already has field name " + newDocumentFieldName + ". Not allow to overwrite the same field name, please check output_map.");
            }
            this.appendFieldValue(modelTensorOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument);
        }
    }

    /**
     * Copies the value of one document field into {@code modelParameters} under
     * {@code modelInputFieldName}. The field is looked up directly first; a dotted name that is not
     * found that way is retried as a JsonPath expression against the document's source and metadata.
     *
     * @throws IllegalArgumentException if the field cannot be found and {@code ignoreMissing} is false
     */
    private void getMappedModelInputFromDocuments(IngestDocument ingestDocument, Map<String, String> modelParameters, String documentFieldName, String modelInputFieldName) {
        String originalFieldPath = this.getFieldPath(ingestDocument, documentFieldName);
        if (originalFieldPath != null) {
            Object documentFieldValue = ingestDocument.getFieldValue(originalFieldPath, Object.class);
            this.updateModelParameters(modelInputFieldName, this.toString(documentFieldValue), modelParameters);
            return;
        }
        // Null check guards against a null value in input_map.
        if (documentFieldName != null && documentFieldName.contains(DOT_SYMBOL)) {
            Map<String, Object> sourceObject = ingestDocument.getSourceAndMetadata();
            List<?> fieldValueList = JsonPath.using(this.suppressExceptionConfiguration).parse(sourceObject).read(documentFieldName);
            if (fieldValueList != null && !fieldValueList.isEmpty()) {
                this.updateModelParameters(modelInputFieldName, this.toString(fieldValueList), modelParameters);
                return;
            }
        }
        if (!this.ignoreMissing) {
            throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName);
        }
    }

    /**
     * Stores {@code originalFieldValueAsString} under {@code modelInputFieldName}. If that key is
     * already present (for example from {@code model_config}), the existing and new values are
     * combined into a list and serialized with {@link #toString(Object)}.
     */
    private void updateModelParameters(String modelInputFieldName, String originalFieldValueAsString, Map<String, String> modelParameters) {
        if (modelParameters.containsKey(modelInputFieldName)) {
            // The map only holds strings. The previous code cast the existing String to List, which
            // always threw ClassCastException, so build a fresh list instead.
            // NOTE(review): an already-serialized value is embedded as a single element; confirm this
            // is the representation the model expects.
            List<Object> combinedValues = new ArrayList<Object>();
            combinedValues.add(modelParameters.get(modelInputFieldName));
            combinedValues.add(originalFieldValueAsString);
            modelParameters.put(modelInputFieldName, this.toString(combinedValues));
        } else {
            modelParameters.put(modelInputFieldName, originalFieldValueAsString);
        }
    }

    /**
     * Returns {@code documentFieldName} when it is non-empty and present in the document, otherwise
     * {@code null}.
     */
    private String getFieldPath(IngestDocument ingestDocument, String documentFieldName) {
        if (Strings.isNullOrEmpty(documentFieldName) || !ingestDocument.hasField(documentFieldName, true)) {
            return null;
        }
        return documentFieldName;
    }

    /**
     * Extracts {@code modelOutputFieldName} from the model output and writes it to
     * {@code newDocumentFieldName}. When the target path expands to several array element paths, the
     * output must be a list of the same length, and its elements are written pairwise.
     *
     * @throws RuntimeException if the model output is missing or empty, or the array sizes differ
     * @throws IllegalArgumentException if an array target is paired with a non-list output
     */
    private void appendFieldValue(ModelTensorOutput modelTensorOutput, String modelOutputFieldName, String newDocumentFieldName, IngestDocument ingestDocument) {
        if (modelTensorOutput == null || modelTensorOutput.getMlModelOutputs() == null || modelTensorOutput.getMlModelOutputs().isEmpty()) {
            throw new RuntimeException("model inference output cannot be null");
        }
        Object modelOutputValue = this.getModelOutputValue(modelTensorOutput, modelOutputFieldName, this.ignoreMissing);
        List<?> dotPathsInArray = this.writeNewDotPathForNestedObject(ingestDocument.getSourceAndMetadata(), newDocumentFieldName);
        if (dotPathsInArray.size() == 1) {
            this.setDocumentField(ingestDocument, newDocumentFieldName, modelOutputValue);
            return;
        }
        if (!(modelOutputValue instanceof List)) {
            throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents.");
        }
        List<?> modelOutputValueArray = (List<?>) modelOutputValue;
        if (dotPathsInArray.size() != modelOutputValueArray.size()) {
            throw new RuntimeException("the prediction field: " + modelOutputFieldName + " is an array in size of " + modelOutputValueArray.size() + " but the document field array from field " + newDocumentFieldName + " is in size of " + dotPathsInArray.size());
        }
        for (int i = 0; i < dotPathsInArray.size(); ++i) {
            this.setDocumentField(ingestDocument, (String) dotPathsInArray.get(i), modelOutputValueArray.get(i));
        }
    }

    /** Writes {@code value} to {@code fieldPath}, compiling the path as a template as the ingest API requires. */
    private void setDocumentField(IngestDocument ingestDocument, String fieldPath, Object value) {
        ValueSource ingestValue = ValueSource.wrap(value, this.scriptService);
        TemplateScript.Factory ingestField = ConfigurationUtils.compileTemplate(TYPE, this.tag, fieldPath, fieldPath, this.scriptService);
        ingestDocument.setFieldValue(ingestField, ingestValue, this.ignoreMissing);
    }

    /** Returns the processor type, {@value #TYPE}. */
    public String getType() {
        return TYPE;
    }

    /** Parses processor configuration and creates {@link MLInferenceIngestProcessor} instances. */
    public static class Factory
    implements Processor.Factory {
        private final ScriptService scriptService;
        private final Client client;

        /**
         * @param scriptService passed to every created processor for template compilation
         * @param client passed to every created processor for prediction requests
         */
        public Factory(ScriptService scriptService, Client client) {
            this.scriptService = scriptService;
            this.client = client;
        }

        /**
         * Reads {@code model_id}, {@code model_config}, {@code input_map}, {@code output_map},
         * {@code max_prediction_tasks}, {@code ignore_missing} and {@code ignore_failure} from
         * {@code config}.
         *
         * @throws IllegalArgumentException if there are more input maps than {@code max_prediction_tasks},
         *         or if the input and output maps differ in length
         */
        public MLInferenceIngestProcessor create(Map<String, Processor.Factory> registry, String processorTag, String description, Map<String, Object> config) throws Exception {
            String modelId = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, "model_id");
            Map<String, Object> modelConfigInput = ConfigurationUtils.readOptionalMap(TYPE, processorTag, config, "model_config");
            List<Map<String, String>> inputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, "input_map");
            List<Map<String, String>> outputMaps = ConfigurationUtils.readOptionalList(TYPE, processorTag, config, "output_map");
            int maxPredictionTask = ConfigurationUtils.readIntProperty(TYPE, processorTag, config, "max_prediction_tasks", DEFAULT_MAX_PREDICTION_TASKS);
            boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false);
            // NOTE(review): the generic processor parser may consume "ignore_failure" before create()
            // is called; confirm this read can ever see the user's value.
            boolean ignoreFailure = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, "ignore_failure", false);

            Map<String, String> modelConfigMaps = modelConfigInput != null ? StringUtils.getParameterMap(modelConfigInput) : null;

            if (inputMaps != null && inputMaps.size() > maxPredictionTask) {
                throw new IllegalArgumentException("The number of prediction task setting in this process is " + inputMaps.size() + ". It exceeds the max_prediction_tasks of " + maxPredictionTask + ". Please reduce the size of input_map or increase max_prediction_tasks.");
            }
            if (inputMaps != null && outputMaps != null && outputMaps.size() != inputMaps.size()) {
                throw new IllegalArgumentException("The length of output_map and the length of input_map do no match.");
            }
            return new MLInferenceIngestProcessor(modelId, inputMaps, outputMaps, modelConfigMaps, maxPredictionTask, processorTag, description, ignoreMissing, ignoreFailure, this.scriptService, this.client);
        }
    }
}

