package elki.clustering.em;

import elki.clustering.ClusteringAlgorithm;
import elki.clustering.em.models.BetulaClusterModel;
import elki.clustering.em.models.BetulaClusterModelFactory;
import elki.clustering.kmeans.AbstractKMeans;
import elki.clustering.kmeans.KMeans;
import elki.data.Cluster;
import elki.data.Clustering;
import elki.data.NumberVector;
import elki.data.model.EMModel;
import elki.data.type.SimpleTypeInformation;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.datastore.DataStoreFactory;
import elki.database.datastore.DataStoreUtil;
import elki.database.datastore.WritableDataStore;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.ModifiableDBIDs;
import elki.database.relation.MaterializedRelation;
import elki.database.relation.Relation;
import elki.index.tree.betula.CFTree;
import elki.index.tree.betula.features.ClusterFeature;
import elki.logging.Logging;
import elki.logging.statistics.DoubleStatistic;
import elki.logging.statistics.Duration;
import elki.logging.statistics.LongStatistic;
import elki.math.linearalgebra.VMath;
import elki.result.Metadata;
import elki.utilities.documentation.Reference;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import it.unimi.dsi.fastutil.objects.Reference2ObjectOpenHashMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import net.jafama.FastMath;

@Reference(authors = "Andreas Lang and Erich Schubert", title = "BETULA: Fast Clustering of Large Data with Improved BIRCH CF-Trees", booktitle = "Information Systems", url = "https://doi.org/10.1016/j.is.2021.101918", bibkey = "DBLP:journals/is/LangS22")
/* loaded from: input_file:elki/clustering/em/BetulaGMM.class */
public class BetulaGMM implements ClusteringAlgorithm<Clustering<EMModel>> {
    CFTree.Factory<?> cffactory;
    int k;
    private double delta;
    int maxiter;
    private double prior;
    private boolean soft;
    protected static final double MIN_LOGLIKELIHOOD = -100000.0d;
    BetulaClusterModelFactory<?> initializer;
    private static final Logging LOG = Logging.getLogger(BetulaGMMWeighted.class);
    public static final SimpleTypeInformation<double[]> SOFT_TYPE = new SimpleTypeInformation<>(double[].class);

    /* loaded from: input_file:elki/clustering/em/BetulaGMM$Par.class */
    public static class Par implements Parameterizer {
        public static final OptionID INIT_ID = new OptionID("em.model", "Model factory.");
        public static final OptionID DELTA_ID = new OptionID("em.delta", "The termination criterion for maximization of E(M): E(M) - E(M') < em.delta");
        public static final OptionID PRIOR_ID = new OptionID("em.map.prior", "Regularization factor for MAP estimation.");
        CFTree.Factory<?> cffactory;
        protected int k;
        protected double delta;
        protected boolean soft;
        protected BetulaClusterModelFactory<?> initialization;
        protected int maxiter = -1;
        protected double prior = 0.0d;

