package org.dromara.easyai.rnnNerveEntity;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.dromara.easyai.i.ActiveFunction;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;

/* loaded from: input_file:org/dromara/easyai/rnnNerveEntity/Nerve.class */
/**
 * Base class for a single neuron ("nerve") of the RNN network graph.
 *
 * <p>A nerve works in one of two modes, chosen at construction time:
 * <ul>
 *   <li><b>scalar mode</b>: weights live in {@link #dendrites} (1-based input index to weight)
 *       together with a {@link #threshold}; the inputs of one sample are buffered in
 *       {@link #features}, keyed by event id, via {@link #insertParameter(long, double)}.</li>
 *   <li><b>matrix (convolution) mode</b>: a single kernel column of {@code kernLen * kernLen + 1}
 *       rows is kept in {@link #nerveMatrix} and applied through im2col in {@link #conv(Matrix)}.</li>
 * </ul>
 *
 * <p>Forward signals are pushed to child nerves ({@code son}, and {@code rnnOut} for
 * {@link #sendRnnMessage}); error signals are pushed back to {@code father} nerves.
 * Instances hold mutable per-sample state and perform no synchronization.
 */
public abstract class Nerve {
    private final int id;
    /** True when this nerve was constructed with the name {@code "OutNerve"}. */
    boolean fromOutNerve;
    /** Number of inputs (dendrites) feeding this nerve. */
    protected int upNub;
    /** Number of error signals to wait for before updating, when the sender is not an OutNerve. */
    protected int downNub;
    /** Number of error signals to wait for before updating, when the sender is an OutNerve. */
    protected int rnnOutNub;
    /** Convolution kernel (matrix mode only), a column of {@code kernLen * kernLen + 1} rows. */
    protected Matrix nerveMatrix;
    /** Bias subtracted from the weighted input sum in {@link #calculation(long)}. */
    protected double threshold;
    protected String name;
    /** Output value at which the activation derivative is evaluated during back-propagation. */
    protected double outNub;
    protected double E;
    /** Error gradient of this nerve from the most recent completed backward pass. */
    protected double gradient;
    /** Learning rate applied to every weight/threshold/kernel update. */
    protected double studyPoint;
    /** Accumulated weighted error received from downstream nerves in the current backward pass. */
    protected double sigmaW;
    protected ActiveFunction activeFunction;
    /** Regularization type: 0 = none, 1 = L1, 2 = L2. */
    private final int rzType;
    /** Regularization strength. */
    private final double lParam;
    /** Convolution stride. */
    private final int step;
    /** Convolution kernel side length. */
    private final int kernLen;
    // The next four fields are cached by conv() and consumed by backMatrix().
    private Matrix im2col;
    private int xInput;
    private int yInput;
    private Matrix outMatrix;
    private final List<Nerve> son = new ArrayList<Nerve>();
    private final List<Nerve> rnnOut = new ArrayList<Nerve>();
    private final List<Nerve> father = new ArrayList<Nerve>();
    /** Input weights keyed by 1-based input index. */
    protected Map<Integer, Double> dendrites = new HashMap<Integer, Double>();
    /** Weighted error (pre-update weight * gradient) per input index, sent back to fathers. */
    protected Map<Integer, Double> wg = new HashMap<Integer, Double>();
    /** Buffered input values of each in-flight event, in arrival order. */
    protected Map<Long, List<Double>> features = new HashMap<Long, List<Double>>();
    /** Number of error signals received so far in the current backward pass. */
    private int backNub = 0;
    private final MatrixOperation matrixOperation = new MatrixOperation();

    /** Returns the live (not copied) input-weight map, keyed by 1-based input index. */
    public Map<Integer, Double> getDendrites() {
        return this.dendrites;
    }

    /** Returns the convolution kernel; {@code null} unless created in matrix mode or set explicitly. */
    public Matrix getNerveMatrix() {
        return this.nerveMatrix;
    }

