/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.processor;

import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Option;
import com.jayway.jsonpath.Predicate;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.action.ActionRequest;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.XContentParserUtils;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;

/**
 * Default helper methods for building ML model prediction requests from a templated payload and for
 * extracting values from model outputs with JsonPath expressions.
 */
public interface ModelExecutor {

    /**
     * JsonPath configuration that suppresses path-evaluation exceptions and yields {@code null} for
     * missing leaf values. Used by {@link #hasField(Object, String)}.
     */
    Configuration suppressExceptionConfiguration = Configuration
        .builder()
        .options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL)
        .build();

    /**
     * Builds a prediction request for {@code modelId} by filling in the {@code modelInput} template
     * and parsing the result as an {@link MLInput}.
     *
     * <p>Placeholders are substituted in this order: {@code ${model_config.*}} from
     * {@code modelConfigs}, then {@code ${input_map.*}} from {@code inputMappings}, then
     * {@code ${ml_inference.*}} from the parameters. For {@link FunctionName#REMOTE} the parameters
     * are exposed as one JSON string under {@code ${ml_inference.parameters}}. For any other function
     * name, each parameter is exposed under its own key.
     *
     * @param <T> unused; kept for source compatibility with existing callers
     * @param xContentRegistry registry used to parse the substituted payload
     * @param parameters model input parameters; must not be {@code null}
     * @param modelConfigs values for {@code ${model_config.*}} placeholders
     * @param inputMappings values for {@code ${input_map.*}} placeholders
     * @param modelId id of the model to run
     * @param functionNameStr function name; {@code null} means {@link FunctionName#REMOTE}
     * @param modelInput payload template
     * @return an {@link MLPredictionTaskRequest} for {@code modelId}
     * @throws IllegalArgumentException if {@code parameters} is {@code null} or the substituted
     *     payload is not valid JSON
     * @throws IOException if reading the payload fails
     */
    default <T> ActionRequest getMLModelInferenceRequest(
        NamedXContentRegistry xContentRegistry,
        Map<String, String> parameters,
        Map<String, String> modelConfigs,
        Map<String, String> inputMappings,
        String modelId,
        String functionNameStr,
        String modelInput
    ) throws IOException {
        if (parameters == null) {
            throw new IllegalArgumentException("wrong input. The model input cannot be empty.");
        }
        FunctionName functionName = functionNameStr == null ? FunctionName.REMOTE : FunctionName.from(functionNameStr);

        Map<String, String> inputParams = new HashMap<>();
        if (functionName == FunctionName.REMOTE) {
            // Remote models receive all parameters as a single JSON document.
            inputParams.put("parameters", StringUtils.toJson(parameters));
        } else {
            inputParams.putAll(parameters);
        }

        String payload = new StringSubstitutor(modelConfigs, "${model_config.", "}").replace(modelInput);
        payload = new StringSubstitutor(inputMappings, "${input_map.", "}").replace(payload);
        payload = new StringSubstitutor(inputParams, "${ml_inference.", "}").replace(payload);
        if (!StringUtils.isJson(payload)) {
            throw new IllegalArgumentException("Invalid payload: " + payload);
        }

        // The parser is Closeable; release it once the MLInput has been read.
        try (XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, null, payload)) {
            XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
            MLInput mlInput = MLInput.parse(parser, functionName.name());
            return new MLPredictionTaskRequest(modelId, mlInput);
        }
    }

    /**
     * Extracts a value from the first {@link ModelTensors} of a tensor output.
     *
     * <p>If there is exactly one tensor, its value is returned directly. Otherwise a list is
     * returned with one value per non-null tensor. For each tensor, map data is resolved with
     * {@link #getModelOutputField(Map, String, boolean)}; any other tensor is converted with
     * {@link #parseDataInTensor(ModelTensor)}.
     *
     * @param modelTensorOutput the model output
     * @param modelOutputFieldName field name or JsonPath to read from map-shaped tensor data; {@code null}
     *     returns the whole map
     * @param ignoreMissing passed through to {@link #getModelOutputField(Map, String, boolean)}
     * @return the extracted value, or a list of values when there are several tensors
     * @throws RuntimeException if the outputs or tensors are null/empty or a tensor cannot be read;
     *     the original failure is kept as the cause
     */
    default Object getModelOutputValue(ModelTensorOutput modelTensorOutput, String modelOutputFieldName, boolean ignoreMissing) {
        try {
            if (modelTensorOutput == null
                || modelTensorOutput.getMlModelOutputs() == null
                || modelTensorOutput.getMlModelOutputs().isEmpty()) {
                throw new RuntimeException("Model outputs are null or empty.");
            }
            ModelTensors output = modelTensorOutput.getMlModelOutputs().get(0);
            if (output == null || output.getMlModelTensors() == null || output.getMlModelTensors().isEmpty()) {
                throw new RuntimeException("Output tensors are null or empty.");
            }

            List<ModelTensor> tensors = output.getMlModelTensors();
            if (tensors.size() == 1) {
                ModelTensor tensor = tensors.get(0);
                Map<String, ?> dataAsMap = tensor.getDataAsMap();
                return dataAsMap != null
                    ? getModelOutputField(dataAsMap, modelOutputFieldName, ignoreMissing)
                    : parseDataInTensor(tensor);
            }

            List<Object> tensorValues = new ArrayList<>();
            for (ModelTensor tensor : tensors) {
                if (tensor == null) {
                    continue; // null tensors are skipped rather than reported
                }
                try {
                    Map<String, ?> dataAsMap = tensor.getDataAsMap();
                    tensorValues
                        .add(
                            dataAsMap != null
                                ? getModelOutputField(dataAsMap, modelOutputFieldName, ignoreMissing)
                                : parseDataInTensor(tensor)
                        );
                } catch (Exception e) {
                    throw new RuntimeException("Error accessing tensor data: " + e.getMessage(), e);
                }
            }
            return tensorValues;
        } catch (Exception e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Extracts a value from an arbitrary {@link MLOutput}.
     *
     * <p>If {@code fullResponsePath} is {@code false} and the output is a {@link ModelTensorOutput},
     * this delegates to {@link #getModelOutputValue(ModelTensorOutput, String, boolean)}. Otherwise
     * the output is serialized to JSON, converted to a map, and {@code modelOutputFieldName} is
     * evaluated as a JsonPath against that map.
     *
     * @param mlOutput the model output
     * @param modelOutputFieldName JsonPath to read; {@code null} returns the whole serialized map
     * @param ignoreMissing used only on the delegated tensor path
     * @param fullResponsePath whether to evaluate the path against the full serialized response
     * @return the extracted value
     * @throws RuntimeException wrapping any failure with the prefix "An unexpected error occurred: ";
     *     the original failure is kept as the cause
     */
    default Object getModelOutputValue(MLOutput mlOutput, String modelOutputFieldName, boolean ignoreMissing, boolean fullResponsePath) {
        try {
            // Decide on delegation before serializing, so the JSON round-trip only happens when its
            // result is actually used. This stays inside the try so error messages are wrapped as before.
            if (!fullResponsePath && mlOutput instanceof ModelTensorOutput) {
                return getModelOutputValue((ModelTensorOutput) mlOutput, modelOutputFieldName, ignoreMissing);
            }

            Map<String, Object> modelOutputMap;
            try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
                String modelOutputJsonStr = mlOutput.toXContent(builder, ToXContent.EMPTY_PARAMS).toString();
                modelOutputMap = StringUtils.gson.fromJson(modelOutputJsonStr, Map.class);
            }

            if (modelOutputFieldName == null || modelOutputMap == null) {
                return modelOutputMap;
            }
            Object modelOutputValue = JsonPath.parse(modelOutputMap).read(modelOutputFieldName);
            if (modelOutputValue == null) {
                throw new IllegalArgumentException(
                    "model inference output cannot find such json path: " + modelOutputFieldName + " in " + modelOutputMap
                );
            }
            return modelOutputValue;
        } catch (Exception e) {
            throw new RuntimeException("An unexpected error occurred: " + e.getMessage(), e);
        }
    }

    /**
     * Converts the raw numeric data of a tensor into a list typed by the tensor's data type.
     *
     * <p>Integer types map to {@code Integer} and floating types to {@code Float}. String types map
     * each element through {@link String#valueOf(Object)}. Boolean types map each element to
     * {@code intValue() != 0}.
     *
     * @param tensor the tensor to convert
     * @return a {@link List} of converted values
     * @throws RuntimeException if the tensor's data type is not one of the above
     */
    static Object parseDataInTensor(ModelTensor tensor) {
        // Autoboxing (valueOf) replaces the deprecated-for-removal Integer/Float/Boolean constructors.
        if (tensor.getDataType().isInteger()) {
            return Arrays.stream(tensor.getData()).map(Number::intValue).collect(Collectors.toList());
        }
        if (tensor.getDataType().isFloating()) {
            return Arrays.stream(tensor.getData()).map(Number::floatValue).collect(Collectors.toList());
        }
        if (tensor.getDataType().isString()) {
            return Arrays.stream(tensor.getData()).map(String::valueOf).collect(Collectors.toList());
        }
        if (tensor.getDataType().isBoolean()) {
            return Arrays.stream(tensor.getData()).map(num -> num.intValue() != 0).collect(Collectors.toList());
        }
        throw new RuntimeException("unsupported data type in prediction data.");
    }

    /**
     * Reads {@code fieldName} from a tensor's map data.
     *
     * <p>An exact key match is returned first, so keys that contain JsonPath syntax (for example
     * dots) still resolve literally. Otherwise {@code fieldName} is evaluated as a JsonPath.
     *
     * @param modelTensorOutputMap the tensor data
     * @param fieldName key or JsonPath; {@code null} returns the map unchanged
     * @param ignoreMissing when {@code true}, a failed lookup returns the whole map instead of throwing
     * @return the resolved value, or the map itself as described above
     * @throws IllegalArgumentException if the lookup fails and {@code ignoreMissing} is {@code false}
     * @throws IOException declared for compatibility; not thrown by this implementation
     */
    default Object getModelOutputField(Map<String, ?> modelTensorOutputMap, String fieldName, boolean ignoreMissing) throws IOException {
        if (fieldName == null || modelTensorOutputMap == null) {
            return modelTensorOutputMap;
        }
        if (modelTensorOutputMap.containsKey(fieldName)) {
            return modelTensorOutputMap.get(fieldName);
        }
        try {
            return JsonPath.parse(modelTensorOutputMap).read(fieldName);
        } catch (Exception e) {
            if (ignoreMissing) {
                return modelTensorOutputMap;
            }
            throw new IllegalArgumentException("model inference output cannot find field name: " + fieldName, e);
        }
    }

    /**
     * Serializes a value to a JSON string via {@code StringUtils.toJson}.
     *
     * @param originalFieldValue the value to serialize
     * @return its JSON representation
     */
    default String toString(Object originalFieldValue) {
        return StringUtils.toJson(originalFieldValue);
    }

    /**
     * Returns whether {@code path} resolves to a non-null value in {@code json}.
     *
     * <p>Evaluation uses {@link #suppressExceptionConfiguration}, so path errors yield a null read
     * and therefore {@code false} instead of an exception.
     *
     * @param json the document to inspect
     * @param path JsonPath expression
     * @return {@code true} if the path resolves to a non-null value
     */
    default boolean hasField(Object json, String path) {
        Object value = JsonPath.using(suppressExceptionConfiguration).parse(json).read(path);
        return value != null;
    }

    /**
     * Expands a dot path whose parent segment may match several nodes in {@code json} into one
     * concrete dot path per match.
     *
     * <p>For example, if {@code a.b} matches {@code $['a'][0]['b']} and {@code $['a'][1]['b']}, then
     * {@code a.b.c} becomes {@code [a.0.b.c, a.1.b.c]}. A path with no dot is returned unchanged as
     * a single-element list. If the parent segment matches nothing, the list is empty.
     *
     * @param json the document to match against
     * @param dotPath the dot-separated path
     * @return a new mutable list of concrete dot paths
     */
    default List<String> writeNewDotPathForNestedObject(Object json, String dotPath) {
        List<String> dotPaths = new ArrayList<>();
        int lastDotIndex = dotPath.lastIndexOf('.');
        if (lastDotIndex == -1) {
            dotPaths.add(dotPath);
            return dotPaths;
        }

        String leadingDotPath = dotPath.substring(0, lastDotIndex);
        String lastLeaf = dotPath.substring(lastDotIndex + 1);
        // AS_PATH_LIST makes JsonPath return normalized paths such as $['a'][0]['b'] instead of values.
        Configuration configuration = Configuration
            .builder()
            .options(Option.ALWAYS_RETURN_LIST, Option.AS_PATH_LIST, Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL)
            .build();
        List<String> resultPaths = JsonPath.using(configuration).parse(json).read(leadingDotPath);
        if (resultPaths != null) {
            for (String path : resultPaths) {
                dotPaths.add(convertToDotPath(path) + "." + lastLeaf);
            }
        }
        return dotPaths;
    }

    /**
     * Converts a normalized JsonPath path, such as {@code $['a'][0]['b']}, into dot notation, such as
     * {@code a.0.b}.
     *
     * @param path normalized JsonPath path
     * @return the equivalent dot path
     */
    default String convertToDotPath(String path) {
        return path
            .replaceAll("\\[(\\d+)\\]", "$1\\.")   // [0]    -> 0.
            .replaceAll("\\['(.*?)']", "$1\\.")    // ['a']  -> a.
            .replaceAll("^\\$", "")                // drop root marker
            .replaceAll("\\.$", "");               // drop trailing dot
    }
}

