package smile.classification;

import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.classification.DecisionTree;
import smile.data.Attribute;
import smile.data.AttributeDataset;
import smile.data.NumericAttribute;
import smile.math.Math;
import smile.util.SmileUtils;
import smile.validation.Accuracy;
import smile.validation.ClassificationMeasure;

/* loaded from: input_file:smile/classification/AdaBoost.class */
public class AdaBoost implements SoftClassifier<double[]> {
    // Serialization version of this class.
    private static final long serialVersionUID = 1;
    private static final Logger logger = LoggerFactory.getLogger(AdaBoost.class);
    // Shared prefix of the error message raised for a tree count below 1.
    private static final String INVALID_NUMBER_OF_TREES = "Invalid number of trees: ";
    // Number of distinct class labels seen at training time (always >= 2).
    private int k;
    // The fitted weak learners, in the order they were accepted.
    private DecisionTree[] trees;
    // Voting weight of each tree; alpha[t] belongs to trees[t].
    private double[] alpha;
    // Weighted training error of each accepted tree; error[t] belongs to trees[t].
    private double[] error;
    // Per-attribute importance, summed over the importance() of every tree kept after training.
    private double[] importance;

    /* loaded from: input_file:smile/classification/AdaBoost$Trainer.class */
    /**
     * Configurable trainer that builds {@link AdaBoost} models.
     *
     * <p>Unless overridden, 500 trees are requested and each tree is limited
     * to 2 leaf nodes.
     */
    public static class Trainer extends ClassifierTrainer<double[]> {
        /** Number of trees requested from the ensemble. */
        private int ntrees = 500;
        /** Maximum number of leaf nodes allowed in each tree. */
        private int maxNodes = 2;

        /** Creates a trainer with the default settings. */
        public Trainer() {
        }

        /**
         * Creates a trainer for the given ensemble size.
         *
         * @param i the number of trees; must be at least 1
         * @throws IllegalArgumentException if {@code i < 1}
         */
        public Trainer(int i) {
            this.ntrees = requireValidTreeCount(i);
        }

        /**
         * Creates a trainer for the given attributes and ensemble size.
         *
         * @param attributeArr the attributes handed to the superclass
         * @param i the number of trees; must be at least 1
         * @throws IllegalArgumentException if {@code i < 1}
         */
        public Trainer(Attribute[] attributeArr, int i) {
            super(attributeArr);
            this.ntrees = requireValidTreeCount(i);
        }

        /**
         * Sets the number of trees.
         *
         * @param i the number of trees; must be at least 1
         * @return this trainer, for chaining
         * @throws IllegalArgumentException if {@code i < 1}
         */
        public Trainer setNumTrees(int i) {
            this.ntrees = requireValidTreeCount(i);
            return this;
        }

        /**
         * Sets the maximum number of leaf nodes per tree.
         *
         * @param i the leaf limit; must be at least 2
         * @return this trainer, for chaining
         * @throws IllegalArgumentException if {@code i < 2}
         */
        public Trainer setMaxNodes(int i) {
            if (i >= 2) {
                this.maxNodes = i;
                return this;
            }
            throw new IllegalArgumentException("Invalid maximum number of leaf nodes: " + i);
        }

        /**
         * Fits an AdaBoost model with the current settings.
         *
         * @param dArr the training instances
         * @param iArr the class label of each instance
         * @return the trained model
         */
        @Override // smile.classification.ClassifierTrainer
        public AdaBoost train(double[][] dArr, int[] iArr) {
            return new AdaBoost(this.attributes, dArr, iArr, this.ntrees, this.maxNodes);
        }

        /** Returns {@code count} unchanged when it is a legal ensemble size, otherwise throws. */
        private static int requireValidTreeCount(int count) {
            if (count < 1) {
                throw new IllegalArgumentException(AdaBoost.INVALID_NUMBER_OF_TREES + count);
            }
            return count;
        }
    }

    /**
     * Trains a model without explicit attribute metadata, using the default
     * leaf limit of the attribute-aware four-argument constructor.
     *
     * @param dArr the training instances
     * @param iArr the class label of each instance
     * @param i the number of trees
     */
    public AdaBoost(double[][] dArr, int[] iArr, int i) {
        // The cast selects the (Attribute[], double[][], int[], int) overload.
        this((Attribute[]) null, dArr, iArr, i);
    }

