package org.neuroph.contrib.rnn.util;

import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/* loaded from: input_file:org/neuroph/contrib/rnn/util/LossFunction.class */
/**
 * Static helpers for computing categorical cross-entropy between two matrices.
 *
 * <p>In every method the first matrix supplies the weights and the second matrix is the one
 * whose element-wise natural logarithm is taken, i.e. the result is
 * {@code -sum(first * log(second))}. Under the usual cross-entropy convention that makes the
 * first argument the target distribution and the second the predicted probabilities.
 */
public class LossFunction {
    /** Substitute for exact zeros in the logged matrix so that {@code log(0)} is never evaluated. */
    private static final double ZERO_REPLACEMENT = 1.0E-10d;

    /**
     * Computes {@code -sum(target * log(predicted))} over all elements.
     *
     * <p>Entries of {@code predicted} that are exactly {@code 0.0} are overwritten in place with
     * {@link #ZERO_REPLACEMENT} before the logarithm is taken. Negative entries are not adjusted.
     *
     * @param target    matrix multiplied element-wise with the logarithm
     * @param predicted matrix whose logarithm is taken; modified in place where it holds zeros
     * @return the negated sum of the element-wise products
     */
    private static double getCategoricalCrossEntropy(DoubleMatrix target, DoubleMatrix predicted) {
        for (int i = 0; i < predicted.length; i++) {
            if (predicted.get(i) == 0.0d) {
                predicted.put(i, ZERO_REPLACEMENT);
            }
        }
        return -target.mul(MatrixFunctions.log(predicted)).sum();
    }

    /**
     * Computes the row-wise categorical cross-entropy of the two matrices and averages it over
     * the number of rows.
     *
     * <p>If both matrices have zero rows the result is {@code NaN} (0.0 divided by 0).
     *
     * @param doubleMatrix  matrix whose rows are multiplied with the logged rows
     * @param doubleMatrix2 matrix whose rows are logged; must have the same number of rows as
     *                      {@code doubleMatrix}
     * @return the mean of the per-row cross-entropy values
     * @throws IllegalArgumentException if the two matrices have a different number of rows
     */
    public static double getMeanCategoricalCrossEntropy(DoubleMatrix doubleMatrix, DoubleMatrix doubleMatrix2) {
        if (doubleMatrix.rows != doubleMatrix2.rows) {
            // A library helper must not terminate the JVM on bad input; report it to the caller.
            throw new IllegalArgumentException(
                    "Row count mismatch: " + doubleMatrix.rows + " vs " + doubleMatrix2.rows);
        }
        double total = 0.0d;
        for (int row = 0; row < doubleMatrix.rows; row++) {
            // NOTE(review): jblas getRow(int) appears to return a fresh matrix, so the zero
            // replacement in the helper should not touch the caller's data - confirm against
            // the jblas version in use.
            total += getCategoricalCrossEntropy(doubleMatrix.getRow(row), doubleMatrix2.getRow(row));
        }
        return total / doubleMatrix.rows;
    }
}
