/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.freeform.training;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.mathutil.error.ErrorCalculation;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.freeform.FreeformConnection;
import org.encog.neural.freeform.FreeformNetwork;
import org.encog.neural.freeform.FreeformNeuron;
import org.encog.neural.freeform.task.ConnectionTask;

/**
 * Base class for propagation (gradient based) training of a {@link FreeformNetwork}.
 *
 * <p>Each training pair is run through the network. The significance-weighted output error
 * is then propagated backwards, and every connection accumulates its gradient in its
 * temp-training slot 0. Subclasses decide how a weight changes from that gradient by
 * implementing {@link #learnConnection(FreeformConnection)}; the slot is reset to zero right
 * after each call.
 *
 * <p>A batch size of 0 treats the whole training set as one batch. Any other value updates the
 * weights after that many pairs, plus once more for a trailing partial batch.
 *
 * <p>NOTE(review): neuron temp-training slot 0 and connection temp-training slot 0 must be
 * allocated before training starts. This is presumably done by subclasses; confirm.
 */
public abstract class FreeformPropagationTraining
extends BasicTraining
implements Serializable {
    private static final long serialVersionUID = 1L;

    /**
     * Amount added to the derivative of an {@link ActivationSigmoid} while the flat-spot fix is
     * enabled. It keeps the derivative, and so the propagated error, away from zero for
     * saturated sigmoid neurons.
     */
    public static final double FLAT_SPOT_CONST = 0.1;

    /** Network being trained; {@code null} only when built with the no-arg constructor. */
    private final FreeformNetwork network;

    /** Training data; {@code null} only when built with the no-arg constructor. */
    private final MLDataSet training;

    /** Number of iterations performed, or the value last given to {@link #setIteration(int)}. */
    private int iterationCount;

    /** Error recorded by the most recent iteration. */
    private double error;

    /**
     * Output neurons plus every neuron reached backwards from them. Rebuilt at the start of
     * each iteration; it also marks which target deltas are valid for the current pair.
     */
    private final Set<FreeformNeuron> visited = new HashSet<FreeformNeuron>();

    /**
     * Whether {@link #FLAT_SPOT_CONST} is added to sigmoid derivatives. The misspelled name is
     * kept for compatibility with the public accessors.
     */
    private boolean fixFlatSopt = true;

    /** Pairs per weight update; 0 means one update per pass over the whole training set. */
    private int batchSize = 0;

    /**
     * Creates a trainer with no network and no training data.
     * Calling {@link #iteration()} on such an instance throws a {@link NullPointerException}.
     */
    public FreeformPropagationTraining() {
        super(TrainingImplementationType.Iterative);
        this.network = null;
        this.training = null;
    }

    /**
     * Creates a trainer for the given network and data.
     *
     * @param theNetwork  the network whose connection weights are trained
     * @param theTraining the training pairs
     */
    public FreeformPropagationTraining(FreeformNetwork theNetwork, MLDataSet theTraining) {
        super(TrainingImplementationType.Iterative);
        this.network = theNetwork;
        this.training = theTraining;
    }

    /**
     * Returns the neuron at {@code index} in the network's output layer.
     */
    private FreeformNeuron outputNeuron(int index) {
        return this.network.getOutputLayer().getNeurons().get(index);
    }

    /**
     * Builds the order in which non-output neurons receive their deltas.
     *
     * <p>The walk goes backwards from the output neurons through their input connections. In
     * the returned order, every neuron comes after all of its reached targets, so each delta
     * can be computed exactly once per pair from final values. The old recursive walk instead
     * re-entered shared neurons once per path, which added their gradients several times.
     *
     * @return reached non-output neurons, targets before sources
     */
    private List<FreeformNeuron> buildBackpropagationOrder() {
        this.visited.clear();
        final int outputCount = this.network.getOutputCount();
        // Pre-mark the outputs: their deltas come from the ideal values, never from targets.
        for (int i = 0; i < outputCount; ++i) {
            this.visited.add(this.outputNeuron(i));
        }
        final List<FreeformNeuron> order = new ArrayList<FreeformNeuron>();
        for (int i = 0; i < outputCount; ++i) {
            this.collectSources(this.outputNeuron(i), order);
        }
        // The post-order lists sources before targets; backpropagation needs the reverse.
        Collections.reverse(order);
        return order;
    }

    /**
     * Depth-first post-order walk over the neurons feeding {@code neuron}. Each newly reached
     * source is appended after all of its own sources.
     *
     * @param neuron the neuron whose inbound connections are followed
     * @param order  receives the reached neurons
     */
    private void collectSources(FreeformNeuron neuron, List<FreeformNeuron> order) {
        if (neuron.getInputSummation() == null) {
            return;
        }
        for (FreeformConnection connection : neuron.getInputSummation().list()) {
            final FreeformNeuron source = connection.getSource();
            // Set.add returns false for already-seen neurons, which also prevents endless cycles.
            if (this.visited.add(source)) {
                this.collectSources(source, order);
                order.add(source);
            }
        }
    }

    /**
     * Returns the derivative of the neuron's own activation function at its current sum and
     * activation. {@link #FLAT_SPOT_CONST} is added for sigmoid activations while the flat-spot
     * fix is enabled. The neuron must have an input summation.
     */
    private double activationDerivative(FreeformNeuron neuron) {
        double deriv = neuron.getInputSummation().getActivationFunction()
                .derivativeFunction(neuron.getInputSummation().getSum(), neuron.getActivation());
        if (this.fixFlatSopt
                && neuron.getInputSummation().getActivationFunction() instanceof ActivationSigmoid) {
            deriv += FLAT_SPOT_CONST;
        }
        return deriv;
    }

    /**
     * Stores an output neuron's delta (derivative * error) in its temp-training slot 0.
     *
     * @param neuron an output neuron
     * @param diff   (ideal - actual) * significance for this neuron
     */
    private void calculateOutputDelta(FreeformNeuron neuron, double diff) {
        neuron.setTempTraining(0, this.activationDerivative(neuron) * diff);
    }

    /**
     * Stores a non-output neuron's delta in its temp-training slot 0. The delta is derived from
     * the deltas of its targets, which must already be final for the current pair.
     */
    private void calculateHiddenDelta(FreeformNeuron neuron) {
        double sum = 0.0;
        for (FreeformConnection connection : neuron.getOutputs()) {
            final FreeformNeuron target = connection.getTarget();
            // Targets not reached from the outputs cannot affect the error; their slots may be stale.
            if (this.visited.contains(target)) {
                sum += target.getTempTraining(0) * connection.getWeight();
            }
        }
        // Uses this neuron's own activation derivative; the old code used the downstream neuron's.
        neuron.setTempTraining(0, sum * this.activationDerivative(neuron));
    }

    /**
     * Adds (source activation * neuron delta) to the gradient slot of every inbound connection.
     */
    private void accumulateGradients(FreeformNeuron neuron) {
        if (neuron.getInputSummation() == null) {
            return;
        }
        final double delta = neuron.getTempTraining(0);
        for (FreeformConnection connection : neuron.getInputSummation().list()) {
            connection.addTempTraining(0, connection.getSource().getActivation() * delta);
        }
    }

    /**
     * Runs one pair forward, records its error, and accumulates its gradients.
     *
     * @param pair      the training pair
     * @param errorCalc accumulates the error of the current iteration
     * @param order     result of {@link #buildBackpropagationOrder()} for this iteration
     */
    private void processPair(MLDataPair pair, ErrorCalculation errorCalc, List<FreeformNeuron> order) {
        final MLData input = pair.getInput();
        final MLData ideal = pair.getIdeal();
        final MLData actual = this.network.compute(input);
        final double sig = pair.getSignificance();
        errorCalc.updateError(actual.getData(), ideal.getData(), sig);

        final int outputCount = this.network.getOutputCount();
        // Every output delta must be set before any hidden delta is derived from it.
        for (int i = 0; i < outputCount; ++i) {
            final double diff = (ideal.getData(i) - actual.getData(i)) * sig;
            this.calculateOutputDelta(this.outputNeuron(i), diff);
        }
        for (int i = 0; i < outputCount; ++i) {
            this.accumulateGradients(this.outputNeuron(i));
        }
        for (FreeformNeuron neuron : order) {
            // Neurons without an input summation (presumably input/bias neurons) have no
            // inbound connections, so there is nothing to learn for them.
            if (neuron.getInputSummation() != null) {
                this.calculateHiddenDelta(neuron);
                this.accumulateGradients(neuron);
            }
        }
    }

    /**
     * Returns false; this trainer does not support continuation.
     *
     * @return always {@code false}
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Clears the network's temp-training state after training.
     */
    @Override
    public void finishTraining() {
        this.network.tempTrainingClear();
    }

    /**
     * @return the error recorded by the most recent iteration
     */
    @Override
    public double getError() {
        return this.error;
    }

    /**
     * @return {@link TrainingImplementationType#Iterative}
     */
    @Override
    public TrainingImplementationType getImplementationType() {
        return TrainingImplementationType.Iterative;
    }

    /**
     * @return the current iteration number
     */
    @Override
    public int getIteration() {
        return this.iterationCount;
    }

    /**
     * @return the network being trained
     */
    @Override
    public MLMethod getMethod() {
        return this.network;
    }

    /**
     * @return the training data
     */
    @Override
    public MLDataSet getTraining() {
        return this.training;
    }

    /**
     * @return whether {@link #FLAT_SPOT_CONST} is added to sigmoid derivatives
     */
    public boolean isFixFlatSopt() {
        return this.fixFlatSopt;
    }

    /**
     * Performs one iteration, in this order:
     * <ol>
     *   <li>runs the pre-iteration hook;</li>
     *   <li>increments the iteration counter;</li>
     *   <li>clears the network context;</li>
     *   <li>processes the training set, as one batch when the batch size is 0 and in
     *       mini-batches otherwise;</li>
     *   <li>runs the post-iteration hook.</li>
     * </ol>
     */
    @Override
    public void iteration() {
        this.preIteration();
        ++this.iterationCount;
        this.network.clearContext();
        if (this.batchSize == 0) {
            this.processPureBatch();
        } else {
            this.processBatches();
        }
        this.postIteration();
    }

    /**
     * Performs {@code count} iterations.
     *
     * @param count number of iterations to run
     */
    @Override
    public void iteration(int count) {
        for (int i = 0; i < count; ++i) {
            this.iteration();
        }
    }

    /**
     * Accumulates gradients over the whole training set, records the error, then updates the
     * weights once.
     */
    protected void processPureBatch() {
        final ErrorCalculation errorCalc = new ErrorCalculation();
        final List<FreeformNeuron> order = this.buildBackpropagationOrder();
        for (MLDataPair pair : this.training) {
            this.processPair(pair, errorCalc, order);
        }
        this.setError(errorCalc.calculate());
        this.learn();
    }

    /**
     * Updates the weights after every {@code batchSize} pairs, and once more for a trailing
     * partial batch. The error is recorded at the end. Batch sizes below 1 lead to an update
     * after every pair.
     */
    protected void processBatches() {
        final ErrorCalculation errorCalc = new ErrorCalculation();
        final List<FreeformNeuron> order = this.buildBackpropagationOrder();
        int pending = 0;
        for (MLDataPair pair : this.training) {
            this.processPair(pair, errorCalc, order);
            if (++pending >= this.batchSize) {
                pending = 0;
                this.learn();
            }
        }
        if (pending > 0) {
            this.learn();
        }
        this.setError(errorCalc.calculate());
    }

    /**
     * Applies {@link #learnConnection(FreeformConnection)} to every connection. Each
     * connection's accumulated gradient (temp-training slot 0) is reset afterwards.
     */
    protected void learn() {
        this.network.performConnectionTask(new ConnectionTask(){

            @Override
            public void task(FreeformConnection connection) {
                FreeformPropagationTraining.this.learnConnection(connection);
                connection.setTempTraining(0, 0.0);
            }
        });
    }

    /**
     * Updates one connection's weight from the gradient accumulated in its temp-training slot 0.
     *
     * @param connection the connection to update
     */
    protected abstract void learnConnection(FreeformConnection connection);

    /**
     * @param theError the error to record
     */
    @Override
    public void setError(double theError) {
        this.error = theError;
    }

    /**
     * @param fixFlatSopt whether {@link #FLAT_SPOT_CONST} is added to sigmoid derivatives
     */
    public void setFixFlatSopt(boolean fixFlatSopt) {
        this.fixFlatSopt = fixFlatSopt;
    }

    /**
     * @param iteration the iteration number to report from {@link #getIteration()}
     */
    @Override
    public void setIteration(int iteration) {
        this.iterationCount = iteration;
    }

    /**
     * @return pairs per weight update; 0 means the whole training set
     */
    public int getBatchSize() {
        return this.batchSize;
    }

    /**
     * @param batchSize pairs per weight update; 0 means the whole training set
     */
    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }
}

