/*
 * Decompiled with CFR 0.152.
 */
package elki.utilities.scaling.outlier;

import elki.database.ids.ArrayDBIDs;
import elki.database.ids.DBIDArrayIter;
import elki.database.ids.DBIDIter;
import elki.database.ids.DBIDRef;
import elki.database.ids.DBIDUtil;
import elki.database.ids.DBIDs;
import elki.database.relation.DoubleRelation;
import elki.logging.Logging;
import elki.math.MeanVariance;
import elki.result.outlier.OutlierResult;
import elki.utilities.datastructures.BitsUtil;
import elki.utilities.datastructures.arraylike.NumberArrayAdapter;
import elki.utilities.documentation.Reference;
import elki.utilities.scaling.outlier.OutlierScaling;
import net.jafama.FastMath;

/**
 * Outlier score scaling that fits a sigmoid
 * {@code 1 / (1 + exp(-A * score - B))} to the observed scores and uses it to
 * map raw scores into the interval {@code [0, 1]}.
 * <p>
 * The parameters are estimated by an alternating procedure:
 * <ol>
 * <li>label every score as "set" when {@code A * score + B > 0}, otherwise
 * "unset";</li>
 * <li>re-fit {@code A} and {@code B} to these labels with a regularized Newton
 * iteration and a step-halving line search.</li>
 * </ol>
 * The procedure stops as soon as the labels no longer change, or after the
 * iteration limit was exceeded (a warning is logged in the latter case).
 * <p>
 * Note: scores that are not finite are ignored when computing the initial
 * offset, but they still take part in the labeling and fitting steps.
 */
@Reference(authors="J. Gao, P.-N. Tan", title="Converting Output Scores from Outlier Detection Algorithms into Probability Estimates", booktitle="Proc. Sixth International Conference on Data Mining, 2006. ICDM'06.", url="https://doi.org/10.1109/ICDM.2006.43", bibkey="DBLP:conf/icdm/GaoT06")
public class SigmoidOutlierScaling implements OutlierScaling {
    /** Class logger. */
    private static final Logging LOG = Logging.getLogger(SigmoidOutlierScaling.class);

    /** The outer loop gives up (with a warning) once more than this many M-steps were run. */
    private static final int MAX_EM_ITERATIONS = 100;

    /** Maximum number of Newton iterations within a single M-step. */
    private static final int MAX_NEWTON_ITERATIONS = 10;

    /** Smallest step size tried by the line search. */
    private static final double MIN_STEP = 1.0E-8;

    /** Initial value of the Hessian diagonal; keeps the determinant away from zero. */
    private static final double SIGMA = 1.0E-12;

    /** The Newton iteration stops when both gradient components are below this value. */
    private static final double GRADIENT_EPSILON = 1.0E-5;

    /** Sufficient-decrease factor of the line search acceptance test. */
    private static final double SUFFICIENT_DECREASE = 1.0E-4;

    /** Fitted slope of the sigmoid; written by the {@code prepare} methods. */
    double Afinal;

    /** Fitted offset of the sigmoid; written by the {@code prepare} methods. */
    double Bfinal;

    /**
     * Fits the sigmoid to the scores of an outlier result.
     *
     * @param or outlier result whose score relation is used for fitting
     */
    @Override
    public void prepare(OutlierResult or) {
        final DoubleRelation scores = or.getScores();
        // The mean of the finite scores determines the initial offset.
        final MeanVariance mv = new MeanVariance();
        for (DBIDIter id = scores.iterDBIDs(); id.valid(); id.advance()) {
            final double val = scores.doubleValue(id);
            if (Double.isFinite(val)) {
                mv.put(val);
            }
        }
        // Copy the scores once (in array order) so that the fitting passes
        // work on a plain array instead of repeated relation lookups.
        final ArrayDBIDs ids = DBIDUtil.ensureArray(scores.getDBIDs());
        final double[] values = new double[ids.size()];
        final DBIDArrayIter it = ids.iter();
        for (int i = 0; i < values.length; i++, it.advance()) {
            values[i] = scores.doubleValue(it);
        }
        fit(values, mv.getMean());
    }