    /** Replaces the convolution kernel. */
    public void setNerveMatrix(Matrix matrix) {
        this.nerveMatrix = matrix;
    }

    /** Replaces the input-weight map; the map is stored by reference. */
    public void setDendrites(Map<Integer, Double> map) {
        this.dendrites = map;
    }

    /** Returns the bias subtracted in {@link #calculation(long)}. */
    public double getThreshold() {
        return this.threshold;
    }

    /** Sets the bias subtracted in {@link #calculation(long)}. */
    public void setThreshold(double d) {
        this.threshold = d;
    }

    /**
     * Creates a nerve and initializes its weights.
     *
     * @param id             identifier returned by {@link #getId()}
     * @param upNub          number of inputs (dendrites)
     * @param name           nerve name; {@code "OutNerve"} marks an output nerve
     * @param downNub        number of error signals to wait for from non-OutNerve senders
     * @param studyPoint     learning rate
     * @param init           if true weights are drawn at random, otherwise they start at zero
     * @param activeFunction activation function
     * @param isMatrix       if true a convolution kernel is created instead of scalar dendrites
     * @param rzType         regularization type (0 none, 1 L1, 2 L2)
     * @param lParam         regularization strength
     * @param step           convolution stride
     * @param kernLen        convolution kernel side length
     * @param rnnOutNub      number of error signals to wait for from OutNerve senders
     * @throws Exception propagated from the {@code Matrix} operations used to build the kernel
     */
    public Nerve(int id, int upNub, String name, int downNub, double studyPoint, boolean init,
                 ActiveFunction activeFunction, boolean isMatrix, int rzType, double lParam,
                 int step, int kernLen, int rnnOutNub) throws Exception {
        this.id = id;
        this.upNub = upNub;
        this.name = name;
        this.downNub = downNub;
        this.studyPoint = studyPoint;
        this.activeFunction = activeFunction;
        this.rzType = rzType;
        this.lParam = lParam;
        this.step = step;
        this.kernLen = kernLen;
        this.rnnOutNub = rnnOutNub;
        // Constant-first comparison so that a null name does not throw.
        this.fromOutNerve = "OutNerve".equals(name);
        initPower(init, isMatrix);
    }

    /** Sets the learning rate. */
    protected void setStudyPoint(double d) {
        this.studyPoint = d;
    }

    /**
     * Forwards a scalar signal to every child nerve by calling its {@code input} hook. All
     * arguments are passed through unchanged; their meaning is defined by the subclasses'
     * {@code input} implementations.
     *
     * @param eventId   key of the in-flight sample
     * @param isStudy   boolean flag passed through (presumably the training-mode flag; TODO confirm)
     * @throws Exception {@code "this layer is lastIndex"} if there are no children, or whatever a
     *                   child's {@code input} throws
     */
    public void sendMessage(long eventId, double parameter, boolean isStudy, Map<Integer, Double> map,
                            OutBack outBack, boolean flag, Matrix matrix) throws Exception {
        broadcast(this.son, eventId, parameter, isStudy, map, outBack, flag, matrix);
    }

    /**
     * Same as {@link #sendMessage} but targets the nerves registered via {@link #connectOut(List)}.
     *
     * @throws Exception {@code "this layer is lastIndex"} if no such nerves are registered, or
     *                   whatever a target's {@code input} throws
     */
    public void sendRnnMessage(long eventId, double parameter, boolean isStudy, Map<Integer, Double> map,
                               OutBack outBack, boolean flag, Matrix matrix) throws Exception {
        broadcast(this.rnnOut, eventId, parameter, isStudy, map, outBack, flag, matrix);
    }

    /** Calls {@code input} on every target, failing if there are none (shared by the send methods). */
    private static void broadcast(List<Nerve> targets, long eventId, double parameter, boolean isStudy,
                                  Map<Integer, Double> map, OutBack outBack, boolean flag,
                                  Matrix matrix) throws Exception {
        if (targets.isEmpty()) {
            throw new Exception("this layer is lastIndex");
        }
        for (Nerve target : targets) {
            target.input(eventId, parameter, isStudy, map, outBack, flag, matrix);
        }
    }

