package elki.algorithm.statistics;

import elki.Algorithm;
import elki.data.LabelList;
import elki.data.type.AlternativeTypeInformation;
import elki.data.type.TypeInformation;
import elki.data.type.TypeUtil;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.ids.DoubleDBIDListMIter;
import elki.database.ids.HashSetModifiableDBIDs;
import elki.database.ids.ModifiableDBIDs;
import elki.database.ids.ModifiableDoubleDBIDList;
import elki.database.query.QueryBuilder;
import elki.database.query.distance.DistanceQuery;
import elki.database.relation.Relation;
import elki.distance.Distance;
import elki.distance.minkowski.EuclideanDistance;
import elki.evaluation.scores.AveragePrecisionEvaluation;
import elki.evaluation.scores.ROCEvaluation;
import elki.logging.Logging;
import elki.logging.progress.FiniteProgress;
import elki.logging.statistics.DoubleStatistic;
import elki.result.textwriter.TextWriteable;
import elki.result.textwriter.TextWriterStream;
import elki.utilities.exceptions.AbortException;
import elki.utilities.optionhandling.OptionID;
import elki.utilities.optionhandling.Parameterizer;
import elki.utilities.optionhandling.constraints.CommonConstraints;
import elki.utilities.optionhandling.parameterization.Parameterization;
import elki.utilities.optionhandling.parameters.DoubleParameter;
import elki.utilities.optionhandling.parameters.Flag;
import elki.utilities.optionhandling.parameters.IntParameter;
import elki.utilities.optionhandling.parameters.ObjectParameter;
import elki.utilities.optionhandling.parameters.RandomParameter;
import elki.utilities.random.RandomFactory;
import it.unimi.dsi.fastutil.objects.Object2IntMap;
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap;
import it.unimi.dsi.fastutil.objects.ObjectIterator;

/* loaded from: input_file:elki/algorithm/statistics/EvaluateRetrievalPerformance.class */
/**
 * Evaluates how well a distance function retrieves objects that share a label
 * with the query object.
 *
 * For every sampled query object, all objects of the data relation are ranked
 * by their distance to the query; the ranking is then scored against the set
 * of objects whose label matches the query label. Reported measures are the
 * mean average precision, the mean area under the ROC curve, and a per-k
 * kNN precision profile.
 *
 * @param <O> object type accepted by the distance function
 */
public class EvaluateRetrievalPerformance<O> implements Algorithm {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(EvaluateRetrievalPerformance.class);
    /** Prefix of the emitted statistic keys: the runtime class name. */
    private final String PREFIX = getClass().getName();
    /** Distance function whose retrieval quality is evaluated. */
    protected Distance<? super O> distance;
    /** Sampling rate handed to {@code DBIDUtil.randomSample} when choosing query objects. */
    protected double sampling;
    /** Random factory handed to {@code DBIDUtil.randomSample}. */
    protected RandomFactory random;
    /** When {@code true}, the query object itself is kept in its own ranking. */
    protected boolean includeSelf;
    /** Number of kNN precision slots that are computed and reported. */
    protected int maxk;

    /* loaded from: input_file:elki/algorithm/statistics/EvaluateRetrievalPerformance$KNNEvaluator.class */
    /**
     * Accumulates a kNN classification precision profile from a distance-sorted
     * neighbor list.
     */
    public static class KNNEvaluator {
        /** Shared instance; the evaluator has no fields of its own. */
        public static final KNNEvaluator STATIC = new KNNEvaluator();

