package elki.clustering.em;

import elki.clustering.ClusteringAlgorithm;
import elki.clustering.em.EM;
import elki.clustering.em.models.TextbookMultivariateGaussianModel;
import elki.clustering.em.models.TextbookMultivariateGaussianModelFactory;
import elki.data.Cluster;
import elki.data.Clustering;
import elki.data.DoubleVector;
import elki.data.NumberVector;
import elki.data.model.EMModel;
import elki.data.type.SimpleTypeInformation;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.datastore.DataStoreUtil;
import elki.database.datastore.WritableDataStore;
import elki.database.ids.ArrayModifiableDBIDs;
import elki.database.ids.DBIDArrayIter;
import elki.database.ids.DBIDArrayMIter;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.ModifiableDBIDs;
import elki.database.relation.MaterializedRelation;
import elki.database.relation.Relation;
import elki.logging.Logging;
import elki.logging.statistics.DoubleStatistic;
import elki.logging.statistics.Duration;
import elki.logging.statistics.LongStatistic;
import elki.math.MathUtil;
import elki.math.linearalgebra.ConstrainedQuadraticProblemSolver;
import elki.math.linearalgebra.VMath;
import elki.result.Metadata;
import elki.utilities.datastructures.arraylike.IntegerArray;
import elki.utilities.documentation.Description;
import elki.utilities.documentation.Reference;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.optionhandling.parameters.Flag;
import elki.utilities.optionhandling.parameters.IntParameter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import net.jafama.FastMath;

/**
 * Expectation-Maximization clustering of Gaussian mixtures, accelerated by
 * aggregating the data in a kd-tree (Moore, NIPS 1998).
 * <p>
 * Every kd-tree node stores the sum and the sum of outer products of its
 * points. During the E-step, a node is processed as a whole (using its mean as
 * representative) when the bounds on the model weights inside its bounding box
 * are tight enough (controlled by {@code tau}); otherwise its children are
 * visited. Models whose maximum possible weight inside a node is below
 * {@code tauClass} times the largest minimum weight are ignored for that
 * subtree.
 */
@Reference(authors = "Andrew W. Moore", booktitle = "Advances in Neural Information Processing Systems 11 (NIPS 1998)", title = "Very Fast EM-based Mixture Model Clustering using Multiresolution kd-trees", bibkey = "DBLP:conf/nips/Moore98")
@Description("Gaussian mixture modeling accelerated using a kd-tree")
public class KDTreeEM implements ClusteringAlgorithm<Clustering<EMModel>> {
    /** Factory for the initial Gaussian models. */
    private final TextbookMultivariateGaussianModelFactory mfactory;

    /** Attach the soft assignment probabilities to the result. */
    private final boolean soft;

    /** Convergence threshold on the change of the average log-likelihood. */
    private final double delta;

    /** Number of mixture components. */
    private final int k;

    /** Tree construction: stop splitting when 2 * halfwidth &lt; mbw * data set width. */
    private final double mbw;

    /** Pruning threshold for descending the tree during the E-step. */
    private final double tau;

    /** Threshold for dropping components within a subtree (0 disables). */
    private final double tauClass;

    /** Minimum number of iterations. */
    private final int miniter;

    /** Maximum number of iterations (negative: unlimited). */
    private final int maxiter;

    /** Object ids, reordered by the tree so that every node is a contiguous range. */
    protected ArrayModifiableDBIDs sorted;

    /** Current models, used to evaluate densities. */
    private List<TextbookMultivariateGaussianModel> models;

    /** Models whose statistics are accumulated in the current E-step. */
    private List<TextbookMultivariateGaussianModel> newmodels;

    /** Solver used to bound model densities over a bounding box. */
    private ConstrainedQuadraticProblemSolver solver;

    /** Normalization constant 1 / sqrt(pi)^dim passed to the density bounds. */
    private double ipiPow;

    /** Accumulated (point count) weight of each component in the current E-step. */
    private double[] wsum;

    /** Assign each point individually in the final step instead of via the tree. */
    protected boolean exactAssign;

    private static final Logging LOG = Logging.getLogger(KDTreeEM.class);

    /** Type of the soft assignment relation attached to the result. */
    public static final SimpleTypeInformation<double[]> SOFT_TYPE = new SimpleTypeInformation<>(double[].class);