    /**
     * Trains a model without explicit attribute metadata.
     *
     * @param dArr the training instances
     * @param iArr the class label of each instance
     * @param i the number of trees
     * @param i2 the maximum number of leaves per tree
     */
    public AdaBoost(double[][] dArr, int[] iArr, int i, int i2) {
        this(null, dArr, iArr, i, i2);
    }

    /**
     * Trains a model whose trees are limited to 2 leaves each.
     *
     * @param attributeArr the attribute descriptions, or {@code null}
     * @param dArr the training instances
     * @param iArr the class label of each instance
     * @param i the number of trees
     */
    public AdaBoost(Attribute[] attributeArr, double[][] dArr, int[] iArr, int i) {
        this(attributeArr, dArr, iArr, i, 2);
    }

    /**
     * Trains a model from a dataset, taking attributes, instances and labels
     * from {@code attributes()}, {@code x()} and {@code labels()}.
     *
     * @param attributeDataset the training dataset
     * @param i the number of trees
     */
    public AdaBoost(AttributeDataset attributeDataset, int i) {
        this(attributeDataset.attributes(), attributeDataset.x(), attributeDataset.labels(), i);
    }

    /**
     * Trains a model from a dataset, taking attributes, instances and labels
     * from {@code attributes()}, {@code x()} and {@code labels()}.
     *
     * @param attributeDataset the training dataset
     * @param i the number of trees
     * @param i2 the maximum number of leaves per tree
     */
    public AdaBoost(AttributeDataset attributeDataset, int i, int i2) {
        this(attributeDataset.attributes(), attributeDataset.x(), attributeDataset.labels(), i, i2);
    }

    public AdaBoost(Attribute[] attributeArr, double[][] dArr, int[] iArr, int i, int i2) {
        if (dArr.length != iArr.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", Integer.valueOf(dArr.length), Integer.valueOf(iArr.length)));
        }
        if (i < 1) {
            throw new IllegalArgumentException(INVALID_NUMBER_OF_TREES + i);
        }
        if (i2 < 2) {
            throw new IllegalArgumentException("Invalid maximum leaves: " + i2);
        }
        int[] unique = Math.unique(iArr);
        Arrays.sort(unique);
        for (int i3 = 0; i3 < unique.length; i3++) {
            if (unique[i3] < 0) {
                throw new IllegalArgumentException("Negative class label: " + unique[i3]);
            }
            if (i3 > 0 && unique[i3] - unique[i3 - 1] > 1) {
                throw new IllegalArgumentException("Missing class: " + (unique[i3 - 1] + 1));
            }
        }
        this.k = unique.length;
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        if (attributeArr == null) {
            int length = dArr[0].length;
            attributeArr = new Attribute[length];
            for (int i4 = 0; i4 < length; i4++) {
                attributeArr[i4] = new NumericAttribute("V" + (i4 + 1));
            }
        }
        int[][] sort = SmileUtils.sort(attributeArr, dArr);
        int length2 = dArr.length;
        int[] iArr2 = new int[length2];
        double[] dArr2 = new double[length2];
        boolean[] zArr = new boolean[length2];
        for (int i5 = 0; i5 < length2; i5++) {
            dArr2[i5] = 1.0d;
        }
        double d = 1.0d / this.k;
        double log = Math.log(this.k - 1);
        int i6 = 0;
        this.trees = new DecisionTree[i];
        this.alpha = new double[i];
        this.error = new double[i];
        int i7 = 0;
        while (true) {
            if (i7 >= i) {
                break;
            }
            double sum = Math.sum(dArr2);
            for (int i8 = 0; i8 < length2; i8++) {
                int i9 = i8;
                dArr2[i9] = dArr2[i9] / sum;
            }
            Arrays.fill(iArr2, 0);
            for (int i10 : Math.random(dArr2, length2)) {
                iArr2[i10] = iArr2[i10] + 1;
            }
            this.trees[i7] = new DecisionTree(attributeArr, dArr, iArr, i2, 1, dArr[0].length, DecisionTree.SplitRule.GINI, iArr2, sort);
            for (int i11 = 0; i11 < length2; i11++) {
                zArr[i11] = this.trees[i7].predict(dArr[i11]) != iArr[i11];
            }
            double d2 = 0.0d;
            for (int i12 = 0; i12 < length2; i12++) {
                if (zArr[i12]) {
                    d2 += dArr2[i12];
                }
            }
            if (1.0d - d2 <= d) {
                logger.error(String.format("Skip the weak classifier %d makes %.2f%% weighted error", Integer.valueOf(i7), Double.valueOf(100.0d * d2)));
                i6++;
                if (i6 > 3) {
                    this.trees = (DecisionTree[]) Arrays.copyOf(this.trees, i7);
                    this.alpha = Arrays.copyOf(this.alpha, i7);
                    this.error = Arrays.copyOf(this.error, i7);
                    break;
                }
                i7--;
            } else {
                i6 = 0;
                this.error[i7] = d2;
                this.alpha[i7] = Math.log((1.0d - d2) / Math.max(1.0E-10d, d2)) + log;
                double exp = Math.exp(this.alpha[i7]);
                for (int i13 = 0; i13 < length2; i13++) {
                    if (zArr[i13]) {
                        int i14 = i13;
                        dArr2[i14] = dArr2[i14] * exp;
                    }
                }
            }
            i7++;
        }
        this.importance = new double[attributeArr.length];
        for (DecisionTree decisionTree : this.trees) {
            double[] importance = decisionTree.importance();
            for (int i15 = 0; i15 < importance.length; i15++) {
                double[] dArr3 = this.importance;
                int i16 = i15;
                dArr3[i16] = dArr3[i16] + importance[i15];
            }
        }
    }

