/*
 * Decompiled with CFR 0.152.
 */
package org.languagetool.rules;

import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.Streams;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import javax.net.ssl.SSLException;
import org.apache.commons.lang3.tuple.Pair;
import org.jetbrains.annotations.Nullable;
import org.languagetool.AnalyzedSentence;
import org.languagetool.Language;
import org.languagetool.languagemodel.bert.RemoteLanguageModel;
import org.languagetool.rules.RemoteRule;
import org.languagetool.rules.RemoteRuleConfig;
import org.languagetool.rules.RemoteRuleResult;
import org.languagetool.rules.Rule;
import org.languagetool.rules.RuleMatch;
import org.languagetool.rules.SuggestedReplacement;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Remote rule that wraps another {@link Rule} and re-ranks the suggestions of the wrapped rule's
 * matches using scores returned by a remote BERT language model service.
 *
 * <p>Match positions and messages come from the wrapped rule unchanged. This rule limits the
 * number of suggestions per match (see {@link #prepareSuggestions}) and then re-orders them.
 * When no model client could be obtained for the configured service, matches are returned
 * without re-ranking.
 */
public class BERTSuggestionRanking
extends RemoteRule {
    public static final String RULE_ID = "BERT_SUGGESTION_RANKING";
    private static final Logger logger = LoggerFactory.getLogger(BERTSuggestionRanking.class);

    /**
     * Model clients shared by all instances of this rule, one per service configuration, created
     * on first lookup. The options {@code secure}, {@code clientKey}, {@code clientCertificate}
     * and {@code rootCertificate} of the configuration are passed to the client.
     */
    private static final LoadingCache<RemoteRuleConfig, RemoteLanguageModel> models = CacheBuilder.newBuilder().build(CacheLoader.from(serviceConfiguration -> {
        String host = serviceConfiguration.getUrl();
        int port = serviceConfiguration.getPort();
        boolean ssl = Boolean.parseBoolean(serviceConfiguration.getOptions().getOrDefault("secure", "false"));
        String key = serviceConfiguration.getOptions().get("clientKey");
        String cert = serviceConfiguration.getOptions().get("clientCertificate");
        String ca = serviceConfiguration.getOptions().get("rootCertificate");
        try {
            return new RemoteLanguageModel(host, port, ssl, key, cert, ca);
        } catch (SSLException e) {
            throw new RuntimeException(e);
        }
    }));

    static {
        // Register a routine that shuts down every model client created so far.
        shutdownRoutines.add(() -> models.asMap().values().forEach(RemoteLanguageModel::shutdown));
    }

    /**
     * Maximum number of suggestions kept per match; last value chosen by
     * {@link #prepareSuggestions} (25 if a translation suggestion is present, otherwise 10).
     */
    protected int suggestionLimit = 10;
    /** Client for the BERT service, or {@code null} if it could not be created. */
    private final RemoteLanguageModel model;
    private final Rule wrappedRule;

    /**
     * Creates a ranking rule around {@code rule}, reusing its id, messages and category.
     *
     * <p>Connection failures are logged and leave the rule without a model, in which case it
     * returns the wrapped rule's matches without re-ranking.
     *
     * @param language language checked by this rule
     * @param rule the rule whose suggestions are re-ranked
     * @param config remote service configuration (also the key into the shared client cache)
     * @param inputLogging passed through to {@link RemoteRule}
     */
    public BERTSuggestionRanking(Language language, Rule rule, RemoteRuleConfig config, boolean inputLogging) {
        super(language, rule.messages, config, inputLogging, rule.getId());
        this.wrappedRule = rule;
        super.setCategory(this.wrappedRule.getCategory());
        RemoteLanguageModel loaded = null;
        synchronized (models) {
            try {
                loaded = models.get(this.serviceConfiguration);
            } catch (Exception e) {
                logger.error("Could not connect to BERT service at " + this.serviceConfiguration + " for suggestion reranking", e);
            }
        }
        this.model = loaded;
    }

    /**
     * Truncates the suggestion list before scoring: up to 25 entries when any suggestion is a
     * translation, otherwise up to 10.
     *
     * @param suggestions the wrapped rule's suggestions for one match
     * @return a prefix of {@code suggestions} (a {@code subList} view)
     */
    protected List<SuggestedReplacement> prepareSuggestions(List<SuggestedReplacement> suggestions) {
        // Use a local value: the field is shared by every thread using this rule instance, so
        // reading it back after the write could observe another thread's limit.
        int limit = suggestions.stream().anyMatch(s -> s.getType() == SuggestedReplacement.SuggestionType.Translation) ? 25 : 10;
        this.suggestionLimit = limit;
        return suggestions.subList(0, Math.min(suggestions.size(), limit));
    }

    /**
     * Runs the wrapped rule on each sentence, truncates each match's suggestions and builds one
     * scoring request per match. The request is {@code null} when a match has nothing to rank, so
     * requests stay index-aligned with matches.
     *
     * <p>If the wrapped rule throws an {@link IOException}, the error is logged and an empty
     * request is returned.
     */
    @Override
    protected RemoteRule.RemoteRequest prepareRequest(List<AnalyzedSentence> sentences, Long textSessionId) {
        // ArrayList: matches are later looked up by index in executeRequest.
        List<RuleMatch> matches = new ArrayList<>();
        List<RemoteLanguageModel.Request> requests = new ArrayList<>();
        try {
            for (AnalyzedSentence sentence : sentences) {
                RuleMatch[] sentenceMatches = this.wrappedRule.match(sentence);
                for (RuleMatch match : sentenceMatches) {
                    match.setSuggestedReplacementObjects(this.prepareSuggestions(match.getSuggestedReplacementObjects()));
                    requests.add(this.buildRequest(match));
                }
                Collections.addAll(matches, sentenceMatches);
            }
            return new MatchesForReordering(sentences, matches, requests);
        } catch (IOException e) {
            logger.error("Error while executing rule " + this.wrappedRule.getId(), e);
            return new MatchesForReordering(sentences, Collections.emptyList(), Collections.emptyList());
        }
    }

    /** Returns the prepared matches without re-ranking. */
    @Override
    protected RemoteRuleResult fallbackResults(RemoteRule.RemoteRequest request) {
        MatchesForReordering req = (MatchesForReordering) request;
        return new RemoteRuleResult(false, false, req.matches, req.sentences);
    }

    /**
     * Builds a task that scores all non-null requests in one batch and re-orders the
     * corresponding matches' suggestions in place.
     *
     * <p>The task returns {@link #fallbackResults} when no model is available.
     */
    @Override
    protected Callable<RemoteRuleResult> executeRequest(RemoteRule.RemoteRequest request, long timeoutMilliseconds) throws TimeoutException {
        return () -> {
            if (this.model == null) {
                return this.fallbackResults(request);
            }
            MatchesForReordering data = (MatchesForReordering) request;
            List<RuleMatch> matches = data.matches;
            // Keep only real requests and remember which match each one belongs to.
            List<Integer> matchIndices = new ArrayList<>();
            List<RemoteLanguageModel.Request> requests = new ArrayList<>();
            int position = 0;
            for (RemoteLanguageModel.Request req : data.requests) {
                if (req != null) {
                    matchIndices.add(position);
                    requests.add(req);
                }
                position++;
            }
            if (requests.isEmpty()) {
                return new RemoteRuleResult(false, true, matches, data.sentences);
            }
            List<List<Double>> results = this.model.batchScore(requests, timeoutMilliseconds);
            for (int i = 0; i < requests.size(); i++) {
                this.rerank(matches.get(matchIndices.get(i)), requests.get(i), results.get(i));
            }
            return new RemoteRuleResult(true, true, matches, data.sentences);
        };
    }

    /**
     * Re-orders the suggestions of {@code match} using {@code scores}, one score per suggestion
     * in the order they were sent in {@code request}.
     *
     * <p>If fewer scores than suggestions are returned, zipping would silently drop suggestions.
     * In that case the original order is kept and a warning is logged.
     */
    private void rerank(RuleMatch match, RemoteLanguageModel.Request request, List<Double> scores) {
        List<SuggestedReplacement> suggestions = match.getSuggestedReplacementObjects();
        if (scores == null || scores.size() < suggestions.size()) {
            logger.warn("BERT service returned {} scores for {} suggestions of rule {}; keeping original order",
                scores == null ? "no" : scores.size(), suggestions.size(), this.wrappedRule.getId());
            return;
        }
        String userWord = request.text.substring(request.start, request.end);
        List<SuggestedReplacement> ranked = Streams.zip(suggestions.stream(), scores.stream(), Pair::of)
            .sorted(new CuratedAndSameCaseComparator(userWord))
            .map(Pair::getLeft)
            .collect(Collectors.toList());
        match.setSuggestedReplacementObjects(ranked);
    }

    /**
     * Builds a scoring request for the match's sentence text, span and suggestions.
     *
     * @return the request, or {@code null} when the match has fewer than two suggestions
     *     (nothing to re-order)
     */
    @Nullable
    private RemoteLanguageModel.Request buildRequest(RuleMatch match) {
        List<String> suggestions = match.getSuggestedReplacements();
        if (suggestions != null && suggestions.size() > 1) {
            return new RemoteLanguageModel.Request(match.getSentence().getText(), match.getFromPos(), match.getToPos(), suggestions);
        }
        return null;
    }

    /** Returns the wrapped rule's id. */
    @Override
    public String getId() {
        return this.wrappedRule.getId();
    }

    /** Returns the wrapped rule's description. */
    @Override
    public String getDescription() {
        return this.wrappedRule.getDescription();
    }

    /**
     * Orders (suggestion, score) pairs. The ordering keys are, in priority order:
     * <ol>
     *   <li>the suggestion equal to the user's word, ignoring case;</li>
     *   <li>curated suggestions;</li>
     *   <li>descending score.</li>
     * </ol>
     *
     * <p>Each key is applied only when it distinguishes the two pairs. This keeps the ordering a
     * valid total order: {@code compare(a, b) == -compare(b, a)} and {@code compare(a, a) == 0}.
     */
    private static final class CuratedAndSameCaseComparator
    implements Comparator<Pair<SuggestedReplacement, Double>> {
        private final String userWord;

        CuratedAndSameCaseComparator(String userWord) {
            this.userWord = userWord;
        }

        @Override
        public int compare(Pair<SuggestedReplacement, Double> a, Pair<SuggestedReplacement, Double> b) {
            boolean aIsUserWord = a.getLeft().getReplacement().equalsIgnoreCase(this.userWord);
            boolean bIsUserWord = b.getLeft().getReplacement().equalsIgnoreCase(this.userWord);
            if (aIsUserWord != bIsUserWord) {
                return aIsUserWord ? -1 : 1;
            }
            boolean aCurated = a.getLeft().getType() == SuggestedReplacement.SuggestionType.Curated;
            boolean bCurated = b.getLeft().getType() == SuggestedReplacement.SuggestionType.Curated;
            if (aCurated != bCurated) {
                return aCurated ? -1 : 1;
            }
            // Higher score first.
            return Double.compare(b.getRight(), a.getRight());
        }
    }

    /** Per-check state passed from {@link #prepareRequest} to execution and fallback. */
    class MatchesForReordering
    extends RemoteRule.RemoteRequest {
        final List<AnalyzedSentence> sentences;
        final List<RuleMatch> matches;
        /** Index-aligned with {@link #matches}; {@code null} where a match has nothing to rank. */
        final List<RemoteLanguageModel.Request> requests;

        MatchesForReordering(List<AnalyzedSentence> sentences, List<RuleMatch> matches, List<RemoteLanguageModel.Request> requests) {
            // The enclosing BERTSuggestionRanking instance is passed implicitly to the inner superclass.
            super();
            this.sentences = sentences;
            this.matches = matches;
            this.requests = requests;
        }
    }
}

