package smile.regression;

import smile.math.matrix.DenseMatrix;
import smile.math.matrix.Matrix;
import smile.math.matrix.SVD;

/* loaded from: input_file:smile/regression/RLS.class */
/**
 * Recursive least squares (RLS) linear regression with an intercept term.
 * <p>
 * The model is first fitted by ordinary least squares on a batch of samples
 * (solved via SVD of the design matrix augmented with a column of ones).
 * It keeps {@code V = (X'X)^-1} from that fit and can then be updated one
 * sample at a time with {@link #learn(double[], double)}.
 * <p>
 * Each update applies the Sherman-Morrison rank-one inverse update to
 * {@code V}. Older samples are discounted by the forgetting factor
 * {@code lambda} in (0, 1]; {@code lambda == 1} means no forgetting.
 * <p>
 * {@code learn} mutates the coefficients, {@code V} and internal scratch
 * buffers without any synchronization.
 */
public class RLS implements OnlineRegression<double[]> {
    private static final long serialVersionUID = 1;

    /** Number of input features (the intercept is not counted). */
    private int p;
    /** Coefficients: {@code w[0..p-1]} are feature weights, {@code w[p]} is the intercept. */
    private double[] w;
    /** Forgetting factor in (0, 1]. */
    private double lambda;
    /** Inverse of the (discounted) Gram matrix of the augmented design matrix; (p+1) x (p+1). */
    private DenseMatrix V;
    /** Scratch buffer: the current sample followed by a constant 1 for the intercept. */
    private double[] x1;
    /** Scratch buffer holding {@code V * x1}. */
    private double[] Vx;

    /** Trainer that fits an {@link RLS} model with the default forgetting factor of 1. */
    public static class Trainer extends RegressionTrainer<double[]> {
        @Override
        public RLS train(double[][] x, double[] y) {
            return new RLS(x, y);
        }
    }

    /**
     * Fits the initial model by least squares with forgetting factor 1 (no forgetting).
     *
     * @param x training samples, one row per sample
     * @param y response values, one per row of {@code x}
     * @throws IllegalArgumentException if the inputs are inconsistent or not over-determined
     */
    public RLS(double[][] x, double[] y) {
        this(x, y, 1.0);
    }

    /**
     * Fits the initial model by least squares.
     *
     * @param x      training samples, one row per sample; all rows must have the same length
     * @param y      response values, one per row of {@code x}
     * @param lambda forgetting factor in (0, 1] used by subsequent online updates
     * @throws IllegalArgumentException if sizes mismatch, {@code x} is empty or ragged,
     *                                  {@code lambda} is out of range, or there are not
     *                                  more samples than features
     */
    public RLS(double[][] x, double[] y, double lambda) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (lambda <= 0.0 || lambda > 1.0) {
            throw new IllegalArgumentException("The forgetting factor must be in (0, 1]");
        }
        if (x.length == 0) {
            throw new IllegalArgumentException("The input matrix X is empty");
        }

        this.lambda = lambda;
        int n = x.length;
        this.p = x[0].length;
        if (n <= p) {
            throw new IllegalArgumentException(String.format("The input matrix is not over determined: %d rows, %d columns", n, p));
        }

        // Design matrix augmented with a trailing column of ones for the intercept.
        DenseMatrix X = Matrix.zeros(n, p + 1);
        for (int i = 0; i < n; i++) {
            if (x[i].length != p) {
                throw new IllegalArgumentException(String.format("Row %d of X has %d columns, expected: %d", i, x[i].length, p));
            }
            for (int j = 0; j < p; j++) {
                X.set(i, j, x[i][j]);
            }
            X.set(i, p, 1.0);
        }

        this.w = new double[p + 1];
        SVD svd = X.svd();
        svd.solve(y, w);
        // V = (X'X)^-1, the starting point for the recursive updates.
        this.V = svd.CholeskyOfAtA().inverse();