    /**
     * Convolves {@code matrix} with this nerve's kernel (via im2col) and applies the activation
     * function to every result. The input size, the im2col expansion and the output are cached
     * for {@link #backMatrix(Matrix)}.
     *
     * @param matrix input feature map
     * @return activated output of size
     *         {@code ((x - (kernLen - step)) / step) x ((y - (kernLen - step)) / step)}
     * @throws Exception propagated from the {@code Matrix}/{@code MatrixOperation} calls
     */
    public Matrix conv(Matrix matrix) throws Exception {
        this.xInput = matrix.getX();
        this.yInput = matrix.getY();
        int overlap = this.kernLen - this.step;
        int outX = (this.xInput - overlap) / this.step;
        int outY = (this.yInput - overlap) / this.step;
        Matrix out = new Matrix(outX, outY);
        this.im2col = this.matrixOperation.im2col(matrix, this.kernLen, this.step);
        Matrix raw = this.matrixOperation.mulMatrix(this.im2col, this.nerveMatrix);
        // Row (r * outY + c) of the product holds the pre-activation value of output cell (r, c).
        for (int r = 0; r < outX; r++) {
            for (int c = 0; c < outY; c++) {
                out.setNub(r, c, this.activeFunction.function(raw.getNumber(r * outY + c, 0)));
            }
        }
        this.outMatrix = out;
        return out;
    }

    /**
     * Forwards a matrix signal to every child nerve by calling its {@code inputMatrix} hook. All
     * arguments are passed through unchanged.
     *
     * @throws Exception {@code "this layer is lastIndex"} if there are no children, or whatever a
     *                   child's {@code inputMatrix} throws
     */
    public void sendMatrix(long eventId, Matrix matrix, boolean isStudy, int index, OutBack outBack)
            throws Exception {
        if (this.son.isEmpty()) {
            throw new Exception("this layer is lastIndex");
        }
        for (Nerve child : this.son) {
            child.inputMatrix(eventId, matrix, isStudy, index, outBack);
        }
    }

    /** Sends each father the weighted error recorded for the dendrite it feeds. */
    private void backSendMessage(long eventId, boolean senderIsOutNerve) throws Exception {
        for (int i = 0; i < this.father.size(); i++) {
            // Father i feeds dendrite i + 1 (dendrite keys are 1-based).
            this.father.get(i).backGetMessage(this.wg.get(i + 1), eventId, senderIsOutNerve);
        }
    }

    /** Passes the same input-space error matrix to every father. */
    private void backMatrixMessage(Matrix matrix) throws Exception {
        for (Nerve parent : this.father) {
            parent.backMatrix(matrix);
        }
    }

    /** Forward-pass hook for scalar signals; a no-op here, overridden by subclasses. */
    protected void input(long j, double d, boolean z, Map<Integer, Double> map, OutBack outBack,
                         boolean z2, Matrix matrix) throws Exception {
    }

    /** Forward-pass hook for matrix signals; a no-op here, overridden by subclasses. */
    protected void inputMatrix(long j, Matrix matrix, boolean z, int i, OutBack outBack) throws Exception {
    }

    /**
     * Accumulates one weighted error from a child. Once the expected number of signals has
     * arrived ({@code rnnOutNub} if the sender is an OutNerve, otherwise {@code downNub}), computes
     * this nerve's gradient and updates its weights.
     */
    private void backGetMessage(double weightedError, long eventId, boolean senderIsOutNerve)
            throws Exception {
        this.backNub++;
        this.sigmaW += weightedError;
        int expected = senderIsOutNerve ? this.rnnOutNub : this.downNub;
        if (this.backNub == expected) {
            this.backNub = 0;
            this.gradient = this.activeFunction.functionG(this.outNub) * this.sigmaW;
            updatePower(eventId);
        }
    }

