package org.neuroph.contrib.rnn.bptt;

import java.util.Map;
import org.jblas.DoubleMatrix;
import org.neuroph.contrib.rnn.GRU;
import org.neuroph.contrib.rnn.RNN;

/* loaded from: input_file:org/neuroph/contrib/rnn/bptt/GRUBackPropagationThroughTime.class */
/**
 * Back-propagation through time for {@link GRU} networks.
 *
 * <p>Forward-pass values and the deltas computed here are exchanged through a
 * {@code Map<String, DoubleMatrix>} whose keys are a name suffixed with the timestep index
 * (e.g. {@code "output3"}).
 */
public class GRUBackPropagationThroughTime extends BackPropagationThroughTime {

    /**
     * Computes deltas for every timestep from {@code lastTimestep} down to 0, then applies the
     * accumulated gradients with {@link #updateParameters}.
     *
     * <p>For each timestep {@code t} this reads {@code "predictedResult"+t}, {@code "result"+t},
     * {@code "output"+t}, {@code "updateActivation"+t}, {@code "resetActivation"+t} and
     * {@code "memoryCellGate"+t} from {@code map}, and writes {@code "resultDelta"+t},
     * {@code "outputDelta"+t}, {@code "memoryCellGateDelta"+t}, {@code "resetActivationDelta"+t}
     * and {@code "updateActivationDelta"+t} back into it.
     *
     * @param map          per-timestep forward values; also receives the computed deltas
     * @param lastTimestep index of the final timestep of the sequence
     * @param learningRate step size applied to the averaged gradients
     */
    @Override // org.neuroph.contrib.rnn.bptt.BackPropagationThroughTime
    public void propagate(Map<String, DoubleMatrix> map, int lastTimestep, double learningRate) {
        GRU gru = (GRU) getNeuralNetwork();
        for (int t = lastTimestep; t >= 0; t--) {
            DoubleMatrix resultDelta = map.get("predictedResult" + t).sub(map.get("result" + t));
            map.put("resultDelta" + t, resultDelta);

            DoubleMatrix output = map.get("output" + t);
            DoubleMatrix update = map.get("updateActivation" + t);
            DoubleMatrix reset = map.get("resetActivation" + t);
            DoubleMatrix candidate = map.get("memoryCellGate" + t);

            // The last timestep has no successor, so only the output layer contributes to its delta.
            DoubleMatrix outputDelta = t == lastTimestep
                    ? computeOutputDeltaForLastTimestep(resultDelta, gru)
                    : computeOutputDeltaForNotLastTimeStep(resultDelta, map, t, gru);
            map.put("outputDelta" + t, outputDelta);

            DoubleMatrix candidateDelta = outputDelta.mul(update).mul(deriveTanh(candidate));
            map.put("memoryCellGateDelta" + t, candidateDelta);

            // The output preceding timestep 0 is treated as an all-zero row vector.
            DoubleMatrix previousOutput = t > 0
                    ? map.get("output" + (t - 1))
                    : DoubleMatrix.zeros(1, output.length);

            // NOTE(review): updateParameters accumulates the memory-cell output-weight gradient as
            // (reset .* previousOutput)^T * candidateDelta, which implies the candidate
            // pre-activation contains (reset .* previousOutput) * W. Under that form this delta
            // would be (candidateDelta * W^T) .* previousOutput, but the code computes
            // (candidateDelta .* previousOutput) * W^T. Confirm against GRU's forward pass.
            map.put("resetActivationDelta" + t,
                    backThrough(gru.getMemoryCellOutputWeight(), candidateDelta.mul(previousOutput))
                            .mul(deriveExp(reset)));
            map.put("updateActivationDelta" + t,
                    outputDelta.mul(candidate.sub(previousOutput)).mul(deriveExp(update)));
        }
        updateParameters(map, lastTimestep, learningRate, gru);
    }

