/*
 * Decompiled with CFR 0.152.
 */
package com.github.kilianB.clustering;

import com.github.kilianB.ArrayUtil;
import com.github.kilianB.clustering.ClusterAlgorithm;
import com.github.kilianB.clustering.ClusterResult;
import com.github.kilianB.clustering.distance.DistanceFunction;
import com.github.kilianB.clustering.distance.EuclideanDistance;
import com.github.kilianB.pcg.fast.PcgRSFast;
import java.util.DoubleSummaryStatistics;

/**
 * K-means clustering: partitions data points into {@code k} clusters by repeatedly
 * assigning every point to its closest cluster mean and recomputing the means from
 * the new assignment until no assignment changes.
 *
 * <p>The number of passes needed by the most recent {@link #cluster(double[][])} call
 * is stored in an instance field, so a single instance should not be used by several
 * threads at the same time.
 */
public class KMeans
implements ClusterAlgorithm {
    /** Number of clusters to partition the data into. */
    protected int k;
    /** Distance measure used to compare a data point against a cluster mean. */
    protected DistanceFunction distanceFunction;
    /** Number of assignment passes performed by the most recent clustering run. */
    protected int lastIterationCount;

    /**
     * Creates a k-means clusterer using the euclidean distance.
     *
     * @param clusters the number of clusters to compute
     */
    public KMeans(int clusters) {
        this(clusters, new EuclideanDistance());
    }

    /**
     * Creates a k-means clusterer using a custom distance function.
     *
     * @param clusters         the number of clusters to compute
     * @param distanceFunction the distance measure between a cluster mean and a data point
     */
    public KMeans(int clusters, DistanceFunction distanceFunction) {
        this.k = clusters;
        this.distanceFunction = distanceFunction;
    }

    /**
     * Assigns every row of {@code data} to one of the {@code k} clusters.
     *
     * @param data the data points; {@code data[i][j]} is feature {@code j} of point {@code i}.
     *             The dimensionality is taken from the first row.
     * @return the clustering result built from the cluster index of each point and the data
     * @throws IllegalArgumentException if {@code k} is smaller than 1, or if {@code k} is
     *                                  greater than 1 and not smaller than the number of
     *                                  data points
     */
    @Override
    public ClusterResult cluster(double[][] data) {
        if (this.k < 1) {
            // Previously surfaced as an incidental NegativeArraySize/ArrayIndexOutOfBounds exception.
            throw new IllegalArgumentException("The number of clusters must be at least 1 but was " + this.k);
        }
        // A freshly allocated int[] is zero filled: every point starts out in cluster 0.
        int[] cluster = new int[data.length];
        if (this.k == 1) {
            // With a single cluster there is nothing to compute.
            return new ClusterResult(cluster, data);
        }
        if (this.k >= data.length) {
            throw new IllegalArgumentException("Can't compute more clusters than datapoints are present");
        }
        int dataDimension = data[0].length;
        DoubleSummaryStatistics[][] clusterMeans = this.computeStartingClusters(data, this.k, dataDimension);
        this.computeKMeans(clusterMeans, data, cluster, dataDimension);
        return new ClusterResult(cluster, data);
    }

    /**
     * Creates the initial cluster means by drawing, for every cluster and every dimension,
     * a random value between the smallest and the largest value observed in that dimension.
     *
     * @param data          the data points to cluster
     * @param k             the number of clusters to create
     * @param dataDimension the number of features of each data point
     * @return a {@code k x dataDimension} array of statistics, each holding exactly one
     *         (the randomly drawn) value
     */
    protected DoubleSummaryStatistics[][] computeStartingClusters(double[][] data, int k, int dataDimension) {
        PcgRSFast rng = new PcgRSFast();

        // range[j] = {min, max} of feature j over all data points. The table has to be sized and
        // indexed by dimension (not by data point) because it is read per dimension below.
        double[][] range = new double[dataDimension][2];
        for (double[] bounds : range) {
            bounds[0] = Double.MAX_VALUE;
            bounds[1] = -Double.MAX_VALUE;
        }
        for (double[] point : data) {
            for (int j = 0; j < dataDimension; ++j) {
                double value = point[j];
                if (value < range[j][0]) {
                    range[j][0] = value;
                }
                if (value > range[j][1]) {
                    range[j][1] = value;
                }
            }
        }

        DoubleSummaryStatistics[][] clusterMeans = new DoubleSummaryStatistics[k][dataDimension];
        ArrayUtil.fillArrayMulti(clusterMeans, () -> new DoubleSummaryStatistics());
        for (int i = 0; i < k; ++i) {
            for (int j = 0; j < dataDimension; ++j) {
                // Scale the random draw into [min, max] of dimension j
                // (presumably nextDouble() yields a value in [0, 1) - confirm against PcgRSFast).
                double span = range[j][1] - range[j][0];
                clusterMeans[i][j].accept(rng.nextDouble() * span + range[j][0]);
            }
        }
        return clusterMeans;
    }

    /**
     * Runs the assign/update loop until a full pass over the data changes no assignment.
     * The number of passes, including the final pass that detects convergence, is stored
     * in {@link #lastIterationCount}.
     *
     * @param clusterMeans  the current cluster means; the entries are replaced by fresh
     *                      statistics and refilled whenever an assignment changed
     * @param data          the data points to cluster
     * @param cluster       in/out: the cluster index of every data point
     * @param dataDimension the number of features of each data point
     */
    protected void computeKMeans(DoubleSummaryStatistics[][] clusterMeans, double[][] data, int[] cluster, int dataDimension) {
        this.lastIterationCount = 0;
        boolean dirty;
        do {
            dirty = false;

            // Assignment step: move every point to the closest cluster mean.
            for (int dataIndex = 0; dataIndex < data.length; ++dataIndex) {
                double minDistance = Double.MAX_VALUE;
                int bestCluster = -1;
                for (int clusterIndex = 0; clusterIndex < this.k; ++clusterIndex) {
                    double distToCluster = this.distanceFunction.distance(clusterMeans[clusterIndex], data[dataIndex]);
                    // Strict comparison: ties are resolved in favor of the lower cluster index.
                    if (distToCluster < minDistance) {
                        bestCluster = clusterIndex;
                        minDistance = distToCluster;
                    }
                }
                if (cluster[dataIndex] != bestCluster) {
                    cluster[dataIndex] = bestCluster;
                    dirty = true;
                }
            }

            // Update step: rebuild the means from scratch, only needed if something moved.
            if (dirty) {
                // NOTE(review): a cluster that lost all of its points keeps empty statistics,
                // whose getAverage() is 0 - confirm this is what the distance function expects.
                ArrayUtil.fillArrayMulti(clusterMeans, () -> new DoubleSummaryStatistics());
                for (int dataIndex = 0; dataIndex < data.length; ++dataIndex) {
                    double[] point = data[dataIndex];
                    DoubleSummaryStatistics[] assignedMean = clusterMeans[cluster[dataIndex]];
                    for (int j = 0; j < dataDimension; ++j) {
                        assignedMean[j].accept(point[j]);
                    }
                }
            }
            ++this.lastIterationCount;
        } while (dirty);
    }

    /**
     * Returns the number of passes the most recent k-means computation performed.
     *
     * @return the pass count of the last run, or 0 if no k-means computation has run yet
     */
    public int iterations() {
        return this.lastIterationCount;
    }
}