    /**
     * Fits the sigmoid to the values of an array-like data structure.
     *
     * @param array array-like object holding the scores
     * @param adapter adapter used to read the size and the values of {@code array}
     * @param <A> array type
     */
    @Override
    public <A> void prepare(A array, NumberArrayAdapter<?, A> adapter) {
        final int size = adapter.size(array);
        final double[] values = new double[size];
        final MeanVariance mv = new MeanVariance();
        for (int i = 0; i < size; i++) {
            final double val = adapter.getDouble(array, i);
            values[i] = val;
            // Only finite values contribute to the initial offset.
            if (Double.isFinite(val)) {
                mv.put(val);
            }
        }
        fit(values, mv.getMean());
    }

    /**
     * Runs the alternating labeling / re-fitting loop and stores the result in
     * {@link #Afinal} and {@link #Bfinal}.
     *
     * @param values scores to fit
     * @param mean mean of the finite scores; the offset starts at {@code -mean}
     */
    private void fit(double[] values, double mean) {
        double a = 1.0;
        double b = -mean;
        final long[] t = BitsUtil.zero(values.length);
        int iter = 0;
        // E-step: relabel; stop as soon as no label changed.
        while (updateLabels(values, a, b, t)) {
            // M-step: re-fit the sigmoid parameters to the new labels.
            final double[] newab = mStepLevenbergMarquardt(a, b, values, t);
            a = newab[0];
            b = newab[1];
            if (++iter > MAX_EM_ITERATIONS) {
                LOG.warning("Max iterations met in sigmoid fitting.");
                break;
            }
        }
        this.Afinal = a;
        this.Bfinal = b;
        LOG.debugFine("A = " + this.Afinal + " B = " + this.Bfinal);
    }

    /**
     * Labeling step: bit {@code i} of {@code t} is set exactly when
     * {@code a * values[i] + b > 0} (a NaN result therefore clears the bit).
     *
     * @param values scores
     * @param a current slope
     * @param b current offset
     * @param t label bit set, updated in place
     * @return {@code true} when at least one label changed
     */
    private static boolean updateLabels(double[] values, double a, double b, long[] t) {
        boolean changed = false;
        for (int i = 0; i < values.length; i++) {
            final boolean positive = a * values[i] + b > 0.0;
            if (positive == BitsUtil.get(t, i)) {
                continue;
            }
            if (positive) {
                BitsUtil.setI(t, i);
            } else {
                BitsUtil.clearI(t, i);
            }
            changed = true;
        }
        return changed;
    }

    /**
     * Evaluates the objective that the M-step minimizes, using a formulation
     * that only exponentiates non-positive arguments.
     *
     * @param values scores
     * @param t label bit set
     * @param a slope to evaluate
     * @param b offset to evaluate
     * @param setTarget target value for scores whose bit is set
     * @param unsetTarget target value for scores whose bit is not set
     * @return objective value
     */
    private static double objective(double[] values, long[] t, double a, double b, double setTarget, double unsetTarget) {
        double f = 0.0;
        for (int i = 0; i < values.length; i++) {
            final double fApB = values[i] * a + b;
            final double ti = BitsUtil.get(t, i) ? setTarget : unsetTarget;
            if (fApB >= 0.0) {
                f += ti * fApB + FastMath.log(1.0 + FastMath.exp(-fApB));
            } else {
                f += (ti - 1.0) * fApB + FastMath.log(1.0 + FastMath.exp(fApB));
            }
        }
        return f;
    }

