/*
 * Decompiled with CFR 0.152.
 */
package org.encog.app.analyst.commands;

import java.io.File;
import org.encog.app.analyst.EncogAnalyst;
import org.encog.app.analyst.commands.Cmd;
import org.encog.ml.MLMethod;
import org.encog.ml.MLResettable;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.bayesian.BayesianNetwork;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.ea.train.EvolutionaryAlgorithm;
import org.encog.ml.factory.MLTrainFactory;
import org.encog.ml.train.MLTrain;
import org.encog.neural.networks.training.cross.CrossValidationKFold;
import org.encog.persist.EncogDirectoryPersistence;
import org.encog.util.logging.EncogLogging;
import org.encog.util.validate.ValidateNetwork;

/**
 * Analyst command that trains the current machine learning method and then
 * persists the result to the configured machine learning file.
 */
public class CmdTrain extends Cmd {
    /** Name under which this command is registered. */
    public static final String COMMAND_NAME = "TRAIN";

    /**
     * Construct the train command.
     *
     * @param analyst the analyst this command operates on
     */
    public CmdTrain(EncogAnalyst analyst) {
        super(analyst);
    }

    /**
     * Build the trainer described by the {@code ML:TRAIN_type} and
     * {@code ML:TRAIN_arguments} script properties.
     *
     * <p>When the configured k-fold value is greater than zero, the trainer
     * is wrapped in a {@link CrossValidationKFold}.
     *
     * @param method      the method to train
     * @param trainingSet the data to train with
     * @return the trainer, possibly wrapped for cross validation
     */
    private MLTrain createTrainer(MLMethod method, MLDataSet trainingSet) {
        MLTrainFactory factory = new MLTrainFactory();
        String type = this.getProp().getPropertyString("ML:TRAIN_type");
        String args = this.getProp().getPropertyString("ML:TRAIN_arguments");
        // NOTE(review): level 0 is presumably the debug level - confirm against EncogLogging.
        EncogLogging.log(0, "training type:" + type);
        EncogLogging.log(0, "training args:" + args);
        // Resettable methods are registered with the analyst before the trainer is built.
        if (method instanceof MLResettable) {
            this.getAnalyst().setMethod(method);
        }
        MLTrain train = factory.create(method, trainingSet, type, args);
        if (this.getKfold() > 0) {
            train = new CrossValidationKFold(train, this.getKfold());
        }
        return train;
    }

    /**
     * Run the command: obtain the training data and method, train, and save
     * the trained object to the file named by
     * {@code ML:CONFIG_machineLearningFile}.
     *
     * <p>The training set is closed whether or not training and saving
     * succeed, so a failure part-way through does not leave it open.
     *
     * @param args command arguments; not used by this command
     * @return the analyst's stop flag as reported after the command ran
     */
    @Override
    public boolean executeCommand(String args) {
        this.setKfold(this.obtainCross());
        MLDataSet trainingSet = this.obtainTrainingSet();
        try {
            MLMethod method = this.obtainMethod();
            MLTrain trainer = this.createTrainer(method, trainingSet);
            if (method instanceof BayesianNetwork) {
                String query = this.getProp().getPropertyString("ML:CONFIG_query");
                ((BayesianNetwork)method).defineClassificationStructure(query);
            }
            EncogLogging.log(0, "Beginning training");
            this.performTraining(trainer, method, trainingSet);
            String resourceID = this.getProp().getPropertyString("ML:CONFIG_machineLearningFile");
            File resourceFile = this.getAnalyst().getScript().resolveFilename(resourceID);

            // Decide what to persist: an evolutionary trainer's population when
            // it has one, otherwise whatever method the trainer now holds.
            MLMethod toSave = null;
            if (trainer instanceof EvolutionaryAlgorithm) {
                EvolutionaryAlgorithm ea = (EvolutionaryAlgorithm)((Object)trainer);
                toSave = ea.getPopulation();
            }
            if (toSave == null) {
                toSave = trainer.getMethod();
            }
            EncogDirectoryPersistence.saveObject(resourceFile, (Object)toSave);
            EncogLogging.log(0, "save to:" + resourceID);
        } finally {
            // Previously close() was skipped whenever any step above threw.
            trainingSet.close();
        }
        return this.getAnalyst().shouldStopCommand();
    }

    /**
     * @return the name of this command, {@link #COMMAND_NAME}
     */
    @Override
    public String getName() {
        return COMMAND_NAME;
    }

    /**
     * Validate the method against the data, run the training iterations, and
     * hand the trained method back to the analyst.
     *
     * <p>A trainer whose implementation type is
     * {@link TrainingImplementationType#OnePass} gets exactly one iteration.
     * Any other trainer iterates at least once and continues while the error
     * is above the {@code ML:TRAIN_targetError} property, the analyst has not
     * requested a stop, the trainer does not report itself done, and the
     * analyst's maximum iteration count (where {@code -1} disables the limit)
     * has not been reached.
     *
     * @param train       the trainer to run
     * @param method      the method being trained, used for validation
     * @param trainingSet the training data, used for validation
     */
    private void performTraining(MLTrain train, MLMethod method, MLDataSet trainingSet) {
        ValidateNetwork.validateMethodToData(method, trainingSet);
        double targetError = this.getProp().getPropertyDouble("ML:TRAIN_targetError");
        this.getAnalyst().reportTrainingBegin();
        int maxIteration = this.getAnalyst().getMaxIteration();
        if (train.getImplementationType() == TrainingImplementationType.OnePass) {
            train.iteration();
            this.getAnalyst().reportTraining(train);
        } else {
            do {
                train.iteration();
                this.getAnalyst().reportTraining(train);
            } while (train.getError() > targetError
                    && !this.getAnalyst().shouldStopCommand()
                    && !train.isTrainingDone()
                    && (maxIteration == -1 || train.getIteration() < maxIteration));
        }
        train.finishTraining();
        this.getAnalyst().reportTrainingEnd();
        this.getAnalyst().setMethod(train.getMethod());
    }
}

