package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import smile.base.mlp.Cost;
import smile.base.mlp.Layer;
import smile.base.mlp.LayerBuilder;
import smile.base.mlp.MultilayerPerceptron;
import smile.math.MathEx;
import smile.util.IntSet;

/* loaded from: input_file:smile/classification/MLP.class */
/**
 * Multilayer perceptron classifier built on top of {@link MultilayerPerceptron}.
 *
 * <p>The network can be trained incrementally, one sample or one batch at a time,
 * and can report per-class scores alongside the predicted label. External class
 * labels are translated to and from internal class indices through an
 * {@link IntSet}.
 *
 * <p>A network whose output layer has a single unit is treated as a binary
 * classifier ({@code k == 2}): the single output is used as the score of class
 * index 1 and {@code 1 - output} as the score of class index 0.
 */
public class MLP extends MultilayerPerceptron implements OnlineClassifier<double[]>, SoftClassifier<double[]>, Serializable {
    private static final long serialVersionUID = 2;

    /** Number of classes: the output layer size, or 2 when the output layer has a single unit. */
    private final int k;

    /** Mapping between external class labels and internal class indices. */
    private final IntSet labels;

    /**
     * Builds a network whose class labels are the default set produced by
     * {@code IntSet.of(k)}.
     *
     * @param i               the size of the input fed to the first layer.
     * @param layerBuilderArr the builders of the layers, in order from input to output.
     */
    public MLP(int i, LayerBuilder... layerBuilderArr) {
        super(net(i, layerBuilderArr));
        this.k = classCount(this.output.getOutputSize());
        this.labels = IntSet.of(this.k);
    }

    /**
     * Builds a network that uses the given label mapping.
     *
     * @param intSet          the mapping between class labels and class indices.
     * @param i               the size of the input fed to the first layer.
     * @param layerBuilderArr the builders of the layers, in order from input to output.
     */
    public MLP(IntSet intSet, int i, LayerBuilder... layerBuilderArr) {
        super(net(i, layerBuilderArr));
        this.k = classCount(this.output.getOutputSize());
        this.labels = intSet;
    }

    /**
     * Derives the number of classes from the size of the output layer.
     * A single output unit encodes a two-class problem.
     */
    private static int classCount(int outputSize) {
        return outputSize == 1 ? 2 : outputSize;
    }

    /**
     * Builds the layers in order, wiring each layer's input size to the number
     * of neurons of the layer before it.
     *
     * @param i               the input size of the first layer.
     * @param layerBuilderArr the layer builders, in order.
     * @return the built layers, one per builder.
     */
    private static Layer[] net(int i, LayerBuilder... layerBuilderArr) {
        Layer[] layers = new Layer[layerBuilderArr.length];
        int inputSize = i;
        for (int idx = 0; idx < layers.length; idx++) {
            LayerBuilder builder = layerBuilderArr[idx];
            layers[idx] = builder.build(inputSize);
            // The next layer consumes what this layer emits.
            inputSize = builder.neurons();
        }
        return layers;
    }

    /**
     * Predicts the class label of a sample and writes the per-class scores.
     *
     * <p>The arg-max is taken over {@code dArr2} itself, so the buffer is
     * expected to hold exactly one entry per class.
     *
     * @param dArr  the input sample.
     * @param dArr2 the buffer that receives the score of each class index.
     * @return the label mapped from the index of the largest entry of {@code dArr2}.
     */
    @Override // smile.classification.SoftClassifier
    public int predict(double[] dArr, double[] dArr2) {
        propagate(dArr);
        int outputSize = this.output.getOutputSize();
        if (outputSize == 1 && this.k == 2) {
            // Single output unit: its value scores class 1, the complement scores class 0.
            dArr2[1] = this.output.output()[0];
            dArr2[0] = 1.0d - dArr2[1];
        } else {
            System.arraycopy(this.output.output(), 0, dArr2, 0, outputSize);
        }
        return this.labels.valueOf(MathEx.whichMax(dArr2));
    }

    /**
     * Predicts the class label of a sample.
     *
     * @param dArr the input sample.
     * @return the predicted class label.
     */
    @Override // smile.classification.Classifier
    public int predict(double[] dArr) {
        propagate(dArr);
        if (this.output.getOutputSize() == 1 && this.k == 2) {
            // Single output unit: threshold at 0.5 to pick between class index 0 and 1.
            return this.labels.valueOf(this.output.output()[0] > 0.5d ? 1 : 0);
        }
        return this.labels.valueOf(MathEx.whichMax(this.output.output()));
    }

    /**
     * Updates the network with a single training sample.
     *
     * @param dArr the training sample.
     * @param i    the class label of the sample.
     */
    @Override // smile.classification.OnlineClassifier
    public void update(double[] dArr, int i) {
        propagate(dArr);
        setTarget(this.labels.indexOf(i));
        backpropagate();
        update();
    }

    /**
     * Updates the network with a batch of training samples. Every sample is
     * propagated and back-propagated, then a single weight update is applied
     * for the whole batch.
     *
     * @param dArr the training samples.
     * @param iArr the class labels, one per sample.
     * @throws IllegalArgumentException if the number of samples and labels differ.
     */
    @Override // smile.classification.OnlineClassifier
    public void update(double[][] dArr, int[] iArr) {
        // Fail before touching any state rather than part-way through the batch.
        if (dArr.length != iArr.length) {
            throw new IllegalArgumentException(String.format(
                    "The sizes of X and Y don't match: %d != %d", dArr.length, iArr.length));
        }

        // alpha is forced to 1.0 for the duration of the batch and restored afterwards.
        // NOTE(review): alpha is inherited from MultilayerPerceptron, presumably the
        // momentum factor - confirm against the base class.
        double savedAlpha = this.alpha;
        this.alpha = 1.0d;
        try {
            for (int row = 0; row < dArr.length; row++) {
                propagate(dArr[row]);
                setTarget(this.labels.indexOf(iArr[row]));
                backpropagate();
            }
            update();
        } finally {
            // Restore even when a sample fails, so the model keeps its configured alpha.
            this.alpha = savedAlpha;
        }
    }

    /**
     * Fills the target vector for the given class index.
     *
     * <p>The "on" value is 1.0 when the output layer's cost is
     * {@link Cost#LIKELIHOOD} and 0.9 otherwise; the "off" value is its
     * complement {@code 1.0 - on}.
     *
     * @param i the class index of the current sample.
     */
    private void setTarget(int i) {
        double on = this.output.cost() == Cost.LIKELIHOOD ? 1.0d : 0.9d;
        double off = 1.0d - on;
        if (this.output.getOutputSize() == 1) {
            // Single output unit: class index 1 is the "on" class.
            this.target[0] = i == 1 ? on : off;
            return;
        }
        Arrays.fill(this.target, off);
        this.target[i] = on;
    }
}