    /**
     * Node of the kd-tree.
     * <p>
     * A node covers the offsets {@code [left, right)} of the shared id array and
     * stores the sufficient statistics (sum and sum of outer products) of the
     * points in that range, plus their axis-aligned bounding box.
     */
    public static class KDTree {
        /** Left child, {@code null} for a leaf. */
        KDTree leftChild;

        /** Right child, {@code null} for a leaf. */
        KDTree rightChild;

        /** First offset (inclusive) in the id array. */
        int left;

        /** Last offset (exclusive) in the id array. */
        int right;

        /** Sum of the points in this node. */
        double[] sum;

        /** Sum of the outer products x x^T of the points in this node. */
        double[][] sumSq;

        /** Center of the bounding box. */
        double[] midpoint;

        /** Half of the bounding box extent in each dimension. */
        double[] halfwidth;

        /**
         * Build the (sub-)tree for the offsets {@code [left, right)}, reordering
         * {@code sorted} in place.
         *
         * @param relation data relation
         * @param sorted id array, partially reordered by this constructor
         * @param left first offset (inclusive)
         * @param right last offset (exclusive)
         * @param dimWidth extent of the complete data set in each dimension
         * @param mbw leaf criterion: stop when 2 * halfwidth &lt; mbw * dimWidth
         *        in the widest dimension of this node
         */
        public KDTree(Relation<? extends NumberVector> relation, ArrayModifiableDBIDs sorted, int left, int right, double[] dimWidth, double mbw) {
            final DBIDArrayMIter iter = sorted.iter();
            final int dim = relation.get(iter).getDimensionality();
            this.left = left;
            this.right = right;
            computeBoundingBox(relation, iter);
            // Split along the widest side of the bounding box.
            final int splitDim = VMath.argmax(halfwidth);
            if (2. * halfwidth[splitDim] < mbw * dimWidth[splitDim]) {
                aggregateStats(relation, iter, dim);
                return;
            }
            // Partition the range around the midpoint of the split dimension.
            final double splitpos = midpoint[splitDim];
            int l = left, r = right - 1;
            while (true) {
                while (l <= r && relation.get(iter.seek(l)).doubleValue(splitDim) <= splitpos) {
                    ++l;
                }
                while (l <= r && relation.get(iter.seek(r)).doubleValue(splitDim) >= splitpos) {
                    --r;
                }
                if (l >= r) {
                    break;
                }
                sorted.swap(l++, r--);
            }
            assert relation.get(iter.seek(r)).doubleValue(splitDim) <= splitpos : relation.get(iter.seek(r)).doubleValue(splitDim) + " not less than " + splitpos;
            final int mid = r + 1;
            if (mid == right) {
                // Nothing went to the right side (duplicate points): make a leaf.
                aggregateStats(relation, iter, dim);
                return;
            }
            leftChild = new KDTree(relation, sorted, left, mid, dimWidth, mbw);
            rightChild = new KDTree(relation, sorted, mid, right, dimWidth, mbw);
            // Inner nodes combine the statistics of their children.
            sum = VMath.plus(leftChild.sum, rightChild.sum);
            sumSq = VMath.plus(leftChild.sumSq, rightChild.sumSq);
        }

        /**
         * Compute {@link #midpoint} and {@link #halfwidth} of the points in
         * {@code [left, right)}.
         *
         * @param relation data relation
         * @param iter iterator over the id array (repositioned)
         */
        private void computeBoundingBox(Relation<? extends NumberVector> relation, DBIDArrayIter iter) {
            final double[] min = relation.get(iter.seek(left)).toArray();
            final double[] max = min.clone();
            for (iter.advance(); iter.getOffset() < right; iter.advance()) {
                final NumberVector vec = relation.get(iter);
                for (int d = 0; d < min.length; d++) {
                    final double x = vec.doubleValue(d);
                    min[d] = x < min[d] ? x : min[d];
                    max[d] = x > max[d] ? x : max[d];
                }
            }
            // Convert [min, max] to center and half width, reusing the arrays.
            for (int d = 0; d < min.length; d++) {
                final double lo = min[d], hi = max[d];
                min[d] = (lo + hi) * 0.5;
                max[d] = (hi - lo) * 0.5;
            }
            midpoint = min;
            halfwidth = max;
        }