        public void configure(Parameterization parameterization) {
            this.cffactory = (CFTree.Factory) parameterization.tryInstantiate(CFTree.Factory.class);
            new ObjectParameter(INIT_ID, BetulaClusterModelFactory.class).grab(parameterization, betulaClusterModelFactory -> {
                this.initialization = betulaClusterModelFactory;
            });
            new IntParameter(AbstractKMeans.K_ID).addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT).grab(parameterization, i -> {
                this.k = i;
            });
            new DoubleParameter(DELTA_ID, 1.0E-7d).addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_DOUBLE).grab(parameterization, d -> {
                this.delta = d;
            });
            new DoubleParameter(PRIOR_ID).setOptional(true).addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE).grab(parameterization, d2 -> {
                this.prior = d2;
            });
            new IntParameter(KMeans.MAXITER_ID).addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_INT).setOptional(true).grab(parameterization, i2 -> {
                this.maxiter = i2;
            });
        }

        @Override // 
        /* renamed from: make, reason: merged with bridge method [inline-methods] */
        public BetulaGMM mo97make() {
            return new BetulaGMM(this.cffactory, this.delta, this.k, this.maxiter, this.soft, this.initialization, this.prior);
        }
    }

    public BetulaGMM(CFTree.Factory<?> factory, double d, int i, int i2, boolean z, BetulaClusterModelFactory<?> betulaClusterModelFactory, double d2) {
        this.prior = 0.0d;
        this.cffactory = factory;
        this.delta = d;
        this.k = i;
        this.maxiter = i2;
        this.soft = z;
        this.initializer = betulaClusterModelFactory;
        this.prior = d2;
    }

    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(new TypeInformation[]{TypeUtil.NUMBER_VECTOR_FIELD});
    }

    public Clustering<EMModel> run(Relation<NumberVector> relation) {
        if (relation.size() == 0) {
            throw new IllegalArgumentException("database empty: must contain elements");
        }
        CFTree<?> newTree = this.cffactory.newTree(relation.getDBIDs(), relation, false);
        Duration begin = LOG.newDuration(getClass().getName() + ".modeltime").begin();
        ArrayList<?> leaves = newTree.getLeaves();
        List<?> buildInitialModels = this.initializer.buildInitialModels(leaves, this.k, newTree);
        Reference2ObjectOpenHashMap reference2ObjectOpenHashMap = new Reference2ObjectOpenHashMap(leaves.size());
        double assignProbabilitiesToInstances = assignProbabilitiesToInstances((ArrayList<? extends ClusterFeature>) leaves, (List<? extends BetulaClusterModel>) buildInitialModels, (Map<ClusterFeature, double[]>) reference2ObjectOpenHashMap);
        DoubleStatistic doubleStatistic = new DoubleStatistic(getClass().getName() + ".modelloglikelihood");
        LOG.statistics(doubleStatistic.setDouble(assignProbabilitiesToInstances));
        int i = 0;
        int i2 = 0;
        double d = Double.NEGATIVE_INFINITY;
        do {
            i++;
            if (i >= this.maxiter && this.maxiter >= 0) {
                break;
            }
            double d2 = assignProbabilitiesToInstances;
            recomputeCovarianceMatrices(leaves, reference2ObjectOpenHashMap, buildInitialModels, this.prior, newTree.getRoot().getCF().getWeight());
            assignProbabilitiesToInstances = assignProbabilitiesToInstances((ArrayList<? extends ClusterFeature>) leaves, (List<? extends BetulaClusterModel>) buildInitialModels, (Map<ClusterFeature, double[]>) reference2ObjectOpenHashMap);
            LOG.statistics(doubleStatistic.setDouble(assignProbabilitiesToInstances));
            if (assignProbabilitiesToInstances - d > this.delta) {
                i2 = i;
                d = assignProbabilitiesToInstances;
            }
            if (Math.abs(assignProbabilitiesToInstances - d2) <= this.delta) {
                break;
            }
        } while (i2 >= (i >> 1));
        LOG.statistics(new LongStatistic(getClass().getName() + ".iterations", i));
        LOG.statistics(begin.end());
        ArrayList arrayList = new ArrayList(this.k);
        for (int i3 = 0; i3 < this.k; i3++) {
            arrayList.add(DBIDUtil.newArray());
        }
        WritableDataStore<double[]> makeStorage = DataStoreUtil.makeStorage(relation.getDBIDs(), 10, double[].class);
        LOG.statistics(new DoubleStatistic(getClass().getName() + ".loglikelihood", assignProbabilitiesToInstances((Relation<? extends NumberVector>) relation, (List<? extends BetulaClusterModel>) buildInitialModels, makeStorage)));
        DBIDIter iterDBIDs = relation.iterDBIDs();
        while (iterDBIDs.valid()) {
            ((ModifiableDBIDs) arrayList.get(VMath.argmax((double[]) makeStorage.get(iterDBIDs)))).add(iterDBIDs);
            iterDBIDs.advance();
        }
        Clustering<EMModel> clustering = new Clustering<>();
        Metadata.of(clustering).setLongName("EM Clustering");
        for (int i4 = 0; i4 < this.k; i4++) {
            clustering.addToplevelCluster(new Cluster<>((DBIDs) arrayList.get(i4), buildInitialModels.get(i4).finalizeCluster()));
        }
        if (isSoft()) {
            Metadata.hierarchyOf(clustering).addChild(new MaterializedRelation("EM Cluster Probabilities", SOFT_TYPE, relation.getDBIDs(), makeStorage));
        }
        return clustering;
    }

    private boolean isSoft() {
        return this.soft;
    }

    public double assignProbabilitiesToInstances(ArrayList<? extends ClusterFeature> arrayList, List<? extends BetulaClusterModel> list, Map<ClusterFeature, double[]> map) {
        int size = list.size();
        double d = 0.0d;
        int i = 0;
        for (int i2 = 0; i2 < arrayList.size(); i2++) {
            ClusterFeature clusterFeature = arrayList.get(i2);
            double[] dArr = new double[size];
            for (int i3 = 0; i3 < size; i3++) {
                double estimateLogDensity = list.get(i3).estimateLogDensity((BetulaClusterModel) clusterFeature);
                dArr[i3] = estimateLogDensity > MIN_LOGLIKELIHOOD ? estimateLogDensity : MIN_LOGLIKELIHOOD;
            }
            double logSumExp = EM.logSumExp(dArr);
            for (int i4 = 0; i4 < size; i4++) {
                dArr[i4] = FastMath.exp(dArr[i4] - logSumExp);
            }
            map.put(clusterFeature, dArr);
            d += logSumExp * clusterFeature.getWeight();
            i += clusterFeature.getWeight();
        }
        return d / i;
    }

    public double assignProbabilitiesToInstances(Relation<? extends NumberVector> relation, List<? extends BetulaClusterModel> list, WritableDataStore<double[]> writableDataStore) {
        int size = list.size();
        double d = 0.0d;
        DBIDIter iterDBIDs = relation.iterDBIDs();
        while (iterDBIDs.valid()) {
            NumberVector numberVector = (NumberVector) relation.get(iterDBIDs);
            double[] dArr = new double[size];
            for (int i = 0; i < size; i++) {
                double estimateLogDensity = list.get(i).estimateLogDensity((BetulaClusterModel) numberVector);
                dArr[i] = estimateLogDensity > MIN_LOGLIKELIHOOD ? estimateLogDensity : MIN_LOGLIKELIHOOD;
            }
            double logSumExp = EM.logSumExp(dArr);
            for (int i2 = 0; i2 < size; i2++) {
                dArr[i2] = FastMath.exp(dArr[i2] - logSumExp);
            }
            writableDataStore.put(iterDBIDs, dArr);
            d += logSumExp;
            iterDBIDs.advance();
        }
        return d / relation.size();
    }

    public void recomputeCovarianceMatrices(ArrayList<? extends ClusterFeature> arrayList, Map<ClusterFeature, double[]> map, List<? extends BetulaClusterModel> list, double d, int i) {
        double d2;
        double d3;
        int size = list.size();
        boolean z = false;
        for (BetulaClusterModel betulaClusterModel : list) {
            betulaClusterModel.beginEStep();
            z |= betulaClusterModel.needsTwoPass();
        }
        if (z) {
            throw new IllegalStateException("Not Implemented");
        }
        double[] dArr = new double[size];
        for (int i2 = 0; i2 < arrayList.size(); i2++) {
            ClusterFeature clusterFeature = arrayList.get(i2);
            double[] dArr2 = map.get(clusterFeature);
            for (int i3 = 0; i3 < dArr2.length; i3++) {
                double d4 = dArr2[i3];
                if (d4 > 1.0E-10d) {
                    list.get(i3).updateE(clusterFeature, d4 * clusterFeature.getWeight());
                }
                int i4 = i3;
                dArr[i4] = dArr[i4] + (d4 * clusterFeature.getWeight());
            }
        }
        for (int i5 = 0; i5 < list.size(); i5++) {
            if (d <= 0.0d) {
                d2 = dArr[i5];
                d3 = i;
            } else {
                d2 = (dArr[i5] + d) - 1.0d;
                d3 = (i + (d * size)) - size;
            }
            list.get(i5).finalizeEStep(d2 / d3, d);
        }
    }
}