    /**
     * Returns the per-attribute importance, summed over the trees that were
     * kept when training finished.
     *
     * @return the model's own importance array (not a copy)
     */
    public double[] importance() {
        return importance;
    }

    /**
     * Returns the number of trees currently in the ensemble.
     *
     * @return the ensemble size
     */
    public int size() {
        return trees.length;
    }

    /**
     * Shrinks the ensemble to its first {@code i} trees, dropping the matching
     * tail of the vote weights and training errors. The importance array is
     * left as it was.
     *
     * @param i the new number of trees; must be in {@code 1..size()}
     * @throws IllegalArgumentException if {@code i} is larger than the current
     *         size or not positive
     */
    public void trim(int i) {
        int current = this.trees.length;
        if (i > current) {
            throw new IllegalArgumentException("The new model size is larger than the current size.");
        }
        if (i <= 0) {
            throw new IllegalArgumentException("Invalid new model size: " + i);
        }
        if (i == current) {
            // Already at the requested size; nothing to drop.
            return;
        }
        this.trees = Arrays.copyOf(this.trees, i);
        this.alpha = Arrays.copyOf(this.alpha, i);
        this.error = Arrays.copyOf(this.error, i);
    }

    /**
     * Predicts the class of an instance by weighted vote: each tree adds its
     * vote weight to the class it predicts, and the class with the largest
     * total wins.
     *
     * @param dArr the instance to classify
     * @return the predicted class label
     */
    @Override // smile.classification.Classifier
    public int predict(double[] dArr) {
        double[] votes = new double[this.k];
        int t = 0;
        for (DecisionTree tree : this.trees) {
            votes[tree.predict(dArr)] += this.alpha[t++];
        }
        return Math.whichMax(votes);
    }

