/*
 * Decompiled with CFR 0.152.
 */
package org.encog.ml.hmm.alog;

import java.util.EnumSet;
import java.util.Iterator;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.hmm.HiddenMarkovModel;

/**
 * Computes the forward (alpha) and/or backward (beta) variables of a
 * {@link HiddenMarkovModel} for a single observation sequence, together with
 * the probability of that sequence under the model.
 *
 * <p>The values are stored as plain (unscaled) probabilities. NOTE(review):
 * for long sequences these products may underflow to zero.
 */
public class ForwardBackwardCalculator {
    /** Forward variables indexed {@code [t][state]}; {@code null} until computed. */
    protected double[][] alpha = null;
    /** Backward variables indexed {@code [t][state]}; {@code null} until computed. */
    protected double[][] beta = null;
    /** Probability of the observation sequence given the model. */
    protected double probability;

    /**
     * Creates a calculator without computing anything. Only reachable from
     * subclasses and this package.
     */
    protected ForwardBackwardCalculator() {
    }

    /**
     * Computes the alpha table and the sequence probability.
     *
     * @param oseq the observation sequence (must not be empty)
     * @param hmm  the model
     * @throws IllegalArgumentException if {@code oseq} is empty
     */
    public ForwardBackwardCalculator(MLDataSet oseq, HiddenMarkovModel hmm) {
        this(oseq, hmm, EnumSet.of(Computation.ALPHA));
    }

    /**
     * Computes the requested tables and the sequence probability.
     *
     * <p>The probability is taken from the alpha table when {@code ALPHA} is
     * requested, and otherwise from the beta table.
     *
     * @param oseq  the observation sequence (must not be empty)
     * @param hmm   the model
     * @param flags which tables to compute; must contain at least one of
     *              {@link Computation#ALPHA} or {@link Computation#BETA}
     * @throws IllegalArgumentException if {@code oseq} is empty or
     *                                  {@code flags} requests neither table
     */
    public ForwardBackwardCalculator(MLDataSet oseq, HiddenMarkovModel hmm, EnumSet<Computation> flags) {
        if (oseq.size() < 1) {
            throw new IllegalArgumentException("Empty sequence");
        }
        // The probability needs at least one table. Without this check an empty
        // flag set would dereference the null beta array in computeProbability.
        if (!flags.contains(Computation.ALPHA) && !flags.contains(Computation.BETA)) {
            throw new IllegalArgumentException("At least one of ALPHA or BETA must be requested");
        }
        if (flags.contains(Computation.ALPHA)) {
            this.computeAlpha(hmm, oseq);
        }
        if (flags.contains(Computation.BETA)) {
            this.computeBeta(hmm, oseq);
        }
        this.computeProbability(oseq, hmm, flags);
    }

    /**
     * Returns the forward variable for time {@code t} and state {@code i}.
     *
     * @throws UnsupportedOperationException if the alpha table was not computed
     */
    public double alphaElement(int t, int i) {
        if (this.alpha == null) {
            throw new UnsupportedOperationException("Alpha array has not been computed");
        }
        return this.alpha[t][i];
    }

    /**
     * Returns the backward variable for time {@code t} and state {@code i}.
     *
     * @throws UnsupportedOperationException if the beta table was not computed
     */
    public double betaElement(int t, int i) {
        if (this.beta == null) {
            throw new UnsupportedOperationException("Beta array has not been computed");
        }
        return this.beta[t][i];
    }

    /**
     * Allocates and fills the alpha table: the first row is set by
     * {@link #computeAlphaInit}, each later row by {@link #computeAlphaStep}.
     */
    protected void computeAlpha(HiddenMarkovModel hmm, MLDataSet oseq) {
        this.alpha = new double[oseq.size()][hmm.getStateCount()];
        // Fetched once: the first observation is the same for every state.
        final MLDataPair first = oseq.get(0);
        for (int i = 0; i < hmm.getStateCount(); ++i) {
            this.computeAlphaInit(hmm, first, i);
        }
        // Indexed access keeps this consistent with computeBeta and avoids a
        // raw Iterator with an unchecked cast.
        for (int t = 1; t < oseq.size(); ++t) {
            final MLDataPair observation = oseq.get(t);
            for (int j = 0; j < hmm.getStateCount(); ++j) {
                this.computeAlphaStep(hmm, observation, t, j);
            }
        }
    }

