/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.common.model;

import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import lombok.Generated;
import lombok.NonNull;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.client.Client;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.ml.common.model.Guardrail;
import org.opensearch.ml.common.model.Guardrails;
import org.opensearch.ml.common.model.StopWords;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.repackage.com.google.common.collect.ImmutableSet;
import org.opensearch.search.builder.SearchSourceBuilder;

/**
 * Validates text against configured guardrails.
 *
 * <p>Two independent rule sets are held, one for {@link Type#INPUT} and one for
 * {@link Type#OUTPUT}. Each rule set consists of a list of regular expressions and a map of
 * stop-words indices (index name to the source fields queried on that index). A text passes
 * validation when no regular expression matches the whole text and no stop-words index reports
 * a hit for it.
 */
public class MLGuard {
    @Generated
    private static final Logger log = LogManager.getLogger(MLGuard.class);
    /** Maximum time, in seconds, to wait for a stop-words search to complete. */
    private static final long SEARCH_TIMEOUT_SECONDS = 5L;

    /** Stop-words index name to source fields, applied to input text. */
    private Map<String, List<String>> stopWordsIndicesInput = new HashMap<String, List<String>>();
    /** Stop-words index name to source fields, applied to output text. */
    private Map<String, List<String>> stopWordsIndicesOutput = new HashMap<String, List<String>>();
    /** Raw input regular expressions; stays null when no input guardrail was supplied. */
    private List<String> inputRegex;
    /** Raw output regular expressions; stays null when no output guardrail was supplied. */
    private List<String> outputRegex;
    /** Compiled form of {@link #inputRegex}; stays null when no input guardrail was supplied. */
    private List<Pattern> inputRegexPattern;
    /** Compiled form of {@link #outputRegex}; stays null when no output guardrail was supplied. */
    private List<Pattern> outputRegexPattern;
    private NamedXContentRegistry xContentRegistry;
    private Client client;
    /** Index names for which the thread context is stashed before the search is issued. */
    private Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");

