package org.neuroph.contrib.matrixmlp;

import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.transfer.TransferFunction;
import org.neuroph.nnet.learning.MomentumBackpropagation;

/* loaded from: input_file:org/neuroph/contrib/matrixmlp/MatrixMomentumBackpropagation.class */
/**
 * Momentum backpropagation that works directly on the dense weight/error arrays exposed by a
 * {@link MatrixMultiLayerPerceptron}, instead of walking per-neuron connection objects.
 *
 * <p>Weights are updated online. The output layer is updated first. Then each hidden layer is
 * updated in turn, moving backwards towards index 1. The layer at index 0 is never trained.
 */
public class MatrixMomentumBackpropagation extends MomentumBackpropagation {

    /** The network being trained, cached in its concrete matrix form. */
    private MatrixMultiLayerPerceptron matrixMlp;

    /** Matrix view of the layers; the last entry is treated as the output layer. */
    private MatrixLayer[] matrixLayers;

    /**
     * Attaches this learning rule to a network and caches its matrix representation.
     *
     * @param neuralNetwork the network to train; it must be a {@link MatrixMultiLayerPerceptron}
     * @throws ClassCastException if the network is not a {@link MatrixMultiLayerPerceptron}
     */
    public void setNeuralNetwork(NeuralNetwork neuralNetwork) {
        super.setNeuralNetwork(neuralNetwork);
        this.matrixMlp = (MatrixMultiLayerPerceptron) getNeuralNetwork();
        this.matrixLayers = this.matrixMlp.getMatrixLayers();
    }

    /**
     * Computes the output layer's error terms and immediately applies the resulting weight
     * changes to that layer.
     *
     * <p>Each error term is {@code outputError[i] * f'(netInput[i])}, where {@code f} is the
     * layer's transfer function.
     *
     * @param outputError per-output-neuron error for the current training pattern, as supplied by
     *     the superclass training loop
     */
    protected void calculateErrorAndUpdateOutputNeurons(double[] outputError) {
        MatrixMlpLayer outputLayer =
                (MatrixMlpLayer) this.matrixLayers[this.matrixLayers.length - 1];
        TransferFunction transferFunction = outputLayer.getTransferFunction();
        int outputCount = outputLayer.getOutputs().length;
        double[] netInput = outputLayer.getNetInput();
        double[] errors = outputLayer.getErrors();

        for (int i = 0; i < outputCount; i++) {
            errors[i] = outputError[i] * transferFunction.getDerivative(netInput[i]);
        }
        updateLayerWeights(outputLayer, errors);
        // Intentionally no per-pattern console output here. The previous debug println of the
        // total error ran once per training pattern, flooding stdout and slowing training.
    }

    /**
     * Applies a momentum-smoothed gradient step to every weight of {@code layer}.
     *
     * <p>For weight {@code w[i][j]} the change is
     * {@code learningRate * errors[i] * inputs[j] + momentum * previousChange[i][j]}. The change
     * is written back into the layer's delta-weight matrix, where it becomes the momentum term
     * for the next update.
     *
     * @param layer the layer whose weights and delta weights are updated in place
     * @param errors error term for each neuron of {@code layer}
     */
    protected void updateLayerWeights(MatrixMlpLayer layer, double[] errors) {
        double[] inputs = layer.getInputs();
        double[][] weights = layer.getWeights();
        double[][] deltaWeights = layer.getDeltaWeights();
        int neuronsCount = layer.getNeuronsCount();

        for (int neuron = 0; neuron < neuronsCount; neuron++) {
            double[] neuronWeights = weights[neuron];
            double[] neuronDeltas = deltaWeights[neuron];
            // Hoisted out of the inner loop. The result is bit-identical, because
            // lr * e * x already evaluates as (lr * e) * x.
            double scaledError = this.learningRate * errors[neuron];
            for (int input = 0; input < neuronWeights.length; input++) {
                double change = scaledError * inputs[input] + this.momentum * neuronDeltas[input];
                neuronDeltas[input] = change;
                neuronWeights[input] += change;
            }
        }
    }

    /**
     * Back-propagates error terms through the hidden layers and updates their weights.
     *
     * <p>The loop walks from layer {@code getLayersCount() - 2} down to layer 1. The layer at
     * index 0 is not trained. For hidden neuron {@code i}, the error term is
     * {@code f'(netInput[i]) * sum_k(nextErrors[k] * nextWeights[k][i])}.
     */
    protected void calculateErrorAndUpdateHiddenNeurons() {
        for (int layerIndex = this.matrixMlp.getLayersCount() - 2; layerIndex > 0; layerIndex--) {
            MatrixMlpLayer layer = (MatrixMlpLayer) this.matrixLayers[layerIndex];
            MatrixMlpLayer nextLayer = (MatrixMlpLayer) layer.getNextLayer();
            TransferFunction transferFunction = layer.getTransferFunction();
            int neuronsCount = layer.getNeuronsCount();
            int nextNeuronsCount = nextLayer.getNeuronsCount();
            double[] errors = layer.getErrors();
            double[] netInput = layer.getNetInput();
            double[] nextErrors = nextLayer.getErrors();
            double[][] nextWeights = nextLayer.getWeights();

            // NOTE(review): nextWeights were already updated earlier in this pass (output layer
            // first, then each later hidden layer), so this sum uses post-update weights. This
            // keeps the existing online behaviour; confirm it is intended.
            for (int i = 0; i < neuronsCount; i++) {
                double weightedErrorSum = 0.0d;
                for (int k = 0; k < nextNeuronsCount; k++) {
                    weightedErrorSum += nextErrors[k] * nextWeights[k][i];
                }
                errors[i] = transferFunction.getDerivative(netInput[i]) * weightedErrorSum;
            }
            updateLayerWeights(layer, errors);
        }
    }
}