    /**
     * Predicts the class of an instance and reports normalized vote shares.
     *
     * <p>Each tree adds its vote weight to the class it predicts; the first
     * {@code k} entries of {@code dArr2} are then divided by the total so that
     * they sum to one. If the ensemble holds no trees (training can end with
     * none when every weak learner is rejected) the total is zero; a uniform
     * distribution is written in that case instead of the NaN values a 0/0
     * division would produce.
     *
     * @param dArr the instance to classify
     * @param dArr2 output buffer for the per-class shares; must have room for
     *              at least {@code k} entries and is overwritten
     * @return the index of the largest entry of {@code dArr2}
     * @throws IllegalArgumentException if {@code dArr2} has fewer than
     *         {@code k} entries
     */
    @Override // smile.classification.SoftClassifier
    public int predict(double[] dArr, double[] dArr2) {
        // Fail fast with a clear message instead of an index error after a partial overwrite.
        if (dArr2.length < this.k) {
            throw new IllegalArgumentException(
                    String.format("Invalid posteriori vector size: %d, expected: %d", dArr2.length, this.k));
        }

        Arrays.fill(dArr2, 0.0d);
        for (int t = 0; t < this.trees.length; t++) {
            dArr2[this.trees[t].predict(dArr)] += this.alpha[t];
        }

        double total = Math.sum(dArr2);
        if (total > 0.0d) {
            for (int c = 0; c < this.k; c++) {
                dArr2[c] /= total;
            }
        } else {
            // No votes were cast: report an uninformative distribution, not NaN.
            Arrays.fill(dArr2, 0, this.k, 1.0d / this.k);
        }
        return Math.whichMax(dArr2);
    }

    /**
     * Measures accuracy on a labeled sample as the ensemble grows one tree at
     * a time.
     *
     * <p>Entry {@code t} of the result is the accuracy of the weighted vote of
     * trees {@code 0..t}. Votes are accumulated per class exactly as in
     * {@link #predict(double[])}, for any number of classes. The earlier
     * two-class shortcut multiplied the vote weight by the 0/1 label, which
     * discarded every vote for class 0 and disagreed with {@code predict}.
     *
     * @param dArr the test instances
     * @param iArr the true class label of each test instance
     * @return the accuracy after each successive tree; its length equals the
     *         number of trees
     */
    public double[] test(double[][] dArr, int[] iArr) {
        int ntrees = this.trees.length;
        int n = dArr.length;
        double[] accuracy = new double[ntrees];
        int[] label = new int[n];
        // votes[j][c] is the total weight cast for class c on instance j so far.
        double[][] votes = new double[n][this.k];
        Accuracy measure = new Accuracy();

        for (int t = 0; t < ntrees; t++) {
            for (int j = 0; j < n; j++) {
                votes[j][this.trees[t].predict(dArr[j])] += this.alpha[t];
                label[j] = Math.whichMax(votes[j]);
            }
            accuracy[t] = measure.measure(iArr, label);
        }
        return accuracy;
    }

    /**
     * Evaluates several measures on a labeled sample as the ensemble grows one
     * tree at a time.
     *
     * <p>Row {@code t} of the result holds every measure computed on the
     * weighted vote of trees {@code 0..t}. Votes are accumulated per class
     * exactly as in {@link #predict(double[])}, for any number of classes. The
     * earlier two-class shortcut multiplied the vote weight by the 0/1 label,
     * which discarded every vote for class 0 and disagreed with {@code predict}.
     *
     * @param dArr the test instances
     * @param iArr the true class label of each test instance
     * @param classificationMeasureArr the measures to evaluate
     * @return a matrix indexed by tree count (row) and measure (column)
     */
    public double[][] test(double[][] dArr, int[] iArr, ClassificationMeasure[] classificationMeasureArr) {
        int ntrees = this.trees.length;
        int nmeasures = classificationMeasureArr.length;
        int n = dArr.length;
        double[][] results = new double[ntrees][nmeasures];
        int[] label = new int[n];
        // votes[j][c] is the total weight cast for class c on instance j so far.
        double[][] votes = new double[n][this.k];

        for (int t = 0; t < ntrees; t++) {
            for (int j = 0; j < n; j++) {
                votes[j][this.trees[t].predict(dArr[j])] += this.alpha[t];
                label[j] = Math.whichMax(votes[j]);
            }
            for (int m = 0; m < nmeasures; m++) {
                results[t][m] = classificationMeasureArr[m].measure(iArr, label);
            }
        }
        return results;
    }

    /**
     * Returns the trees that make up the ensemble.
     *
     * @return the model's own tree array (not a copy)
     */
    public DecisionTree[] getTrees() {
        return trees;
    }
}
