package smile.base.svm;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.kernel.MercerKernel;

/**
 * LASVM solver for binary kernel support vector machines with labels +1 and -1.
 * <p>
 * Training interleaves two kinds of steps:
 * <ul>
 * <li><b>process</b>: try to insert a new sample as a support vector and run one
 *     SMO step pairing it with an existing vector;</li>
 * <li><b>reprocess</b>: run one SMO step on the most violating existing vector and
 *     then evict vectors that can no longer contribute.</li>
 * </ul>
 * A final pass repeatedly reprocesses, up to one iteration per support vector,
 * while the gradient gap {@code gmax - gmin} exceeds the tolerance.
 * <p>
 * Instances hold mutable, unsynchronized training state. Kernel values in
 * {@code process} are computed from a parallel stream, so the kernel is
 * presumably expected to be thread-safe (TODO confirm for custom kernels).
 *
 * @param <T> the type of input samples.
 */
public class LASVM<T> implements Serializable {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(LASVM.class);
    /** Substitute for a non-positive curvature so SMO steps never divide by zero. */
    private static final double TAU = 1E-12;

    /** The kernel function. */
    private final MercerKernel<T> kernel;
    /** Penalty handed to every new support vector (positive class). */
    private final double Cp;
    /** Penalty handed to every new support vector (negative class). */
    private final double Cn;
    /** Stopping threshold on the gradient gap {@code gmax - gmin}. */
    private final double tol;
    /** Current support vectors (newly inserted vectors go to the front). */
    private final LinkedList<SupportVector<T>> sv = new LinkedList<>();
    /** Intercept, set to {@code (gmax + gmin) / 2} after each SMO step. */
    private double b = 0.0;
    /** True while gmin/gmax/svmin/svmax are up to date with {@link #sv}. */
    private boolean minmaxflag = false;
    /** Vector with the smallest gradient among those whose alpha can decrease, or null. */
    private SupportVector<T> svmin = null;
    /** Vector with the largest gradient among those whose alpha can increase, or null. */
    private SupportVector<T> svmax = null;
    /** Gradient of {@link #svmin}, or {@code Double.MAX_VALUE} if none. */
    private double gmin = Double.MAX_VALUE;
    /** Gradient of {@link #svmax}, or {@code -Double.MAX_VALUE} if none. */
    private double gmax = -Double.MAX_VALUE;
    /** Training samples of the current fit; null outside of {@link #fit}. */
    private T[] x;
    /**
     * Kernel cache: row {@code K[i]} exists only while sample i is a support
     * vector; {@code NaN} entries have not been computed yet.
     */
    private double[][] K;

    /**
     * Constructor with the same penalty for both classes.
     *
     * @param kernel the kernel function.
     * @param C the penalty used for both the positive and the negative class.
     * @param tol the tolerance of the convergence test.
     */
    public LASVM(MercerKernel<T> kernel, double C, double tol) {
        this(kernel, C, C, tol);
    }

    /**
     * Constructor.
     *
     * @param kernel the kernel function.
     * @param Cp the penalty for the positive class (passed to each SupportVector).
     * @param Cn the penalty for the negative class (passed to each SupportVector).
     * @param tol the tolerance of the convergence test.
     */
    public LASVM(MercerKernel<T> kernel, double Cp, double Cn, double tol) {
        this.kernel = kernel;
        this.Cp = Cp;
        this.Cn = Cn;
        this.tol = tol;
    }

    /**
     * Trains the model with two passes over the data.
     *
     * @param x the training samples.
     * @param y the labels, each +1 or -1.
     * @return the kernel machine built from the support vectors.
     */
    public KernelMachine<T> fit(T[] x, int[] y) {
        return fit(x, y, 2);
    }

    /**
     * Trains the model. Any state from a previous call is discarded first.
     *
     * @param x the training samples.
     * @param y the labels, each +1 or -1.
     * @param epochs the number of passes over the training data.
     * @return the kernel machine whose instances are the support vectors and
     *         whose weights are their alpha values.
     * @throws IllegalArgumentException if {@code x} and {@code y} differ in length,
     *         or a label other than +1 or -1 is processed.
     */
    public KernelMachine<T> fit(T[] x, int[] y, int epochs) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        reset(x);
        // Seed the model with a few samples of each class.
        init(x, y);

