package org.neuroph.contrib.rnn;

import java.util.Map;
import org.jblas.DoubleMatrix;
import org.neuroph.contrib.rnn.util.Activation;
import org.neuroph.contrib.rnn.util.MatrixInitializer;

/* loaded from: input_file:org/neuroph/contrib/rnn/GRU.class */
/**
 * Gated recurrent unit (GRU) layer with an additional decoding (output) projection.
 *
 * <p>The layer keeps three groups of recurrent parameters (reset gate, update gate and
 * candidate "memory cell"), each made of an input weight, a recurrent ("output") weight
 * and a bias, plus one output weight/bias pair used by {@link #decode(DoubleMatrix)}.
 *
 * <p>When created through the built-in initializers, input weights are
 * {@code inputSize x outputSize}, recurrent weights are {@code outputSize x outputSize},
 * gate biases are {@code 1 x outputSize}, the decoding weight is
 * {@code outputSize x inputSize} and the decoding bias is {@code 1 x inputSize}.
 * The setters accept any matrix without checking its shape.
 */
public final class GRU extends RNN {
    // Reset gate parameters.
    private DoubleMatrix resetGateInputWeight;
    private DoubleMatrix resetGateOutputWeight;
    private DoubleMatrix resetGateBias;

    // Update gate parameters.
    private DoubleMatrix updateGateInputWeight;
    private DoubleMatrix updateGateOutputWeight;
    private DoubleMatrix updateGateBias;

    // Candidate state ("memory cell") parameters.
    private DoubleMatrix memoryCellInputWeight;
    private DoubleMatrix memoryCellOutputWeight;
    private DoubleMatrix memoryCellBias;

    // Decoding projection parameters.
    private DoubleMatrix outputWeight;
    private DoubleMatrix outputBias;

    /**
     * Creates a GRU layer and initializes its parameters.
     *
     * <p>Weights are only assigned when the initializer reports the {@code Uniform} or
     * {@code Gaussian} type; for any other type the constructor leaves every parameter
     * unset, and they must be supplied through the setters before the layer is used.
     *
     * @param i                 stored as the layer's {@code inputSize}
     * @param i2                stored as the layer's {@code outputSize} (hidden state width)
     * @param matrixInitializer source of the initial weight matrices
     */
    public GRU(int i, int i2, MatrixInitializer matrixInitializer) {
        this.inputSize = i;
        this.outputSize = i2;

        MatrixInitializer.Type scheme = matrixInitializer.getType();
        if (scheme == MatrixInitializer.Type.Uniform) {
            setUniformWeights(matrixInitializer);
            return;
        }
        if (scheme == MatrixInitializer.Type.Gaussian) {
            setGaussianWeights(matrixInitializer);
        }
    }

    /** @return the weight applied to the input when computing the reset gate */
    public DoubleMatrix getResetGateInputWeight() {
        return resetGateInputWeight;
    }

    /** @param doubleMatrix new weight applied to the input when computing the reset gate */
    public void setResetGateInputWeight(DoubleMatrix doubleMatrix) {
        resetGateInputWeight = doubleMatrix;
    }

    /** @return the weight applied to the previous output when computing the reset gate */
    public DoubleMatrix getResetGateOutputWeight() {
        return resetGateOutputWeight;
    }

    /** @param doubleMatrix new weight applied to the previous output when computing the reset gate */
    public void setResetGateOutputWeight(DoubleMatrix doubleMatrix) {
        resetGateOutputWeight = doubleMatrix;
    }

    /** @return the reset gate bias */
    public DoubleMatrix getResetGateBias() {
        return resetGateBias;
    }

    /** @param doubleMatrix new reset gate bias */
    public void setResetGateBias(DoubleMatrix doubleMatrix) {
        resetGateBias = doubleMatrix;
    }

    /** @return the weight applied to the input when computing the update gate */
    public DoubleMatrix getUpdateGateInputWeight() {
        return updateGateInputWeight;
    }

    /** @param doubleMatrix new weight applied to the input when computing the update gate */
    public void setUpdateGateInputWeight(DoubleMatrix doubleMatrix) {
        updateGateInputWeight = doubleMatrix;
    }

    /** @return the weight applied to the previous output when computing the update gate */
    public DoubleMatrix getUpdateGateOutputWeight() {
        return updateGateOutputWeight;
    }