    /**
     * Back-propagates an error map through the convolution. Must run after {@link #conv(Matrix)},
     * whose cached im2col, output and input size it uses.
     *
     * <p>The incoming error is scaled by the activation derivative (evaluated at the cached conv
     * output) and by the learning rate. The kernel is then increased by {@code im2col^T * delta}.
     * Every father receives the input-space error, which is rebuilt with
     * {@code reverseIm2col} using the pre-update kernel.
     *
     * @param matrix error for each output cell of the last {@code conv} call
     * @throws Exception propagated from the {@code Matrix}/{@code MatrixOperation} calls
     */
    public void backMatrix(Matrix matrix) throws Exception {
        int rows = matrix.getX();
        int cols = matrix.getY();
        // Flatten row-major to match the row order produced by conv().
        Matrix delta = new Matrix(rows * cols, 1);
        int k = 0;
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                delta.setNub(k, 0, matrix.getNumber(r, c)
                        * this.activeFunction.functionG(this.outMatrix.getNumber(r, c)) * this.studyPoint);
                k++;
            }
        }
        // Kernel step must be computed before im2col is overwritten below.
        Matrix kernelDelta = this.matrixOperation.mulMatrix(this.matrixOperation.transPosition(this.im2col), delta);
        int patchCount = this.im2col.getX();
        // The last im2col column is not propagated (it pairs with the extra kernel row,
        // presumably a bias — TODO confirm).
        int kernelCells = this.im2col.getY() - 1;
        for (int p = 0; p < patchCount; p++) {
            double patchError = delta.getNumber(p, 0);
            for (int w = 0; w < kernelCells; w++) {
                this.im2col.setNub(p, w, this.nerveMatrix.getNumber(w, 0) * patchError);
            }
        }
        Matrix inputError = this.matrixOperation.reverseIm2col(this.im2col, this.kernLen, this.step,
                this.xInput, this.yInput);
        this.nerveMatrix = this.matrixOperation.add(this.nerveMatrix, kernelDelta);
        backMatrixMessage(inputError);
    }

    /**
     * Applies one gradient step for event {@code eventId}. It subtracts
     * {@code gradient * studyPoint} from the threshold and updates the dendrite weights. It then
     * resets the accumulated error and sends weighted errors to the fathers.
     *
     * @throws Exception propagated from the fathers' backward processing
     */
    public void updatePower(long eventId) throws Exception {
        double scaledGradient = this.gradient * this.studyPoint;
        this.threshold -= scaledGradient;
        updateW(scaledGradient, eventId);
        this.sigmaW = 0.0d;
        backSendMessage(eventId, this.fromOutNerve);
    }

    /**
     * Returns the regularization adjustment for one weight.
     *
     * @param weight   current weight
     * @param strength scaled regularization strength computed in {@code updateW}
     * @return {@code -strength * weight} for L2, {@code -strength * sign(weight)} for L1, else 0
     */
    private double regularization(double weight, double strength) {
        switch (this.rzType) {
            case 2:
                return strength * (-weight);
            case 1:
                if (weight > 0.0d) {
                    return -strength;
                }
                if (weight < 0.0d) {
                    return strength;
                }
                return 0.0d;
            default:
                return 0.0d;
        }
    }

    /**
     * Updates every dendrite weight from the buffered inputs of {@code eventId}, records the
     * per-input weighted error in {@link #wg}, and discards the buffered inputs.
     *
     * @param scaledGradient {@code gradient * studyPoint}
     * @param eventId        event whose inputs were buffered by {@link #insertParameter}
     */
    private void updateW(double scaledGradient, long eventId) {
        List<Double> inputs = this.features.get(eventId);
        double strength = 0.0d;
        if (this.rzType != 0) {
            // Sum of squares (L2) or of absolute values (otherwise) over all weights.
            double norm = 0.0d;
            for (double w : this.dendrites.values()) {
                norm = this.rzType == 2 ? norm + Math.pow(w, 2.0d) : norm + Math.abs(w);
            }
            strength = norm * this.lParam * this.studyPoint;
        }
        for (Map.Entry<Integer, Double> entry : this.dendrites.entrySet()) {
            int index = entry.getKey();
            double weight = entry.getValue();
            double change = inputs.get(index - 1) * scaledGradient;
            // The error sent back to fathers uses the weight *before* this update.
            this.wg.put(index, weight * this.gradient);
            // setValue is the only update the Map contract allows during entrySet iteration.
            entry.setValue(weight + regularization(weight, strength) + change);
        }
        this.features.remove(eventId);
    }

    /**
     * Buffers one input value for {@code eventId}.
     *
     * @return true once at least {@code upNub} values have been buffered for that event
     */
    public boolean insertParameter(long eventId, double value) {
        List<Double> inputs = this.features.get(eventId);
        if (inputs == null) {
            inputs = new ArrayList<Double>();
            this.features.put(eventId, inputs);
        }
        inputs.add(value);
        return inputs.size() >= this.upNub;
    }

    /** Discards the buffered inputs of {@code j}. (Method name kept as-is for API compatibility.) */
    public void destoryParameter(long j) {
        this.features.remove(j);
    }

    /**
     * Returns the weight of the first buffered input of {@code eventId} whose value exceeds 0.5,
     * or 0 if none does (presumably meant for one-hot inputs — TODO confirm).
     * Assumes inputs were buffered for {@code eventId}.
     */
    public double getWOne(long eventId) {
        List<Double> inputs = this.features.get(eventId);
        for (int i = 0; i < inputs.size(); i++) {
            if (inputs.get(i) > 0.5d) {
                return this.dendrites.get(i + 1);
            }
        }
        return 0.0d;
    }

    /**
     * Returns the weighted sum of the buffered inputs of {@code eventId} minus the threshold.
     * Assumes inputs were buffered for {@code eventId}.
     */
    public double calculation(long eventId) {
        List<Double> inputs = this.features.get(eventId);
        double sum = 0.0d;
        for (int i = 0; i < inputs.size(); i++) {
            sum += this.dendrites.get(i + 1) * inputs.get(i);
        }
        return sum - this.threshold;
    }

    /**
     * Initializes the weights.
     *
     * <p>In matrix mode it creates a {@code kernLen^2 + 1} row kernel with values uniform in
     * {@code [0, sqrt(2 / kernLen^2))}. In scalar mode it creates {@code upNub} dendrites and a
     * threshold with values uniform in {@code [0, 1 / sqrt(upNub))}. With {@code init == false}
     * every value is 0.
     */
    private void initPower(boolean init, boolean isMatrix) throws Exception {
        Random random = new Random();
        if (isMatrix) {
            int kernelCells = this.kernLen * this.kernLen;
            double scale = Math.sqrt(2.0d / kernelCells);
            this.nerveMatrix = new Matrix(kernelCells + 1, 1);
            for (int r = 0; r < this.nerveMatrix.getX(); r++) {
                this.nerveMatrix.setNub(r, 0, init ? random.nextDouble() * scale : 0.0d);
            }
            return;
        }
        if (this.upNub > 0) {
            double scale = Math.sqrt(this.upNub);
            for (int i = 1; i <= this.upNub; i++) {
                this.dendrites.put(i, init ? random.nextDouble() / scale : 0.0d);
            }
            this.threshold = init ? random.nextDouble() / scale : 0.0d;
        }
    }

    public int getId() {
        return this.id;
    }

    /** Adds forward (child) connections used by {@link #sendMessage} and {@link #sendMatrix}. */
    public void connect(List<Nerve> list) {
        this.son.addAll(list);
    }

    /** Adds connections used by {@link #sendRnnMessage}. */
    public void connectOut(List<Nerve> list) {
        this.rnnOut.addAll(list);
    }

    /** Adds backward (father) connections that receive this nerve's error signals. */
    public void connectFather(List<Nerve> list) {
        this.father.addAll(list);
    }
}