    /**
     * Sums the per-timestep gradients stored in {@code map} and applies
     * {@code W -= learningRate * sum / divisor} to every GRU parameter.
     *
     * <p>Input-weight, bias and output-layer sums are divided by {@code max(lastTimestep, 1)}.
     * Recurrent ("output") gate weights only receive contributions from timesteps {@code t > 0}
     * and are divided by {@code max(lastTimestep - 1, 1)}.
     *
     * <p>NOTE(review): those sums cover {@code lastTimestep + 1} and {@code lastTimestep}
     * timesteps respectively, so the divisors are one smaller than a true mean. That scaling is
     * kept to preserve existing training behaviour. The clamp to at least 1 prevents a division
     * by zero for single-timestep sequences, which previously filled the weights with
     * Infinity/NaN.
     *
     * @param map          per-timestep forward values and deltas produced by {@link #propagate}
     * @param lastTimestep index of the final timestep of the sequence
     * @param learningRate step size applied to the averaged gradients
     * @param rnn          the network to update; must be a {@link GRU}
     */
    @Override // org.neuroph.contrib.rnn.bptt.BackPropagationThroughTime
    protected void updateParameters(Map<String, DoubleMatrix> map, int lastTimestep, double learningRate, RNN rnn) {
        GRU gru = (GRU) rnn;

        DoubleMatrix resetInputSum = zerosLike(gru.getResetGateInputWeight());
        DoubleMatrix resetRecurrentSum = zerosLike(gru.getResetGateOutputWeight());
        DoubleMatrix resetBiasSum = zerosLike(gru.getResetGateBias());
        DoubleMatrix updateInputSum = zerosLike(gru.getUpdateGateInputWeight());
        DoubleMatrix updateRecurrentSum = zerosLike(gru.getUpdateGateOutputWeight());
        DoubleMatrix updateBiasSum = zerosLike(gru.getUpdateGateBias());
        DoubleMatrix candidateInputSum = zerosLike(gru.getMemoryCellInputWeight());
        DoubleMatrix candidateRecurrentSum = zerosLike(gru.getMemoryCellOutputWeight());
        DoubleMatrix candidateBiasSum = zerosLike(gru.getMemoryCellBias());
        DoubleMatrix outputWeightSum = zerosLike(gru.getOutputWeight());
        DoubleMatrix outputBiasSum = zerosLike(gru.getOutputBias());

        for (int t = 0; t <= lastTimestep; t++) {
            DoubleMatrix inputT = map.get("input" + t).transpose();
            DoubleMatrix resetDelta = map.get("resetActivationDelta" + t);
            DoubleMatrix updateDelta = map.get("updateActivationDelta" + t);
            DoubleMatrix candidateDelta = map.get("memoryCellGateDelta" + t);
            DoubleMatrix resultDelta = map.get("resultDelta" + t);

            resetInputSum = resetInputSum.add(inputT.mmul(resetDelta));
            updateInputSum = updateInputSum.add(inputT.mmul(updateDelta));
            candidateInputSum = candidateInputSum.add(inputT.mmul(candidateDelta));

            // Recurrent weights need a previous output, so timestep 0 contributes nothing to them.
            if (t > 0) {
                DoubleMatrix previousOutputT = map.get("output" + (t - 1)).transpose();
                resetRecurrentSum = resetRecurrentSum.add(previousOutputT.mmul(resetDelta));
                updateRecurrentSum = updateRecurrentSum.add(previousOutputT.mmul(updateDelta));
                candidateRecurrentSum = candidateRecurrentSum.add(
                        map.get("resetActivation" + t).transpose().mul(previousOutputT).mmul(candidateDelta));
            }

            outputWeightSum = outputWeightSum.add(map.get("output" + t).transpose().mmul(resultDelta));
            resetBiasSum = resetBiasSum.add(resetDelta);
            updateBiasSum = updateBiasSum.add(updateDelta);
            candidateBiasSum = candidateBiasSum.add(candidateDelta);
            outputBiasSum = outputBiasSum.add(resultDelta);
        }

        // Clamped so a one-timestep sequence (lastTimestep == 0) never divides by zero.
        double divisor = Math.max(lastTimestep, 1);
        double recurrentDivisor = Math.max(lastTimestep - 1, 1);

        gru.setResetGateInputWeight(step(gru.getResetGateInputWeight(), resetInputSum, divisor, learningRate));
        gru.setResetGateOutputWeight(step(gru.getResetGateOutputWeight(), resetRecurrentSum, recurrentDivisor, learningRate));
        gru.setResetGateBias(step(gru.getResetGateBias(), resetBiasSum, divisor, learningRate));
        gru.setUpdateGateInputWeight(step(gru.getUpdateGateInputWeight(), updateInputSum, divisor, learningRate));
        gru.setUpdateGateOutputWeight(step(gru.getUpdateGateOutputWeight(), updateRecurrentSum, recurrentDivisor, learningRate));
        gru.setUpdateGateBias(step(gru.getUpdateGateBias(), updateBiasSum, divisor, learningRate));
        gru.setMemoryCellInputWeight(step(gru.getMemoryCellInputWeight(), candidateInputSum, divisor, learningRate));
        gru.setMemoryCellOutputWeight(step(gru.getMemoryCellOutputWeight(), candidateRecurrentSum, recurrentDivisor, learningRate));
        gru.setMemoryCellBias(step(gru.getMemoryCellBias(), candidateBiasSum, divisor, learningRate));
        gru.setOutputWeight(step(gru.getOutputWeight(), outputWeightSum, divisor, learningRate));
        gru.setOutputBias(step(gru.getOutputBias(), outputBiasSum, divisor, learningRate));
    }

