package org.neuroph.contrib.rnn;

import java.util.Map;
import org.jblas.DoubleMatrix;
import org.neuroph.contrib.rnn.util.Activation;
import org.neuroph.contrib.rnn.util.MatrixInitializer;

/* loaded from: input_file:org/neuroph/contrib/rnn/LSTM.class */
/**
 * Long short-term memory (LSTM) recurrent layer with peephole connections from the memory cell
 * into the input, forget and output gates.
 *
 * <p>Activations are handled as row vectors (the initial state and every bias are created as
 * {@code 1 x n} matrices). For time step {@code t}, with input {@code x}, previous output
 * {@code h} and previous memory cell {@code c}, {@link #activate(int, Map)} computes (matrix
 * products via {@code mmul}, element-wise products via {@code mul}):
 *
 * <pre>
 *   i  = logistic(x*Wxi + h*Whi + c*Wci + bi)     input gate
 *   f  = logistic(x*Wxf + h*Whf + c*Wcf + bf)     forget gate
 *   g  = tanh(x*Wxc + h*Whc + bc)                 candidate memory (no peephole)
 *   c' = f .* c + i .* g                          new memory cell
 *   o  = logistic(x*Wxo + h*Who + c'*Wco + bo)    output gate (peeks at the NEW cell)
 *   h' = o .* tanh(c')                            new output
 * </pre>
 *
 * <p>{@link #decode(DoubleMatrix)} maps an output row vector back to {@code inputSize} scores
 * through {@code Activation.softmax(h*Wy + by)}.
 */
public final class LSTM extends RNN {
    // Shapes documented below are those produced by setUniformWeights/setGaussianWeights.
    // The setters accept any matrix; callers replacing parameters must keep compatible shapes.

    /** Input gate: weights on the current input (inputSize x outputSize). */
    private DoubleMatrix inputGateInputWeight;
    /** Input gate: weights on the previous output (outputSize x outputSize). */
    private DoubleMatrix inputGateOutputWeight;
    /** Input gate: peephole weights on the previous memory cell (outputSize x outputSize). */
    private DoubleMatrix inputGateMemoryCellWeight;
    /** Input gate: bias (1 x outputSize, zero-initialised). */
    private DoubleMatrix inputGateBias;
    /** Forget gate: weights on the current input (inputSize x outputSize). */
    private DoubleMatrix forgetGateInputWeight;
    /** Forget gate: weights on the previous output (outputSize x outputSize). */
    private DoubleMatrix forgetGateOutputWeight;
    /** Forget gate: peephole weights on the previous memory cell (outputSize x outputSize). */
    private DoubleMatrix forgetGateMemoryCellWeight;
    /** Forget gate: bias (1 x outputSize, zero-initialised). */
    private DoubleMatrix forgetGateBias;
    /** Candidate memory: weights on the current input (inputSize x outputSize). */
    private DoubleMatrix memoryCellInputWeight;
    /** Candidate memory: weights on the previous output (outputSize x outputSize). */
    private DoubleMatrix memoryCellOutputWeight;
    /** Candidate memory: bias (1 x outputSize, zero-initialised). */
    private DoubleMatrix memoryCellBias;
    /** Output gate: weights on the current input (inputSize x outputSize). */
    private DoubleMatrix outputGateInputWeight;
    /** Output gate: weights on the previous output (outputSize x outputSize). */
    private DoubleMatrix outputGateOutputWeight;
    /** Output gate: peephole weights on the NEW memory cell (outputSize x outputSize). */
    private DoubleMatrix outputGateMemoryCellWeight;
    /** Output gate: bias (1 x outputSize, zero-initialised). */
    private DoubleMatrix outputGateBias;
    /** Decoder projection from hidden output to input space (outputSize x inputSize). */
    private DoubleMatrix outputWeight;
    /** Decoder bias (1 x inputSize, zero-initialised). */
    private DoubleMatrix outputBias;

