/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.processor;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import lombok.Generated;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.env.Environment;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

public class TextEmbeddingProcessor
extends AbstractProcessor {
    @Generated
    private static final Logger log = LogManager.getLogger(TextEmbeddingProcessor.class);
    public static final String TYPE = "text_embedding";
    public static final String MODEL_ID_FIELD = "model_id";
    public static final String FIELD_MAP_FIELD = "field_map";
    private static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
    @VisibleForTesting
    private final String modelId;
    private final Map<String, Object> fieldMap;
    private final MLCommonsClientAccessor mlCommonsClientAccessor;
    private final Environment environment;

    public TextEmbeddingProcessor(String tag, String description, String modelId, Map<String, Object> fieldMap, MLCommonsClientAccessor clientAccessor, Environment environment) {
        super(tag, description);
        if (StringUtils.isBlank((CharSequence)modelId)) {
            throw new IllegalArgumentException("model_id is null or empty, can not process it");
        }
        this.validateEmbeddingConfiguration(fieldMap);
        this.modelId = modelId;
        this.fieldMap = fieldMap;
        this.mlCommonsClientAccessor = clientAccessor;
        this.environment = environment;
    }

    private void validateEmbeddingConfiguration(Map<String, Object> fieldMap) {
        if (fieldMap == null || fieldMap.size() == 0 || fieldMap.entrySet().stream().anyMatch(x -> StringUtils.isBlank((CharSequence)((CharSequence)x.getKey())) || Objects.isNull(x.getValue()) || StringUtils.isBlank((CharSequence)x.getValue().toString()))) {
            throw new IllegalArgumentException("Unable to create the TextEmbedding processor as field_map has invalid key or value");
        }
    }

    public IngestDocument execute(IngestDocument ingestDocument) {
        return ingestDocument;
    }

    public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
        try {
            this.validateEmbeddingFieldsValue(ingestDocument);
            Map<String, Object> knnMap = this.buildMapWithKnnKeyAndOriginalValue(ingestDocument);
            List<String> inferenceList = this.createInferenceList(knnMap);
            if (inferenceList.size() == 0) {
                handler.accept(ingestDocument, null);
            } else {
                this.mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, (ActionListener<List<List<Float>>>)ActionListener.wrap(vectors -> {
                    this.appendVectorFieldsToDocument(ingestDocument, knnMap, (List<List<Float>>)vectors);
                    handler.accept(ingestDocument, null);
                }, e -> handler.accept((IngestDocument)null, (Exception)e)));
            }
        }
        catch (Exception e2) {
            handler.accept(null, e2);
        }
    }

    void appendVectorFieldsToDocument(IngestDocument ingestDocument, Map<String, Object> knnMap, List<List<Float>> vectors) {
        Objects.requireNonNull(vectors, "embedding failed, inference returns null result!");
        log.debug("Text embedding result fetched, starting build vector output!");
        Map<String, Object> textEmbeddingResult = this.buildTextEmbeddingResult(knnMap, vectors, ingestDocument.getSourceAndMetadata());
        textEmbeddingResult.forEach((arg_0, arg_1) -> ((IngestDocument)ingestDocument).appendFieldValue(arg_0, arg_1));
    }

    private List<String> createInferenceList(Map<String, Object> knnKeyMap) {
        ArrayList<String> texts = new ArrayList<String>();
        knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> {
            Object sourceValue = knnMapEntry.getValue();
            if (sourceValue instanceof List) {
                texts.addAll((List)sourceValue);
            } else if (sourceValue instanceof Map) {
                this.createInferenceListForMapTypeInput(sourceValue, texts);
            } else {
                texts.add(sourceValue.toString());
            }
        });
        return texts;
    }

    private void createInferenceListForMapTypeInput(Object sourceValue, List<String> texts) {
        if (sourceValue instanceof Map) {
            ((Map)sourceValue).forEach((k, v) -> this.createInferenceListForMapTypeInput(v, texts));
        } else if (sourceValue instanceof List) {
            texts.addAll((List)sourceValue);
        } else {
            if (sourceValue == null) {
                return;
            }
            texts.add(sourceValue.toString());
        }
    }

    @VisibleForTesting
    Map<String, Object> buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) {
        Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
        LinkedHashMap<String, Object> mapWithKnnKeys = new LinkedHashMap<String, Object>();
        for (Map.Entry<String, Object> fieldMapEntry : this.fieldMap.entrySet()) {
            String originalKey = fieldMapEntry.getKey();
            Object targetKey = fieldMapEntry.getValue();
            if (targetKey instanceof Map) {
                LinkedHashMap<String, Object> treeRes = new LinkedHashMap<String, Object>();
                this.buildMapWithKnnKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes);
                mapWithKnnKeys.put(originalKey, treeRes.get(originalKey));
                continue;
            }
            mapWithKnnKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey));
        }
        return mapWithKnnKeys;
    }

    private void buildMapWithKnnKeyAndOriginalValueForMapType(String parentKey, Object knnKey, Map<String, Object> sourceAndMetadataMap, Map<String, Object> treeRes) {
        if (knnKey == null || sourceAndMetadataMap == null) {
            return;
        }
        if (knnKey instanceof Map) {
            LinkedHashMap<String, Object> next = new LinkedHashMap<String, Object>();
            for (Map.Entry nestedFieldMapEntry : ((Map)knnKey).entrySet()) {
                this.buildMapWithKnnKeyAndOriginalValueForMapType((String)nestedFieldMapEntry.getKey(), nestedFieldMapEntry.getValue(), (Map)sourceAndMetadataMap.get(parentKey), next);
            }
            treeRes.put(parentKey, next);
        } else {
            String key = String.valueOf(knnKey);
            treeRes.put(key, sourceAndMetadataMap.get(parentKey));
        }
    }

    @VisibleForTesting
    Map<String, Object> buildTextEmbeddingResult(Map<String, Object> knnMap, List<List<Float>> modelTensorList, Map<String, Object> sourceAndMetadataMap) {
        IndexWrapper indexWrapper = new IndexWrapper(0);
        LinkedHashMap<String, Object> result = new LinkedHashMap<String, Object>();
        for (Map.Entry<String, Object> knnMapEntry : knnMap.entrySet()) {
            String knnKey = knnMapEntry.getKey();
            Object sourceValue = knnMapEntry.getValue();
            if (sourceValue instanceof String) {
                List<Float> modelTensor = modelTensorList.get(indexWrapper.index++);
                result.put(knnKey, modelTensor);
                continue;
            }
            if (sourceValue instanceof List) {
                result.put(knnKey, this.buildTextEmbeddingResultForListType((List)sourceValue, modelTensorList, indexWrapper));
                continue;
            }
            if (!(sourceValue instanceof Map)) continue;
            this.putTextEmbeddingResultToSourceMapForMapType(knnKey, sourceValue, modelTensorList, indexWrapper, sourceAndMetadataMap);
        }
        return result;
    }

    private void putTextEmbeddingResultToSourceMapForMapType(String knnKey, Object sourceValue, List<List<Float>> modelTensorList, IndexWrapper indexWrapper, Map<String, Object> sourceAndMetadataMap) {
        if (knnKey == null || sourceAndMetadataMap == null || sourceValue == null) {
            return;
        }
        if (sourceValue instanceof Map) {
            for (Map.Entry inputNestedMapEntry : ((Map)sourceValue).entrySet()) {
                this.putTextEmbeddingResultToSourceMapForMapType((String)inputNestedMapEntry.getKey(), inputNestedMapEntry.getValue(), modelTensorList, indexWrapper, (Map)sourceAndMetadataMap.get(knnKey));
            }
        } else if (sourceValue instanceof String) {
            sourceAndMetadataMap.put(knnKey, modelTensorList.get(indexWrapper.index++));
        } else if (sourceValue instanceof List) {
            sourceAndMetadataMap.put(knnKey, this.buildTextEmbeddingResultForListType((List)sourceValue, modelTensorList, indexWrapper));
        }
    }

    private List<Map<String, List<Float>>> buildTextEmbeddingResultForListType(List<String> sourceValue, List<List<Float>> modelTensorList, IndexWrapper indexWrapper) {
        ArrayList<Map<String, List<Float>>> numbers = new ArrayList<Map<String, List<Float>>>();
        IntStream.range(0, sourceValue.size()).forEachOrdered(x -> numbers.add((Map<String, List<Float>>)ImmutableMap.of((Object)LIST_TYPE_NESTED_MAP_KEY, (Object)((List)modelTensorList.get(indexWrapper.index++)))));
        return numbers;
    }

    private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
        Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
        for (Map.Entry<String, Object> embeddingFieldsEntry : this.fieldMap.entrySet()) {
            Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey());
            if (sourceValue == null) continue;
            String sourceKey = embeddingFieldsEntry.getKey();
            Class<?> sourceValueClass = sourceValue.getClass();
            if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
                this.validateNestedTypeValue(sourceKey, sourceValue, () -> 1);
                continue;
            }
            if (!String.class.isAssignableFrom(sourceValueClass)) {
                throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it");
            }
            if (!StringUtils.isBlank((CharSequence)sourceValue.toString())) continue;
            throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it");
        }
    }

    private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier<Integer> maxDepthSupplier) {
        int maxDepth = maxDepthSupplier.get();
        if ((long)maxDepth > (Long)MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(this.environment.settings())) {
            throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it");
        }
        if (List.class.isAssignableFrom(sourceValue.getClass())) {
            TextEmbeddingProcessor.validateListTypeValue(sourceKey, sourceValue);
        } else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
            ((Map)sourceValue).values().stream().filter(Objects::nonNull).forEach(x -> this.validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
        } else {
            if (!String.class.isAssignableFrom(sourceValue.getClass())) {
                throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it");
            }
            if (StringUtils.isBlank((CharSequence)sourceValue.toString())) {
                throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it");
            }
        }
    }

    private static void validateListTypeValue(String sourceKey, Object sourceValue) {
        for (Object value : (List)sourceValue) {
            if (value == null) {
                throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it");
            }
            if (!(value instanceof String)) {
                throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it");
            }
            if (!StringUtils.isBlank((CharSequence)value.toString())) continue;
            throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it");
        }
    }

    public String getType() {
        return TYPE;
    }

    static class IndexWrapper {
        private int index;

        protected IndexWrapper(int index) {
            this.index = index;
        }
    }
}