        this.Vx = new double[p + 1];
        this.x1 = new double[p + 1];
        this.x1[p] = 1.0;
    }

    /**
     * Returns the model coefficients. The last element is the intercept.
     * <p>
     * This is the model's internal array, not a copy; changes to it affect the model.
     *
     * @return the coefficient array of length p + 1
     */
    public double[] coefficients() {
        return w;
    }

    /**
     * Predicts the response for a sample as {@code w[p] + sum_i x[i] * w[i]}.
     *
     * @param x a sample with exactly p features
     * @return the predicted response
     * @throws IllegalArgumentException if {@code x} does not have p elements
     */
    @Override
    public double predict(double[] x) {
        if (x.length != p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, p));
        }
        double y = w[p];
        for (int i = 0; i < x.length; i++) {
            y += x[i] * w[i];
        }
        return y;
    }

    /**
     * Updates the model sequentially with each (x[i], y[i]) pair, in order.
     *
     * @param x samples, one row per sample
     * @param y response values, one per row of {@code x}
     * @throws IllegalArgumentException if the lengths differ or a row has the wrong size
     */
    public void learn(double[][] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("Input vector x of size %d not equal to length %d of y", x.length, y.length));
        }
        for (int i = 0; i < x.length; i++) {
            learn(x[i], y[i]);
        }
    }

    /**
     * Performs one RLS update with a single sample.
     * <p>
     * With {@code d = lambda + x'Vx}, the update is:
     * <pre>
     *   V <- (V - (Vx)(Vx)' / d) / lambda
     *   w <- w + (Vx / d) * (y - w'x)
     * </pre>
     * Here {@code x} is augmented with a trailing 1 for the intercept, and
     * {@code Vx / d} equals the updated {@code V} times {@code x}.
     *
     * @param x a sample with exactly p features
     * @param y its response value
     * @throws IllegalArgumentException if {@code x} does not have p elements
     * @throws IllegalStateException    if the update would make V singular or non-finite
     */
    @Override
    public void learn(double[] x, double y) {
        if (x.length != p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, p));
        }

        // x1[p] is permanently 1 (intercept); copy the features in front of it.
        System.arraycopy(x, 0, x1, 0, p);

        // Sherman-Morrison denominator for (lambda * V^-1 + x x')^-1.
        // The previous code used 1 + x'Vx, which is only correct when lambda == 1.
        double denom = lambda + V.xax(x1);
        // A zero denominator means the updated matrix is singular. Note that
        // 1/0 is Infinity, not NaN, so test the denominator itself.
        if (denom == 0.0 || !Double.isFinite(denom)) {
            throw new IllegalStateException("The updated V matrix is no longer invertible.");
        }

        // A priori error, computed with the current (not yet updated) coefficients.
        double err = y - predict(x);

        V.ax(x1, Vx);
        for (int j = 0; j <= p; j++) {
            for (int i = 0; i <= p; i++) {
                V.set(i, j, (V.get(i, j) - Vx[i] * Vx[j] / denom) / lambda);
            }
        }

        // Gain vector k = V_new * x = V_old * x / denom; no second matrix-vector product needed.
        for (int i = 0; i <= p; i++) {
            w[i] += Vx[i] / denom * err;
        }
    }

    /**
     * Returns the forgetting factor.
     *
     * @return the forgetting factor in (0, 1]
     */
    public double getForgettingFactor() {
        return lambda;
    }

    /**
     * Sets the forgetting factor used by subsequent online updates.
     *
     * @param lambda the forgetting factor in (0, 1]
     * @throws IllegalArgumentException if {@code lambda} is outside (0, 1]
     */
    public void setForgettingFactor(double lambda) {
        if (lambda <= 0.0 || lambda > 1.0) {
            throw new IllegalArgumentException("The forgetting factor must be in (0, 1]");
        }
        this.lambda = lambda;
    }
}
