/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.som.training.basic;

import org.encog.mathutil.matrices.Matrix;
import org.encog.mathutil.matrices.MatrixMath;
import org.encog.ml.MLMethod;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLData;
import org.encog.ml.data.MLDataPair;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.training.LearningRate;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.som.SOM;
import org.encog.neural.som.training.basic.BestMatchingUnit;
import org.encog.neural.som.training.basic.neighborhood.NeighborhoodFunction;
import org.encog.util.Format;
import org.encog.util.logging.EncogLogging;

/**
 * Iterative trainer for a {@link SOM}.
 *
 * <p>For every training pattern the best matching unit (BMU) is located, and every output
 * neuron's weight row is moved toward the pattern by
 * {@code neighborhood(neuron, bmu) * learningRate * (input - weight)}. The learning rate and
 * neighborhood radius can be decayed manually ({@link #decay(double)},
 * {@link #decay(double, double)}) or along a linear schedule ({@link #setAutoDecay} +
 * {@link #autoDecay()}).
 *
 * <p>Pausing/resuming is not supported by this trainer.
 */
public class BasicTrainSOM
extends BasicTraining
implements LearningRate {
    /** Scales each neuron's update according to its relation to the BMU. */
    private final NeighborhoodFunction neighborhood;
    /** Current learning rate applied to every weight update. */
    private double learningRate;
    /** The network being trained. */
    private final SOM network;
    /** Number of input neurons (columns of the weight matrix). */
    private final int inputNeuronCount;
    /** Number of output neurons (rows of the weight matrix). */
    private final int outputNeuronCount;
    /** Helper that finds the BMU and tracks the worst BMU distance. */
    private final BestMatchingUnit bmuUtil;
    /** Scratch matrix receiving the updated weights before they are copied into the network. */
    private final Matrix correctionMatrix;
    /** When true, neurons that never win may be forced onto the least represented pattern. */
    private boolean forceWinner;
    /** Auto-decay schedule: starting learning rate. */
    private double startRate;
    /** Auto-decay schedule: learning rate at which decay stops. */
    private double endRate;
    /** Auto-decay schedule: starting neighborhood radius. */
    private double startRadius;
    /** Auto-decay schedule: radius at which decay stops. */
    private double endRadius;
    /** Per-call change applied to the learning rate by {@link #autoDecay()}. */
    private double autoDecayRate;
    /** Per-call change applied to the radius by {@link #autoDecay()}. */
    private double autoDecayRadius;
    /** Current neighborhood radius. */
    private double radius;

    /**
     * Creates a trainer.
     *
     * @param network      the SOM to train
     * @param learningRate initial learning rate
     * @param training     the training data
     * @param neighborhood neighborhood function used to weight updates around the BMU
     */
    public BasicTrainSOM(SOM network, double learningRate, MLDataSet training, NeighborhoodFunction neighborhood) {
        super(TrainingImplementationType.Iterative);
        this.neighborhood = neighborhood;
        this.setTraining(training);
        this.learningRate = learningRate;
        this.network = network;
        this.inputNeuronCount = network.getInputCount();
        this.outputNeuronCount = network.getOutputCount();
        this.forceWinner = false;
        this.setError(0.0);
        this.correctionMatrix = new Matrix(this.outputNeuronCount, this.inputNeuronCount);
        this.bmuUtil = new BestMatchingUnit(network);
    }

    /** Copies the accumulated correction matrix into the network's weight matrix. */
    private void applyCorrection() {
        this.network.getWeights().set(this.correctionMatrix);
    }

    /**
     * Advances the learning rate and radius one step along the schedule configured by
     * {@link #setAutoDecay}, then pushes the radius into the neighborhood function.
     *
     * <p>Each value moves by its per-call step only while it is still above its end value, and
     * is clamped so that it never steps past that end value (for example because of
     * floating-point accumulation).
     */
    public void autoDecay() {
        if (this.radius > this.endRadius) {
            this.radius = Math.max(this.radius + this.autoDecayRadius, this.endRadius);
        }
        if (this.learningRate > this.endRate) {
            this.learningRate = Math.max(this.learningRate + this.autoDecayRate, this.endRate);
        }
        this.getNeighborhood().setRadius(this.radius);
    }

    /**
     * Pausing is not supported.
     *
     * @return always {@code false}
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Overwrites row {@code outputNeuron} of {@code matrix} with the values of {@code input}.
     *
     * @param matrix       the weight matrix to modify
     * @param outputNeuron the row (output neuron) to overwrite
     * @param input        the pattern to copy
     */
    private void copyInputPattern(Matrix matrix, int outputNeuron, MLData input) {
        for (int inputNeuron = 0; inputNeuron < this.inputNeuronCount; ++inputNeuron) {
            matrix.set(outputNeuron, inputNeuron, input.getData(inputNeuron));
        }
    }

    /**
     * Multiplies both the radius and the learning rate by {@code 1 - d} and pushes the new
     * radius into the neighborhood function (so training actually uses the decayed radius,
     * matching {@link #decay(double, double)}).
     *
     * @param d the fractional decay applied to both values
     */
    public void decay(double d) {
        this.decay(d, d);
    }

    /**
     * Multiplies the radius by {@code 1 - decayRadius} and the learning rate by
     * {@code 1 - decayRate}, then pushes the new radius into the neighborhood function.
     *
     * @param decayRate   fractional decay of the learning rate
     * @param decayRadius fractional decay of the radius
     */
    public void decay(double decayRate, double decayRadius) {
        this.radius *= 1.0 - decayRadius;
        this.learningRate *= 1.0 - decayRate;
        this.getNeighborhood().setRadius(this.radius);
    }

    /**
     * Computes the Kohonen update for a single weight.
     *
     * @param weight        current weight
     * @param input         corresponding input value
     * @param currentNeuron the output neuron that owns the weight
     * @param bmu           the best matching unit for the current pattern
     * @return {@code weight + neighborhood(currentNeuron, bmu) * learningRate * (input - weight)}
     */
    private double determineNewWeight(double weight, double input, int currentNeuron, int bmu) {
        return weight + this.neighborhood.function(currentNeuron, bmu) * this.learningRate * (input - weight);
    }

    /**
     * Among the output neurons that have not won any pattern yet, picks the one with the highest
     * activation for {@code leastRepresented} and overwrites its weight row with that pattern.
     *
     * @param matrix           weight matrix to modify
     * @param won              per-neuron win counts
     * @param leastRepresented the pattern whose BMU activation was lowest; may be {@code null}
     *                         if no activation compared below {@code Double.MAX_VALUE}
     * @return {@code true} if a neuron was forced, {@code false} otherwise
     */
    private boolean forceWinners(Matrix matrix, int[] won, MLData leastRepresented) {
        if (leastRepresented == null) {
            // Nothing was tracked (e.g. every activation was NaN or infinite); nothing to force.
            return false;
        }
        double maxActivation = Double.NEGATIVE_INFINITY;
        int maxActivationNeuron = -1;
        MLData output = this.compute(this.network, leastRepresented);
        for (int outputNeuron = 0; outputNeuron < won.length; ++outputNeuron) {
            if (won[outputNeuron] != 0) {
                continue; // only neurons that never won are candidates
            }
            double activation = output.getData(outputNeuron);
            // The first candidate is always taken, so NaN activations behave as before.
            if (maxActivationNeuron == -1 || activation > maxActivation) {
                maxActivation = activation;
                maxActivationNeuron = outputNeuron;
            }
        }
        if (maxActivationNeuron == -1) {
            return false;
        }
        this.copyInputPattern(matrix, maxActivationNeuron, leastRepresented);
        return true;
    }

    /** @return the number of input neurons of the trained network */
    public int getInputNeuronCount() {
        return this.inputNeuronCount;
    }

    /** @return the current learning rate */
    @Override
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @return the network being trained */
    @Override
    public MLMethod getMethod() {
        return this.network;
    }

    /** @return the neighborhood function used by this trainer */
    public NeighborhoodFunction getNeighborhood() {
        return this.neighborhood;
    }

    /** @return the number of output neurons of the trained network */
    public int getOutputNeuronCount() {
        return this.outputNeuronCount;
    }

    /** @return whether non-winning neurons are forced onto the least represented pattern */
    public boolean isForceWinner() {
        return this.forceWinner;
    }

    /**
     * Runs one pass over the training set.
     *
     * <p>For each pattern the BMU is found, updated weights for every neuron are written into the
     * correction matrix, and the correction is copied into the network. When force-winner mode is
     * on and a non-winning neuron is forced onto the least represented pattern seen so far, the
     * correction for that pattern is not applied. The error is set to the worst BMU distance
     * divided by 100.
     */
    @Override
    public void iteration() {
        // Level 1 is the logging level this trainer has always used for this message.
        EncogLogging.log(1, "Performing SOM Training iteration.");
        this.preIteration();
        this.bmuUtil.reset();
        // won[i] counts the patterns (so far this iteration) whose BMU was output neuron i.
        int[] won = new int[this.outputNeuronCount];
        double leastRepresentedActivation = Double.MAX_VALUE;
        MLData leastRepresented = null;
        this.correctionMatrix.clear();
        for (MLDataPair pair : this.getTraining()) {
            MLData input = pair.getInput();
            int bmu = this.bmuUtil.calculateBMU(input);
            won[bmu]++;
            if (this.forceWinner) {
                // Remember the pattern to which its own BMU responds most weakly.
                double bmuActivation = this.compute(this.network, input).getData(bmu);
                if (bmuActivation < leastRepresentedActivation) {
                    leastRepresentedActivation = bmuActivation;
                    leastRepresented = input;
                }
            }
            this.train(bmu, this.network.getWeights(), input);
            boolean forced = this.forceWinner
                    && this.forceWinners(this.network.getWeights(), won, leastRepresented);
            if (!forced) {
                this.applyCorrection();
            }
        }
        this.setError(this.bmuUtil.getWorstDistance() / 100.0);
        this.postIteration();
    }

    /**
     * Pausing is not supported.
     *
     * @return always {@code null}
     */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /** Resuming is not supported; this method does nothing. */
    @Override
    public void resume(TrainingContinuation state) {
    }

    /**
     * Configures a linear schedule for {@link #autoDecay()} and applies the start values.
     *
     * @param plannedIterations number of {@link #autoDecay()} calls over which to move from the
     *                          start to the end values
     * @param startRate         initial learning rate
     * @param endRate           final learning rate
     * @param startRadius       initial neighborhood radius
     * @param endRadius         final neighborhood radius
     */
    public void setAutoDecay(int plannedIterations, double startRate, double endRate, double startRadius, double endRadius) {
        this.startRate = startRate;
        this.endRate = endRate;
        this.startRadius = startRadius;
        this.endRadius = endRadius;
        this.autoDecayRadius = (endRadius - startRadius) / (double) plannedIterations;
        this.autoDecayRate = (endRate - startRate) / (double) plannedIterations;
        this.setParams(this.startRate, this.startRadius);
    }

    /** @param forceWinner whether non-winning neurons should be forced onto the least represented pattern */
    public void setForceWinner(boolean forceWinner) {
        this.forceWinner = forceWinner;
    }

    /** @param rate the new learning rate */
    @Override
    public void setLearningRate(double rate) {
        this.learningRate = rate;
    }

    /**
     * Sets the learning rate and radius, pushing the radius into the neighborhood function.
     *
     * @param rate   the learning rate
     * @param radius the neighborhood radius
     */
    public void setParams(double rate, double radius) {
        this.radius = radius;
        this.learningRate = rate;
        this.getNeighborhood().setRadius(radius);
    }

    /** @return the current rate (as a percentage) and radius */
    @Override
    public String toString() {
        StringBuilder result = new StringBuilder();
        result.append("Rate=");
        result.append(Format.formatPercent(this.learningRate));
        result.append(", Radius=");
        result.append(Format.formatDouble(this.radius, 2));
        return result.toString();
    }

    /**
     * Writes the updated weights of every output neuron for {@code input} into the correction
     * matrix.
     *
     * @param bmu    best matching unit for {@code input}
     * @param matrix current weight matrix (read only)
     * @param input  the training pattern
     */
    private void train(int bmu, Matrix matrix, MLData input) {
        for (int outputNeuron = 0; outputNeuron < this.outputNeuronCount; ++outputNeuron) {
            this.trainPattern(matrix, input, outputNeuron, bmu);
        }
    }

    /**
     * Writes the updated weights of output neuron {@code current} into the correction matrix.
     *
     * @param matrix  current weight matrix (read only)
     * @param input   the training pattern
     * @param current the output neuron whose row is updated
     * @param bmu     best matching unit for {@code input}
     */
    private void trainPattern(Matrix matrix, MLData input, int current, int bmu) {
        for (int inputNeuron = 0; inputNeuron < this.inputNeuronCount; ++inputNeuron) {
            double currentWeight = matrix.get(current, inputNeuron);
            double inputValue = input.getData(inputNeuron);
            double newWeight = this.determineNewWeight(currentWeight, inputValue, current, bmu);
            this.correctionMatrix.set(current, inputNeuron, newWeight);
        }
    }

    /**
     * Trains the network on a single pattern: finds its BMU, computes updated weights for every
     * neuron and copies them into the network.
     *
     * @param pattern the pattern to train on
     */
    public void trainPattern(MLData pattern) {
        int bmu = this.bmuUtil.calculateBMU(pattern);
        this.train(bmu, this.network.getWeights(), pattern);
        this.applyCorrection();
    }

    /**
     * Computes each output neuron's activation as the dot product of {@code input} with that
     * neuron's weight row.
     *
     * @param som   the network whose weights are used
     * @param input the input pattern
     * @return one activation per output neuron
     */
    private MLData compute(SOM som, MLData input) {
        int outputCount = som.getOutputCount();
        BasicMLData result = new BasicMLData(outputCount);
        // The input row is identical for every output neuron, so build it only once.
        Matrix inputMatrix = Matrix.createRowMatrix(input.getData());
        for (int i = 0; i < outputCount; ++i) {
            Matrix weightRow = som.getWeights().getRow(i);
            result.setData(i, MatrixMath.dotProduct(inputMatrix, weightRow));
        }
        return result;
    }
}

