package org.neuroph.contrib.rnn.bptt;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;
import org.neuroph.contrib.rnn.RNN;
import org.neuroph.contrib.rnn.util.LossFunction;
import org.neuroph.contrib.rnn.util.SequenceModeller;
import org.neuroph.core.data.DataSet;
import org.neuroph.nnet.learning.BackPropagation;

/* loaded from: input_file:org/neuroph/contrib/rnn/bptt/BackPropagationThroughTime.class */
/**
 * Base learning rule for training an {@link RNN} with back-propagation through time.
 *
 * <p>{@link #learn(DataSet, int)} unrolls the network over each character sequence, records the
 * per-step values in a map keyed by step index, and delegates the actual gradient computation to
 * the subclass via {@link #propagate(Map, int, double)}.
 */
public abstract class BackPropagationThroughTime extends BackPropagation {

    /**
     * Trains the network returned by {@code getNeuralNetwork()} (cast to {@link RNN}) on the
     * character sequences extracted from {@code dataSet}.
     *
     * <p>Each iteration visits every sequence of at least three characters. At step {@code t} the
     * vector of character {@code t} is stored under {@code "input" + t} and the network is activated.
     * The decoded {@code "output" + t} is stored as {@code "predictedResult" + t}. The vector of
     * character {@code t + 1} is stored as {@code "result" + t}. The pair is scored with mean
     * categorical cross-entropy. After a sequence is unrolled, the collected map, the index of its
     * last step and the learning rate are passed to {@link #propagate(Map, int, double)}.
     *
     * <p>After each iteration, the mean loss per prediction and the elapsed wall-clock time are
     * printed to standard output. The mean is {@code NaN} when no sequence was long enough to be
     * used.
     *
     * @param dataSet training data, turned into sequences and character vectors by
     *     {@link SequenceModeller}
     * @param i number of training iterations (full passes over all sequences) to run
     */
    public void learn(DataSet dataSet, int i) {
        SequenceModeller sequenceModeller = new SequenceModeller(dataSet);
        Map<String, DoubleMatrix> charVector = sequenceModeller.getCharVector();
        List<String> sequence = sequenceModeller.getSequence();
        // The trained network does not change during learning, so resolve it once.
        RNN rnn = (RNN) getNeuralNetwork();

        for (int iteration = 0; iteration < i; iteration++) {
            double totalLoss = 0.0d;
            // Number of (prediction, target) pairs scored; this is the divisor of the mean loss.
            long predictionCount = 0L;
            long startMillis = System.currentTimeMillis();

            for (String str : sequence) {
                if (str.length() < 3) {
                    continue; // too short to be used for training
                }
                // Concrete HashMap kept: the parameter type of RNN.activate is not visible here.
                HashMap<String, DoubleMatrix> state = new HashMap<>();
                // The last character has no successor, so the final step predicts str[length - 1].
                int lastStep = str.length() - 2;
                for (int t = 0; t <= lastStep; t++) {
                    state.put("input" + t, charVector.get(String.valueOf(str.charAt(t))));
                    rnn.activate(t, state);
                    DoubleMatrix predicted = rnn.decode(state.get("output" + t));
                    state.put("predictedResult" + t, predicted);
                    DoubleMatrix expected = charVector.get(String.valueOf(str.charAt(t + 1)));
                    state.put("result" + t, expected);
                    totalLoss += LossFunction.getMeanCategoricalCrossEntropy(predicted, expected);
                }
                // One loss term was added per step, i.e. str.length() - 1 terms (not str.length()).
                predictionCount += lastStep + 1;

                BackPropagationThroughTime learningRule = rnn.getLearningRule();
                learningRule.propagate(state, lastStep, learningRule.getLearningRate());
            }

            double meanError = totalLoss / predictionCount;
            double elapsedSeconds = (System.currentTimeMillis() - startMillis) / 1000.0d;
            System.out.println("Iteration = " + (iteration + 1) + ", error = " + meanError
                    + ", time = " + elapsedSeconds + "s");
        }
    }

    /**
     * Back-propagates through the unrolled sequence recorded in {@code map}.
     *
     * <p>{@link #learn(DataSet, int)} calls this once per sequence. The map holds the
     * {@code "input"}, {@code "output"}, {@code "predictedResult"} and {@code "result"} entries for
     * steps {@code 0..i} (plus whatever {@code RNN.activate} adds). The call also passes the rule's
     * current learning rate.
     *
     * @param map per-step values keyed by name plus step index
     * @param i index of the last unrolled time step
     * @param d learning rate
     */
    public abstract void propagate(Map<String, DoubleMatrix> map, int i, double d);

    /**
     * Subclass hook, presumably for applying accumulated parameter updates to {@code rnn}.
     * It is not invoked by this class directly (TODO confirm against the subclasses).
     *
     * @param map per-step values / gradients (exact contents defined by the subclass)
     * @param i time-step index (semantics defined by the subclass)
     * @param d learning rate (presumably)
     * @param rnn network whose parameters are updated
     */
    protected abstract void updateParameters(Map<String, DoubleMatrix> map, int i, double d, RNN rnn);

    /**
     * Returns {@code x * (1 - x)} element-wise. This is the derivative of the logistic sigmoid
     * expressed in terms of its output. The result has the same shape as {@code doubleMatrix}.
     *
     * @param doubleMatrix activation values (presumably sigmoid outputs)
     * @return a new matrix holding {@code x * (1 - x)}
     */
    public DoubleMatrix deriveExp(DoubleMatrix doubleMatrix) {
        return doubleMatrix.mul(doubleMatrix.rsub(1.0d));
    }

    /**
     * Returns {@code 1 - x^2} element-wise. This is the derivative of {@code tanh} expressed in
     * terms of its output. The result has the same shape as {@code doubleMatrix}; previously it
     * was always a 1 x n row vector.
     *
     * @param doubleMatrix activation values (presumably tanh outputs)
     * @return a new matrix holding {@code 1 - x^2}
     */
    public DoubleMatrix deriveTanh(DoubleMatrix doubleMatrix) {
        return MatrixFunctions.pow(doubleMatrix, 2.0d).rsub(1.0d);
    }
}
