package org.neuroph.contrib.bpbench;

import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.learning.LearningRule;
import org.neuroph.nnet.learning.BackPropagation;

/* loaded from: input_file:org/neuroph/contrib/bpbench/BackpropagationTraining.class */
/**
 * Benchmark training run that trains a neural network with a plain {@link BackPropagation}
 * learning rule whose parameters are taken from the run's {@link TrainingSettings}.
 */
public class BackpropagationTraining extends AbstractTraining {

    /**
     * Creates a training run for the given network.
     *
     * @param neuralNetwork    network to train
     * @param dataSet          data set passed to {@code learn} during training
     * @param trainingSettings settings from which the learning-rule parameters are read
     */
    public BackpropagationTraining(NeuralNetwork neuralNetwork, DataSet dataSet, TrainingSettings trainingSettings) {
        super(neuralNetwork, dataSet, trainingSettings);
    }

    /**
     * Creates a training run without an explicit network.
     * NOTE(review): how the network is obtained is decided by {@link AbstractTraining};
     * confirm there.
     *
     * @param dataSet          data set passed to {@code learn} during training
     * @param trainingSettings settings from which the learning-rule parameters are read
     */
    public BackpropagationTraining(DataSet dataSet, TrainingSettings trainingSettings) {
        super(dataSet, trainingSettings);
    }

    /**
     * Trains the network once with a freshly configured {@link BackPropagation} rule.
     * It then records the iteration count, the total network error and
     * {@code createMatrix()} as a {@link TrainingResult} in the statistics, and
     * recalculates the statistics.
     *
     * @throws ClassCastException if an overriding {@link #setParameters()} returns a rule
     *                            that is not a {@link BackPropagation}
     */
    @Override
    public void testNeuralNet() {
        // setParameters() is declared to return LearningRule, but the iteration count and
        // total network error are read through BackPropagation, so an explicit cast is needed.
        BackPropagation backPropagation = (BackPropagation) setParameters();
        getNeuralNet().setLearningRule(backPropagation);
        getNeuralNet().learn(getDataset());
        getStats().addData(new TrainingResult(
                backPropagation.getCurrentIteration(),
                backPropagation.getTotalNetworkError(),
                createMatrix()));
        getStats().calculateParameters();
    }

    /**
     * Builds a new {@link BackPropagation} rule configured with the learning rate,
     * max error, batch mode and max iterations from the training settings.
     *
     * @return a newly created, configured {@link BackPropagation} instance
     */
    @Override
    public LearningRule setParameters() {
        BackPropagation backPropagation = new BackPropagation();
        backPropagation.setLearningRate(getSettings().getLearningRate());
        backPropagation.setMaxError(getSettings().getMaxError());
        backPropagation.setBatchMode(getSettings().isBatchMode());
        backPropagation.setMaxIterations(getSettings().getMaxIterations());
        return backPropagation;
    }
}