        int phase = Math.min(x.length, 1000);
        int count = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int i : MathEx.permutate(x.length)) {
                process(i, x[i], y[i]);
                // Reprocess at least once, and keep going while the gap is still huge.
                do {
                    reprocess(tol);
                    minmax();
                } while (gmax - gmin > 1000.0);

                count++;
                if (count % phase == 0) {
                    logger.info("{} iterations, {} support vectors", count, sv.size());
                }
            }
        }

        finish();
        KernelMachine<T> model = toKernelMachine(x);

        // Release the (potentially large) kernel cache and the training data.
        this.K = null;
        this.x = null;
        return model;
    }

    /** Clears state left by a previous fit and binds the new training data. */
    private void reset(T[] x) {
        this.x = x;
        this.K = new double[x.length][];
        sv.clear();
        b = 0.0;
        minmaxflag = false;
        svmin = null;
        svmax = null;
        gmin = Double.MAX_VALUE;
        gmax = -Double.MAX_VALUE;
    }

    /** Copies the support vectors and their alphas into a KernelMachine. */
    @SuppressWarnings("unchecked")
    private KernelMachine<T> toKernelMachine(T[] x) {
        int n = sv.size();
        // Array.newInstance keeps the runtime component type of the input array.
        T[] vectors = (T[]) Array.newInstance(x.getClass().getComponentType(), n);
        double[] weight = new double[n];
        int j = 0;
        for (SupportVector<T> v : sv) {
            vectors[j] = v.x;
            weight[j] = v.alpha;
            j++;
        }
        return new KernelMachine<>(kernel, vectors, weight, b);
    }

    /** Processes samples in random order until 5 of each class have been inserted. */
    private void init(T[] x, int[] y) {
        int few = 5;
        int cp = 0;
        int cn = 0;
        for (int i : MathEx.permutate(x.length)) {
            if (y[i] == +1 && cp < few) {
                if (process(i, x[i], y[i])) {
                    cp++;
                }
            } else if (y[i] == -1 && cn < few) {
                if (process(i, x[i], y[i])) {
                    cn++;
                }
            }

            if (cp >= few && cn >= few) {
                break;
            }
        }
    }

    /**
     * Recomputes gmin/gmax and svmin/svmax unless the cached values are still
     * valid. Only vectors whose alpha can still move in the relevant direction
     * are eligible; when none is, the vector is null and the gradient sentinel.
     */
    private void minmax() {
        if (minmaxflag) {
            return;
        }

        gmin = Double.MAX_VALUE;
        gmax = -Double.MAX_VALUE;
        // Drop stale references: they may point to vectors evicted since last time.
        svmin = null;
        svmax = null;
        for (SupportVector<T> v : sv) {
            double g = v.g;
            double a = v.alpha;
            if (g < gmin && a > v.cmin) {
                svmin = v;
                gmin = g;
            }
            if (g > gmax && a < v.cmax) {
                svmax = v;
                gmax = g;
            }
        }
        minmaxflag = true;
    }

    /**
     * Returns the kernel value between samples i and j. The value is read from,
     * and stored into, cache row {@code K[i]} when that row exists.
     */
    private double k(int i, int j) {
        double[] row = K[i];
        double value = row == null ? Double.NaN : row[j];
        if (Double.isNaN(value)) {
            value = kernel.k(x[i], x[j]);
            if (row != null) {
                row[j] = value;
            }
        }
        return value;
    }

    /** Returns {@code kii + kjj - 2 kij}, or {@link #TAU} when that is not positive. */
    private static double curvature(double kii, double kjj, double kij) {
        double curv = kii + kjj - 2.0 * kij;
        return curv <= 0.0 ? TAU : curv;
    }

    /**
     * Performs one SMO step: moves {@code step} units of alpha from v1 to v2.
     * A null v1 or v2 is chosen by maximal gain. If both are null, the starting
     * vector is svmax or svmin, whichever side violates more.
     *
     * @param v1 the vector whose alpha decreases, or null to select it.
     * @param v2 the vector whose alpha increases, or null to select it.
     * @param epsgr the gap threshold for the returned optimality test.
     * @return true if a step was taken and {@code gmax - gmin > epsgr} afterwards.
     */
    private boolean smo(SupportVector<T> v1, SupportVector<T> v2, double epsgr) {
        if (v1 == null && v2 == null) {
            minmax();
            if (gmax > -gmin) {
                v2 = svmax;
            } else {
                v1 = svmin;
            }

            if (v1 == null && v2 == null) {
                // No alpha can move in either direction: nothing to optimize.
                return false;
            }
        }

        double k12 = Double.NaN;
        if (v2 == null) {
            // Select v2 maximizing the gain of a step from the fixed v1.
            double best = 0.0;
            for (SupportVector<T> v : sv) {
                double z = v.g - v1.g;
                double kv = k(v1.i, v.i);
                double mu = z / curvature(v1.k, v.k, kv);
                if ((mu > 0.0 && v.alpha < v.cmax) || (mu < 0.0 && v.alpha > v.cmin)) {
                    double gain = z * mu;
                    if (gain > best) {
                        best = gain;
                        v2 = v;
                        k12 = kv;
                    }
                }
            }
        }

        if (v1 == null) {
            // Select v1 maximizing the gain of a step toward the fixed v2.
            double best = 0.0;
            for (SupportVector<T> v : sv) {
                double z = v2.g - v.g;
                double kv = k(v2.i, v.i);
                double mu = z / curvature(v2.k, v.k, kv);
                if ((mu > 0.0 && v.alpha > v.cmin) || (mu < 0.0 && v.alpha < v.cmax)) {
                    double gain = z * mu;
                    if (gain > best) {
                        best = gain;
                        v1 = v;
                        k12 = kv;
                    }
                }
            }
        }

        if (v1 == null || v2 == null) {
            return false;
        }

        if (Double.isNaN(k12)) {
            k12 = kernel.k(v1.x, v2.x);
        }

        // Unconstrained step, then clip it so both alphas stay in their bounds.
        double step = (v2.g - v1.g) / curvature(v1.k, v2.k, k12);
        if (step >= 0.0) {
            double limit = v1.alpha - v1.cmin;
            if (limit < step) {
                step = limit;
            }
            limit = v2.cmax - v2.alpha;
            if (limit < step) {
                step = limit;
            }
        } else {
            double limit = v2.cmin - v2.alpha;
            if (limit > step) {
                step = limit;
            }
            limit = v1.alpha - v1.cmax;
            if (limit > step) {
                step = limit;
            }
        }

        v1.alpha -= step;
        v2.alpha += step;
        for (SupportVector<T> v : sv) {
            v.g -= step * (k(v2.i, v.i) - k(v1.i, v.i));
        }

        minmaxflag = false;
        minmax();
        b = (gmax + gmin) / 2.0;
        return gmax - gmin > epsgr;
    }

    /**
     * PROCESS step: tries to insert sample i as a new support vector.
     *
     * @param i the index of the sample in the training data.
     * @param xi the sample.
     * @param yi the label, +1 or -1.
     * @return false if the sample was rejected because its gradient cannot make
     *         it violating; true if it was inserted or is already a support vector.
     * @throws IllegalArgumentException if the label is not +1 or -1.
     */
    private boolean process(int i, T xi, int yi) {
        if (yi != +1 && yi != -1) {
            throw new IllegalArgumentException("Invalid label: " + yi);
        }

        // Already a support vector (identity comparison).
        for (SupportVector<T> v : sv) {
            if (v.x == xi) {
                return true;
            }
        }

        // Gradient of the new sample; kernel values are cached as a new K row.
        double[] cache = new double[K.length];
        Arrays.fill(cache, Double.NaN);
        double g = yi - sv.parallelStream().mapToDouble(s -> {
            double ks = kernel.k(s.x, xi);
            cache[s.i] = ks; // distinct index per vector, so no write conflicts
            return s.alpha * ks;
        }).sum();

        minmax();
        if (gmin < gmax) {
            if ((yi > 0 && g < gmin) || (yi < 0 && g > gmax)) {
                return false;
            }
        }

        SupportVector<T> candidate = new SupportVector<>(i, xi, yi, 0.0, g, Cp, Cn, kernel.k(xi, xi));
        sv.addFirst(candidate);
        K[i] = cache;

        if (yi > 0) {
            smo(null, candidate, 0.0);
        } else {
            smo(candidate, null, 0.0);
        }

        minmaxflag = false;
        return true;
    }

    /** REPROCESS step: one SMO step between existing vectors, then eviction. */
    private boolean reprocess(double epsgr) {
        boolean status = smo(null, null, epsgr);
        evict();
        return status;
    }

    /** Final reprocessing, then logs how many support vectors are at a bound. */
    private void finish() {
        finish(tol, sv.size());

        int bounded = 0;
        for (SupportVector<T> v : sv) {
            if (v.alpha == v.cmin || v.alpha == v.cmax) {
                bounded++;
            }
        }
        logger.info("{} samples, {} support vectors, {} bounded", x.length, sv.size(), bounded);
    }

    /** Runs up to {@code maxIter} SMO steps until the gap test passes, then evicts. */
    private void finish(double epsgr, int maxIter) {
        logger.info("Finalizing the training by reprocess.");
        for (int iter = 1; iter <= maxIter && smo(null, null, epsgr); iter++) {
            if (iter % 1000 == 0) {
                logger.info("{} reprocess iterations.", iter);
            }
        }
        evict();
    }

    /**
     * Removes vectors with zero alpha whose gradient places them beyond the
     * current extremes, and frees their kernel cache rows.
     */
    private void evict() {
        minmax();
        boolean removed = false;
        for (Iterator<SupportVector<T>> it = sv.iterator(); it.hasNext(); ) {
            SupportVector<T> v = it.next();
            if (v.alpha == 0.0 && ((v.g >= gmax && 0.0 >= v.cmax) || (v.g <= gmin && 0.0 <= v.cmin))) {
                K[v.i] = null;
                it.remove();
                removed = true;
            }
        }

        if (removed) {
            // svmin/svmax may have referred to an evicted vector.
            minmaxflag = false;
        }
    }
}