    /** @param doubleMatrix new weight applied to the previous output when computing the update gate */
    public void setUpdateGateOutputWeight(DoubleMatrix doubleMatrix) {
        updateGateOutputWeight = doubleMatrix;
    }

    /** @return the update gate bias */
    public DoubleMatrix getUpdateGateBias() {
        return updateGateBias;
    }

    /** @param doubleMatrix new update gate bias */
    public void setUpdateGateBias(DoubleMatrix doubleMatrix) {
        updateGateBias = doubleMatrix;
    }

    /** @return the weight applied to the input when computing the candidate state */
    public DoubleMatrix getMemoryCellInputWeight() {
        return memoryCellInputWeight;
    }

    /** @param doubleMatrix new weight applied to the input when computing the candidate state */
    public void setMemoryCellInputWeight(DoubleMatrix doubleMatrix) {
        memoryCellInputWeight = doubleMatrix;
    }

    /** @return the weight applied to the reset-gated previous output when computing the candidate state */
    public DoubleMatrix getMemoryCellOutputWeight() {
        return memoryCellOutputWeight;
    }

    /** @param doubleMatrix new weight applied to the reset-gated previous output */
    public void setMemoryCellOutputWeight(DoubleMatrix doubleMatrix) {
        memoryCellOutputWeight = doubleMatrix;
    }

    /** @return the candidate state bias */
    public DoubleMatrix getMemoryCellBias() {
        return memoryCellBias;
    }

    /** @param doubleMatrix new candidate state bias */
    public void setMemoryCellBias(DoubleMatrix doubleMatrix) {
        memoryCellBias = doubleMatrix;
    }

    /** @return the weight used by {@link #decode(DoubleMatrix)} */
    public DoubleMatrix getOutputWeight() {
        return outputWeight;
    }

    /** @param doubleMatrix new weight used by {@link #decode(DoubleMatrix)} */
    public void setOutputWeight(DoubleMatrix doubleMatrix) {
        outputWeight = doubleMatrix;
    }

    /** @return the bias used by {@link #decode(DoubleMatrix)} */
    public DoubleMatrix getOutputBias() {
        return outputBias;
    }

    /** @param doubleMatrix new bias used by {@link #decode(DoubleMatrix)} */
    public void setOutputBias(DoubleMatrix doubleMatrix) {
        outputBias = doubleMatrix;
    }

    /**
     * Runs one forward step of the GRU for time step {@code i}.
     *
     * <p>Reads {@code "input" + i} from the map and, for {@code i > 0}, the previous
     * state {@code "output" + (i - 1)}; at {@code i == 0} a fresh
     * {@code 1 x outputSize} matrix is used as the previous state. The step stores
     * four entries back into the map: {@code "resetActivation" + i},
     * {@code "updateActivation" + i}, {@code "memoryCellGate" + i} and
     * {@code "output" + i}.
     *
     * <p>NOTE(review): the input is presumably a {@code 1 x inputSize} row vector, and
     * the required map entries are expected to be present; neither is checked here.
     *
     * @param i   time step index
     * @param map per-sequence activation store, read from and written to by this call
     */
    @Override
    public void activate(int i, Map<String, DoubleMatrix> map) {
        DoubleMatrix input = map.get("input" + i);

        DoubleMatrix previous;
        if (i == 0) {
            previous = new DoubleMatrix(1, this.outputSize);
        } else {
            previous = map.get("output" + (i - 1));
        }

        // r = logistic(x * Wr + h * Ur + br)
        DoubleMatrix reset = Activation.logistic(
                gatePreActivation(input, resetGateInputWeight, previous, resetGateOutputWeight, resetGateBias));

        // z = logistic(x * Wz + h * Uz + bz)
        DoubleMatrix update = Activation.logistic(
                gatePreActivation(input, updateGateInputWeight, previous, updateGateOutputWeight, updateGateBias));

        // c = tanh(x * Wc + (r .* h) * Uc + bc); the reset gate scales the previous
        // state element-wise before it goes through the recurrent weight.
        DoubleMatrix candidateFromInput = input.mmul(memoryCellInputWeight);
        DoubleMatrix gatedPrevious = reset.mul(previous);
        DoubleMatrix candidate = Activation.tanh(
                candidateFromInput.add(gatedPrevious.mmul(memoryCellOutputWeight)).add(memoryCellBias));

        // h' = (1 - z) .* h + z .* c
        DoubleMatrix retained = DoubleMatrix.ones(1, update.columns).sub(update).mul(previous);
        DoubleMatrix admitted = update.mul(candidate);
        DoubleMatrix hidden = retained.add(admitted);

        map.put("resetActivation" + i, reset);
        map.put("updateActivation" + i, update);
        map.put("memoryCellGate" + i, candidate);
        map.put("output" + i, hidden);
    }

