package org.nd4j.linalg.learning;

import java.io.Serializable;
import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * AdaGrad gradient updater.
 *
 * <p>{@link #getGradient(INDArray, int)} adds the element-wise square of the incoming gradient to
 * {@link #historicalGradient} and returns {@code learningRate / sqrt(historicalGradient) * gradient}.
 * The history is created lazily on first use unless it is set explicitly.
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class AdaGrad implements Serializable, GradientUpdater {
    /** Accumulated squared gradients; {@code null} until the first update or an explicit set. */
    public INDArray historicalGradient;
    /** Shape used by some code paths to lazily allocate the history; may be {@code null}. */
    public int[] shape;
    protected double learningRate = 0.1d;
    protected int numIterations = 0;
    private double epsilon = 1.0E-8d;

    /**
     * Averages the state of several {@link AdaGrad} updaters: historical gradients, learning rates
     * and iteration counts are summed and later divided by the number of aggregated updaters.
     */
    public static class AdaGradAggregator implements GradientUpdaterAggregator {
        private INDArray historicalGradientSum;
        private double lrSum;
        private long numIterationsSum = 0;
        private int count = 0;

        /**
         * Builds an updater whose state is the mean of everything aggregated so far.
         *
         * @throws IllegalStateException if nothing has been aggregated yet
         */
        @Override
        public GradientUpdater getUpdater() {
            if (this.count == 0) {
                throw new IllegalStateException("Cannot create AdaGrad updater: no updaters have been aggregated");
            }
            AdaGrad adaGrad = new AdaGrad(this.lrSum / this.count);
            adaGrad.setHistoricalGradient(this.historicalGradientSum.div(Integer.valueOf(this.count)));
            adaGrad.setNumIterations((int) (this.numIterationsSum / this.count));
            return adaGrad;
        }

        /**
         * Adds one AdaGrad updater's state to the running sums.
         *
         * @throws UnsupportedOperationException if {@code gradientUpdater} is not an {@link AdaGrad}
         */
        @Override
        public void aggregate(GradientUpdater gradientUpdater) {
            if (!(gradientUpdater instanceof AdaGrad)) {
                throw new UnsupportedOperationException("Cannot aggregate AdaGrad with updater: " + gradientUpdater);
            }
            AdaGrad adaGrad = (AdaGrad) gradientUpdater;
            if (this.historicalGradientSum == null) {
                // First contribution: take a private copy so later addi() calls don't mutate the updater.
                this.historicalGradientSum = adaGrad.historicalGradient.dup();
                this.lrSum = adaGrad.learningRate;
                this.numIterationsSum = adaGrad.numIterations;
            } else {
                this.historicalGradientSum.addi(adaGrad.historicalGradient);
                this.lrSum += adaGrad.learningRate;
                this.numIterationsSum += adaGrad.numIterations;
            }
            this.count++;
        }

        /**
         * Merges another AdaGrad aggregator into this one and returns {@code this}.
         * Either side may be empty.
         *
         * @throws IllegalArgumentException if {@code other} is not an {@link AdaGradAggregator}
         */
        @Override
        public GradientUpdaterAggregator combine(GradientUpdaterAggregator other) {
            if (!(other instanceof AdaGradAggregator)) {
                throw new IllegalArgumentException("Cannot combine AdaGradAggregator with aggregator: " + other);
            }
            AdaGradAggregator that = (AdaGradAggregator) other;
            if (that.count == 0) {
                // Nothing to merge (and its sum is null, which addi() cannot accept).
                return this;
            }
            if (this.historicalGradientSum == null) {
                // This aggregator is empty: start from a private copy of the other's sum.
                this.historicalGradientSum = that.historicalGradientSum.dup();
            } else {
                this.historicalGradientSum.addi(that.historicalGradientSum);
            }
            this.lrSum += that.lrSum;
            this.numIterationsSum += that.numIterationsSum;
            this.count += that.count;
            return this;
        }
    }

    /** Creates an updater with default learning rate 0.1 and no shape. */
    public AdaGrad() {
    }

    /** Creates an updater for a {@code rows x cols} parameter with the given learning rate. */
    public AdaGrad(int rows, int cols, double learningRate) {
        this(new int[] {rows, cols}, learningRate);
    }

    /** Creates an updater for a {@code rows x cols} parameter with learning rate 0.1. */
    public AdaGrad(int rows, int cols) {
        this(rows, cols, 0.1d);
    }

    /** Creates an updater for a parameter of the given shape (stored by reference). */
    public AdaGrad(int[] shape, double learningRate) {
        this.shape = shape;
        this.learningRate = learningRate;
    }

    /** Creates an updater with the given learning rate and no shape. */
    public AdaGrad(double learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Updates hyper-parameters.
     *
     * @param args if non-empty, {@code args[0]} (any {@link Number}) becomes the new learning rate
     */
    @Override
    public void update(Object... args) {
        if (args.length > 0) {
            this.learningRate = ((Number) args[0]).doubleValue();
        }
    }

    /**
     * Accumulates {@code gradient^2} into the history and returns the AdaGrad-scaled gradient.
     * The input array is not modified.
     *
     * @param gradient raw gradient
     * @param iteration ignored by this implementation
     * @return {@code learningRate / sqrt(history) * gradient}
     */
    @Override
    public INDArray getGradient(INDArray gradient, int iteration) {
        INDArray squared = gradient.mul(gradient);
        if (this.historicalGradient == null) {
            // Seed with epsilon so the first division cannot be by exactly zero.
            this.historicalGradient = squared.add(Double.valueOf(this.epsilon));
        } else {
            this.historicalGradient.addi(squared);
        }
        // Transforms.sqrt returns a new array here; rdivi below then works on that copy.
        INDArray sqrtHistory = Transforms.sqrt(this.historicalGradient);
        INDArray adjusted;
        try {
            adjusted = sqrtHistory.rdivi(Double.valueOf(this.learningRate)).muli(gradient);
        } catch (ArithmeticException e) {
            // NOTE(review): floating-point division normally yields Inf/NaN rather than throwing,
            // so this fallback is likely unreachable; kept for backward compatibility.
            adjusted = sqrtHistory.rdivi(Double.valueOf(this.learningRate)).muli(gradient.add(Double.valueOf(this.epsilon)));
        }
        this.numIterations++;
        return adjusted;
    }

    /**
     * Scalar variant for a single element of the history.
     *
     * <p>The history is lazily initialised to ones of {@code historyShape}. The returned value is
     * computed from the history before {@code gradient^2} is added to it.
     *
     * @param gradient raw gradient value for element {@code index}
     * @param index linear index into the history
     * @param historyShape shape used if the history must be allocated
     * @return {@code gradient * learningRate / (sqrt(history[index]) + epsilon)}
     */
    public double getGradient(double gradient, int index, int[] historyShape) {
        if (this.historicalGradient == null) {
            // sqrt(1.0) == 1.0 exactly, so a fresh history needs no special casing below.
            this.historicalGradient = Nd4j.ones(historyShape);
        }
        double history = this.historicalGradient.getDouble(index);
        double adjusted = gradient * (this.learningRate / (Math.sqrt(history) + this.epsilon));
        this.historicalGradient.putScalar(index, history + gradient * gradient);
        this.numIterations++;
        return adjusted;
    }

    /**
     * Slice-wise variant. It scales {@code gradient} IN PLACE by learning rates derived from the
     * history, then adds the square of the scaled gradient to history slice {@code slice}.
     *
     * <p>NOTE(review): unlike {@link #getGradient(INDArray, int)}, this path accumulates the squared
     * <em>adjusted</em> gradient. On the first call it also divides by the epsilon-seeded history
     * without a square root. Both behaviours are preserved for compatibility.
     *
     * @param gradient gradient to scale; modified in place and returned
     * @param slice index of the history slice to use and update
     * @param historyShape shape used if the history must be allocated
     * @return {@code gradient}, after in-place scaling
     * @throws IllegalArgumentException if a matrix history's slice length differs from the gradient's
     */
    public INDArray getGradient(INDArray gradient, int slice, int[] historyShape) {
        boolean freshlyInitialized = false;
        if (this.historicalGradient == null) {
            this.historicalGradient = Nd4j.zeros(historyShape).add(Double.valueOf(this.epsilon));
            freshlyInitialized = true;
        } else if (!this.historicalGradient.isVector() && this.historicalGradient.slice(slice).length() != gradient.length()) {
            throw new IllegalArgumentException("Illegal gradient");
        }
        INDArray denominator;
        if (this.historicalGradient.isVector()) {
            denominator = Transforms.sqrt(this.historicalGradient);
        } else if (freshlyInitialized) {
            // Must be a copy: rdivi below is in place and would otherwise overwrite the
            // freshly seeded history with learningRate / epsilon.
            denominator = this.historicalGradient.dup();
        } else {
            denominator = Transforms.sqrt(this.historicalGradient.slice(slice));
        }
        INDArray learningRates;
        try {
            learningRates = denominator.rdivi(Double.valueOf(this.learningRate));
        } catch (ArithmeticException e) {
            // NOTE(review): likely unreachable for floating-point data; kept for compatibility.
            learningRates = denominator.rdivi(Double.valueOf(this.learningRate + this.epsilon));
        }
        if (gradient.length() != learningRates.length()) {
            gradient.muli(learningRates.slice(slice));
        } else {
            gradient.muli(learningRates);
        }
        this.historicalGradient.slice(slice).addi(gradient.mul(gradient));
        this.numIterations++;
        return gradient;
    }

    /**
     * Creates a new updater holding a copy of one slice, or one element, of this updater's history.
     * The learning rate and epsilon are copied too.
     *
     * <p>Allocates a history of ones for {@link #shape} if none exists yet.
     *
     * @param index row index (matrix shape) or element index (otherwise)
     * @return an independent updater for that subset
     */
    public AdaGrad createSubset(int index) {
        if (this.historicalGradient == null) {
            this.historicalGradient = Nd4j.ones(this.shape);
        }
        AdaGrad subset;
        if (Shape.isMatrix(this.shape)) {
            subset = new AdaGrad(1, this.historicalGradient.columns());
            subset.historicalGradient = this.historicalGradient.slice(index).dup();
        } else {
            subset = new AdaGrad(1, 1);
            subset.historicalGradient = Nd4j.scalar(this.historicalGradient.getDouble(index));
        }
        subset.setLearningRate(this.learningRate);
        // Keep a customised epsilon; previously the subset silently reverted to the default.
        subset.setEpsilon(this.epsilon);
        return subset;
    }

    /**
     * Returns a new aggregator. If {@code addThis} is true, this updater is aggregated into it.
     */
    @Override
    public GradientUpdaterAggregator getAggregator(boolean addThis) {
        AdaGradAggregator aggregator = new AdaGradAggregator();
        if (addThis) {
            aggregator.aggregate(this);
        }
        return aggregator;
    }

    public INDArray getHistoricalGradient() {
        return this.historicalGradient;
    }

    public int[] getShape() {
        return this.shape;
    }

    public double getLearningRate() {
        return this.learningRate;
    }

    public int getNumIterations() {
        return this.numIterations;
    }

    public double getEpsilon() {
        return this.epsilon;
    }

    public void setHistoricalGradient(INDArray historicalGradient) {
        this.historicalGradient = historicalGradient;
    }

    public void setShape(int[] shape) {
        this.shape = shape;
    }

    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    public void setNumIterations(int numIterations) {
        this.numIterations = numIterations;
    }

    public void setEpsilon(double epsilon) {
        this.epsilon = epsilon;
    }

    /** Value equality over history, shape, learning rate, iteration count and epsilon. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof AdaGrad)) {
            return false;
        }
        AdaGrad other = (AdaGrad) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        INDArray mine = getHistoricalGradient();
        INDArray theirs = other.getHistoricalGradient();
        if (mine == null ? theirs != null : !mine.equals(theirs)) {
            return false;
        }
        return Arrays.equals(getShape(), other.getShape())
                && Double.compare(getLearningRate(), other.getLearningRate()) == 0
                && getNumIterations() == other.getNumIterations()
                && Double.compare(getEpsilon(), other.getEpsilon()) == 0;
    }

    protected boolean canEqual(Object obj) {
        return obj instanceof AdaGrad;
    }

    /** Hash over the same fields as {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        INDArray history = getHistoricalGradient();
        int result = 1;
        result = result * prime + (history == null ? 0 : history.hashCode());
        result = result * prime + Arrays.hashCode(getShape());
        long lrBits = Double.doubleToLongBits(getLearningRate());
        result = result * prime + (int) ((lrBits >>> 32) ^ lrBits);
        result = result * prime + getNumIterations();
        long epsBits = Double.doubleToLongBits(getEpsilon());
        result = result * prime + (int) ((epsBits >>> 32) ^ epsBits);
        return result;
    }

    @Override
    public String toString() {
        return "AdaGrad(historicalGradient=" + getHistoricalGradient() + ", shape=" + Arrays.toString(getShape()) + ", learningRate=" + getLearningRate() + ", numIterations=" + getNumIterations() + ", epsilon=" + getEpsilon() + ")";
    }
}
