package org.neuroph.contrib.rnn.example;

import java.io.PrintStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.jblas.DoubleMatrix;
import org.neuroph.contrib.rnn.GRU;
import org.neuroph.contrib.rnn.RNN;
import org.neuroph.contrib.rnn.bptt.BackPropagationThroughTime;
import org.neuroph.contrib.rnn.bptt.GRUBackPropagationThroughTime;
import org.neuroph.contrib.rnn.util.LossFunction;
import org.neuroph.contrib.rnn.util.MatrixInitializer;
import org.neuroph.contrib.rnn.util.SequenceModeller;
import org.neuroph.core.data.DataSet;

/* loaded from: input_file:org/neuroph/contrib/rnn/example/GRUStockPricePredictionExample.class */
/**
 * Example that trains a GRU recurrent network on a stock-price CSV file and then prints
 * character-by-character predictions, plus an average loss, for a separate test file.
 */
public class GRUStockPricePredictionExample {

    /** CSV file used for training. */
    private static final String TRAIN_FILE = "google-stock-price-train.csv";

    /** CSV file used for evaluation. */
    private static final String TEST_FILE = "google-stock-price-test.csv";

    /** Number of hidden units passed to the {@link GRU} constructor. */
    private static final int HIDDEN_SIZE = 100;

    /** Learning rate configured on the BPTT learning rule. */
    private static final double LEARNING_RATE = 0.8d;

    /**
     * Entry point: trains the network and evaluates it on the test file.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        trainNetwork();
    }

    /**
     * Loads the training and test data, builds a GRU sized from the training vocabulary,
     * trains it with GRU back-propagation through time, and then runs {@link #testNetwork}.
     */
    private static void trainNetwork() {
        // createFromFile(path, 3, 1, ","): presumably 3 input columns and 1 output column
        // per CSV row -- TODO confirm against the data files.
        DataSet trainingSet = DataSet.createFromFile(TRAIN_FILE, 3, 1, ",");
        DataSet testSet = DataSet.createFromFile(TEST_FILE, 3, 1, ",");
        // The network's input size is the number of distinct characters in the training data.
        int vocabularySize = new SequenceModeller(trainingSet).getCharIndex().size();

        System.out.println("Creating neural network...");
        GRU gru = new GRU(vocabularySize, HIDDEN_SIZE,
                new MatrixInitializer(MatrixInitializer.Type.Uniform, 0.1d, 0.0d, 0.0d));
        GRUBackPropagationThroughTime learningRule = new GRUBackPropagationThroughTime();
        learningRule.setLearningRate(LEARNING_RATE);
        gru.setLearningRule(learningRule);

        System.out.println("Training network...");
        // NOTE(review): the meaning of the second argument (100) depends on
        // GRUBackPropagationThroughTime.learn -- presumably an iteration/epoch count.
        learningRule.learn(trainingSet, 100);
        System.out.println("Training completed.");

        testNetwork(gru, testSet);
    }

    /**
     * Prints the test set, then feeds each test sequence through the network one character
     * at a time. For each position it prints the predicted next character and accumulates
     * the categorical cross-entropy against the actual next character. Finally it prints
     * the mean loss per prediction and the elapsed time.
     *
     * <p>This method only evaluates. It does not run back-propagation, so the network's
     * weights are not modified by test data.
     *
     * @param rnn     the trained network
     * @param dataSet the test data
     */
    private static void testNetwork(RNN rnn, DataSet dataSet) {
        // NOTE(review): this vocabulary is built from the test set, not from the training set
        // that sized the network. Confirm SequenceModeller produces the same char<->index
        // mapping for both, or pass the training modeller in instead.
        SequenceModeller sequenceModeller = new SequenceModeller(dataSet);
        Map<Integer, String> indexChar = sequenceModeller.getIndexChar();
        Map<String, DoubleMatrix> charVector = sequenceModeller.getCharVector();
        List<String> sequences = sequenceModeller.getSequence();

        System.out.println("Test set:");
        dataSet.forEach(System.out::println);

        System.out.println("Prediction:");
        double totalLoss = 0.0d;
        long predictionCount = 0L;
        long startMillis = System.currentTimeMillis();
        for (String sequence : sequences) {
            if (sequence.isEmpty()) {
                // Nothing to seed or predict; charAt(0) would throw.
                continue;
            }
            // Per-sequence activation cache: "input"+t is written here and "output"+t is
            // read back after rnn.activate(t, ...).
            Map<String, DoubleMatrix> cache = new HashMap<>();
            // Echo the seed character; every subsequent printed char is a prediction.
            System.out.print(sequence.charAt(0));
            for (int t = 0; t < sequence.length() - 1; t++) {
                cache.put("input" + t, charVector.get(String.valueOf(sequence.charAt(t))));
                rnn.activate(t, cache);
                DoubleMatrix predicted = rnn.decode(cache.get("output" + t));
                DoubleMatrix expected = charVector.get(String.valueOf(sequence.charAt(t + 1)));
                System.out.print(indexChar.get(predicted.argmax()));
                totalLoss += LossFunction.getMeanCategoricalCrossEntropy(predicted, expected);
                predictionCount++;
            }
            System.out.println();
        }
        // Average over the predictions actually scored (NaN if there were none).
        double meanLoss = totalLoss / predictionCount;
        double elapsedSeconds = (System.currentTimeMillis() - startMillis) / 1000.0d;
        System.out.println("Error = " + meanLoss + ", time = " + elapsedSeconds + "s");
    }
}
