package elki.projection;

import elki.Algorithm;
import elki.data.DoubleVector;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.data.type.VectorFieldTypeInformation;
import elki.database.Database;
import elki.database.datastore.DataStoreFactory;
import elki.database.datastore.WritableDataStore;
import elki.database.ids.DBIDArrayIter;
import elki.database.ids.DBIDs;
import elki.database.relation.MaterializedRelation;
import elki.database.relation.Relation;
import elki.logging.Logging;
import elki.logging.progress.FiniteProgress;
import elki.logging.statistics.Duration;
import elki.logging.statistics.LongStatistic;
import elki.math.MathUtil;
import elki.utilities.documentation.Reference;
import elki.utilities.exceptions.AbortException;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.optionhandling.parameters.Flag;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import elki.utilities.optionhandling.parameters.RandomParameter;
import elki.utilities.random.RandomFactory;
import java.util.Arrays;
import java.util.Random;

@Reference(authors = "G. Hinton, S. Roweis", //
    title = "Stochastic Neighbor Embedding", //
    booktitle = "Advances in Neural Information Processing Systems 15", //
    url = "http://papers.nips.cc/paper/2276-stochastic-neighbor-embedding", //
    bibkey = "DBLP:conf/nips/HintonR02")
/**
 * Stochastic Neighbor Embedding (SNE).
 * <p>
 * Projects the input relation into a {@code dim}-dimensional space. Starting
 * from a small random Gaussian solution, it iteratively applies the gradient
 * {@code 4 * sum_j (p_ij - q_ij) (y_i - y_j)}. Here {@code p_ij} are the input
 * affinities and {@code q_ij} are normalized output affinities
 * {@code exp(-||y_i - y_j||^2)}.
 * <p>
 * The optimizer uses momentum and per-coordinate adaptive gains. The initial
 * momentum is used for the first quarter of the iterations, and only if it is
 * smaller than the final momentum.
 *
 * @param <O> input object type
 */
public class SNE<O> extends AbstractProjectionAlgorithm<Relation<DoubleVector>> {
    /**
     * Class logger.
     */
    private static final Logging LOG = Logging.getLogger(SNE.class);

    /**
     * Lower bound applied to each normalized output affinity q_ij in the gradient.
     */
    protected static final double MIN_QIJ = 1.0E-12d;

    /**
     * Standard deviation of the Gaussian random initial solution.
     */
    protected static final double INITIAL_SOLUTION_SCALE = 1.0E-4d;

    /**
     * Lower bound for the adaptive per-coordinate gains.
     */
    protected static final double MIN_GAIN = 0.01d;

    /**
     * Builder for the input affinity matrix p_ij.
     */
    protected AffinityMatrixBuilder<? super O> affinity;

    /**
     * Number of output-space squared distances computed (statistics counter,
     * reset at the start of {@link #run}).
     */
    protected long projectedDistances;

    /**
     * Output dimensionality.
     */
    protected int dim;

    /**
     * Learning rate (step size) of the gradient descent.
     */
    protected double learningRate;

    /**
     * Momentum used before {@link #momentumSwitch}.
     */
    protected double initialMomentum;

    /**
     * Momentum used from {@link #momentumSwitch} on.
     */
    protected double finalMomentum;

    /**
     * Iteration at which the momentum switches to {@link #finalMomentum}.
     */
    protected int momentumSwitch;

    /**
     * Number of optimization iterations.
     */
    protected int iterations;

    /**
     * Random generator factory used for the initial solution.
     */
    protected RandomFactory random;

    /**
     * Parameterization class.
     *
     * @param <O> input object type
     */
    public static class Par<O> implements Parameterizer {
        /**
         * Affinity matrix builder.
         */
        public static final OptionID AFFINITY_ID = new OptionID("tsne.affinity", "Affinity matrix builder.");

        /**
         * Output dimensionality.
         */
        public static final OptionID DIM_ID = new OptionID("tsne.dim", "Output dimensionality.");

        /**
         * Final momentum.
         */
        public static final OptionID MOMENTUM_ID = new OptionID("tsne.momentum", "The final momentum to use.");

        /**
         * Learning rate.
         */
        public static final OptionID LEARNING_RATE_ID = new OptionID("tsne.learningrate", "Learning rate of the method.");

        /**
         * Number of iterations.
         */
        public static final OptionID ITER_ID = new OptionID("tsne.iter", "Number of iterations to perform.");

