package smile.neighbor;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;
import smile.math.MathEx;
import smile.neighbor.lsh.Bucket;
import smile.neighbor.lsh.Hash;
import smile.sort.HeapSelect;
import smile.util.IntArrayList;

/* loaded from: input_file:smile/neighbor/LSH.class */
/**
 * Locality-sensitive hashing (LSH) index for approximate nearest neighbor search over
 * {@code double[]} keys.
 *
 * <p>The index keeps {@code L} independent hash tables ({@link Hash}). Each table is built from
 * {@code k} random projections of width {@code w} and has size {@code H}. Every inserted key is
 * added to all tables.
 *
 * <p>A query works in two steps:
 * <ol>
 *   <li>It collects the union of the points stored in the buckets it hashes to, across all tables.</li>
 *   <li>It ranks those candidates by {@link MathEx#distance}.</li>
 * </ol>
 * Points that never share a bucket with the query are not examined, so results are approximate.
 *
 * <p>In every search method, a stored key that is the <em>same array object</em> as the query
 * (reference identity, not value equality) is skipped. Querying with a key taken from the index
 * therefore does not return that key itself.
 *
 * <p>This class performs no synchronization.
 *
 * @param <E> the type of data objects associated with the keys.
 */
public class LSH<E> implements NearestNeighborSearch<double[], E>, KNNSearch<double[], E>, RNNSearch<double[], E>, Serializable {
    private static final long serialVersionUID = 2;

    /** Hash table size used by the constructors that do not take one explicitly. */
    private static final int DEFAULT_HASH_TABLE_SIZE = 1017881;

    /** The stored keys; a key's position in this list is its index. */
    protected ArrayList<double[]> keys;
    /** The stored data objects; {@code data.get(i)} is associated with {@code keys.get(i)}. */
    protected ArrayList<E> data;
    /** The hash tables, one {@link Hash} per table. */
    protected List<Hash> hash;
    /** The size of each hash table. */
    protected int H;
    /** The number of random projections per hash value. */
    protected int k;
    /** The width of the random projections. */
    protected double w;

    /**
     * Builds an index over the given keys using the default hash table size.
     *
     * @param keys the keys; must be non-empty.
     * @param data the data objects, parallel to {@code keys}.
     * @param w the width of the random projections.
     * @throws IllegalArgumentException if the inputs are invalid (see the 4-argument constructor).
     */
    public LSH(double[][] keys, E[] data, double w) {
        this(keys, data, w, DEFAULT_HASH_TABLE_SIZE);
    }

    /**
     * Builds an index over the given keys.
     *
     * <p>The number of tables is {@code max(50, (int) n^0.25)} and the number of projections per
     * hash value is {@code max(3, (int) log10(n))}, where {@code n = keys.length}. The key
     * dimension is taken from {@code keys[0]}.
     *
     * @param keys the keys; must be non-empty.
     * @param data the data objects, parallel to {@code keys}.
     * @param w the width of the random projections.
     * @param H the size of each hash table; must be at least {@code keys.length}.
     * @throws IllegalArgumentException if {@code keys} is empty, if {@code keys} and {@code data}
     *     differ in length, if {@code H < keys.length}, or if the derived parameters are rejected
     *     by {@link #LSH(int, int, int, double, int)}.
     */
    public LSH(double[][] keys, E[] data, double w, int H) {
        this(dimensionOf(keys), Math.max(50, (int) Math.pow(keys.length, 0.25d)), Math.max(3, (int) Math.log10(keys.length)), w, H);
        if (keys.length != data.length) {
            throw new IllegalArgumentException("The array size of keys and data are different.");
        }
        if (H < keys.length) {
            throw new IllegalArgumentException("Hash table size is too small: " + H);
        }
        for (int i = 0; i < keys.length; i++) {
            put(keys[i], data[i]);
        }
    }

    /**
     * Returns the dimension of the first key. Runs inside the {@code this(...)} call, so an empty
     * array is rejected with a clear message instead of an ArrayIndexOutOfBoundsException.
     */
    private static int dimensionOf(double[][] keys) {
        if (keys.length == 0) {
            throw new IllegalArgumentException("Empty keys");
        }
        return keys[0].length;
    }

    /**
     * Creates an empty index using the default hash table size.
     *
     * @param d the dimension of the input space.
     * @param L the number of hash tables.
     * @param k the number of random projections per hash value.
     * @param w the width of the random projections.
     * @throws IllegalArgumentException if any argument is out of range.
     */
    public LSH(int d, int L, int k, double w) {
        this(d, L, k, w, DEFAULT_HASH_TABLE_SIZE);
    }

    /**
     * Creates an empty index.
     *
     * @param d the dimension of the input space; must be at least 2.
     * @param L the number of hash tables; must be positive.
     * @param k the number of random projections per hash value; must be positive.
     * @param w the width of the random projections; must be positive.
     * @param H the size of each hash table; must be positive.
     * @throws IllegalArgumentException if any argument is out of range.
     */
    public LSH(int d, int L, int k, double w, int H) {
        if (d < 2) {
            throw new IllegalArgumentException("Invalid input space dimension: " + d);
        }
        if (L < 1) {
            throw new IllegalArgumentException("Invalid number of hash tables: " + L);
        }
        if (k < 1) {
            throw new IllegalArgumentException("Invalid number of random projections per hash value: " + k);
        }
        if (w <= 0.0d) {
            throw new IllegalArgumentException("Invalid width of random projections: " + w);
        }
        if (H < 1) {
            throw new IllegalArgumentException("Invalid size of hash tables: " + H);
        }
        this.k = k;
        this.w = w;
        this.H = H;
        this.keys = new ArrayList<>();
        this.data = new ArrayList<>();
        initHashTable(d, L, k, w, H);
    }