    /**
     * Creates an LSTM layer and randomly initialises its parameters.
     *
     * @param inputSize width of each input row vector, and of the decoded score vector
     * @param outputSize number of memory cells, i.e. width of the hidden output
     * @param matrixInitializer source of random weights; its type selects uniform or Gaussian
     *     initialisation. For any other type the parameters are left {@code null} and must be
     *     supplied through the setters before {@link #activate(int, Map)} or
     *     {@link #decode(DoubleMatrix)} is called.
     * @throws NullPointerException if {@code matrixInitializer} is {@code null}
     */
    public LSTM(int inputSize, int outputSize, MatrixInitializer matrixInitializer) {
        if (matrixInitializer == null) {
            throw new NullPointerException("matrixInitializer");
        }
        this.inputSize = inputSize;
        this.outputSize = outputSize;
        if (matrixInitializer.getType() == MatrixInitializer.Type.Uniform) {
            setUniformWeights(matrixInitializer);
        } else if (matrixInitializer.getType() == MatrixInitializer.Type.Gaussian) {
            setGaussianWeights(matrixInitializer);
        }
        // Any other initializer type intentionally leaves every parameter null (see Javadoc).
    }

    /** @return input-gate weights on the current input */
    public DoubleMatrix getInputGateInputWeight() {
        return this.inputGateInputWeight;
    }

    /** @param matrix new input-gate weights on the current input */
    public void setInputGateInputWeight(DoubleMatrix matrix) {
        this.inputGateInputWeight = matrix;
    }

    /** @return input-gate weights on the previous output */
    public DoubleMatrix getInputGateOutputWeight() {
        return this.inputGateOutputWeight;
    }

    /** @param matrix new input-gate weights on the previous output */
    public void setInputGateOutputWeight(DoubleMatrix matrix) {
        this.inputGateOutputWeight = matrix;
    }

    /** @return input-gate peephole weights on the previous memory cell */
    public DoubleMatrix getInputGateMemoryCellWeight() {
        return this.inputGateMemoryCellWeight;
    }

    /** @param matrix new input-gate peephole weights on the previous memory cell */
    public void setInputGateMemoryCellWeight(DoubleMatrix matrix) {
        this.inputGateMemoryCellWeight = matrix;
    }

    /** @return input-gate bias */
    public DoubleMatrix getInputGateBias() {
        return this.inputGateBias;
    }

    /** @param matrix new input-gate bias */
    public void setInputGateBias(DoubleMatrix matrix) {
        this.inputGateBias = matrix;
    }

    /** @return forget-gate weights on the current input */
    public DoubleMatrix getForgetGateInputWeight() {
        return this.forgetGateInputWeight;
    }

    /** @param matrix new forget-gate weights on the current input */
    public void setForgetGateInputWeight(DoubleMatrix matrix) {
        this.forgetGateInputWeight = matrix;
    }

    /** @return forget-gate weights on the previous output */
    public DoubleMatrix getForgetGateOutputWeight() {
        return this.forgetGateOutputWeight;
    }

    /** @param matrix new forget-gate weights on the previous output */
    public void setForgetGateOutputWeight(DoubleMatrix matrix) {
        this.forgetGateOutputWeight = matrix;
    }

    /** @return forget-gate peephole weights on the previous memory cell */
    public DoubleMatrix getForgetGateMemoryCellWeight() {
        return this.forgetGateMemoryCellWeight;
    }

    /** @param matrix new forget-gate peephole weights on the previous memory cell */
    public void setForgetGateMemoryCellWeight(DoubleMatrix matrix) {
        this.forgetGateMemoryCellWeight = matrix;
    }

    /** @return forget-gate bias */
    public DoubleMatrix getForgetGateBias() {
        return this.forgetGateBias;
    }

    /** @param matrix new forget-gate bias */
    public void setForgetGateBias(DoubleMatrix matrix) {
        this.forgetGateBias = matrix;
    }

    /** @return candidate-memory weights on the current input */
    public DoubleMatrix getMemoryCellInputWeight() {
        return this.memoryCellInputWeight;
    }

    /** @param matrix new candidate-memory weights on the current input */
    public void setMemoryCellInputWeight(DoubleMatrix matrix) {
        this.memoryCellInputWeight = matrix;
    }

    /** @return candidate-memory weights on the previous output */
    public DoubleMatrix getMemoryCellOutputWeight() {
        return this.memoryCellOutputWeight;
    }

    /** @param matrix new candidate-memory weights on the previous output */
    public void setMemoryCellOutputWeight(DoubleMatrix matrix) {
        this.memoryCellOutputWeight = matrix;
    }

    /** @return candidate-memory bias */
    public DoubleMatrix getMemoryCellBias() {
        return this.memoryCellBias;
    }

    /** @param matrix new candidate-memory bias */
    public void setMemoryCellBias(DoubleMatrix matrix) {
        this.memoryCellBias = matrix;
    }

    /** @return output-gate weights on the current input */
    public DoubleMatrix getOutputGateInputWeight() {
        return this.outputGateInputWeight;
    }