    /**
     * Delta of the hidden output at the final timestep: only the output layer contributes.
     *
     * @param resultDelta prediction error at this timestep
     * @param gru         network providing the output weight
     * @return {@code resultDelta * outputWeight^T}
     */
    private DoubleMatrix computeOutputDeltaForLastTimestep(DoubleMatrix resultDelta, GRU gru) {
        return backThrough(gru.getOutputWeight(), resultDelta);
    }

    /**
     * Delta of the hidden output at a non-final timestep {@code t}. It is the sum of the
     * output-layer error at {@code t} and the contributions flowing back from timestep
     * {@code t + 1}: its reset gate, update gate and memory cell, plus the carried-over
     * {@code outputDelta(t+1) .* (1 - updateActivation(t+1))} term.
     *
     * @param resultDelta prediction error at timestep {@code t}
     * @param map         forward values and already-computed deltas for timestep {@code t + 1}
     * @param t           current timestep
     * @param gru         network providing the weights
     * @return the output delta for timestep {@code t}
     */
    private DoubleMatrix computeOutputDeltaForNotLastTimeStep(DoubleMatrix resultDelta, Map<String, DoubleMatrix> map, int t, GRU gru) {
        int next = t + 1;
        DoubleMatrix nextOutputDelta = map.get("outputDelta" + next);
        DoubleMatrix nextCandidateDelta = map.get("memoryCellGateDelta" + next);
        DoubleMatrix nextResetDelta = map.get("resetActivationDelta" + next);
        DoubleMatrix nextUpdateDelta = map.get("updateActivationDelta" + next);
        DoubleMatrix nextReset = map.get("resetActivation" + next);
        DoubleMatrix nextUpdate = map.get("updateActivation" + next);

        DoubleMatrix fromOutputLayer = backThrough(gru.getOutputWeight(), resultDelta);
        DoubleMatrix fromResetGate = backThrough(gru.getResetGateOutputWeight(), nextResetDelta);
        DoubleMatrix fromUpdateGate = backThrough(gru.getUpdateGateOutputWeight(), nextUpdateDelta);
        // NOTE(review): as in propagate's reset delta, the elementwise product with nextReset is
        // applied before the matmul. (nextCandidateDelta * W^T) .* nextReset may be intended;
        // confirm against GRU's forward pass.
        DoubleMatrix fromMemoryCell = backThrough(gru.getMemoryCellOutputWeight(), nextCandidateDelta.mul(nextReset));
        DoubleMatrix fromCarry = nextOutputDelta.mul(DoubleMatrix.ones(1, nextUpdate.columns).sub(nextUpdate));

        return fromOutputLayer.add(fromResetGate).add(fromUpdateGate).add(fromMemoryCell).add(fromCarry);
    }

    /** Returns {@code (weight * delta^T)^T}, i.e. {@code delta * weight^T}. */
    private static DoubleMatrix backThrough(DoubleMatrix weight, DoubleMatrix delta) {
        return weight.mmul(delta.transpose()).transpose();
    }

    /** Returns a new zero-filled matrix with the same dimensions as {@code m}. */
    private static DoubleMatrix zerosLike(DoubleMatrix m) {
        return new DoubleMatrix(m.rows, m.columns);
    }

    /** Returns {@code weights - (gradientSum / divisor) * learningRate} as a new matrix. */
    private static DoubleMatrix step(DoubleMatrix weights, DoubleMatrix gradientSum, double divisor, double learningRate) {
        return weights.sub(gradientSum.div(divisor).mul(learningRate));
    }
}