    /** Sets {@code alpha[0][i] = pi(i) * P(o | state i)}. */
    protected void computeAlphaInit(HiddenMarkovModel hmm, MLDataPair o, int i) {
        this.alpha[0][i] = hmm.getPi(i) * hmm.getStateDistribution(i).probability(o);
    }

    /**
     * Sets {@code alpha[t][j] = (sum_i alpha[t-1][i] * a(i, j)) * P(o | state j)},
     * where {@code a(i, j)} is the model's transition probability.
     */
    protected void computeAlphaStep(HiddenMarkovModel hmm, MLDataPair o, int t, int j) {
        double sum = 0.0;
        for (int i = 0; i < hmm.getStateCount(); ++i) {
            sum += this.alpha[t - 1][i] * hmm.getTransitionProbability(i, j);
        }
        this.alpha[t][j] = sum * hmm.getStateDistribution(j).probability(o);
    }

    /**
     * Allocates and fills the beta table backwards in time. The last row is 1.0
     * for every state; each earlier row is set by {@link #computeBetaStep}.
     */
    protected void computeBeta(HiddenMarkovModel hmm, MLDataSet oseq) {
        this.beta = new double[oseq.size()][hmm.getStateCount()];
        for (int i = 0; i < hmm.getStateCount(); ++i) {
            this.beta[oseq.size() - 1][i] = 1.0;
        }
        for (int t = oseq.size() - 2; t >= 0; --t) {
            // Row t depends on the observation at t + 1, which is the same for all states.
            final MLDataPair next = oseq.get(t + 1);
            for (int i = 0; i < hmm.getStateCount(); ++i) {
                this.computeBetaStep(hmm, next, t, i);
            }
        }
    }

    /**
     * Sets {@code beta[t][i] = sum_j beta[t+1][j] * a(i, j) * P(o | state j)},
     * where {@code o} is the observation at time {@code t + 1}.
     */
    protected void computeBetaStep(HiddenMarkovModel hmm, MLDataPair o, int t, int i) {
        double sum = 0.0;
        for (int j = 0; j < hmm.getStateCount(); ++j) {
            sum += this.beta[t + 1][j] * hmm.getTransitionProbability(i, j) * hmm.getStateDistribution(j).probability(o);
        }
        this.beta[t][i] = sum;
    }

    /**
     * Stores the sequence probability. With alpha it is the sum of the last
     * alpha row. Otherwise it is {@code sum_i pi(i) * P(o_0 | state i) * beta[0][i]}.
     */
    private void computeProbability(MLDataSet oseq, HiddenMarkovModel hmm, EnumSet<Computation> flags) {
        this.probability = 0.0;
        if (flags.contains(Computation.ALPHA)) {
            for (int i = 0; i < hmm.getStateCount(); ++i) {
                this.probability += this.alpha[oseq.size() - 1][i];
            }
        } else {
            final MLDataPair first = oseq.get(0);
            for (int i = 0; i < hmm.getStateCount(); ++i) {
                this.probability += hmm.getPi(i) * hmm.getStateDistribution(i).probability(first) * this.beta[0][i];
            }
        }
    }

    /** Returns the probability of the observation sequence given the model. */
    public double probability() {
        return this.probability;
    }

    /** Selects which tables a {@link ForwardBackwardCalculator} computes. */
    public static enum Computation {
        /** The forward (alpha) table. */
        ALPHA,
        /** The backward (beta) table. */
        BETA;

    }
}