    /** @param matrix new output-gate weights on the current input */
    public void setOutputGateInputWeight(DoubleMatrix matrix) {
        this.outputGateInputWeight = matrix;
    }

    /** @return output-gate weights on the previous output */
    public DoubleMatrix getOutputGateOutputWeight() {
        return this.outputGateOutputWeight;
    }

    /** @param matrix new output-gate weights on the previous output */
    public void setOutputGateOutputWeight(DoubleMatrix matrix) {
        this.outputGateOutputWeight = matrix;
    }

    /** @return output-gate peephole weights on the new memory cell */
    public DoubleMatrix getOutputGateMemoryCellWeight() {
        return this.outputGateMemoryCellWeight;
    }

    /** @param matrix new output-gate peephole weights on the new memory cell */
    public void setOutputGateMemoryCellWeight(DoubleMatrix matrix) {
        this.outputGateMemoryCellWeight = matrix;
    }

    /** @return output-gate bias */
    public DoubleMatrix getOutputGateBias() {
        return this.outputGateBias;
    }

    /** @param matrix new output-gate bias */
    public void setOutputGateBias(DoubleMatrix matrix) {
        this.outputGateBias = matrix;
    }

    /** @return decoder projection weights */
    public DoubleMatrix getOutputWeight() {
        return this.outputWeight;
    }

    /** @param matrix new decoder projection weights */
    public void setOutputWeight(DoubleMatrix matrix) {
        this.outputWeight = matrix;
    }

    /** @return decoder bias */
    public DoubleMatrix getOutputBias() {
        return this.outputBias;
    }

    /** @param matrix new decoder bias */
    public void setOutputBias(DoubleMatrix matrix) {
        this.outputBias = matrix;
    }

    /**
     * Runs one LSTM time step and records every intermediate activation in {@code map}.
     *
     * <p>Reads {@code "input" + i}. For {@code i > 0} it also reads {@code "output" + (i - 1)}
     * and {@code "memoryCellActivation" + (i - 1)}. At {@code i == 0} the previous output and
     * memory cell are zero row vectors of width {@code outputSize}. It writes the keys
     * {@code inputActivation}, {@code forgetActivation}, {@code memoryCellGate},
     * {@code memoryCellActivation}, {@code outputActivation}, {@code outputActivationGate} and
     * {@code output}, each suffixed with {@code i}.
     *
     * @param i zero-based time step
     * @param map per-sequence state shared across time steps; updated in place
     * @throws NullPointerException if an entry this step needs is missing from {@code map}
     */
    @Override // org.neuroph.contrib.rnn.RNN
    public void activate(int i, Map<String, DoubleMatrix> map) {
        DoubleMatrix input = lookup(map, "input" + i);
        DoubleMatrix prevOutput;
        DoubleMatrix prevCell;
        if (i == 0) {
            // No history yet: start from a zero output and a zero memory cell.
            prevOutput = new DoubleMatrix(1, this.outputSize);
            prevCell = prevOutput.dup();
        } else {
            prevOutput = lookup(map, "output" + (i - 1));
            prevCell = lookup(map, "memoryCellActivation" + (i - 1));
        }

        DoubleMatrix inputGate = Activation.logistic(input.mmul(this.inputGateInputWeight)
                .add(prevOutput.mmul(this.inputGateOutputWeight))
                .add(prevCell.mmul(this.inputGateMemoryCellWeight))
                .add(this.inputGateBias));
        DoubleMatrix forgetGate = Activation.logistic(input.mmul(this.forgetGateInputWeight)
                .add(prevOutput.mmul(this.forgetGateOutputWeight))
                .add(prevCell.mmul(this.forgetGateMemoryCellWeight))
                .add(this.forgetGateBias));
        // The candidate memory has no peephole term.
        DoubleMatrix candidate = Activation.tanh(input.mmul(this.memoryCellInputWeight)
                .add(prevOutput.mmul(this.memoryCellOutputWeight))
                .add(this.memoryCellBias));
        // New cell = forget-gated old cell + input-gated candidate (element-wise products).
        DoubleMatrix cell = forgetGate.mul(prevCell).add(inputGate.mul(candidate));
        // The output gate peeks at the freshly updated cell, not the previous one.
        DoubleMatrix outputGate = Activation.logistic(input.mmul(this.outputGateInputWeight)
                .add(prevOutput.mmul(this.outputGateOutputWeight))
                .add(cell.mmul(this.outputGateMemoryCellWeight))
                .add(this.outputGateBias));
        DoubleMatrix squashedCell = Activation.tanh(cell);
        DoubleMatrix output = outputGate.mul(squashedCell);

        map.put("inputActivation" + i, inputGate);
        map.put("forgetActivation" + i, forgetGate);
        map.put("memoryCellGate" + i, candidate);
        map.put("memoryCellActivation" + i, cell);
        map.put("outputActivation" + i, outputGate);
        map.put("outputActivationGate" + i, squashedCell);
        map.put("output" + i, output);
    }