        /**
         * Compute {@link #sum} and {@link #sumSq} directly from the points of
         * this (leaf) node.
         *
         * @param relation data relation
         * @param iter iterator over the id array (repositioned)
         * @param dim dimensionality
         */
        private void aggregateStats(Relation<? extends NumberVector> relation, DBIDArrayIter iter, int dim) {
            sum = new double[dim];
            sumSq = new double[dim][dim];
            for (iter.seek(left); iter.getOffset() < right; iter.advance()) {
                final NumberVector vec = relation.get(iter);
                for (int d1 = 0; d1 < dim; d1++) {
                    final double x1 = vec.doubleValue(d1);
                    sum[d1] += x1;
                    final double[] row = sumSq[d1];
                    for (int d2 = 0; d2 < dim; d2++) {
                        row[d2] += x1 * vec.doubleValue(d2);
                    }
                }
            }
        }
    }

    /**
     * Parameterization class.
     */
    public static class Par implements Parameterizer {
        public static final OptionID K_ID = EM.Par.K_ID;
        public static final OptionID DELTA_ID = EM.Par.DELTA_ID;
        public static final OptionID MBW_ID = new OptionID("emkd.mbw", "Pruning criterion for the KD-Tree during construction. Stop splitting when leafwidth < mbw * width.");
        public static final OptionID TAU_ID = new OptionID("emkd.tau", "Pruning criterion for the KD-Tree during algorithm. Stop traversing when error e < tau * totalweight.");
        public static final OptionID TAU_CLASS_ID = new OptionID("emkd.tauclass", "Parameter for pruning. Drop a class if w[c] < tauclass * max(wmins). Set to 0 to disable dropping of classes.");
        public static final OptionID MINITER_ID = EM.Par.MINITER_ID;
        public static final OptionID MAXITER_ID = EM.Par.MAXITER_ID;
        public static final OptionID SOFT_ID = EM.Par.SOFT_ID;
        public static final OptionID EXACT_ASSIGN_ID = new OptionID("emkd.exactassign", "Assign each point individually, not using the kd-tree in the final step.");

        protected int k;
        protected double mbw;
        protected double tau;
        protected double tauclass;
        protected double delta;
        protected TextbookMultivariateGaussianModelFactory mfactory;
        protected int miniter = 1;
        protected int maxiter = -1;
        boolean soft = false;
        boolean exactAssign = false;