    /**
     * M-step: improves {@code a} and {@code b} for the given labels with up to
     * {@link #MAX_NEWTON_ITERATIONS} Newton steps. Each step uses a 2x2 Hessian
     * whose diagonal starts at {@link #SIGMA}, and a line search that halves
     * the step size until the objective decreases sufficiently.
     * <p>
     * When the line search cannot find an acceptable step, the parameters are
     * left unchanged; any further Newton iteration would recompute exactly the
     * same direction and fail again, so the iteration is stopped right away.
     *
     * @param a initial slope
     * @param b initial offset
     * @param values scores
     * @param t label bit set
     * @return array {@code {a, b}} with the updated parameters
     */
    private double[] mStepLevenbergMarquardt(double a, double b, double[] values, long[] t) {
        final int prior1 = BitsUtil.cardinality(t);
        final int prior0 = values.length - prior1;
        // Naming caveat: "hiTarget" (the smaller value) is the target of the
        // *set* bits, "loTarget" the target of the unset ones. This matches
        // getScaled(), which evaluates 1 - p for the p fitted here.
        final double loTarget = (prior1 + 1.0) / (prior1 + 2.0);
        final double hiTarget = 1.0 / (prior0 + 2.0);

        double fval = objective(values, t, a, b, hiTarget, loTarget);
        for (int it = 0; it < MAX_NEWTON_ITERATIONS; it++) {
            // Accumulate Hessian (h11, h21, h22) and gradient (g1, g2).
            double h11 = SIGMA;
            double h22 = SIGMA;
            double h21 = 0.0;
            double g1 = 0.0;
            double g2 = 0.0;
            for (int i = 0; i < values.length; i++) {
                final double val = values[i];
                final double fApB = val * a + b;
                final double p;
                final double q;
                // Only exponentiate non-positive arguments.
                if (fApB >= 0.0) {
                    final double e = FastMath.exp(-fApB);
                    p = e / (1.0 + e);
                    q = 1.0 / (1.0 + e);
                } else {
                    final double e = FastMath.exp(fApB);
                    p = 1.0 / (1.0 + e);
                    q = e / (1.0 + e);
                }
                final double d2 = p * q;
                h11 += val * val * d2;
                h22 += d2;
                h21 += val * d2;
                final double d1 = (BitsUtil.get(t, i) ? hiTarget : loTarget) - p;
                g1 += val * d1;
                g2 += d1;
            }
            // Stop condition: gradient is (almost) zero.
            if (Math.abs(g1) < GRADIENT_EPSILON && Math.abs(g2) < GRADIENT_EPSILON) {
                break;
            }
            // Newton direction from the inverse of the 2x2 Hessian.
            final double det = h11 * h22 - h21 * h21;
            final double dA = -(h22 * g1 - h21 * g2) / det;
            final double dB = -(-h21 * g1 + h11 * g2) / det;
            final double gd = g1 * dA + g2 * dB;

            // Line search: halve the step until the decrease is sufficient.
            boolean accepted = false;
            for (double stepsize = 1.0; stepsize >= MIN_STEP; stepsize /= 2.0) {
                final double newA = a + stepsize * dA;
                final double newB = b + stepsize * dB;
                final double newf = objective(values, t, newA, newB, hiTarget, loTarget);
                if (newf < fval + SUFFICIENT_DECREASE * stepsize * gd) {
                    a = newA;
                    b = newB;
                    fval = newf;
                    accepted = true;
                    break;
                }
            }
            if (!accepted) {
                // a and b are unchanged, so iterating again cannot make progress.
                LOG.debug("Minstep hit.");
                break;
            }
        }
        return new double[]{a, b};
    }

    /**
     * Upper bound of the scaled values.
     *
     * @return always {@code 1.0}
     */
    public double getMax() {
        return 1.0;
    }

    /**
     * Lower bound of the scaled values.
     *
     * @return always {@code 0.0}
     */
    public double getMin() {
        return 0.0;
    }

    /**
     * Applies the fitted sigmoid to a raw score.
     *
     * @param value raw score
     * @return {@code 1 / (1 + exp(-Afinal * value - Bfinal))}
     */
    public double getScaled(double value) {
        return 1.0 / (1.0 + FastMath.exp(-this.Afinal * value - this.Bfinal));
    }
}