        /**
         * Random seed.
         */
        public static final OptionID RANDOM_ID = new OptionID("tsne.seed", "Random generator seed");

        /** Affinity matrix builder. */
        protected AffinityMatrixBuilder<? super O> affinity;

        /** Output dimensionality. */
        protected int dim;

        /** Learning rate. */
        protected double learningRate;

        /** Final momentum. */
        protected double finalMomentum;

        /** Number of iterations. */
        protected int iterations;

        /** Random generator factory. */
        protected RandomFactory random;

        /** Keep the original data relation. */
        protected boolean keep;

        /**
         * Reads all SNE options from the given parameterization.
         *
         * @param parameterization parameterization to read from
         */
        public void configure(Parameterization parameterization) {
            new ObjectParameter<AffinityMatrixBuilder<? super O>>(AFFINITY_ID, AffinityMatrixBuilder.class, getDefaultAffinity()) //
                .grab(parameterization, x -> this.affinity = x);
            new IntParameter(DIM_ID) //
                .setDefaultValue(2) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT) //
                .grab(parameterization, x -> this.dim = x);
            new DoubleParameter(MOMENTUM_ID) //
                .setDefaultValue(Double.valueOf(0.8d)) //
                .addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE) //
                .addConstraint(CommonConstraints.LESS_EQUAL_ONE_DOUBLE) //
                .grab(parameterization, x -> this.finalMomentum = x);
            new DoubleParameter(LEARNING_RATE_ID) //
                .setDefaultValue(Double.valueOf(200.0d)) //
                .addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE) //
                .grab(parameterization, x -> this.learningRate = x);
            new IntParameter(ITER_ID) //
                .setDefaultValue(1000) //
                .addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_INT) //
                .grab(parameterization, x -> this.iterations = x);
            new RandomParameter(RANDOM_ID).grab(parameterization, x -> this.random = x);
            new Flag(AbstractProjectionAlgorithm.KEEP_ID).grab(parameterization, x -> this.keep = x);
        }

        /**
         * Default affinity matrix builder class.
         *
         * @return {@link PerplexityAffinityMatrixBuilder}
         */
        protected Class<?> getDefaultAffinity() {
            return PerplexityAffinityMatrixBuilder.class;
        }

        /**
         * Builds the configured SNE instance.
         * <p>
         * NOTE: this name was produced by decompilation (the original method was
         * {@code make}); it is kept unchanged for existing callers.
         *
         * @return new SNE instance
         */
        public SNE<O> m177make() {
            return new SNE<>(this.affinity, this.dim, this.finalMomentum, this.learningRate, this.iterations, this.random, this.keep);
        }
    }

    /**
     * Constructor with default final momentum (0.8), learning rate (200) and
     * iterations (1000); the previous relation is kept.
     *
     * @param affinity Affinity matrix builder
     * @param dim Output dimensionality
     * @param random Random generator factory
     */
    public SNE(AffinityMatrixBuilder<? super O> affinity, int dim, RandomFactory random) {
        this(affinity, dim, 0.8d, 200.0d, 1000, random, true);
    }

    /**
     * Full constructor.
     *
     * @param affinity Affinity matrix builder
     * @param dim Output dimensionality
     * @param finalMomentum Final momentum
     * @param learningRate Learning rate
     * @param iterations Number of iterations; the momentum switches after
     *        {@code iterations / 4}
     * @param random Random generator factory
     * @param keep Keep the original data relation
     */
    public SNE(AffinityMatrixBuilder<? super O> affinity, int dim, double finalMomentum, double learningRate, int iterations, RandomFactory random, boolean keep) {
        super(keep);
        this.affinity = affinity;
        this.dim = dim;
        this.iterations = iterations;
        this.learningRate = learningRate;
        // Warm-up momentum: 0.5, or half the final momentum when that is small.
        this.initialMomentum = finalMomentum >= 0.6d ? 0.5d : 0.5d * finalMomentum;
        this.finalMomentum = finalMomentum;
        this.momentumSwitch = iterations / 4;
        this.random = random;
    }

    /**
     * Input type restriction, as required by the affinity matrix builder.
     *
     * @return type restriction array
     */
    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(this.affinity.getInputTypeRestriction());
    }

    /**
     * Runs the algorithm on a database via {@link Algorithm.Utils#autorun}.
     * <p>
     * NOTE: this name was produced by decompilation (the original method was
     * {@code autorun}); it is kept unchanged for existing callers.
     *
     * @param database database to process
     * @return projected relation
     */
    @SuppressWarnings("unchecked")
    public Relation<DoubleVector> m175autorun(Database database) {
        return (Relation<DoubleVector>) Algorithm.Utils.autorun(this, database);
    }

    /**
     * Computes the SNE projection of the given relation.
     * <p>
     * Builds the affinity matrix (second argument {@code 1.0} — presumably no
     * exaggeration; confirm against {@link AffinityMatrixBuilder}), optimizes a
     * random initial solution, and wraps the result rows as {@link DoubleVector}s.
     * Calls {@code removePreviousRelation(relation)} before building the output
     * (presumably honoring the "keep" flag — verify in
     * {@link AbstractProjectionAlgorithm}).
     *
     * @param relation input relation
     * @return new relation of {@code dim}-dimensional vectors, named "SNE"
     */
    public Relation<DoubleVector> run(Relation<O> relation) {
        AffinityMatrix pij = this.affinity.computeAffinityMatrix(relation, 1.0d);
        double[][] solution = randomInitialSolution(pij.size(), this.dim, this.random.getSingleThreadedRandom());
        this.projectedDistances = 0L;
        optimizeSNE(pij, solution);
        LOG.statistics(new LongStatistic(getClass().getName() + ".projected-distances", this.projectedDistances));
        removePreviousRelation(relation);
        DBIDs ids = relation.getDBIDs();
        // Storage hint flags kept as the original literal (30); presumably a
        // combination of DataStoreFactory.HINT_* constants — TODO confirm.
        WritableDataStore<DoubleVector> proj = DataStoreFactory.FACTORY.makeStorage(ids, 30, DoubleVector.class);
        VectorFieldTypeInformation<DoubleVector> otype = new VectorFieldTypeInformation<>(DoubleVector.FACTORY, this.dim);
        for (DBIDArrayIter it = pij.iterDBIDs(); it.valid(); it.advance()) {
            // Rows are wrapped, not copied: the output vectors share the solution arrays.
            proj.put(it, DoubleVector.wrap(solution[it.getOffset()]));
        }
        return new MaterializedRelation<>("SNE", otype, ids, proj);
    }

    /**
     * Generates a random initial solution with coordinates drawn from a Gaussian
     * with standard deviation {@link #INITIAL_SOLUTION_SCALE}.
     *
     * @param size number of points
     * @param dim output dimensionality
     * @param random random generator
     * @return {@code size x dim} initial solution
     */
    protected static double[][] randomInitialSolution(int size, int dim, Random random) {
        double[][] solution = new double[size][dim];
        for (double[] row : solution) {
            for (int d = 0; d < dim; d++) {
                row[d] = random.nextGaussian() * INITIAL_SOLUTION_SCALE;
            }
        }
        return solution;
    }

    /**
     * Runs the iterative optimization, modifying {@code sol} in place.
     * <p>
     * For each point, a working array stores three consecutive blocks of
     * {@code dim} values: the current gradient, the last update (velocity), and
     * the adaptive gains (initialized to 1).
     *
     * @param pij input affinity matrix
     * @param sol solution to optimize, modified in place
     * @throws AbortException if the working array would exceed the Java array
     *         size limit
     */
    protected void optimizeSNE(AffinityMatrix pij, double[][] sol) {
        final int size = pij.size();
        // Compute in long: an int product can overflow and bypass this check.
        if ((long) size * 3L * this.dim > 2147483642L) {
            throw new AbortException("Memory exceeds Java array size limit.");
        }
        final int stride = 3 * this.dim;
        double[] meta = new double[size * stride];
        // Gains occupy the third block of each point's stride and start at 1.
        for (int off = 2 * this.dim; off < meta.length; off += stride) {
            Arrays.fill(meta, off, off + this.dim, 1.0d);
        }
        double[][] qij = new double[size][size];
        FiniteProgress prog = LOG.isVerbose() ? new FiniteProgress("Iterative Optimization", this.iterations, LOG) : null;
        Duration timer = LOG.isStatistics() ? LOG.newDuration(getClass().getName() + ".runtime.optimization").begin() : null;
        for (int it = 0; it < this.iterations; it++) {
            double qijSum = computeQij(qij, sol);
            computeGradient(pij, qij, 1.0d / qijSum, sol, meta);
            updateSolution(sol, meta, it);
            LOG.incrementProcessed(prog);
        }
        LOG.ensureCompleted(prog);
        if (timer != null) {
            LOG.statistics(timer.end());
        }
    }

    /**
     * Fills {@code qij} symmetrically with the unnormalized output affinities
     * {@code exp(-||y_i - y_j||^2)} for all {@code i != j}. The diagonal is not
     * written.
     *
     * @param qij output matrix, {@code n x n}
     * @param solution current solution
     * @return sum of the affinities over all ordered pairs {@code i != j}
     */
    protected double computeQij(double[][] qij, double[][] solution) {
        double sum = 0.0d;
        for (int i = 1; i < qij.length; i++) {
            final double[] rowI = qij[i];
            final double[] yi = solution[i];
            for (int j = 0; j < i; j++) {
                final double q = MathUtil.exp(-sqDist(yi, solution[j]));
                qij[j][i] = q;
                rowI[j] = q;
                sum += q;
            }
        }
        // Only the lower triangle was summed; the matrix is symmetric.
        return sum * 2.0d;
    }

    /**
     * Squared Euclidean distance; also increments {@link #projectedDistances}.
     *
     * @param a first vector
     * @param b second vector, same length as {@code a}
     * @return squared Euclidean distance
     */
    protected double sqDist(double[] a, double[] b) {
        assert a.length == b.length : "Lengths do not agree: " + a.length + " " + b.length;
        double sum = 0.0d;
        for (int d = 0; d < a.length; d++) {
            final double diff = a[d] - b[d];
            sum += diff * diff;
        }
        this.projectedDistances++;
        return sum;
    }

    /**
     * Computes the gradient {@code 4 * sum_j (p_ij - q_ij) (y_i - y_j)} for every
     * point. The result is stored in the first {@code dim} entries of that
     * point's block in {@code meta}. Normalized {@code q_ij} is bounded below by
     * {@link #MIN_QIJ}.
     *
     * @param pij input affinity matrix
     * @param qij unnormalized output affinities (from {@link #computeQij})
     * @param qijIsum inverse of the sum of {@code qij}, used to normalize
     * @param sol current solution
     * @param meta per-point working array (gradient, update, gains)
     */
    protected void computeGradient(AffinityMatrix pij, double[][] qij, double qijIsum, double[][] sol, double[] meta) {
        final int stride = this.dim * 3;
        final int size = pij.size();
        for (int i = 0, off = 0; i < size; i++, off += stride) {
            final double[] yi = sol[i];
            final double[] qRow = qij[i];
            Arrays.fill(meta, off, off + this.dim, 0.0d);
            for (int j = 0; j < size; j++) {
                if (i == j) {
                    continue;
                }
                final double[] yj = sol[j];
                final double factor = 4.0d * (pij.get(i, j) - MathUtil.max(qRow[j] * qijIsum, MIN_QIJ));
                for (int d = 0; d < this.dim; d++) {
                    meta[off + d] += factor * (yi[d] - yj[d]);
                }
            }
        }
    }

    /**
     * Applies one momentum-and-gain update step to the solution.
     * <p>
     * For each coordinate, the gain grows by 0.2 when the gradient and the last
     * update differ in sign (by the {@code > 0} test), otherwise it shrinks by
     * a factor of 0.8. It never falls below {@link #MIN_GAIN}. The update is then
     * {@code mom * update - learningRate * gradient * gain}, and is added to the
     * solution.
     *
     * @param sol solution, modified in place
     * @param meta per-point working array (gradient, update, gains)
     * @param it current iteration number
     */
    protected void updateSolution(double[][] sol, double[] meta, int it) {
        final double mom = (it >= this.momentumSwitch || this.initialMomentum >= this.finalMomentum) ? this.finalMomentum : this.initialMomentum;
        final int stride = this.dim * 3;
        for (int i = 0, off = 0; i < sol.length; i++, off += stride) {
            final double[] yi = sol[i];
            for (int d = 0; d < this.dim; d++) {
                final int gradIdx = off + d;
                final int updIdx = gradIdx + this.dim;
                final int gainIdx = updIdx + this.dim;
                final boolean signDiffers = (meta[gradIdx] > 0.0d) != (meta[updIdx] > 0.0d);
                final double gain = MathUtil.max(signDiffers ? meta[gainIdx] + 0.2d : meta[gainIdx] * 0.8d, MIN_GAIN);
                meta[gainIdx] = gain;
                meta[updIdx] = meta[updIdx] * mom;
                meta[updIdx] = meta[updIdx] - ((this.learningRate * meta[gradIdx]) * gain);
                yi[d] += meta[updIdx];
            }
        }
    }
}