    /**
     * Builds a guard from the given guardrail configuration.
     *
     * @param guardrails       guardrail configuration; when null no rules are registered and
     *                         every validation passes
     * @param xContentRegistry registry used when parsing the generated search query
     * @param client           client used to search the stop-words indices
     * @throws java.util.regex.PatternSyntaxException if a configured regular expression is invalid
     */
    public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) {
        this.xContentRegistry = xContentRegistry;
        this.client = client;
        if (guardrails == null) {
            return;
        }
        Guardrail inputGuardrail = guardrails.getInputGuardrail();
        Guardrail outputGuardrail = guardrails.getOutputGuardrail();
        if (inputGuardrail != null) {
            this.fillStopWordsToMap(inputGuardrail, this.stopWordsIndicesInput);
            this.inputRegex = regexListOf(inputGuardrail);
            this.inputRegexPattern = compilePatterns(this.inputRegex);
        }
        if (outputGuardrail != null) {
            this.fillStopWordsToMap(outputGuardrail, this.stopWordsIndicesOutput);
            this.outputRegex = regexListOf(outputGuardrail);
            this.outputRegexPattern = compilePatterns(this.outputRegex);
        }
    }

    /** Returns the guardrail's regular expressions as a list, or an empty list when it has none. */
    private static List<String> regexListOf(Guardrail guardrail) {
        return guardrail.getRegex() == null ? new ArrayList<>() : Arrays.asList(guardrail.getRegex());
    }

    /** Compiles every regular expression in the list, preserving order. */
    private static List<Pattern> compilePatterns(List<String> regexes) {
        return regexes.stream().map(Pattern::compile).collect(Collectors.toList());
    }

    /**
     * Copies the guardrail's stop-words definitions into {@code map}, keyed by index name.
     *
     * @param guardrail guardrail to read from; must not be null
     * @param map       destination map, updated in place
     */
    private void fillStopWordsToMap(@NonNull Guardrail guardrail, Map<String, List<String>> map) {
        Objects.requireNonNull(guardrail, "guardrail is marked non-null but is null");
        List<StopWords> stopWords = guardrail.getStopWords();
        if (stopWords == null || stopWords.isEmpty()) {
            return;
        }
        for (StopWords entry : stopWords) {
            map.put(entry.getIndex(), Arrays.asList(entry.getSourceFields()));
        }
    }

    /**
     * Validates {@code input} against the rule set selected by {@code type}.
     *
     * @param input text to validate
     * @param type  which rule set to apply; a null value results in a NullPointerException
     * @return true when the text passes both the regex rules and the stop-words rules
     * @throws IllegalArgumentException if {@code type} is not a supported constant
     */
    public Boolean validate(String input, Type type) {
        // Switch on the constants themselves; relying on ordinal() breaks if the enum is reordered.
        switch (type) {
            case INPUT:
                return this.validateRegexList(input, this.inputRegexPattern)
                    && this.validateStopWords(input, this.stopWordsIndicesInput);
            case OUTPUT:
                return this.validateRegexList(input, this.outputRegexPattern)
                    && this.validateStopWords(input, this.stopWordsIndicesOutput);
            default:
                throw new IllegalArgumentException("Unsupported type to validate for guardrails.");
        }
    }

    /**
     * Checks {@code input} against every pattern in the list.
     *
     * @param input         text to validate
     * @param regexPatterns patterns to test; null or empty means there is nothing to violate
     * @return false as soon as one pattern matches the whole input, true otherwise
     */
    public Boolean validateRegexList(String input, List<Pattern> regexPatterns) {
        // The pattern fields stay null when the corresponding guardrail was not configured.
        if (regexPatterns == null || regexPatterns.isEmpty()) {
            return true;
        }
        for (Pattern pattern : regexPatterns) {
            if (!this.validateRegex(input, pattern)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Checks {@code input} against a single pattern.
     *
     * @param input   text to validate
     * @param pattern pattern to test
     * @return true when the pattern does NOT match the entire input
     */
    public Boolean validateRegex(String input, Pattern pattern) {
        return !pattern.matcher(input).matches();
    }

    /**
     * Checks {@code input} against every stop-words index in the map.
     *
     * @param input            text to validate
     * @param stopWordsIndices index name to source fields; null or empty means nothing to check
     * @return false as soon as one index fails validation, true otherwise
     */
    public Boolean validateStopWords(String input, Map<String, List<String>> stopWordsIndices) {
        if (stopWordsIndices == null || stopWordsIndices.isEmpty()) {
            return true;
        }
        for (Map.Entry<String, List<String>> entry : stopWordsIndices.entrySet()) {
            if (!this.validateStopWordsSingleIndex(input, entry.getKey(), entry.getValue())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Runs a percolate query for {@code input} against one stop-words index and waits up to
     * {@value #SEARCH_TIMEOUT_SECONDS} seconds for the result.
     *
     * <p>The query sets every name in {@code fieldNames} to {@code input} in the percolated
     * document and asks for at most one hit.
     *
     * @param input      text to validate
     * @param indexName  index to search
     * @param fieldNames document fields that are populated with {@code input}
     * @return true when the search reported no hits or failed; false when it reported at least
     *         one hit or did not answer before the timeout
     * @throws IllegalStateException if the calling thread is interrupted while waiting
     */
    public Boolean validateStopWordsSingleIndex(String input, String indexName, List<String> fieldNames) {
        // Set to true once the text is known to pass (no hit, or the search could not be run).
        AtomicBoolean passed = new AtomicBoolean(false);
        Map<String, String> documentMap = new HashMap<>();
        for (String field : fieldNames) {
            documentMap.put(field, input);
        }
        Map<String, Object> queryBodyMap = Map.of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap)));
        CountDownLatch latch = new CountDownLatch(1);
        ThreadContext.StoredContext context = null;
        try {
            // The explicit cast picks the PrivilegedAction overload; a bare lambda is ambiguous.
            String queryBody = AccessController.doPrivileged((PrivilegedAction<String>) () -> StringUtils.gson.toJson(queryBodyMap));
            SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
            XContentParser queryParser = XContentType.JSON.xContent().createParser(this.xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody);
            searchSourceBuilder.parseXContent(queryParser);
            searchSourceBuilder.size(1);
            SearchRequest searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName);

            // NOTE(review): the listener is raw because the decompiled source carries no response
            // type; `r` is presumably the search response. Restore the explicit type argument
            // once the original source is available.
            ActionListener listener = new LatchedActionListener(ActionListener.wrap(r -> {
                if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0L) {
                    passed.set(true);
                }
            }, failure -> {
                log.error("Failed to search stop words index {}", indexName, failure);
                passed.set(true);
            }), latch);

            if (this.isStopWordsSystemIndex(indexName)) {
                // Stash the thread context for the search and restore it before the listener runs.
                context = this.client.threadPool().getThreadContext().stashContext();
                ThreadContext.StoredContext stashed = context;
                this.client.search(searchRequest, ActionListener.runBefore(listener, () -> stashed.restore()));
            } else {
                this.client.search(searchRequest, listener);
            }
        } catch (Exception ex) {
            // NOTE(review): the decompiler reported dropping a catch block here. This handler
            // mirrors the asynchronous failure callback above (log, release the latch, pass);
            // confirm against the original source.
            log.error("[validateStopWords] Searching stop words index failed.", ex);
            latch.countDown();
            passed.set(true);
        } finally {
            if (context != null) {
                context.close();
            }
        }
        try {
            if (!latch.await(SEARCH_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
                // No answer in time: `passed` is still false, so validation fails for this index.
                log.warn("[validateStopWords] Timed out after {} seconds waiting for stop words search on index {}", SEARCH_TIMEOUT_SECONDS, indexName);
            }
        } catch (InterruptedException interrupted) {
            // Restore the interrupt flag so callers further up the stack can observe it.
            Thread.currentThread().interrupt();
            log.error("[validateStopWords] Searching stop words index was timeout.", interrupted);
            throw new IllegalStateException(interrupted);
        }
        return passed.get();
    }

    /** Returns true when {@code index} is one of the names in {@link #stopWordsIndices}. */
    private boolean isStopWordsSystemIndex(String index) {
        return this.stopWordsIndices.contains(index);
    }

    /** @return the stop-words map applied to input text */
    @Generated
    public Map<String, List<String>> getStopWordsIndicesInput() {
        return this.stopWordsIndicesInput;
    }

    /** @return the stop-words map applied to output text */
    @Generated
    public Map<String, List<String>> getStopWordsIndicesOutput() {
        return this.stopWordsIndicesOutput;
    }

    /** @return the raw input regular expressions, or null when no input guardrail was supplied */
    @Generated
    public List<String> getInputRegex() {
        return this.inputRegex;
    }

    /** @return the raw output regular expressions, or null when no output guardrail was supplied */
    @Generated
    public List<String> getOutputRegex() {
        return this.outputRegex;
    }

    /** @return the compiled input patterns, or null when no input guardrail was supplied */
    @Generated
    public List<Pattern> getInputRegexPattern() {
        return this.inputRegexPattern;
    }

    /** @return the compiled output patterns, or null when no output guardrail was supplied */
    @Generated
    public List<Pattern> getOutputRegexPattern() {
        return this.outputRegexPattern;
    }

    /** @return the registry passed to the constructor */
    @Generated
    public NamedXContentRegistry getXContentRegistry() {
        return this.xContentRegistry;
    }

    /** @return the client passed to the constructor */
    @Generated
    public Client getClient() {
        return this.client;
    }

    /** @return the index names for which the thread context is stashed before searching */
    @Generated
    public Set<String> getStopWordsIndices() {
        return this.stopWordsIndices;
    }

    /** Selects which rule set {@link #validate(String, Type)} applies. */
    public static enum Type {
        INPUT,
        OUTPUT
    }
}

