package elki.algorithm.statistics;

import elki.Algorithm;
import elki.data.LabelList;
import elki.data.type.AlternativeTypeInformation;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.DoubleDBIDListIter;
import elki.database.query.QueryBuilder;
import elki.database.query.knn.KNNSearcher;
import elki.database.relation.Relation;
import elki.distance.Distance;
import elki.distance.minkowski.EuclideanDistance;
import elki.logging.Logging;
import elki.logging.progress.FiniteProgress;
import elki.math.MeanVarianceMinMax;
import elki.result.CollectionResult;
import elki.result.Metadata;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.optionhandling.parameters.Flag;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import elki.utilities.optionhandling.parameters.RandomParameter;
import elki.utilities.random.RandomFactory;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:elki/algorithm/statistics/AveragePrecisionAtK.class */
public class AveragePrecisionAtK<O> implements Algorithm {
    private static final Logging LOG = Logging.getLogger(AveragePrecisionAtK.class);
    private Distance<? super O> distance;
    private int k;
    private double sampling;
    private RandomFactory random;
    private boolean includeSelf;

    /* loaded from: input_file:elki/algorithm/statistics/AveragePrecisionAtK$Par.class */
    public static class Par<O> implements Parameterizer {
        private static final OptionID K_ID = new OptionID("avep.k", "K to compute the average precision at.");
        public static final OptionID SAMPLING_ID = new OptionID("avep.sampling", "Relative amount of object to sample.");
        public static final OptionID SEED_ID = new OptionID("avep.sampling-seed", "Random seed for deterministic sampling.");
        public static final OptionID INCLUDESELF_ID = new OptionID("avep.includeself", "Include the query object in the evaluation.");
        protected Distance<? super O> distance;
        protected int k = 20;
        protected double sampling = 1.0d;
        protected RandomFactory seed = null;
        protected boolean includeSelf;

        public void configure(Parameterization parameterization) {
            new ObjectParameter(Algorithm.Utils.DISTANCE_FUNCTION_ID, Distance.class, EuclideanDistance.class).grab(parameterization, distance -> {
                this.distance = distance;
            });
            new IntParameter(K_ID).addConstraint(CommonConstraints.GREATER_THAN_ONE_INT).grab(parameterization, i -> {
                this.k = i;
            });
            new DoubleParameter(SAMPLING_ID).addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE).addConstraint(CommonConstraints.LESS_EQUAL_ONE_DOUBLE).setOptional(true).grab(parameterization, d -> {
                this.sampling = d;
            });
            new RandomParameter(SEED_ID).grab(parameterization, randomFactory -> {
                this.seed = randomFactory;
            });
            new Flag(INCLUDESELF_ID).grab(parameterization, z -> {
                this.includeSelf = z;
            });
        }

        /* renamed from: make, reason: merged with bridge method [inline-methods] */
        public AveragePrecisionAtK<O> m16make() {
            return new AveragePrecisionAtK<>(this.distance, this.k, this.sampling, this.seed, this.includeSelf);
        }
    }

    public AveragePrecisionAtK(Distance<? super O> distance, int i, double d, RandomFactory randomFactory, boolean z) {
        this.sampling = 1.0d;
        this.random = null;
        this.distance = distance;
        this.k = i;
        this.sampling = d;
        this.random = randomFactory;
        this.includeSelf = z;
    }

    public TypeInformation[] getInputTypeRestriction() {
        return TypeUtil.array(new TypeInformation[]{this.distance.getInputTypeRestriction(), new AlternativeTypeInformation(new TypeInformation[]{TypeUtil.CLASSLABEL, TypeUtil.LABELLIST})});
    }

    public CollectionResult<double[]> run(Relation<O> relation, Relation<?> relation2) {
        int i = this.k + (this.includeSelf ? 0 : 1);
        KNNSearcher kNNByDBID = new QueryBuilder(relation, this.distance).kNNByDBID(i);
        DBIDs randomSample = DBIDUtil.randomSample(relation.getDBIDs(), this.sampling, this.random);
        MeanVarianceMinMax[] newArray = MeanVarianceMinMax.newArray(this.k);
        FiniteProgress finiteProgress = LOG.isVerbose() ? new FiniteProgress("Computing nearest neighbors", randomSample.size(), LOG) : null;
        DBIDIter iter = randomSample.iter();
        while (iter.valid()) {
            Object obj = relation2.get(iter);
            int i2 = 0;
            int i3 = 0;
            DoubleDBIDListIter iter2 = kNNByDBID.getKNN(iter, i).iter();
            while (i3 < this.k && iter2.valid()) {
                if (this.includeSelf || !DBIDUtil.equal(iter, iter2)) {
                    i2 += match(obj, relation2.get(iter2)) ? 1 : 0;
                    newArray[i3].put(i2 / (i3 + 1));
                    i3++;
                }
                iter2.advance();
            }
            LOG.incrementProcessed(finiteProgress);
            iter.advance();
        }
        LOG.ensureCompleted(finiteProgress);
        ArrayList arrayList = new ArrayList(this.k);
        for (int i4 = 0; i4 < this.k; i4++) {
            MeanVarianceMinMax meanVarianceMinMax = newArray[i4];
            arrayList.add(new double[]{i4 + 1, meanVarianceMinMax.getMean(), meanVarianceMinMax.getCount() > 1.0d ? meanVarianceMinMax.getSampleStddev() : 0.0d, meanVarianceMinMax.getMin(), meanVarianceMinMax.getMax(), meanVarianceMinMax.getCount()});
        }
        CollectionResult<double[]> collectionResult = new CollectionResult<>(arrayList);
        Metadata.of(collectionResult).setLongName("Average Precision");
        return collectionResult;
    }

    protected static boolean match(Object obj, Object obj2) {
        if (obj == null) {
            return false;
        }
        if (obj == obj2) {
            return true;
        }
        if ((obj instanceof LabelList) && (obj2 instanceof LabelList)) {
            LabelList labelList = (LabelList) obj;
            LabelList labelList2 = (LabelList) obj2;
            int size = labelList.size();
            int size2 = labelList2.size();
            if (size == 0 || size2 == 0) {
                return false;
            }
            for (int i = 0; i < size; i++) {
                String str = labelList.get(i);
                if (str != null) {
                    for (int i2 = 0; i2 < size2; i2++) {
                        if (str.equals(labelList2.get(i2))) {
                            return true;
                        }
                    }
                }
            }
        }
        return obj.equals(obj2);
    }
}
