/*
 * Decompiled with CFR 0.152.
 */
package org.encog.ml.hmm.train.bw;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.MLSequenceSet;
import org.encog.ml.hmm.HiddenMarkovModel;
import org.encog.ml.hmm.alog.ForwardBackwardCalculator;
import org.encog.ml.hmm.distributions.StateDistribution;
import org.encog.ml.train.MLTrain;
import org.encog.ml.train.strategy.Strategy;
import org.encog.neural.networks.training.propagation.TrainingContinuation;

public abstract class BaseBaumWelch
implements MLTrain {
    private int iterations;
    private HiddenMarkovModel method;
    private final MLSequenceSet training;

    public BaseBaumWelch(HiddenMarkovModel hmm, MLSequenceSet training) {
        this.method = hmm;
        this.training = training;
    }

    @Override
    public void addStrategy(Strategy strategy) {
    }

    @Override
    public boolean canContinue() {
        return false;
    }

    protected double[][] estimateGamma(double[][][] xi, ForwardBackwardCalculator fbc) {
        int i;
        int t;
        double[][] gamma = new double[xi.length + 1][xi[0].length];
        for (t = 0; t < xi.length + 1; ++t) {
            Arrays.fill(gamma[t], 0.0);
        }
        for (t = 0; t < xi.length; ++t) {
            for (i = 0; i < xi[0].length; ++i) {
                for (int j = 0; j < xi[0].length; ++j) {
                    double[] dArray = gamma[t];
                    int n = i;
                    dArray[n] = dArray[n] + xi[t][i][j];
                }
            }
        }
        for (int j = 0; j < xi[0].length; ++j) {
            for (i = 0; i < xi[0].length; ++i) {
                double[] dArray = gamma[xi.length];
                int n = j;
                dArray[n] = dArray[n] + xi[xi.length - 1][i][j];
            }
        }
        return gamma;
    }

    public abstract double[][][] estimateXi(MLDataSet var1, ForwardBackwardCalculator var2, HiddenMarkovModel var3);

    @Override
    public void finishTraining() {
    }

    public abstract ForwardBackwardCalculator generateForwardBackwardCalculator(MLDataSet var1, HiddenMarkovModel var2);

    @Override
    public double getError() {
        return 0.0;
    }

    @Override
    public TrainingImplementationType getImplementationType() {
        return TrainingImplementationType.Iterative;
    }

    @Override
    public int getIteration() {
        return this.iterations;
    }

    @Override
    public MLMethod getMethod() {
        return this.method;
    }

    @Override
    public List<Strategy> getStrategies() {
        return null;
    }

    @Override
    public MLDataSet getTraining() {
        return this.training;
    }

    @Override
    public boolean isTrainingDone() {
        return false;
    }

    @Override
    public void iteration() {
        int i;
        HiddenMarkovModel nhmm;
        try {
            nhmm = this.method.clone();
        }
        catch (CloneNotSupportedException e) {
            throw new InternalError();
        }
        double[][][] allGamma = new double[this.training.getSequenceCount()][][];
        double[][] aijNum = new double[this.method.getStateCount()][this.method.getStateCount()];
        double[] aijDen = new double[this.method.getStateCount()];
        Arrays.fill(aijDen, 0.0);
        for (int i2 = 0; i2 < this.method.getStateCount(); ++i2) {
            Arrays.fill(aijNum[i2], 0.0);
        }
        int g = 0;
        for (MLDataSet obsSeq : this.training.getSequences()) {
            ForwardBackwardCalculator fbc = this.generateForwardBackwardCalculator(obsSeq, this.method);
            double[][][] xi = this.estimateXi(obsSeq, fbc, this.method);
            int n = g++;
            double[][] dArray = this.estimateGamma(xi, fbc);
            allGamma[n] = dArray;
            double[][] gamma = dArray;
            for (int i3 = 0; i3 < this.method.getStateCount(); ++i3) {
                for (int t = 0; t < obsSeq.size() - 1; ++t) {
                    int n2 = i3;
                    aijDen[n2] = aijDen[n2] + gamma[t][i3];
                    for (int j = 0; j < this.method.getStateCount(); ++j) {
                        double[] dArray2 = aijNum[i3];
                        int n3 = j;
                        dArray2[n3] = dArray2[n3] + xi[t][i3][j];
                    }
                }
            }
        }
        for (i = 0; i < this.method.getStateCount(); ++i) {
            int j;
            if (aijDen[i] == 0.0) {
                for (j = 0; j < this.method.getStateCount(); ++j) {
                    nhmm.setTransitionProbability(i, j, this.method.getTransitionProbability(i, j));
                }
                continue;
            }
            for (j = 0; j < this.method.getStateCount(); ++j) {
                nhmm.setTransitionProbability(i, j, aijNum[i][j] / aijDen[i]);
            }
        }
        for (i = 0; i < this.method.getStateCount(); ++i) {
            nhmm.setPi(i, 0.0);
        }
        for (int o = 0; o < this.training.getSequenceCount(); ++o) {
            for (int i4 = 0; i4 < this.method.getStateCount(); ++i4) {
                nhmm.setPi(i4, nhmm.getPi(i4) + allGamma[o][0][i4] / (double)this.training.getSequenceCount());
            }
        }
        for (i = 0; i < this.method.getStateCount(); ++i) {
            double[] weights = new double[this.training.size()];
            double sum = 0.0;
            int j = 0;
            int o = 0;
            for (MLDataSet obsSeq : this.training.getSequences()) {
                int t = 0;
                while (t < obsSeq.size()) {
                    weights[j] = allGamma[o][t][i];
                    sum += weights[j];
                    ++t;
                    ++j;
                }
                ++o;
            }
            --j;
            while (j >= 0) {
                int n = j--;
                weights[n] = weights[n] / sum;
            }
            StateDistribution opdf = nhmm.getStateDistribution(i);
            opdf.fit(this.training, weights);
        }
        this.method = nhmm;
    }

    @Override
    public void iteration(int count) {
        for (int i = 0; i < count; ++i) {
            this.iteration();
        }
    }

    @Override
    public TrainingContinuation pause() {
        return null;
    }

    @Override
    public void resume(TrainingContinuation state) {
    }

    @Override
    public void setError(double error) {
    }

    @Override
    public void setIteration(int iteration) {
        this.iterations = iteration;
    }
}

