package smile.clustering;

import java.util.Arrays;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;

/**
 * DENCLUE (DENsity-based CLUstEring).
 *
 * <p>The data density is estimated with an isotropic Gaussian kernel centered
 * on a reduced set of samples (k-means centroids). Every observation
 * hill-climbs this density estimate to its density attractor. The attractors
 * are then grouped with DBSCAN, and each observation receives the DBSCAN label
 * of its attractor.
 */
public class DENCLUE extends PartitionClustering {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(DENCLUE.class);

    /** Relative density-change tolerance that stops hill-climbing. */
    private final double tol;
    /** Standard deviation of the Gaussian kernel. */
    private final double sigma;
    /** The density attractor reached by each training observation. */
    public final double[][] attractors;
    /** Per-attractor radius: the sum of the last two hill-climbing step sizes. */
    private final double[] radius;
    /** Kernel centers (k-means centroids) of the density estimate. */
    private final double[][] samples;

    /**
     * Constructor.
     *
     * @param k the number of clusters.
     * @param attractors the density attractor of each training observation.
     * @param radius the radius of each attractor.
     * @param samples the kernel centers of the density estimate.
     * @param sigma the standard deviation of the Gaussian kernel.
     * @param y the cluster label of each training observation.
     * @param tol the hill-climbing convergence tolerance.
     */
    public DENCLUE(int k, double[][] attractors, double[] radius, double[][] samples, double sigma, int[] y, double tol) {
        super(k, y);
        this.attractors = attractors;
        this.radius = radius;
        this.samples = samples;
        this.sigma = sigma;
        this.tol = tol;
    }

    /**
     * Clusters the data with tolerance 0.01 and DBSCAN
     * {@code minPts = max(10, n / 200)}.
     *
     * @param data the observations.
     * @param sigma the standard deviation of the Gaussian kernel.
     * @param m the number of k-means centroids used as kernel centers.
     * @return the model.
     */
    public static DENCLUE fit(double[][] data, double sigma, int m) {
        return fit(data, sigma, m, 0.01d, Math.max(10, data.length / 200));
    }

    /**
     * Clusters the data.
     *
     * @param data the observations.
     * @param sigma the standard deviation of the Gaussian kernel.
     * @param m the number of k-means centroids used as kernel centers.
     * @param tol the relative density-change tolerance that stops hill-climbing.
     * @param minPts the DBSCAN minimum number of neighbors of a core attractor.
     * @return the model.
     * @throws IllegalArgumentException if {@code sigma} or {@code tol} is not
     *     positive, or {@code m} is not in {@code [1, data.length]}.
     */
    public static DENCLUE fit(double[][] data, double sigma, int m, double tol, int minPts) {
        // Negated comparisons so that NaN is rejected as well.
        if (!(sigma > 0.0d)) {
            throw new IllegalArgumentException("Invalid standard deviation of Gaussian kernel: " + sigma);
        }
        if (m <= 0 || m > data.length) {
            throw new IllegalArgumentException("Invalid number of selected samples: " + m);
        }
        if (!(tol > 0.0d)) {
            // A negative tolerance can never be met, so hill-climbing would not stop.
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }

        logger.info("Select {} samples by k-means", m);
        double[][] centers = (double[][]) KMeans.fit(data, m).centroids;

        int n = data.length;
        double[][] attractors = new double[n][data[0].length];
        // The last two step sizes of each climb; their sum is the attractor radius.
        double[][] steps = new double[n][2];

        logger.info("Hill-climbing of density function for each observation");
        // Each task writes only its own row of attractors/steps, so the parallel loop is race-free.
        IntStream.range(0, n).parallel()
                .forEach(i -> climb(data[i], attractors[i], steps[i], centers, sigma, tol));

        double[] radii = Arrays.stream(steps).mapToDouble(s -> s[0] + s[1]).toArray();
        double eps = MathEx.mean(radii);
        logger.info("Clustering attractors with DBSCAN (radius = {})", eps);
        DBSCAN<double[]> dbscan = DBSCAN.fit(attractors, minPts, eps);
        return new DENCLUE(dbscan.k, attractors, radii, centers, sigma, dbscan.y, tol);
    }