        /**
         * Walks the neighbor list, counting the labels seen so far, and adds the
         * precision of the currently most frequent label(s) to the profile.
         *
         * Whenever a run of equal distances ends, the fraction of top-counted
         * labels that agree with the query label is added to every profile slot
         * that has not been filled yet, up to the current position.
         *
         * @param dArr profile to add to (one slot per k); its length bounds the evaluation
         * @param modifiableDoubleDBIDList neighbors with distances; expected to be sorted by distance
         * @param relation label relation used to look up each neighbor's label
         * @param object2IntOpenHashMap scratch label counters; cleared at the start of the call
         * @param obj label of the query object (a single label or a {@link LabelList})
         */
        public void evaluateKNN(double[] dArr, ModifiableDoubleDBIDList modifiableDoubleDBIDList, Relation<?> relation, Object2IntOpenHashMap<Object> object2IntOpenHashMap, Object obj) {
            final int maxk = dArr.length;
            // NOTE(review): k is advanced before the fill loop below, so the first
            // evaluation fills two slots at once - confirm this is intended.
            int k = 1;
            int filled = 0;
            int maxCount = 0;
            object2IntOpenHashMap.clear();
            DoubleDBIDListMIter iter = modifiableDoubleDBIDList.iter();
            while (iter.valid() && filled < maxk) {
                final double currentDistance = iter.doubleValue();
                maxCount = Math.max(maxCount, countkNN(object2IntOpenHashMap, relation.get(iter)));
                iter.advance();
                k++;
                // Only evaluate once the run of tied distances has ended.
                if (!iter.valid() || iter.doubleValue() > currentDistance) {
                    final double precision = tiedPrecision(object2IntOpenHashMap, maxCount, obj);
                    while (filled < k && filled < maxk) {
                        dArr[filled++] += precision;
                    }
                }
            }
        }

        /**
         * Increments the counter of the given label, or of every label in a
         * {@link LabelList}.
         *
         * NOTE(review): fastutil's {@code addTo} is documented to return the value
         * before the increment, so the result lags the stored count by one -
         * confirm this is intended.
         *
         * @param object2IntOpenHashMap label counters to update
         * @param obj label, or list of labels, of one neighbor
         * @return largest value returned by {@code addTo} for the updated labels
         */
        public int countkNN(Object2IntOpenHashMap<Object> object2IntOpenHashMap, Object obj) {
            if (!(obj instanceof LabelList)) {
                return object2IntOpenHashMap.addTo(obj, 1);
            }
            final LabelList labels = (LabelList) obj;
            int best = 0;
            for (int i = 0, size = labels.size(); i < size; i++) {
                best = Math.max(best, object2IntOpenHashMap.addTo(labels.get(i), 1));
            }
            return best;
        }

        /**
         * Computes the fraction of labels whose counter reaches the threshold
         * that also agree with the query label.
         *
         * The division is carried out in floating point; an integer division
         * would truncate every partial agreement to zero. If no label reaches
         * the threshold, the precision is reported as zero.
         */
        private static double tiedPrecision(Object2IntOpenHashMap<Object> counters, int threshold, Object queryLabel) {
            int agreeing = 0;
            int tied = 0;
            for (ObjectIterator<Object2IntMap.Entry<Object>> it = counters.object2IntEntrySet().fastIterator(); it.hasNext();) {
                final Object2IntMap.Entry<Object> entry = it.next();
                if (entry.getIntValue() < threshold) {
                    continue;
                }
                tied++;
                final Object key = entry.getKey();
                if (key != null && agreesWithQuery(key, queryLabel)) {
                    agreeing++;
                }
            }
            return tied > 0 ? (double) agreeing / tied : 0.0;
        }

        /** Tests whether a counted label equals the query label or one of its list entries. */
        private static boolean agreesWithQuery(Object key, Object queryLabel) {
            if (key.equals(queryLabel)) {
                return true;
            }
            if (queryLabel instanceof LabelList) {
                final LabelList labels = (LabelList) queryLabel;
                for (int i = 0, size = labels.size(); i < size; i++) {
                    if (key.equals(labels.get(i))) {
                        return true;
                    }
                }
            }
            return false;
        }
    }

    /* loaded from: input_file:elki/algorithm/statistics/EvaluateRetrievalPerformance$Par.class */
    /**
     * Parameterization class: reads the options and builds the algorithm.
     *
     * @param <O> object type accepted by the distance function
     */
    public static class Par<O> implements Parameterizer {
        /** Option: relative amount of objects to sample. */
        public static final OptionID SAMPLING_ID = new OptionID("map.sampling", "Relative amount of object to sample.");
        /** Option: random seed for the sampling. */
        public static final OptionID SEED_ID = new OptionID("map.sampling-seed", "Random seed for deterministic sampling.");
        /** Option: keep the query object in its own ranking. */
        public static final OptionID INCLUDESELF_ID = new OptionID("map.includeself", "Include the query object in the evaluation.");
        /** Option: largest k of the kNN evaluation. */
        public static final OptionID MAXK_ID = new OptionID("map.maxk", "Maximum value of k for kNN evaluation.");
        /** Distance function to evaluate. */
        protected Distance<? super O> distance;
        /** Whether the query object is included in the evaluation. */
        protected boolean includeSelf;
        /** Sampling rate; stays at 1.0 when the option is not given. */
        protected double sampling = 1.0d;
        /** Random factory for the sampling. */
        protected RandomFactory seed = null;
        /** Largest k to evaluate; stays at 0 when the option is not given. */
        protected int maxk = 0;

