/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.nlp.lm.io;

import edu.berkeley.nlp.lm.ArrayEncodedNgramLanguageModel;
import edu.berkeley.nlp.lm.ConfigOptions;
import edu.berkeley.nlp.lm.WordIndexer;
import edu.berkeley.nlp.lm.io.ArpaLmReaderCallback;
import edu.berkeley.nlp.lm.io.LmReader;
import edu.berkeley.nlp.lm.io.NgramOrderedLmReaderCallback;
import edu.berkeley.nlp.lm.map.HashNgramMap;
import edu.berkeley.nlp.lm.map.NgramMap;
import edu.berkeley.nlp.lm.util.Logger;
import edu.berkeley.nlp.lm.util.LongRef;
import edu.berkeley.nlp.lm.values.KneserNeyCountValueContainer;
import edu.berkeley.nlp.lm.values.ProbBackoffPair;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Accumulates n-gram counts delivered through the reader-callback interface and
 * exposes them as a Kneser-Ney smoothed language model.
 *
 * <p>Counts are stored in a {@link HashNgramMap} backed by a
 * {@link KneserNeyCountValueContainer}. Once counting is done the model can either be
 * streamed out as log-probability/backoff pairs through {@link #parse(ArpaLmReaderCallback)}
 * or queried directly through the {@code getLogProb} methods.
 *
 * @param <W> the word type handled by the {@link WordIndexer}
 */
public class KneserNeyLmReaderCallback<W>
implements NgramOrderedLmReaderCallback<LongRef>,
LmReader<ProbBackoffPair, ArpaLmReaderCallback<ProbBackoffPair>>,
ArrayEncodedNgramLanguageModel<W>,
Serializable {
    protected static final long serialVersionUID = 1L;
    /** Exclusive upper bound on the n-gram order accepted by the constructor. */
    protected static final int MAX_ORDER = 10;
    /** Discount used to fill the array returned by {@link #defaultDiscounts()}. */
    protected static final float DEFAULT_DISCOUNT = 0.75f;
    protected final int lmOrder;
    protected final WordIndexer<W> wordIndexer;
    protected final HashNgramMap<KneserNeyCountValueContainer.KneserNeyCounts> ngrams;
    protected final ConfigOptions opts;
    /** Index of the word indexer's start symbol, resolved once at construction time. */
    protected final int startIndex;

    /**
     * Creates a callback using a default-constructed {@link ConfigOptions}.
     *
     * @param wordIndexer indexer used to map words to integer ids
     * @param maxOrder    order of the language model; must be below {@link #MAX_ORDER}
     */
    public KneserNeyLmReaderCallback(WordIndexer<W> wordIndexer, int maxOrder) {
        this(wordIndexer, maxOrder, new ConfigOptions());
    }

    /**
     * Creates a callback with explicit options.
     *
     * @param wordIndexer indexer used to map words to integer ids
     * @param maxOrder    order of the language model; must be below {@link #MAX_ORDER}
     * @param opts        options; {@code kneserNeyMinCounts} must be non-decreasing
     * @throws IllegalArgumentException if {@code maxOrder} is {@link #MAX_ORDER} or larger,
     *                                  or if {@code opts.kneserNeyMinCounts} decreases anywhere
     */
    public KneserNeyLmReaderCallback(WordIndexer<W> wordIndexer, int maxOrder, ConfigOptions opts) {
        // Orders of MAX_ORDER and above are rejected, so the largest usable order is MAX_ORDER - 1.
        if (maxOrder >= MAX_ORDER) {
            throw new IllegalArgumentException("Requested n-grams of order " + maxOrder + " but we only allow up to " + (MAX_ORDER - 1));
        }
        double previous = Double.NEGATIVE_INFINITY;
        for (double minCount : opts.kneserNeyMinCounts) {
            if (minCount < previous) {
                throw new IllegalArgumentException("Please ensure that ConfigOptions.kneserNeyMinCounts is monotonic (value was " + Arrays.toString(opts.kneserNeyMinCounts) + ")");
            }
            previous = minCount;
        }
        this.lmOrder = maxOrder;
        this.wordIndexer = wordIndexer;
        this.opts = opts;
        this.startIndex = wordIndexer.getIndexPossiblyUnk(wordIndexer.getStartSymbol());
        KneserNeyCountValueContainer values = new KneserNeyCountValueContainer(this.lmOrder, this.startIndex);
        this.ngrams = HashNgramMap.createExplicitWordHashNgramMap(values, opts, this.lmOrder, false);
    }

    /**
     * Indexes the given words (adding unseen ones to the indexer) and records the
     * n-gram together with all of its sub-n-grams.
     *
     * @param ngram words of the n-gram
     * @param value count to record
     */
    public void call(W[] ngram, LongRef value) {
        int[] ints = this.toIndices(ngram);
        this.call(ints, 0, ints.length, value, "");
    }

    /**
     * Like {@link #call(Object[], LongRef)} but delegates to
     * {@link #addNgram(int[], int, int, LongRef, String, boolean, long[][])} with
     * {@code justLastWord = true} and a caller-supplied scratch array.
     *
     * @param ngram   words of the n-gram
     * @param value   count to record
     * @param scratch offset scratch space; see {@code addNgram} for the size requirements
     */
    public void callJustLast(W[] ngram, LongRef value, long[][] scratch) {
        int[] ints = this.toIndices(ngram);
        this.addNgram(ints, 0, ints.length, value, "", true, scratch);
    }

    /**
     * Records the n-gram {@code ngram[startPos, endPos)} and its sub-n-grams,
     * allocating fresh scratch space for the call.
     */
    @Override
    public void call(int[] ngram, int startPos, int endPos, LongRef value, String words) {
        long[][] prevOffsets = new long[this.lmOrder][endPos - startPos];
        this.addNgram(ngram, startPos, endPos, value, words, false, prevOffsets);
    }

    /**
     * Inserts every sub-n-gram of {@code ngram[startPos, endPos)} up to the model order
     * into the count map, shortest first, so that each insertion can reuse the offsets
     * of its prefix and suffix from the previous order.
     *
     * @param ngram        word indices
     * @param startPos     inclusive start of the window
     * @param endPos       exclusive end of the window
     * @param value        count to attach to the inserted n-grams
     * @param words        unused by this method
     * @param justLastWord when {@code true}, only sub-n-grams ending at {@code endPos} are
     *                     inserted with counts; the others are inserted with a {@code null} value
     * @param scratch      offset scratch space indexed as {@code [order][position - startPos]};
     *                     needs at least {@code lmOrder} rows of {@code endPos - startPos} entries
     */
    public void addNgram(int[] ngram, int startPos, int endPos, LongRef value, String words, boolean justLastWord, long[][] scratch) {
        KneserNeyCountValueContainer.KneserNeyCounts scratchCounts = new KneserNeyCountValueContainer.KneserNeyCounts();
        this.ngrams.rehashIfNecessary(endPos - startPos);
        for (int ngramOrder = 0; ngramOrder < this.lmOrder; ++ngramOrder) {
            // Only start positions whose n-gram of this order still fits inside the window.
            for (int i = startPos; i + ngramOrder < endPos; ++i) {
                int j = i + ngramOrder + 1;
                // Scratch is indexed relative to startPos for reads as well as writes, so that
                // windows which do not start at 0 look up the offsets stored by the previous order.
                int rel = i - startPos;
                scratchCounts.tokenCounts = value.value;
                long prevOffset = ngramOrder == 0 ? 0L : scratch[ngramOrder - 1][rel];
                long suffixOffset = ngramOrder == 0 ? 0L : scratch[ngramOrder - 1][rel + 1];
                assert (prevOffset >= 0L);
                KneserNeyCountValueContainer.KneserNeyCounts toStore = !justLastWord || j == endPos ? scratchCounts : null;
                scratch[ngramOrder][rel] = this.ngrams.putWithOffsetAndSuffix(ngram, i, j, prevOffset, suffixOffset, toStore);
            }
        }
    }

    /**
     * Recursively combines lower-order probabilities and backoff weights for
     * {@code ngram[startPos, endPos)}, shortening the context by one word per step.
     *
     * @return the interpolated (non-log) probability mass; {@code 0} for an empty span
     */
    protected float interpolateProb(int[] ngram, int startPos, int endPos) {
        if (startPos == endPos) {
            return 0.0f;
        }
        float backoff = this.getLowerOrderBackoff(ngram, startPos, endPos - 1);
        float prob = this.getLowerOrderProb(ngram, startPos, endPos);
        return prob + backoff * this.interpolateProb(ngram, startPos + 1, endPos);
    }

    /**
     * Discounted relative frequency based on token counts: the n-gram's token count minus
     * the discount for its order, divided by the token count of its context, floored at 0.
     *
     * @return the discounted probability, or {@code 0} when the context has no token counts
     */
    protected float getHighestOrderProb(int[] ngram, int startPos, int endPos) {
        KneserNeyCountValueContainer.KneserNeyCounts counts = this.getCounts(ngram, startPos, endPos, false);
        KneserNeyCountValueContainer.KneserNeyCounts contextCounts = this.getCounts(ngram, startPos, endPos - 1, true);
        float discount = this.getDiscountForOrder(endPos - startPos - 1);
        if (contextCounts.tokenCounts == 0L) {
            return 0.0f;
        }
        return Math.max(0.0f, ((float)counts.tokenCounts - discount) / (float)contextCounts.tokenCounts);
    }

    /**
     * Discounted relative frequency based on type counts ({@code leftDotTypeCounts} over the
     * context's {@code dotdotTypeCounts}). Unigrams are not discounted.
     *
     * @return {@code 1} for an empty span, {@code 0} when the context's
     *         {@code dotdotTypeCounts} is zero, otherwise the discounted ratio
     */
    protected float getLowerOrderProb(int[] ngram, int startPos, int endPos) {
        if (startPos == endPos) {
            return 1.0f;
        }
        KneserNeyCountValueContainer.KneserNeyCounts counts = this.getCounts(ngram, startPos, endPos, false);
        KneserNeyCountValueContainer.KneserNeyCounts prefixCounts = this.getCounts(ngram, startPos, endPos - 1, true);
        int length = endPos - startPos;
        float probDiscount = length == 1 ? 0.0f : this.getDiscountForOrder(length - 1);
        if (prefixCounts.dotdotTypeCounts == 0L) {
            return 0.0f;
        }
        return Math.max(0.0f, (float)counts.leftDotTypeCounts - probDiscount) / (float)prefixCounts.dotdotTypeCounts;
    }

    /**
     * Backoff weight (non-log) for the context {@code ngram[startPos, endPos)}.
     *
     * <p>The denominator is the context's token count when the context has length
     * {@code lmOrder - 1} or begins with the start symbol, and its {@code dotdotTypeCounts}
     * otherwise.
     *
     * @return {@code 1} for an empty context or a zero denominator, otherwise
     *         {@code discount * rightDotTypeCounts / denominator}
     */
    protected float getLowerOrderBackoff(int[] ngram, int startPos, int endPos) {
        if (startPos == endPos) {
            return 1.0f;
        }
        KneserNeyCountValueContainer.KneserNeyCounts counts = this.getCounts(ngram, startPos, endPos, true);
        int length = endPos - startPos;
        boolean useTokenCounts = length == this.lmOrder - 1 || ngram[startPos] == this.startIndex;
        long backoffDenom = useTokenCounts ? counts.tokenCounts : counts.dotdotTypeCounts;
        assert (backoffDenom >= 0L);
        float backoffDiscount = this.getDiscountForOrder(length);
        if (backoffDenom == 0L) {
            return 1.0f;
        }
        return backoffDiscount * (float)counts.rightDotTypeCounts / (float)backoffDenom;
    }

    /**
     * Returns the discount for the given zero-based order: the configured value when
     * {@code opts.kneserNeyDiscounts} is set, otherwise {@code n1 / (n1 + 2 * n2)} computed
     * from the number of n-grams of that order seen once ({@code n1}) and twice ({@code n2}).
     *
     * @return the discount; {@code 1e-5} when the estimate's denominator is zero
     */
    protected float getDiscountForOrder(int ngramOrder) {
        if (this.opts.kneserNeyDiscounts != null) {
            return (float)this.opts.kneserNeyDiscounts[ngramOrder];
        }
        KneserNeyCountValueContainer values = (KneserNeyCountValueContainer)this.ngrams.getValues();
        int numOneCounters = values.getNumOneCountNgrams(ngramOrder);
        int numTwoCounters = values.getNumTwoCountNgrams(ngramOrder);
        float denom = (float)numOneCounters + 2.0f * (float)numTwoCounters;
        return denom == 0.0f ? 1.0E-5f : (float)numOneCounters / denom;
    }

    /** No-op: nothing needs finalizing once counting ends. */
    @Override
    public void cleanup() {
    }

    /**
     * Looks up the stored counts for {@code key[startPos, endPos)} and applies the
     * start-symbol and end-symbol adjustments.
     *
     * <p>An empty span yields counts whose only populated field is {@code dotdotTypeCounts}
     * (set to the container's bigram type count). An n-gram missing from the map yields
     * freshly constructed counts.
     *
     * @param isBackoff whether the counts are being fetched for a context (backoff) lookup;
     *                  this affects only n-grams of length {@code lmOrder - 1} that begin
     *                  with the start symbol
     */
    private KneserNeyCountValueContainer.KneserNeyCounts getCounts(int[] key, int startPos, int endPos, boolean isBackoff) {
        KneserNeyCountValueContainer.KneserNeyCounts value = new KneserNeyCountValueContainer.KneserNeyCounts();
        if (startPos == endPos) {
            value.dotdotTypeCounts = ((KneserNeyCountValueContainer)this.ngrams.getValues()).getBigramTypeCounts();
            return value;
        }
        long offset = this.ngrams.getOffsetForNgramInModel(key, startPos, endPos);
        if (offset < 0L) {
            return value;
        }
        int length = endPos - startPos;
        this.ngrams.getValues().getFromOffset(offset, length - 1, value);
        boolean startsWithStartSym = key[startPos] == this.startIndex;
        boolean endsWithEndSym = key[endPos - 1] == this.wordIndexer.getIndexPossiblyUnk(this.wordIndexer.getEndSymbol());
        if (startsWithStartSym) {
            value.dotdotTypeCounts = value.rightDotTypeCounts;
            if (length < this.lmOrder - 1 || length == this.lmOrder - 1 && !isBackoff) {
                value.tokenCounts = value.leftDotTypeCounts;
            }
        }
        // Applied after the start-symbol branch on purpose: for an n-gram that both starts
        // with the start symbol and ends with the end symbol, these assignments win.
        if (endsWithEndSym) {
            value.rightDotTypeCounts = 1L;
            value.dotdotTypeCounts = value.leftDotTypeCounts;
        }
        return value;
    }

    /** Returns a fresh array of {@link #MAX_ORDER} entries, all {@link #DEFAULT_DISCOUNT}. */
    public static double[] defaultDiscounts() {
        return KneserNeyLmReaderCallback.constantArray(MAX_ORDER, DEFAULT_DISCOUNT);
    }

    /** Returns a fresh array of default per-order minimum counts (1 for the first three orders, 2 after). */
    public static double[] defaultMinCounts() {
        return new double[]{1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0};
    }

    /** Builds an array of length {@code n} with every element set to {@code f}. */
    private static double[] constantArray(int n, double f) {
        double[] ret = new double[n];
        Arrays.fill(ret, f);
        return ret;
    }

    /**
     * Streams the model to {@code callback}: first the number of retained n-grams per
     * order, then, order by order, each retained n-gram with its log-probability/backoff pair.
     *
     * <p>For the two highest orders, n-grams whose token count is below the configured
     * minimum for that order are skipped in both passes.
     */
    @Override
    public void parse(ArpaLmReaderCallback<ProbBackoffPair> callback) {
        Logger.startTrack("Writing Kneser-Ney probabilities", new Object[0]);
        ArrayList<Long> lengths = new ArrayList<Long>();
        for (int ngramOrder = 0; ngramOrder < this.lmOrder; ++ngramOrder) {
            Logger.startTrack("Counting counts for order " + ngramOrder, new Object[0]);
            long numNgrams = 0L;
            for (NgramMap.Entry<KneserNeyCountValueContainer.KneserNeyCounts> entry : this.ngrams.getNgramsForOrder(ngramOrder)) {
                long relevantCount = ((KneserNeyCountValueContainer.KneserNeyCounts)entry.value).tokenCounts;
                if (this.isBelowMinCount(ngramOrder, relevantCount)) continue;
                ++numNgrams;
            }
            lengths.add(numNgrams);
            Logger.endTrack();
        }
        callback.initWithLengths(lengths);
        for (int ngramOrder = 0; ngramOrder < this.lmOrder; ++ngramOrder) {
            callback.handleNgramOrderStarted(ngramOrder + 1);
            Logger.logss("On order " + (ngramOrder + 1));
            int linenum = 0;
            for (NgramMap.Entry<KneserNeyCountValueContainer.KneserNeyCounts> entry : this.ngrams.getNgramsForOrder(ngramOrder)) {
                if (linenum++ % 10000 == 0) {
                    Logger.logs("Writing line " + linenum);
                }
                long relevantCount = ((KneserNeyCountValueContainer.KneserNeyCounts)entry.value).tokenCounts;
                if (this.isBelowMinCount(ngramOrder, relevantCount)) continue;
                int[] ngram = entry.key;
                int endPos = ngram.length;
                ProbBackoffPair value = this.getProbBackoff(ngram, 0, endPos);
                callback.call(ngram, 0, endPos, value, "");
            }
            callback.handleNgramOrderFinished(ngramOrder + 1);
        }
        callback.cleanup();
        Logger.endTrack();
    }

    /** Whether an n-gram of the given zero-based order and token count is filtered out by {@code parse}. */
    private boolean isBelowMinCount(int ngramOrder, long tokenCount) {
        return ngramOrder >= this.lmOrder - 2 && (double)tokenCount < this.opts.kneserNeyMinCounts[ngramOrder];
    }

    /** Maps each word to its index, adding unseen words to the indexer. */
    private int[] toIndices(W[] ngram) {
        int[] ints = new int[ngram.length];
        for (int i = 0; i < ngram.length; ++i) {
            ints[i] = this.wordIndexer.getOrAddIndex(ngram[i]);
        }
        return ints;
    }

    /**
     * Computes the base-10 log-probability and log-backoff for {@code ngram[startPos, endPos)}.
     *
     * <p>Token-count based probabilities are used for n-grams of full order or beginning with
     * the start symbol, type-count based ones otherwise. The lone start-symbol unigram gets a
     * fixed log-probability of {@code -99}, and full-order n-grams get a backoff of {@code 0}.
     */
    private ProbBackoffPair getProbBackoff(int[] ngram, int startPos, int endPos) {
        int ngramOrder = endPos - startPos - 1;
        boolean isHighestOrder = ngramOrder == this.lmOrder - 1;
        boolean startsWithStartSym = ngram[startPos] == this.startIndex;
        float val = isHighestOrder || startsWithStartSym
                ? this.getHighestOrderProb(ngram, startPos, endPos)
                : this.getLowerOrderProb(ngram, startPos, endPos);
        // The interpolation starts after any run of start symbols that follows the first word.
        int nextNonStart = startPos + 1;
        while (nextNonStart < endPos && ngram[nextNonStart] == this.startIndex) {
            ++nextNonStart;
        }
        float prob = val + this.getLowerOrderBackoff(ngram, startPos, endPos - 1) * this.interpolateProb(ngram, nextNonStart, endPos);
        boolean isStartEndSym = endPos - startPos == 1 && startsWithStartSym;
        float logProb = isStartEndSym ? -99.0f : (float)Math.log10(prob);
        float backoff = isHighestOrder ? 0.0f : (float)Math.log10(this.getLowerOrderBackoff(ngram, startPos, endPos));
        return new ProbBackoffPair(logProb, backoff);
    }

    @Override
    public WordIndexer<W> getWordIndexer() {
        return this.wordIndexer;
    }

    /** No-op. */
    @Override
    public void handleNgramOrderFinished(int order) {
    }

    /** No-op. */
    @Override
    public void handleNgramOrderStarted(int order) {
    }

    @Override
    public int getLmOrder() {
        return this.lmOrder;
    }

    @Override
    public float scoreSentence(List<W> sentence) {
        return ArrayEncodedNgramLanguageModel.DefaultImplementations.scoreSentence(sentence, this);
    }

    @Override
    public float getLogProb(List<W> ngram) {
        return ArrayEncodedNgramLanguageModel.DefaultImplementations.getLogProb(ngram, this);
    }

    /** Returns the base-10 log-probability of {@code ngram[startPos, endPos)} computed from the current counts. */
    @Override
    public float getLogProb(int[] ngram, int startPos, int endPos) {
        ProbBackoffPair probBackoff = this.getProbBackoff(ngram, startPos, endPos);
        return probBackoff.prob;
    }

    @Override
    public float getLogProb(int[] ngram) {
        return ArrayEncodedNgramLanguageModel.DefaultImplementations.getLogProb(ngram, this);
    }

    /** Returns the total size reported by the underlying n-gram map. */
    public long getTotalSize() {
        return this.ngrams.getTotalSize();
    }

    /**
     * Not supported by this implementation.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setOovWordLogProb(float logProb) {
        throw new UnsupportedOperationException("Method not yet implemented");
    }
}