    /**
     * Creates the {@code L} hash tables and stores them in {@link #hash}.
     *
     * <p>This is called from the constructor. It is protected so that subclasses can supply their
     * own tables. Because it runs during superclass construction, fields declared in a subclass are
     * not yet initialized when an override executes.
     *
     * @param d the dimension of the input space.
     * @param L the number of hash tables.
     * @param k the number of random projections per hash value.
     * @param w the width of the random projections.
     * @param H the size of each hash table.
     */
    protected void initHashTable(int d, int L, int k, double w, int H) {
        this.hash = new ArrayList<>(L);
        for (int i = 0; i < L; i++) {
            this.hash.add(new Hash(d, k, w, H));
        }
    }

    @Override
    public String toString() {
        return String.format("LSH(L=%d, k=%d, H=%d, w=%.4f)", this.hash.size(), this.k, this.H, this.w);
    }

    /**
     * Inserts a key and its associated data object.
     *
     * <p>The entry receives the next index and is added to every hash table. The key array is
     * stored by reference, not copied.
     *
     * @param key the key.
     * @param value the data object associated with the key.
     */
    public void put(double[] key, E value) {
        int index = this.keys.size();
        this.keys.add(key);
        this.data.add(value);
        for (Hash table : this.hash) {
            table.add(index, key);
        }
    }

    /**
     * Returns the closest candidate to {@code q}.
     *
     * @param q the query key.
     * @return the nearest candidate other than {@code q} itself, or {@code null} if there is none.
     */
    @Override // smile.neighbor.NearestNeighborSearch
    public Neighbor<double[], E> nearest(double[] q) {
        int best = -1;
        double bestDistance = Double.MAX_VALUE;
        for (int index : getCandidates(q)) {
            double[] key = this.keys.get(index);
            if (key == q) {
                continue; // identity check: do not return the query as its own neighbor
            }
            double distance = MathEx.distance(q, key);
            if (distance < bestDistance) {
                best = index;
                bestDistance = distance;
            }
        }
        if (best == -1) {
            return null;
        }
        return new Neighbor<>(this.keys.get(best), this.data.get(best), best, bestDistance);
    }

    /**
     * Returns up to {@code k} candidates closest to {@code q}, sorted by distance.
     *
     * <p>The result may contain fewer than {@code k} entries when fewer candidates are found.
     *
     * @param q the query key.
     * @param k the number of neighbors to return; must be positive.
     * @return the nearest candidates, excluding {@code q} itself.
     * @throws IllegalArgumentException if {@code k < 1}.
     */
    @Override // smile.neighbor.KNNSearch
    @SuppressWarnings({"unchecked", "rawtypes"}) // HeapSelect is used raw, as in the original code
    public Neighbor<double[], E>[] knn(double[] q, int k) {
        if (k < 1) {
            throw new IllegalArgumentException("Invalid k: " + k);
        }

        // The heap must be sized by the number of neighbors actually offered to it. The raw
        // candidate count may include the query itself, which is skipped; sizing by it would
        // leave an unfilled (null) slot in the returned array.
        List<Neighbor<double[], E>> neighbors = new ArrayList<>();
        for (int index : getCandidates(q)) {
            double[] key = this.keys.get(index);
            if (key != q) {
                neighbors.add(new Neighbor<>(key, this.data.get(index), index, MathEx.distance(q, key)));
            }
        }

        HeapSelect heap = new HeapSelect(new Neighbor[Math.min(k, neighbors.size())]);
        for (Neighbor<double[], E> neighbor : neighbors) {
            heap.add(neighbor);
        }
        heap.sort();
        return (Neighbor[]) heap.toArray();
    }

    /**
     * Appends to {@code neighbors} every candidate within {@code radius} of {@code q}.
     *
     * <p>The boundary is inclusive: a candidate at exactly {@code radius} is included. Results are
     * added in candidate order and are not sorted.
     *
     * @param q the query key.
     * @param radius the search radius; must be positive.
     * @param neighbors the list that receives the results; existing entries are kept.
     * @throws IllegalArgumentException if {@code radius <= 0}.
     */
    @Override // smile.neighbor.RNNSearch
    public void range(double[] q, double radius, List<Neighbor<double[], E>> neighbors) {
        if (radius <= 0.0d) {
            throw new IllegalArgumentException("Invalid radius: " + radius);
        }
        for (int index : getCandidates(q)) {
            double[] key = this.keys.get(index);
            if (key == q) {
                continue; // identity check: do not report the query itself
            }
            double distance = MathEx.distance(q, key);
            if (distance <= radius) {
                neighbors.add(new Neighbor<>(key, this.data.get(index), index, distance));
            }
        }
    }

    /**
     * Collects the indices of all points that share a bucket with {@code q} in any hash table.
     *
     * <p>Tables for which {@code q} maps to no bucket are skipped. Duplicates are removed, and
     * indices keep the order in which they were first seen.
     */
    private Set<Integer> getCandidates(double[] q) {
        Set<Integer> candidates = new LinkedHashSet<>();
        for (Hash table : this.hash) {
            Bucket bucket = table.get(q);
            if (bucket == null) {
                continue;
            }
            IntArrayList points = bucket.points();
            for (int i = 0, n = points.size(); i < n; i++) {
                candidates.add(points.get(i));
            }
        }
        return candidates;
    }
}
