/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.collect.Tuple;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.env.Environment;
import org.opensearch.ingest.AbstractBatchingProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.IngestDocumentWrapper;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.ProcessorDocumentUtils;

/**
 * Base class for ingest processors that collect text from the source fields configured in
 * {@code field_map}, hand that text to a subclass for model inference ({@link #doExecute} for a
 * single document, {@link #doBatchExecute} for a batch), and write the returned results back into
 * the document under the configured target keys.
 *
 * <p>{@code field_map} keys are source field names; a dot in a key denotes a nested path
 * ({@code "a.b"} is treated as {@code {"a": {"b": ...}}}). Values are either a target field name
 * or, for nested input, another map with the same shape.
 */
public abstract class InferenceProcessor
extends AbstractBatchingProcessor {
    @Generated
    private static final Logger log = LogManager.getLogger(InferenceProcessor.class);
    public static final String MODEL_ID_FIELD = "model_id";
    public static final String FIELD_MAP_FIELD = "field_map";

    /**
     * Merge function for {@link Map#merge}: when both the existing and the new value are
     * collections, the new elements are appended to the existing collection; when both are maps,
     * the new entries are copied into the existing map. In any other case the new value wins.
     */
    @SuppressWarnings("unchecked")
    private static final BiFunction<Object, Object, Object> REMAPPING_FUNCTION = (v1, v2) -> {
        if (v1 instanceof Collection && v2 instanceof Collection) {
            ((Collection<Object>) v1).addAll((Collection<Object>) v2);
            return v1;
        }
        if (v1 instanceof Map && v2 instanceof Map) {
            ((Map<Object, Object>) v1).putAll((Map<Object, Object>) v2);
            return v1;
        }
        return v2;
    };

    private final String type;
    /** Key used inside each per-element result map produced for list-typed source values. */
    private final String listTypeNestedMapKey;
    protected final String modelId;
    private final Map<String, Object> fieldMap;
    protected final MLCommonsClientAccessor mlCommonsClientAccessor;
    private final Environment environment;
    private final ClusterService clusterService;

    /**
     * Creates the processor after validating its configuration.
     *
     * @param tag processor tag
     * @param description processor description
     * @param batchSize batch size passed to {@link AbstractBatchingProcessor}
     * @param type processor type returned by {@link #getType()}
     * @param listTypeNestedMapKey key used for each element's result when the source value is a list
     * @param modelId id of the model used for inference; must not be blank
     * @param fieldMap source-field to target-field mapping; must be non-empty with non-blank keys and values
     * @param clientAccessor ML Commons client accessor made available to subclasses
     * @param environment node environment, used for field value validation
     * @param clusterService cluster service, used for field value validation
     * @throws IllegalArgumentException if {@code modelId} is blank or {@code fieldMap} is invalid
     */
    public InferenceProcessor(String tag, String description, int batchSize, String type, String listTypeNestedMapKey, String modelId, Map<String, Object> fieldMap, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService) {
        super(tag, description, batchSize);
        this.type = type;
        if (StringUtils.isBlank(modelId)) {
            throw new IllegalArgumentException("model_id is null or empty, cannot process it");
        }
        this.validateEmbeddingConfiguration(fieldMap);
        this.listTypeNestedMapKey = listTypeNestedMapKey;
        this.modelId = modelId;
        this.fieldMap = fieldMap;
        this.mlCommonsClientAccessor = clientAccessor;
        this.environment = environment;
        this.clusterService = clusterService;
    }

    /**
     * Rejects a {@code field_map} that is null, empty, or has a blank key, a null value, or a value
     * whose string form is blank.
     */
    private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {
        if (fieldMap == null
            || fieldMap.isEmpty()
            || fieldMap.entrySet()
                .stream()
                .anyMatch(
                    x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString())
                )) {
            throw new IllegalArgumentException("Unable to create the processor as field_map has invalid key or value");
        }
    }

    /**
     * Runs inference for a single document.
     *
     * @param ingestDocument the document being processed
     * @param processMap target keys mapped to their source values, as built by {@link #buildMapWithTargetKeys}
     * @param inferenceList flattened texts to send to the model
     * @param handler callback receiving the processed document or an exception
     */
    public abstract void doExecute(IngestDocument ingestDocument, Map<String, Object> processMap, List<String> inferenceList, BiConsumer<IngestDocument, Exception> handler);

    /**
     * Synchronous entry point: returns the document unchanged. Inference is only performed by the
     * asynchronous {@link #execute(IngestDocument, BiConsumer)} overload and by batch execution.
     */
    public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
        return ingestDocument;
    }

    /**
     * Validates the document, extracts its inference texts, and delegates to {@link #doExecute}.
     * When there is nothing to infer, the handler receives the unchanged document.
     * Any exception thrown along the way is reported to the handler as {@code (null, e)}.
     */
    public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
        try {
            validateEmbeddingFieldsValue(ingestDocument);
            Map<String, Object> processMap = buildMapWithTargetKeys(ingestDocument);
            List<String> inferenceList = createInferenceList(processMap);
            if (inferenceList.isEmpty()) {
                handler.accept(ingestDocument, null);
            } else {
                doExecute(ingestDocument, processMap, inferenceList, handler);
            }
        }
        catch (Exception e) {
            // NOTE(review): if the handler (or doExecute after invoking the handler) throws, the
            // handler is invoked a second time here with the error — confirm this is acceptable.
            handler.accept(null, e);
        }
    }

    /**
     * Runs inference for a batch of texts.
     *
     * @param inferenceList texts to infer, in the order the results must be returned
     * @param handler receives one result per input text, in input order
     * @param onException receives any failure of the batch call
     */
    abstract void doBatchExecute(List<String> inferenceList, Consumer<List<?>> handler, Consumer<Exception> onException);

    /**
     * Processes a batch of documents with a single inference call.
     *
     * <p>Documents that fail validation are marked with their exception and skipped. The texts
     * from the remaining documents are concatenated, sorted by length before being sent to
     * {@link #doBatchExecute}, and the results are restored to the original order. Each document
     * then receives its own slice of the results. If the batch call fails, every document that
     * has no exception yet is marked with the failure. In all cases the handler receives the
     * wrappers.
     */
    public void subBatchExecute(List<IngestDocumentWrapper> ingestDocumentWrappers, Consumer<List<IngestDocumentWrapper>> handler) {
        if (CollectionUtils.isEmpty(ingestDocumentWrappers)) {
            handler.accept(Collections.emptyList());
            return;
        }
        List<DataForInference> dataForInferences = getDataForInference(ingestDocumentWrappers);
        List<String> inferenceList = constructInferenceTexts(dataForInferences);
        if (inferenceList.isEmpty()) {
            handler.accept(ingestDocumentWrappers);
            return;
        }
        Tuple<List<String>, Map<Integer, Integer>> sortedResult = sortByLengthAndReturnOriginalOrder(inferenceList);
        List<String> sortedInferenceList = sortedResult.v1();
        Map<Integer, Integer> originalOrder = sortedResult.v2();
        doBatchExecute(sortedInferenceList, results -> {
            List<?> orderedResults = restoreToOriginalOrder(results, originalOrder);
            int startIndex = 0;
            for (DataForInference dataForInference : dataForInferences) {
                // Must mirror the skip condition in constructInferenceTexts so the slices line up.
                if (dataForInference.getIngestDocumentWrapper().getException() != null || CollectionUtils.isEmpty(dataForInference.getInferenceList())) {
                    continue;
                }
                int size = dataForInference.getInferenceList().size();
                List<?> inferenceResults = orderedResults.subList(startIndex, startIndex + size);
                startIndex += size;
                setVectorFieldsToDocument(dataForInference.getIngestDocumentWrapper().getIngestDocument(), dataForInference.getProcessMap(), inferenceResults);
            }
            handler.accept(ingestDocumentWrappers);
        }, exception -> {
            for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
                // Keep the first (validation) failure for documents that already have one.
                if (ingestDocumentWrapper.getException() != null) {
                    continue;
                }
                ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), exception);
            }
            handler.accept(ingestDocumentWrappers);
        });
    }

    /**
     * Sorts the texts by length. The returned map sends each sorted position back to the text's
     * original position.
     *
     * @return the sorted texts paired with a {@code sortedIndex -> originalIndex} map
     */
    private Tuple<List<String>, Map<Integer, Integer>> sortByLengthAndReturnOriginalOrder(List<String> inferenceList) {
        List<Tuple<Integer, String>> docsWithIndex = new ArrayList<>(inferenceList.size());
        for (int i = 0; i < inferenceList.size(); ++i) {
            docsWithIndex.add(Tuple.tuple(i, inferenceList.get(i)));
        }
        // List.sort is stable, so texts of equal length keep their relative order.
        docsWithIndex.sort(Comparator.comparingInt(t -> t.v2().length()));
        List<String> sortedInferenceList = docsWithIndex.stream().map(Tuple::v2).collect(Collectors.toList());
        Map<Integer, Integer> originalOrderMap = new HashMap<>();
        for (int i = 0; i < docsWithIndex.size(); ++i) {
            originalOrderMap.put(i, docsWithIndex.get(i).v1());
        }
        return Tuple.tuple(sortedInferenceList, originalOrderMap);
    }

    /**
     * Reorders results produced for the sorted texts back into the original text order, using
     * the {@code sortedIndex -> originalIndex} map from {@link #sortByLengthAndReturnOriginalOrder}.
     * Positions that are missing from the map keep the value they had in {@code results}.
     */
    private List<?> restoreToOriginalOrder(List<?> results, Map<Integer, Integer> originalOrder) {
        List<Object> sortedResults = Arrays.asList(results.toArray());
        for (int i = 0; i < results.size(); ++i) {
            if (!originalOrder.containsKey(i)) {
                continue;
            }
            int oldIndex = originalOrder.get(i);
            sortedResults.set(oldIndex, results.get(i));
        }
        return sortedResults;
    }

    /**
     * Concatenates, in document order, the texts of every document that has no exception and a
     * non-empty inference list.
     */
    private List<String> constructInferenceTexts(List<DataForInference> dataForInferences) {
        List<String> inferenceTexts = new ArrayList<>();
        for (DataForInference dataForInference : dataForInferences) {
            if (dataForInference.getIngestDocumentWrapper().getException() != null || CollectionUtils.isEmpty(dataForInference.getInferenceList())) {
                continue;
            }
            inferenceTexts.addAll(dataForInference.getInferenceList());
        }
        return inferenceTexts;
    }

    /**
     * Builds the per-document inference input for each wrapper. A document whose validation or
     * extraction throws has the exception recorded on its wrapper. It is still included, with
     * null process map and inference list, so that it stays aligned with the input list.
     */
    private List<DataForInference> getDataForInference(List<IngestDocumentWrapper> ingestDocumentWrappers) {
        List<DataForInference> dataForInferences = new ArrayList<>(ingestDocumentWrappers.size());
        for (IngestDocumentWrapper ingestDocumentWrapper : ingestDocumentWrappers) {
            Map<String, Object> processMap = null;
            List<String> inferenceList = null;
            try {
                validateEmbeddingFieldsValue(ingestDocumentWrapper.getIngestDocument());
                processMap = buildMapWithTargetKeys(ingestDocumentWrapper.getIngestDocument());
                inferenceList = createInferenceList(processMap);
            }
            catch (Exception e) {
                ingestDocumentWrapper.update(ingestDocumentWrapper.getIngestDocument(), e);
            }
            dataForInferences.add(new DataForInference(ingestDocumentWrapper, processMap, inferenceList));
        }
        return dataForInferences;
    }

    /**
     * Flattens the non-null values of {@code knnKeyMap} into a list of texts, in map iteration
     * order:
     * <ul>
     *   <li>list values are appended element by element;</li>
     *   <li>map values are walked recursively;</li>
     *   <li>any other value is appended via {@code toString()}.</li>
     * </ul>
     */
    @SuppressWarnings("unchecked")
    private List<String> createInferenceList(Map<String, Object> knnKeyMap) {
        List<String> texts = new ArrayList<>();
        for (Object sourceValue : knnKeyMap.values()) {
            if (sourceValue == null) {
                continue;
            }
            if (sourceValue instanceof List) {
                // NOTE(review): elements are added unchecked; presumably validateEmbeddingFieldsValue
                // guarantees they are strings — confirm.
                texts.addAll((List<String>) sourceValue);
            } else if (sourceValue instanceof Map) {
                createInferenceListForMapTypeInput(sourceValue, texts);
            } else {
                texts.add(sourceValue.toString());
            }
        }
        return texts;
    }

    /**
     * Recursively collects texts from a nested value:
     * <ul>
     *   <li>maps are walked through their values;</li>
     *   <li>lists are appended element by element;</li>
     *   <li>null leaves are skipped;</li>
     *   <li>any other leaf is appended via {@code toString()}.</li>
     * </ul>
     */
    @SuppressWarnings("unchecked")
    private void createInferenceListForMapTypeInput(Object sourceValue, List<String> texts) {
        if (sourceValue instanceof Map) {
            for (Object nestedValue : ((Map<?, ?>) sourceValue).values()) {
                createInferenceListForMapTypeInput(nestedValue, texts);
            }
        } else if (sourceValue instanceof List) {
            texts.addAll((List<String>) sourceValue);
        } else if (sourceValue != null) {
            texts.add(sourceValue.toString());
        }
    }

    /**
     * Maps each configured target key to the corresponding value from the document source.
     *
     * <ul>
     *   <li>Flat entries produce {@code targetKey -> sourceValue}.</li>
     *   <li>Nested entries (a map value, or a dotted key) produce {@code topLevelSourceKey -> tree}.
     *       The tree mirrors the configuration, with target keys as leaves holding the source
     *       values.</li>
     * </ul>
     */
    @VisibleForTesting
    Map<String, Object> buildMapWithTargetKeys(IngestDocument ingestDocument) {
        Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
        Map<String, Object> mapWithProcessorKeys = new LinkedHashMap<>();
        for (Map.Entry<String, Object> fieldMapEntry : fieldMap.entrySet()) {
            Pair<String, Object> processedNestedKey = processNestedKey(fieldMapEntry);
            String originalKey = processedNestedKey.getKey();
            Object targetKey = processedNestedKey.getValue();
            if (targetKey instanceof Map) {
                Map<String, Object> treeRes = new LinkedHashMap<>();
                buildNestedMap(originalKey, targetKey, sourceAndMetadataMap, treeRes);
                mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey));
            } else {
                mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey));
            }
        }
        return mapWithProcessorKeys;
    }

    /**
     * Recursively builds the target-key tree for {@code parentKey} into {@code treeRes}.
     *
     * <ul>
     *   <li>If {@code processorKey} is a map and the source value under {@code parentKey} is a
     *       map, each configured child is resolved against that child map.</li>
     *   <li>If {@code processorKey} is a map and the source value is a list (of maps), each
     *       configured child key is gathered across all list elements into a list.</li>
     *   <li>If {@code processorKey} is not a map, it is the target key, and the source value is
     *       stored under it.</li>
     * </ul>
     */
    @VisibleForTesting
    @SuppressWarnings("unchecked")
    void buildNestedMap(String parentKey, Object processorKey, Map<String, Object> sourceAndMetadataMap, Map<String, Object> treeRes) {
        if (Objects.isNull(processorKey) || Objects.isNull(sourceAndMetadataMap)) {
            return;
        }
        if (processorKey instanceof Map) {
            Map<String, Object> next = new LinkedHashMap<>();
            // Only reads happen below, so the parent value can be looked up once.
            Object parentValue = sourceAndMetadataMap.get(parentKey);
            if (parentValue instanceof Map) {
                for (Map.Entry<String, Object> entry : ((Map<String, Object>) processorKey).entrySet()) {
                    Pair<String, Object> processedNestedKey = processNestedKey(entry);
                    buildNestedMap(processedNestedKey.getKey(), processedNestedKey.getValue(), (Map<String, Object>) parentValue, next);
                }
            } else if (parentValue instanceof List) {
                List<Map<String, Object>> list = (List<Map<String, Object>>) parentValue;
                for (Map.Entry<String, Object> nestedFieldMapEntry : ((Map<String, Object>) processorKey).entrySet()) {
                    List<Object> listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList());
                    Map<String, Object> map = new LinkedHashMap<>();
                    map.put(nestedFieldMapEntry.getKey(), listOfStrings);
                    buildNestedMap(nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), map, next);
                }
            }
            treeRes.merge(parentKey, next, REMAPPING_FUNCTION);
        } else {
            treeRes.put(String.valueOf(processorKey), sourceAndMetadataMap.get(parentKey));
        }
    }

    /**
     * Splits a dotted key at its first dot. For example, {@code ("a.b", v)} becomes
     * {@code ("a", {"b": v})}. Entries without a dot are returned unchanged. Deeper levels are
     * split when the resulting nested entries are processed in turn.
     */
    @VisibleForTesting
    protected Pair<String, Object> processNestedKey(Map.Entry<String, Object> nestedFieldMapEntry) {
        String originalKey = nestedFieldMapEntry.getKey();
        Object targetKey = nestedFieldMapEntry.getValue();
        int nestedDotIndex = originalKey.indexOf('.');
        if (nestedDotIndex != -1) {
            Map<String, Object> newTargetKey = new LinkedHashMap<>();
            newTargetKey.put(originalKey.substring(nestedDotIndex + 1), targetKey);
            targetKey = newTargetKey;
            originalKey = originalKey.substring(0, nestedDotIndex);
        }
        return new ImmutablePair<>(originalKey, targetKey);
    }

    /**
     * Validates the document's values for the configured fields via
     * {@link ProcessorDocumentUtils#validateMapTypeValue}.
     */
    private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
        Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
        // NOTE(review): throws NullPointerException if "_index" is absent — presumably always set during ingest.
        String indexName = sourceAndMetadataMap.get("_index").toString();
        ProcessorDocumentUtils.validateMapTypeValue(FIELD_MAP_FIELD, sourceAndMetadataMap, fieldMap, indexName, clusterService, environment, false);
    }

    /**
     * Writes inference results into the document.
     *
     * @param ingestDocument the document to update
     * @param processorMap the map produced by {@link #buildMapWithTargetKeys} for this document
     * @param results results aligned with the texts produced by {@link #createInferenceList}
     * @throws NullPointerException if {@code results} is null
     */
    protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> processorMap, List<?> results) {
        Objects.requireNonNull(results, "embedding failed, inference returns null result!");
        log.debug("Model inference result fetched, starting build vector output!");
        Map<String, Object> nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata());
        nlpResult.forEach(ingestDocument::setFieldValue);
    }

    /**
     * Pairs each entry of {@code processorMap} with its results, consuming {@code results}
     * sequentially.
     *
     * <ul>
     *   <li>String values yield one result.</li>
     *   <li>List values yield a list of single-entry maps keyed by {@code listTypeNestedMapKey}.</li>
     *   <li>Map values are written directly into {@code sourceAndMetadataMap} and are not part of
     *       the returned map.</li>
     * </ul>
     *
     * @return top-level target keys mapped to their results
     */
    @VisibleForTesting
    @SuppressWarnings("unchecked")
    Map<String, Object> buildNLPResult(Map<String, Object> processorMap, List<?> results, Map<String, Object> sourceAndMetadataMap) {
        IndexWrapper indexWrapper = new IndexWrapper(0);
        Map<String, Object> result = new LinkedHashMap<>();
        for (Map.Entry<String, Object> knnMapEntry : processorMap.entrySet()) {
            Pair<String, Object> processedNestedKey = processNestedKey(knnMapEntry);
            String knnKey = processedNestedKey.getKey();
            Object sourceValue = processedNestedKey.getValue();
            if (sourceValue instanceof String) {
                result.put(knnKey, results.get(indexWrapper.index++));
            } else if (sourceValue instanceof List) {
                result.put(knnKey, buildNLPResultForListType((List<String>) sourceValue, results, indexWrapper));
            } else if (sourceValue instanceof Map) {
                putNLPResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap);
            }
        }
        return result;
    }

    /**
     * Recursively writes results for a nested (map-typed) entry into {@code sourceAndMetadataMap},
     * advancing {@code indexWrapper} for each result consumed.
     *
     * <ul>
     *   <li>When the existing source value under {@code processorKey} is a list, each element
     *       receives one result under the child key.</li>
     *   <li>When the existing value is missing, a new map is created and stored under
     *       {@code processorKey}.</li>
     *   <li>String and list leaves are merged in with {@link #REMAPPING_FUNCTION}.</li>
     * </ul>
     */
    @SuppressWarnings("unchecked")
    private void putNLPResultToSourceMapForMapType(String processorKey, Object sourceValue, List<?> results, IndexWrapper indexWrapper, Map<String, Object> sourceAndMetadataMap) {
        if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) {
            return;
        }
        if (sourceValue instanceof Map) {
            for (Map.Entry<String, Object> entry : ((Map<String, Object>) sourceValue).entrySet()) {
                // Re-read each iteration: an earlier iteration may have created the child map.
                Object existingValue = sourceAndMetadataMap.get(processorKey);
                if (existingValue instanceof List) {
                    for (Map<String, Object> nestedElement : (List<Map<String, Object>>) existingValue) {
                        nestedElement.put(entry.getKey(), results.get(indexWrapper.index++));
                    }
                    continue;
                }
                Pair<String, Object> processedNestedKey = processNestedKey(entry);
                Map<String, Object> sourceMap;
                if (existingValue == null) {
                    sourceMap = new HashMap<>();
                    sourceAndMetadataMap.put(processorKey, sourceMap);
                } else {
                    sourceMap = (Map<String, Object>) existingValue;
                }
                putNLPResultToSourceMapForMapType(processedNestedKey.getKey(), processedNestedKey.getValue(), results, indexWrapper, sourceMap);
            }
        } else if (sourceValue instanceof String) {
            sourceAndMetadataMap.merge(processorKey, results.get(indexWrapper.index++), REMAPPING_FUNCTION);
        } else if (sourceValue instanceof List) {
            sourceAndMetadataMap.merge(processorKey, buildNLPResultForListType((List<String>) sourceValue, results, indexWrapper), REMAPPING_FUNCTION);
        }
    }

    /**
     * Produces one {@code {listTypeNestedMapKey: result}} map per element of {@code sourceValue},
     * consuming the next {@code sourceValue.size()} results.
     */
    private List<Map<String, Object>> buildNLPResultForListType(List<String> sourceValue, List<?> results, IndexWrapper indexWrapper) {
        List<Map<String, Object>> keyToResult = new ArrayList<>(sourceValue.size());
        for (int i = 0; i < sourceValue.size(); i++) {
            keyToResult.add(ImmutableMap.<String, Object>of(listTypeNestedMapKey, results.get(indexWrapper.index++)));
        }
        return keyToResult;
    }

    public String getType() {
        return this.type;
    }

    /** Mutable cursor into the inference results, shared across recursive result-building calls. */
    static class IndexWrapper {
        private int index;

        protected IndexWrapper(int index) {
            this.index = index;
        }
    }

    /**
     * Per-document inference input for batch execution. {@code processMap} and
     * {@code inferenceList} are null when building them failed.
     */
    private static class DataForInference {
        private final IngestDocumentWrapper ingestDocumentWrapper;
        private final Map<String, Object> processMap;
        private final List<String> inferenceList;

        @Generated
        public IngestDocumentWrapper getIngestDocumentWrapper() {
            return this.ingestDocumentWrapper;
        }

        @Generated
        public Map<String, Object> getProcessMap() {
            return this.processMap;
        }

        @Generated
        public List<String> getInferenceList() {
            return this.inferenceList;
        }

        @Generated
        public DataForInference(IngestDocumentWrapper ingestDocumentWrapper, Map<String, Object> processMap, List<String> inferenceList) {
            this.ingestDocumentWrapper = ingestDocumentWrapper;
            this.processMap = processMap;
            this.inferenceList = inferenceList;
        }
    }
}