    /**
     * Predicts the cluster label of an observation.
     *
     * <p>The observation first hill-climbs to its own attractor. It then
     * receives the label of the first training attractor within the combined
     * radius.
     *
     * @param x the observation.
     * @return the cluster label, or {@code Integer.MAX_VALUE} if no training
     *     attractor is close enough.
     * @throws IllegalArgumentException if the dimension of {@code x} is wrong.
     */
    public int predict(double[] x) {
        int d = attractors[0].length;
        if (x.length != d) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, d));
        }

        double[] attractor = new double[d];
        double[] step = new double[2];
        climb(x, attractor, step, samples, sigma, tol);
        double r = step[0] + step[1];

        for (int i = 0; i < attractors.length; i++) {
            if (MathEx.distance(attractors[i], attractor) < radius[i] + r) {
                return y[i];
            }
        }
        // Label for "no matching cluster".
        return Integer.MAX_VALUE;
    }

    /**
     * Hill-climbs the Gaussian kernel density estimate from {@code x}.
     *
     * <p>This uses mean-shift style fixed-point updates: each new position is
     * the kernel-weighted mean of {@code samples}. The climb runs for at least
     * {@code step.length} iterations and stops once the relative change of the
     * density is at most {@code tol}.
     *
     * <p>If the density at the current position is numerically zero (every
     * kernel weight underflows, {@code samples} is empty, or the input contains
     * NaN), climbing stops immediately. In that case the current position is
     * reported as the attractor and the step sizes are zeroed.
     *
     * @param x the starting point; not modified.
     * @param attractor output: the final position (at least {@code x.length} long).
     * @param step output ring buffer of the most recent step sizes.
     * @param samples the kernel centers.
     * @param sigma the standard deviation of the Gaussian kernel.
     * @param tol the relative density-change tolerance.
     * @return the kernel density at the position before the last update.
     * @throws IllegalArgumentException if {@code step} is empty.
     */
    public static double climb(double[] x, double[] attractor, double[] step, double[][] samples, double sigma, double tol) {
        int m = samples.length;
        int d = x.length;
        int k = step.length;
        if (k == 0) {
            throw new IllegalArgumentException("The step buffer must not be empty");
        }

        // Kernel weight: exp(gamma * ||pos - s||^2) with gamma = -1 / (2 sigma^2).
        double gamma = -0.5d / (sigma * sigma);

        double[] position = x.clone();
        double kernelSum = 0.0d;
        // The first pass has no previous density to compare against, so it never converges.
        double change = Double.POSITIVE_INFINITY;
        for (int iter = 0; iter < k || change > tol; iter++) {
            // Accumulate the kernel-weighted sum of samples into attractor.
            Arrays.fill(attractor, 0.0d);
            double sum = 0.0d;
            for (double[] s : samples) {
                double w = Math.exp(gamma * MathEx.squaredDistance(position, s));
                sum += w;
                for (int j = 0; j < d; j++) {
                    attractor[j] += w * s[j];
                }
            }

            if (!(sum > 0.0d)) {
                // Zero (or NaN) density: there is no gradient to follow. Dividing by
                // zero here would produce NaN and the stopping test would never pass.
                System.arraycopy(position, 0, attractor, 0, d);
                Arrays.fill(step, 0.0d);
                return 0.0d;
            }

            for (int j = 0; j < d; j++) {
                attractor[j] /= sum;
            }

            // The relative change of the density equals the relative change of the
            // kernel sum, because the normalizer cancels. Testing on the sum avoids
            // the normalizer overflowing in high dimensions.
            if (iter > 0) {
                change = Math.abs(sum - kernelSum) / kernelSum;
            }
            kernelSum = sum;

            step[iter % k] = MathEx.distance(attractor, position);
            System.arraycopy(attractor, 0, position, 0, d);
        }

        // Gaussian normalizer (2 * pi * sigma^2)^(d/2).
        double h = Math.pow(2.0d * Math.PI * sigma * sigma, d / 2.0d);
        return kernelSum / (m * h);
    }
}