    /**
     * Computes {@code input * inputWeight + previous * recurrentWeight + bias}, the
     * pre-activation shared by the reset and update gates.
     */
    private static DoubleMatrix gatePreActivation(DoubleMatrix input,
                                                  DoubleMatrix inputWeight,
                                                  DoubleMatrix previous,
                                                  DoubleMatrix recurrentWeight,
                                                  DoubleMatrix bias) {
        return input.mmul(inputWeight).add(previous.mmul(recurrentWeight)).add(bias);
    }

    /**
     * Projects a state through the output weight and bias and passes the result to
     * {@code Activation.softmax}.
     *
     * @param doubleMatrix state to decode (multiplied on the left of the output weight)
     * @return the value returned by {@code Activation.softmax} for the projected state
     */
    @Override
    public DoubleMatrix decode(DoubleMatrix doubleMatrix) {
        DoubleMatrix projected = doubleMatrix.mmul(outputWeight).add(outputBias);
        return Activation.softmax(projected);
    }

    /**
     * Assigns every weight from {@code matrixInitializer.uniform(...)} and every bias
     * to a freshly allocated matrix.
     *
     * @param matrixInitializer provider of the uniformly initialized weight matrices
     */
    @Override
    protected void setUniformWeights(MatrixInitializer matrixInitializer) {
        final int in = this.inputSize;
        final int hidden = this.outputSize;

        // Weights: the initializer is called in the same sequence as the parameters
        // are declared, so any random stream behind it is consumed in a fixed order.
        resetGateInputWeight = matrixInitializer.uniform(in, hidden);
        resetGateOutputWeight = matrixInitializer.uniform(hidden, hidden);
        updateGateInputWeight = matrixInitializer.uniform(in, hidden);
        updateGateOutputWeight = matrixInitializer.uniform(hidden, hidden);
        memoryCellInputWeight = matrixInitializer.uniform(in, hidden);
        memoryCellOutputWeight = matrixInitializer.uniform(hidden, hidden);
        outputWeight = matrixInitializer.uniform(hidden, in);

        // Biases: new row vectors, not drawn from the initializer.
        resetGateBias = new DoubleMatrix(1, hidden);
        updateGateBias = new DoubleMatrix(1, hidden);
        memoryCellBias = new DoubleMatrix(1, hidden);
        outputBias = new DoubleMatrix(1, in);
    }

    /**
     * Assigns every weight from {@code matrixInitializer.gaussian(...)} and every bias
     * to a freshly allocated matrix.
     *
     * @param matrixInitializer provider of the Gaussian initialized weight matrices
     */
    @Override
    protected void setGaussianWeights(MatrixInitializer matrixInitializer) {
        final int in = this.inputSize;
        final int hidden = this.outputSize;

        // Weights: the initializer is called in the same sequence as the parameters
        // are declared, so any random stream behind it is consumed in a fixed order.
        resetGateInputWeight = matrixInitializer.gaussian(in, hidden);
        resetGateOutputWeight = matrixInitializer.gaussian(hidden, hidden);
        updateGateInputWeight = matrixInitializer.gaussian(in, hidden);
        updateGateOutputWeight = matrixInitializer.gaussian(hidden, hidden);
        memoryCellInputWeight = matrixInitializer.gaussian(in, hidden);
        memoryCellOutputWeight = matrixInitializer.gaussian(hidden, hidden);
        outputWeight = matrixInitializer.gaussian(hidden, in);

        // Biases: new row vectors, not drawn from the initializer.
        resetGateBias = new DoubleMatrix(1, hidden);
        updateGateBias = new DoubleMatrix(1, hidden);
        memoryCellBias = new DoubleMatrix(1, hidden);
        outputBias = new DoubleMatrix(1, in);
    }
}