        @Override
        public void configure(Parameterization config) {
            new IntParameter(K_ID) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT) //
                .grab(config, x -> k = x);
            new DoubleParameter(MBW_ID, 0.01d) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE) //
                .addConstraint(CommonConstraints.LESS_THAN_ONE_DOUBLE) //
                .grab(config, x -> mbw = x);
            new DoubleParameter(TAU_ID, 0.01d) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE) //
                .addConstraint(CommonConstraints.LESS_THAN_ONE_DOUBLE) //
                .grab(config, x -> tau = x);
            new DoubleParameter(TAU_CLASS_ID, 1.0E-4d) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE) //
                .addConstraint(CommonConstraints.LESS_THAN_ONE_DOUBLE) //
                .grab(config, x -> tauclass = x);
            new DoubleParameter(DELTA_ID, 1.0E-7d) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE) //
                .grab(config, x -> delta = x);
            mfactory = (TextbookMultivariateGaussianModelFactory) config.tryInstantiate(TextbookMultivariateGaussianModelFactory.class);
            new IntParameter(MINITER_ID) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_INT) //
                .setOptional(true) //
                .grab(config, x -> miniter = x);
            new IntParameter(MAXITER_ID) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_INT) //
                .setOptional(true) //
                .grab(config, x -> maxiter = x);
            new Flag(SOFT_ID).grab(config, x -> soft = x);
            new Flag(EXACT_ASSIGN_ID).grab(config, x -> exactAssign = x);
        }

        @Override
        public KDTreeEM make() {
            return new KDTreeEM(k, mbw, tau, tauclass, delta, mfactory, miniter, maxiter, soft, exactAssign);
        }
    }

    /**
     * Constructor.
     *
     * @param k number of mixture components
     * @param mbw tree construction: stop splitting when a node's widest side is
     *        below mbw times the data set width
     * @param tau pruning threshold for traversing the tree
     * @param tauClass threshold for dropping components in a subtree (0 disables)
     * @param delta convergence threshold on the average log-likelihood
     * @param mfactory factory for the initial models
     * @param miniter minimum number of iterations
     * @param maxiter maximum number of iterations (negative: unlimited)
     * @param soft attach the soft assignment probabilities to the result
     * @param exactAssign assign points individually in the final step
     */
    public KDTreeEM(int k, double mbw, double tau, double tauClass, double delta, TextbookMultivariateGaussianModelFactory mfactory, int miniter, int maxiter, boolean soft, boolean exactAssign) {
        this.k = k;
        this.mbw = mbw;
        this.tau = tau;
        this.tauClass = tauClass;
        this.delta = delta;
        this.mfactory = mfactory;
        this.miniter = miniter;
        this.maxiter = maxiter;
        this.soft = soft;
        this.exactAssign = exactAssign;
    }

    /**
     * Run the clustering algorithm.
     *
     * @param relation data relation (must not be empty)
     * @return hard clustering (by maximum probability), optionally with the soft
     *         assignments attached as a child relation
     * @throws IllegalArgumentException if the relation is empty
     */
    public Clustering<EMModel> run(Relation<? extends NumberVector> relation) {
        if (relation.size() == 0) {
            throw new IllegalArgumentException("database empty: must contain elements");
        }
        final int n = relation.size();
        final int dim = relation.get(relation.iterDBIDs()).getDimensionality();
        // The tree reorders this copy of the ids so that every node is a contiguous range.
        sorted = DBIDUtil.newArray(relation.getDBIDs());
        final double[] dimWidth = analyseDimWidth(relation);
        final Duration buildtime = LOG.newDuration(getClass().getName() + ".kdtree.buildtime").begin();
        final KDTree tree = new KDTree(relation, sorted, 0, sorted.size(), dimWidth, mbw);
        LOG.statistics(buildtime.end());

        models = mfactory.buildInitialModels(relation, k);
        newmodels = new ArrayList<>(k);
        for (int i = 0; i < k; i++) {
            newmodels.add(new TextbookMultivariateGaussianModel(0.0d, new double[dim]));
        }
        wsum = new double[k];
        final DoubleStatistic likestat = new DoubleStatistic(getClass().getName() + ".loglikelihood");
        solver = new ConstrainedQuadraticProblemSolver(dim);
        ipiPow = 1.0d / FastMath.pow(MathUtil.SQRTPI, dim);

        int it = 0, lastImprovement = 0;
        double bestLoglikelihood = Double.NEGATIVE_INFINITY, loglikelihood = 0.0d;
        for (; maxiter < 0 || it < maxiter; it++) {
            final double oldLoglikelihood = loglikelihood;
            resetAccumulators();
            loglikelihood = makeStats(tree, MathUtil.sequence(0, k), null) / n;
            for (int i = 0; i < k; i++) {
                final double weight = wsum[i] / n;
                if (weight <= Double.MIN_NORMAL) {
                    LOG.warning("A cluster has degenerated by pruning.");
                    newmodels.get(i).clone(models.get(i));
                } else {
                    newmodels.get(i).finalizeEStep(weight, 0.0d);
                }
            }
            // Swap roles: the freshly estimated models become the current ones.
            final List<TextbookMultivariateGaussianModel> tmp = newmodels;
            newmodels = models;
            models = tmp;
            LOG.statistics(likestat.setDouble(loglikelihood));
            if (loglikelihood - bestLoglikelihood > delta) {
                lastImprovement = it;
                bestLoglikelihood = loglikelihood;
            }
            // Converged, or no significant improvement during the second half of the run.
            if (it >= miniter && (Math.abs(loglikelihood - oldLoglikelihood) <= delta || lastImprovement < (it >> 1))) {
                break;
            }
        }

        final List<ModifiableDBIDs> hardClusters = new ArrayList<>(k);
        for (int i = 0; i < k; i++) {
            hardClusters.add(DBIDUtil.newArray());
        }
        // NOTE(review): hint flags 10 are presumably DataStoreFactory.HINT_HOT | HINT_SORTED -- confirm.
        final WritableDataStore<double[]> probClusterIGivenX = DataStoreUtil.makeStorage(relation.getDBIDs(), 10, double[].class);
        final double finalLoglikelihood;
        if (exactAssign) {
            finalLoglikelihood = EM.assignProbabilitiesToInstances(relation, models, probClusterIGivenX, null);
        } else {
            // newmodels still holds the previous models; reset them like in every
            // iteration, because the pruning criterion reads their accumulated weight.
            resetAccumulators();
            finalLoglikelihood = makeStats(tree, MathUtil.sequence(0, k), probClusterIGivenX) / n;
        }
        LOG.statistics(new LongStatistic(getClass().getName() + ".iterations", it));
        LOG.statistics(new DoubleStatistic(getClass().getName() + ".loglikelihood", finalLoglikelihood));

        // Hard assignment: each point goes to its most probable component.
        for (DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) {
            hardClusters.get(VMath.argmax(probClusterIGivenX.get(iditer))).add(iditer);
        }
        final Clustering<EMModel> clustering = new Clustering<>();
        Metadata.of(clustering).setLongName("KDTreeEM Clustering");
        for (int i = 0; i < k; i++) {
            clustering.addToplevelCluster(new Cluster<>(hardClusters.get(i), models.get(i).finalizeCluster()));
        }
        if (soft) {
            Metadata.hierarchyOf(clustering).addChild(new MaterializedRelation<>("KDTreeEM Cluster Probabilities", SOFT_TYPE, relation.getDBIDs(), probClusterIGivenX));
        } else {
            probClusterIGivenX.destroy();
        }
        solver = null;
        newmodels = null;
        return clustering;
    }

    /**
     * Prepare {@link #newmodels} and {@link #wsum} for a new accumulation pass.
     */
    private void resetAccumulators() {
        for (TextbookMultivariateGaussianModel model : newmodels) {
            model.beginEStep();
        }
        Arrays.fill(wsum, 0.0d);
    }

    /**
     * Compute the extent (max - min) of the data set in each dimension.
     *
     * @param relation data relation (non-empty)
     * @return width per dimension
     */
    private double[] analyseDimWidth(Relation<? extends NumberVector> relation) {
        final DBIDIter iter = relation.iterDBIDs();
        final NumberVector first = relation.get(iter);
        final int dim = first.getDimensionality();
        final double[] min = first.toArray();
        final double[] max = min.clone();
        for (iter.advance(); iter.valid(); iter.advance()) {
            final NumberVector vec = relation.get(iter);
            for (int d = 0; d < dim; d++) {
                final double x = vec.doubleValue(d);
                min[d] = min[d] < x ? min[d] : x;
                max[d] = max[d] > x ? max[d] : x;
            }
        }
        return VMath.minusEquals(max, min);
    }

    /**
     * Decide whether the E-step can process {@code node} as a whole.
     * <p>
     * For every active model, bounds on its posterior weight inside the node's
     * bounding box are derived from bounds on the densities: the lower bound
     * puts this model at its minimum density and all other models at their
     * maximum; the upper bound does the opposite.
     *
     * @param node tree node (with children)
     * @param indices indices of the models active in this subtree
     * @return {@code null} if the weight bounds are tight enough to stop
     *         descending; otherwise the model indices to use for the children
     */
    private int[] checkStoppingCondition(KDTree node, int[] indices) {
        final int nmodels = models.size(), dim = node.sum.length;
        // Output buffers for the model's bound computation; not read here.
        final double[][] buf1 = new double[nmodels][dim];
        final double[][] buf2 = new double[nmodels][dim];
        // limits[i][0]: lower density bound, limits[i][1]: upper density bound
        // (as implied by their use below).
        final double[][] limits = new double[nmodels][2];
        for (int i : indices) {
            calculateModelLimits(node, models.get(i), buf2[i], buf1[i], limits[i]);
        }
        double sumMin = 0.0d, sumMax = 0.0d;
        for (int i : indices) {
            final double weight = models.get(i).getWeight();
            sumMin += weight * limits[i][0];
            sumMax += weight * limits[i][1];
        }
        boolean prune = true;
        double maxMinWeight = Double.NEGATIVE_INFINITY;
        final double[] maxWeights = new double[nmodels];
        final double size = node.right - node.left;
        for (int i : indices) {
            final double weight = models.get(i).getWeight();
            // This model at its minimum, all others at their maximum.
            final double denomMin = sumMax + weight * (limits[i][0] - limits[i][1]);
            // This model at its maximum, all others at their minimum.
            final double denomMax = sumMin + weight * (limits[i][1] - limits[i][0]);
            final double minWeight = MathUtil.clamp((weight * limits[i][0]) / denomMin, 0.0d, 1.0d);
            maxMinWeight = minWeight > maxMinWeight ? minWeight : maxMinWeight;
            maxWeights[i] = MathUtil.clamp((weight * limits[i][1]) / denomMax, 0.0d, 1.0d);
            if (size * (maxWeights[i] - minWeight) > tau * (newmodels.get(i).getWeight() + minWeight * size)) {
                prune = false;
            }
        }
        if (prune) {
            return null;
        }
        if (tauClass <= 0.0d) {
            return indices;
        }
        // Drop models that cannot reach a relevant weight in this subtree.
        final IntegerArray keep = new IntegerArray(indices.length);
        for (int i : indices) {
            if (maxWeights[i] >= tauClass * maxMinWeight) {
                keep.add(i);
            }
        }
        return keep.toArray();
    }

    /**
     * Compute density bounds of a model over the node's bounding box.
     *
     * @param node tree node
     * @param model model to bound
     * @param out1 first point output buffer passed to the model
     * @param out2 second point output buffer passed to the model
     * @param ret output for the two density bounds
     */
    private void calculateModelLimits(KDTree node, TextbookMultivariateGaussianModel model, double[] out1, double[] out2, double[] ret) {
        final double[] min = VMath.minus(node.midpoint, node.halfwidth);
        final double[] max = VMath.plusTimes(min, node.halfwidth, 2.0d);
        model.calculateModelLimits(min, max, solver, ipiPow, out1, out2, ret);
    }

    /**
     * Accumulate the E-step statistics of a subtree into {@link #newmodels} and
     * {@link #wsum}, optionally recording the assignment probabilities.
     *
     * @param node tree node
     * @param indices indices of the models active in this subtree
     * @param probClusterIGivenX output for the assignment probabilities, may be
     *        {@code null}; all points of a processed node share one array
     * @return (approximate) log-likelihood contribution of the subtree
     */
    private double makeStats(KDTree node, int[] indices, WritableDataStore<double[]> probClusterIGivenX) {
        final int size = node.right - node.left;
        if (indices.length == 1) {
            // A single remaining model takes all points of this subtree.
            final int c = indices[0];
            final double logDensity = models.get(c).estimateLogDensity(DoubleVector.wrap(VMath.times(node.sum, 1.0d / size)));
            wsum[c] += size;
            newmodels.get(c).updateE(node.sum, node.sumSq, 1.0d, size);
            if (probClusterIGivenX != null) {
                final double[] probs = new double[k];
                probs[c] = 1.0d;
                storeProbabilities(node, probs, probClusterIGivenX);
            }
            return logDensity * size;
        }
        if (node.leftChild != null) {
            final int[] remaining = checkStoppingCondition(node, indices);
            if (remaining != null) {
                return makeStats(node.leftChild, remaining, probClusterIGivenX) + makeStats(node.rightChild, remaining, probClusterIGivenX);
            }
        }
        // Process the node as a whole, using its mean as representative.
        final NumberVector mean = DoubleVector.wrap(VMath.times(node.sum, 1.0d / size));
        final double[] logProbs = new double[indices.length];
        for (int i = 0; i < indices.length; i++) {
            logProbs[i] = models.get(indices[i]).estimateLogDensity(mean);
        }
        final double logSum = EM.logSumExp(logProbs);
        VMath.minusEquals(logProbs, logSum);
        final double[] probs = probClusterIGivenX != null ? new double[k] : null;
        for (int i = 0; i < indices.length; i++) {
            final int c = indices[i];
            final double p = FastMath.exp(logProbs[i]);
            wsum[c] += p * size;
            newmodels.get(c).updateE(node.sum, node.sumSq, p, p * size);
            if (probs != null) {
                probs[c] = p;
            }
        }
        if (probClusterIGivenX != null) {
            storeProbabilities(node, probs, probClusterIGivenX);
        }
        return logSum * size;
    }

    /**
     * Store the same probability array for every point of a node.
     *
     * @param node tree node
     * @param probs probabilities (shared, not copied)
     * @param store output storage
     */
    private void storeProbabilities(KDTree node, double[] probs, WritableDataStore<double[]> store) {
        for (DBIDArrayMIter iter = sorted.iter().seek(node.left); iter.getOffset() < node.right; iter.advance()) {
            store.put(iter, probs);
        }
    }

    @Override
    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(TypeUtil.NUMBER_VECTOR_FIELD);
    }
}
