package org.neuroph.contrib.bpbench;

import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.learning.LearningRule;
import org.neuroph.nnet.learning.MomentumBackpropagation;

/* loaded from: input_file:org/neuroph/contrib/bpbench/MomentumTraining.class */
/**
 * Training run that uses {@link MomentumBackpropagation} as the learning rule.
 *
 * <p>The rule is configured from the {@code TrainingSettings} returned by
 * {@code getSettings()} (batch mode, learning rate, max error, max iterations
 * and momentum). After learning, one {@code TrainingResult} is appended to the
 * statistics object returned by {@code getStats()}.
 */
public class MomentumTraining extends AbstractTraining {

    /**
     * Creates a training run for an existing network.
     *
     * @param neuralNetwork    network to train; passed through to the superclass
     * @param dataSet          data set to learn from; passed through to the superclass
     * @param trainingSettings settings used to configure the learning rule
     */
    public MomentumTraining(NeuralNetwork neuralNetwork, DataSet dataSet, TrainingSettings trainingSettings) {
        super(neuralNetwork, dataSet, trainingSettings);
    }

    /**
     * Creates a training run without supplying a network; how the network is
     * obtained is decided by the superclass constructor.
     *
     * @param dataSet          data set to learn from; passed through to the superclass
     * @param trainingSettings settings used to configure the learning rule
     */
    public MomentumTraining(DataSet dataSet, TrainingSettings trainingSettings) {
        super(dataSet, trainingSettings);
    }

    /**
     * Installs a freshly configured momentum backpropagation rule on the
     * network, runs learning on the data set, and records the outcome
     * (current iteration, total network error and the result of
     * {@code createMatrix()}) in the statistics before recalculating them.
     */
    @Override
    public void testNeuralNet() {
        // setParameters() is declared to return LearningRule, which cannot be
        // assigned to a MomentumBackpropagation local without an explicit cast.
        MomentumBackpropagation learningRule = (MomentumBackpropagation) setParameters();

        getNeuralNet().setLearningRule(learningRule);
        getNeuralNet().learn(getDataset());

        TrainingResult result = new TrainingResult(
                learningRule.getCurrentIteration(),
                learningRule.getTotalNetworkError(),
                createMatrix());
        getStats().addData(result);
        getStats().calculateParameters();
    }

    /**
     * Builds a new {@link MomentumBackpropagation} rule and copies batch mode,
     * learning rate, max error, max iterations and momentum onto it from
     * {@code getSettings()}.
     *
     * @return the configured rule; the instance is always a
     *         {@link MomentumBackpropagation}
     */
    @Override
    public LearningRule setParameters() {
        MomentumBackpropagation momentumBackpropagation = new MomentumBackpropagation();
        momentumBackpropagation.setBatchMode(getSettings().isBatchMode());
        momentumBackpropagation.setLearningRate(getSettings().getLearningRate());
        momentumBackpropagation.setMaxError(getSettings().getMaxError());
        momentumBackpropagation.setMaxIterations(getSettings().getMaxIterations());
        momentumBackpropagation.setMomentum(getSettings().getMomentum());
        return momentumBackpropagation;
    }
}
