package smile.base.mlp;

import java.io.Serializable;
import java.util.function.Consumer;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.Matrix;
import smile.stat.distribution.GaussianDistribution;

/* loaded from: input_file:smile/base/mlp/Layer.class */
/**
 * A fully connected layer of a multilayer perceptron.
 *
 * <p>The layer owns an {@code n x (p + 1)} weight matrix. Columns {@code 0..p-1}
 * weight the {@code p} input variables, and column {@code p} weights a constant
 * bias (intercept) input of 1.
 */
public abstract class Layer implements Serializable {
    private static final long serialVersionUID = 2L;

    /** The number of neurons in this layer, i.e. the output size. */
    protected int n;
    /** The number of input variables, not counting the bias input. */
    protected int p;
    /**
     * The output vector of this layer. It is written in place by {@link #propagate}.
     * It is not allocated here; presumably subclasses allocate it.
     */
    protected double[] output;
    /**
     * The per-neuron gradient consumed by {@link #computeUpdate}.
     * It is not allocated here; presumably subclasses fill it in {@link #backpropagate}.
     */
    protected double[] gradient;
    /** The weight matrix of size {@code n x (p + 1)}. The last column holds the bias weights. */
    protected DenseMatrix weight;
    /** The most recent raw step {@code eta * gradient[i] * x[j]}, without momentum. */
    protected DenseMatrix delta;
    /** The step added to the weights by {@link #update}: the raw step plus any momentum term. */
    protected DenseMatrix update;
    /** The activation function, applied in place to {@link #output}. */
    protected Consumer<double[]> activation;

    /**
     * Constructor.
     *
     * <p>Each non-bias weight is a Gaussian draw scaled by {@code sqrt(2 / p)}.
     * This is He-style initialization, assuming {@code GaussianDistribution.getInstance()}
     * is the standard normal. The bias column starts at zero.
     *
     * @param n the number of neurons in this layer.
     * @param p the number of input variables, not including the bias input.
     * @throws IllegalArgumentException if {@code n} or {@code p} is not positive.
     */
    public Layer(int n, int p) {
        if (n <= 0) {
            throw new IllegalArgumentException("Invalid number of neurons: " + n);
        }
        if (p <= 0) {
            throw new IllegalArgumentException("Invalid number of input variables: " + p);
        }

        this.n = n;
        this.p = p;
        this.weight = Matrix.zeros(n, p + 1);
        this.delta = Matrix.zeros(n, p + 1);
        this.update = Matrix.zeros(n, p + 1);

        GaussianDistribution gaussian = GaussianDistribution.getInstance();
        double scale = Math.sqrt(2.0 / p);
        // Column-outer / row-inner order fixes the sequence of random draws.
        // Only columns 0..p-1 are initialized; the bias column p stays zero.
        for (int j = 0; j < p; j++) {
            for (int i = 0; i < n; i++) {
                weight.set(i, j, scale * gaussian.rand());
            }
        }
    }

    /**
     * Returns the number of neurons in this layer.
     *
     * @return the output size.
     */
    public int getOutputSize() {
        return n;
    }

    /**
     * Returns the number of input variables, not counting the bias input.
     *
     * @return the input size.
     */
    public int getInputSize() {
        return p;
    }

    /**
     * Returns the output vector. This is the internal buffer itself, not a copy.
     *
     * @return the output vector.
     */
    public double[] output() {
        return output;
    }

    /**
     * Returns the gradient vector. This is the internal buffer itself, not a copy.
     *
     * @return the gradient vector.
     */
    public double[] gradient() {
        return gradient;
    }

    /**
     * Propagates an input signal through this layer.
     *
     * <p>Writes {@code weight * x} into the output buffer via {@code DenseMatrix.ax},
     * then applies the activation function in place.
     *
     * @param x the input vector of length {@code p + 1}. Its last element must be 1
     *          (the bias input); this is checked only when assertions are enabled.
     */
    public void propagate(double[] x) {
        assert x[p] == 1.0 : "bias/intercept is not 1";
        weight.ax(x, output);
        activation.accept(output);
    }

    /**
     * Propagates the errors back through this layer.
     *
     * @param lowerLayerGradient the gradient vector of the lower layer.
     */
    public abstract void backpropagate(double[] lowerLayerGradient);

    /**
     * Computes the weight updates from the current gradient.
     *
     * <p>For every weight {@code (i, j)}, the raw step is {@code eta * gradient[i] * x[j]}.
     * That raw step is stored in {@link #delta}. The step stored in {@link #update} is the
     * raw step plus {@code alpha} times the previous update, when {@code alpha > 0}.
     *
     * @param eta the learning rate.
     * @param alpha the momentum factor. The momentum term is added only when positive.
     * @param x the input vector of length {@code p + 1}, including the bias input.
     */
    public void computeUpdate(double eta, double alpha, double[] x) {
        for (int j = 0; j <= p; j++) {
            double xj = x[j];
            for (int i = 0; i < n; i++) {
                double step = eta * gradient[i] * xj;
                delta.set(i, j, step);
                if (alpha > 0.0) {
                    step += alpha * update.get(i, j);
                }
                update.set(i, j, step);
            }
        }
    }

    /**
     * Adds the pending update to the weights and then applies weight decay.
     *
     * @param alpha the momentum factor. When it is exactly 1.0, the update buffer is
     *              reset to zero after being applied.
     * @param lambda the multiplicative weight-decay factor. When it is below 1.0, every
     *               non-bias weight is multiplied by it. The bias column is never decayed.
     */
    public void update(double alpha, double lambda) {
        weight.add(update);

        if (lambda < 1.0) {
            // Only columns 0..p-1 are decayed; the bias column p is left untouched.
            for (int j = 0; j < p; j++) {
                for (int i = 0; i < n; i++) {
                    weight.mul(i, j, lambda);
                }
            }
        }

        // NOTE(review): with alpha == 1.0, computeUpdate keeps summing into `update`.
        // This reset presumably supports accumulating over a mini-batch before one
        // apply step; confirm against the caller.
        if (alpha == 1.0) {
            update.fill(0.0);
        }
    }

    /**
     * Returns a hidden layer builder with linear activation.
     *
     * @param n the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder linear(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.linear());
    }

    /**
     * Returns a hidden layer builder with rectifier activation.
     *
     * @param n the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder rectifier(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.rectifier());
    }

    /**
     * Returns a hidden layer builder with sigmoid activation.
     *
     * @param n the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder sigmoid(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.sigmoid());
    }

    /**
     * Returns a hidden layer builder with hyperbolic tangent activation.
     *
     * @param n the number of neurons.
     * @return the layer builder.
     */
    public static HiddenLayerBuilder tanh(int n) {
        return new HiddenLayerBuilder(n, ActivationFunction.tanh());
    }

    /**
     * Returns an output layer builder using the {@code Cost.MEAN_SQUARED_ERROR} cost.
     *
     * @param n the number of neurons.
     * @param f the output function.
     * @return the layer builder.
     */
    public static OutputLayerBuilder mse(int n, OutputFunction f) {
        return new OutputLayerBuilder(n, f, Cost.MEAN_SQUARED_ERROR);
    }

    /**
     * Returns an output layer builder using the {@code Cost.LIKELIHOOD} cost.
     *
     * @param n the number of neurons.
     * @param f the output function.
     * @return the layer builder.
     */
    public static OutputLayerBuilder mle(int n, OutputFunction f) {
        return new OutputLayerBuilder(n, f, Cost.LIKELIHOOD);
    }
}