    /**
     * Fetches a required state entry, failing with the offending key instead of an anonymous
     * NPE deep inside the matrix arithmetic.
     */
    private static DoubleMatrix lookup(Map<String, DoubleMatrix> state, String key) {
        DoubleMatrix value = state.get(key);
        if (value == null) {
            throw new NullPointerException("missing LSTM state entry \"" + key + "\"");
        }
        return value;
    }

    /**
     * Projects a hidden output row vector back to {@code inputSize} scores.
     *
     * @param output hidden output, presumably the {@code "output" + i} entry written by
     *     {@link #activate(int, Map)} (1 x outputSize)
     * @return {@code Activation.softmax(output * outputWeight + outputBias)}
     */
    @Override // org.neuroph.contrib.rnn.RNN
    public DoubleMatrix decode(DoubleMatrix output) {
        return Activation.softmax(output.mmul(this.outputWeight).add(this.outputBias));
    }

    /** Draws every weight matrix uniformly from {@code matrixInitializer}; zeroes all biases. */
    @Override // org.neuroph.contrib.rnn.RNN
    protected void setUniformWeights(MatrixInitializer matrixInitializer) {
        initializeParameters(matrixInitializer, false);
    }

    /** Draws every weight matrix from {@code matrixInitializer}'s Gaussian; zeroes all biases. */
    @Override // org.neuroph.contrib.rnn.RNN
    protected void setGaussianWeights(MatrixInitializer matrixInitializer) {
        initializeParameters(matrixInitializer, true);
    }

    /**
     * Shared body of the two weight initialisers: draws all weight matrices from {@code init}
     * and allocates zero biases.
     *
     * <p>Draw order is significant if the initializer is stateful (e.g. backed by a seeded
     * RNG). Keep it unchanged so existing models initialise identically.
     */
    private void initializeParameters(MatrixInitializer init, boolean gaussian) {
        this.inputGateInputWeight = randomMatrix(init, gaussian, this.inputSize, this.outputSize);
        this.inputGateOutputWeight = randomMatrix(init, gaussian, this.outputSize, this.outputSize);
        this.inputGateMemoryCellWeight = randomMatrix(init, gaussian, this.outputSize, this.outputSize);
        this.inputGateBias = new DoubleMatrix(1, this.outputSize);

        this.forgetGateInputWeight = randomMatrix(init, gaussian, this.inputSize, this.outputSize);
        this.forgetGateOutputWeight = randomMatrix(init, gaussian, this.outputSize, this.outputSize);
        this.forgetGateMemoryCellWeight = randomMatrix(init, gaussian, this.outputSize, this.outputSize);
        this.forgetGateBias = new DoubleMatrix(1, this.outputSize);

        this.memoryCellInputWeight = randomMatrix(init, gaussian, this.inputSize, this.outputSize);
        this.memoryCellOutputWeight = randomMatrix(init, gaussian, this.outputSize, this.outputSize);
        this.memoryCellBias = new DoubleMatrix(1, this.outputSize);

        this.outputGateInputWeight = randomMatrix(init, gaussian, this.inputSize, this.outputSize);
        this.outputGateOutputWeight = randomMatrix(init, gaussian, this.outputSize, this.outputSize);
        this.outputGateMemoryCellWeight = randomMatrix(init, gaussian, this.outputSize, this.outputSize);
        this.outputGateBias = new DoubleMatrix(1, this.outputSize);

        this.outputWeight = randomMatrix(init, gaussian, this.outputSize, this.inputSize);
        this.outputBias = new DoubleMatrix(1, this.inputSize);
    }

    /** Draws one {@code rows x columns} matrix from the selected distribution of {@code init}. */
    private static DoubleMatrix randomMatrix(MatrixInitializer init, boolean gaussian, int rows, int columns) {
        return gaussian ? init.gaussian(rows, columns) : init.uniform(rows, columns);
    }
}