        /**
         * Reads all options of this algorithm from the given parameterization.
         *
         * @param parameterization source of the option values
         */
        public void configure(Parameterization parameterization) {
            // The explicit type argument keeps the consumer typed as a Distance;
            // with a raw ObjectParameter the lambda argument would only be an Object.
            new ObjectParameter<Distance<? super O>>(Algorithm.Utils.DISTANCE_FUNCTION_ID, Distance.class, EuclideanDistance.class).grab(parameterization, value -> {
                this.distance = value;
            });
            new DoubleParameter(SAMPLING_ID).addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE).addConstraint(CommonConstraints.LESS_EQUAL_ONE_DOUBLE).setOptional(true).grab(parameterization, value -> {
                this.sampling = value;
            });
            new RandomParameter(SEED_ID).grab(parameterization, value -> {
                this.seed = value;
            });
            new Flag(INCLUDESELF_ID).grab(parameterization, value -> {
                this.includeSelf = value;
            });
            new IntParameter(MAXK_ID).addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT).setOptional(true).grab(parameterization, value -> {
                this.maxk = value;
            });
        }

        /**
         * Builds the algorithm from the configured values.
         *
         * NOTE(review): the decompiler renamed this method from {@code make}
         * (merged bridge method); the name is kept here so existing references
         * keep working - confirm against the Parameterizer interface.
         *
         * @return configured algorithm instance
         */
        public EvaluateRetrievalPerformance<O> m31make() {
            return new EvaluateRetrievalPerformance<>(this.distance, this.sampling, this.seed, this.includeSelf, this.maxk);
        }
    }

    /* loaded from: input_file:elki/algorithm/statistics/EvaluateRetrievalPerformance$RetrievalPerformanceResult.class */
    /**
     * Immutable holder of the evaluation outcome: sample size, MAP, AUROC and
     * the kNN precision profile.
     */
    public static class RetrievalPerformanceResult implements TextWriteable {
        /** Number of query objects that contributed to the averages. */
        private final int samplesize;
        /** Mean average precision. */
        private final double map;
        /** Mean area under the ROC curve. */
        private final double auroc;
        /** kNN precision, slot {@code i} belonging to k = i + 1. */
        private final double[] knnperf;

        /**
         * @param i number of evaluated samples
         * @param d mean average precision
         * @param d2 mean area under the ROC curve
         * @param dArr kNN precision profile (stored by reference, not copied)
         */
        public RetrievalPerformanceResult(int i, double d, double d2, double[] dArr) {
            this.samplesize = i;
            this.map = d;
            this.auroc = d2;
            this.knnperf = dArr;
        }

        /** @return mean area under the ROC curve */
        public double getAUROC() {
            return auroc;
        }

        /** @return mean average precision */
        public double getMAP() {
            return map;
        }

        /** @return descriptive name of this result */
        public String getLongName() {
            return "Distance function retrieval evaluation.";
        }

        /** @return short identifier of this result */
        public String getShortName() {
            return "distance-retrieval-evaluation";
        }

        /**
         * Writes one key/value line per measure: MAP, AUROC, sample size, and
         * then one line for every kNN slot.
         *
         * @param textWriterStream output stream
         * @param str label of the result (not used)
         */
        public void writeToText(TextWriterStream textWriterStream, String str) {
            writeEntry(textWriterStream, "MAP", Double.valueOf(map));
            writeEntry(textWriterStream, "AUROC", Double.valueOf(auroc));
            writeEntry(textWriterStream, "Samplesize", Integer.valueOf(samplesize));
            int k = 1;
            for (double precision : knnperf) {
                writeEntry(textWriterStream, "knn-" + k, Double.valueOf(precision));
                k++;
            }
        }

        /** Emits an unquoted key followed by its value, then flushes the line. */
        private static void writeEntry(TextWriterStream out, String key, Object value) {
            out.inlinePrintNoQuotes(key);
            out.inlinePrint(value);
            out.flush();
        }
    }

    /**
     * Creates the evaluation algorithm.
     *
     * Every field is taken from the arguments; no default values apply here.
     *
     * @param distance distance function to evaluate
     * @param d sampling rate used when choosing query objects
     * @param randomFactory random factory used for the sampling
     * @param z whether the query object stays in its own ranking
     * @param i number of kNN precision slots to compute
     */
    public EvaluateRetrievalPerformance(Distance<? super O> distance, double d, RandomFactory randomFactory, boolean z, int i) {
        this.distance = distance;
        this.sampling = d;
        this.random = randomFactory;
        this.includeSelf = z;
        this.maxk = i;
    }

    /**
     * Declares the two required inputs: data matching the distance function's
     * own restriction, and a label relation of either class labels or label lists.
     *
     * @return input type restrictions, data first and labels second
     */
    public TypeInformation[] getInputTypeRestriction() {
        final TypeInformation dataType = this.distance.getInputTypeRestriction();
        final TypeInformation[] labelAlternatives = new TypeInformation[]{TypeUtil.CLASSLABEL, TypeUtil.LABELLIST};
        final TypeInformation labelType = new AlternativeTypeInformation(labelAlternatives);
        return TypeUtil.array(new TypeInformation[]{dataType, labelType});
    }

    /**
     * Runs the evaluation.
     *
     * Each sampled query object is ranked against the data relation; queries
     * whose label matches no object at all are skipped. The average precision,
     * ROC AUC and kNN precision of the remaining queries are averaged and also
     * logged as statistics.
     *
     * @param relation data relation the distances are computed on
     * @param relation2 label relation (class labels or label lists)
     * @return averaged retrieval measures
     * @throws AbortException if no sampled query had any matching object, or
     *         if the accumulated MAP or AUROC is NaN or negative
     */
    public RetrievalPerformanceResult run(Relation<O> relation, Relation<?> relation2) {
        final DBIDs queries = DBIDUtil.randomSample(relation.getDBIDs(), this.sampling, this.random);
        final DistanceQuery<O> distanceQuery = new QueryBuilder<>(relation, this.distance).distanceQuery();
        // Objects whose label matches the current query.
        final HashSetModifiableDBIDs positives = DBIDUtil.newHashSet();
        // Ranking of the current query, reused across queries.
        final ModifiableDoubleDBIDList neighbors = DBIDUtil.newDistanceDBIDList(relation.size());
        // Scratch label counters for the kNN evaluation.
        final Object2IntOpenHashMap<Object> counters = new Object2IntOpenHashMap<>();
        double mapSum = 0.0d;
        double aurocSum = 0.0d;
        final double[] knnperf = new double[this.maxk];
        int samples = 0;
        final FiniteProgress progress = LOG.isVerbose() ? new FiniteProgress("Processing query objects", queries.size(), LOG) : null;
        for (DBIDIter iter = queries.iter(); iter.valid(); iter.advance()) {
            final Object label = relation2.get(iter);
            findMatches(positives.clear(), relation2, label);
            if (positives.size() > 0) {
                computeDistances(neighbors, iter, distanceQuery, relation);
                final int expected = relation.size() - (this.includeSelf ? 0 : 1);
                if (neighbors.size() != expected) {
                    LOG.warning("Neighbor list does not have the desired size: " + neighbors.size());
                }
                mapSum += AveragePrecisionEvaluation.STATIC.evaluate(positives, neighbors);
                aurocSum += ROCEvaluation.STATIC.evaluate(positives, neighbors);
                KNNEvaluator.STATIC.evaluateKNN(knnperf, neighbors, relation2, counters, label);
                samples++;
            }
            LOG.incrementProcessed(progress);
        }
        LOG.ensureCompleted(progress);
        if (samples < 1) {
            throw new AbortException("No object matched - are labels parsed correctly?");
        }
        // Negated >= so that NaN is caught as well: every ordered comparison
        // with NaN is false, hence a plain "< 0" test would let NaN through.
        if (!(mapSum >= 0.0d) || !(aurocSum >= 0.0d)) {
            throw new AbortException("NaN in MAP/ROC.");
        }
        final double map = mapSum / samples;
        final double auroc = aurocSum / samples;
        LOG.statistics(new DoubleStatistic(this.PREFIX + ".map", map));
        LOG.statistics(new DoubleStatistic(this.PREFIX + ".auroc", auroc));
        LOG.statistics(new DoubleStatistic(this.PREFIX + ".samples", samples));
        for (int k = 0; k < this.maxk; k++) {
            knnperf[k] /= samples;
            LOG.statistics(new DoubleStatistic(this.PREFIX + ".knn-" + (k + 1), knnperf[k]));
        }
        return new RetrievalPerformanceResult(samples, map, auroc, knnperf);
    }

    /**
     * Decides whether a candidate label counts as a match for a reference label.
     *
     * A {@code null} reference never matches and identical references always
     * do. Two label lists match when they share at least one non-null label;
     * if either list is empty they do not match. In every remaining case the
     * result is {@code obj.equals(obj2)}.
     *
     * @param obj reference label (of the query object)
     * @param obj2 candidate label to test
     * @return {@code true} if the labels match
     */
    protected static boolean match(Object obj, Object obj2) {
        if (obj == null) {
            return false;
        }
        // Identity check first: cheap, and sufficient for shared label objects.
        if (obj == obj2) {
            return true;
        }
        final boolean bothLists = obj instanceof LabelList && obj2 instanceof LabelList;
        if (!bothLists) {
            return obj.equals(obj2);
        }
        final LabelList reference = (LabelList) obj;
        final LabelList candidate = (LabelList) obj2;
        final int referenceSize = reference.size();
        final int candidateSize = candidate.size();
        if (referenceSize == 0 || candidateSize == 0) {
            return false;
        }
        for (int r = 0; r < referenceSize; r++) {
            final String label = reference.get(r);
            if (label == null) {
                continue;
            }
            for (int c = 0; c < candidateSize; c++) {
                if (label.equals(candidate.get(c))) {
                    return true;
                }
            }
        }
        // No shared label: fall back to plain equality of the two lists.
        return obj.equals(obj2);
    }

    /**
     * Adds every object of the label relation whose label matches the given
     * label to the output set. The set is not cleared here.
     *
     * @param modifiableDBIDs output set receiving the matching object ids
     * @param relation label relation to scan
     * @param obj reference label to match against
     */
    private void findMatches(ModifiableDBIDs modifiableDBIDs, Relation<?> relation, Object obj) {
        for (DBIDIter candidate = relation.iterDBIDs(); candidate.valid(); candidate.advance()) {
            final Object candidateLabel = relation.get(candidate);
            if (match(obj, candidateLabel)) {
                modifiableDBIDs.add(candidate);
            }
        }
    }

    /**
     * Rebuilds the ranking of one query: the list is cleared, filled with the
     * distance from the query to every object of the relation, and sorted.
     *
     * The query object itself is left out unless {@code includeSelf} is set.
     * NaN distances are stored as positive infinity.
     *
     * @param modifiableDoubleDBIDList output list; previous content is discarded
     * @param dBIDIter reference to the query object
     * @param distanceQuery distance query used to compute the distances
     * @param relation data relation to rank
     */
    private void computeDistances(ModifiableDoubleDBIDList modifiableDoubleDBIDList, DBIDIter dBIDIter, DistanceQuery<O> distanceQuery, Relation<O> relation) {
        modifiableDoubleDBIDList.clear();
        // Typed as O so the call to DistanceQuery<O>.distance is type-correct.
        final O query = relation.get(dBIDIter);
        for (DBIDIter other = relation.iterDBIDs(); other.valid(); other.advance()) {
            if (!this.includeSelf && DBIDUtil.equal(other, dBIDIter)) {
                continue;
            }
            final double dist = distanceQuery.distance(query, other);
            modifiableDoubleDBIDList.add(Double.isNaN(dist) ? Double.POSITIVE_INFINITY : dist, other);
        }
        modifiableDoubleDBIDList.sort();
    }
}
