package org.dromara.easyai.transFormer.seflAttention;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.transFormer.CodecBlock;
import org.dromara.easyai.transFormer.model.MultiSelfAttentionModel;
import org.dromara.easyai.transFormer.model.QKVModel;

/* loaded from: input_file:org/dromara/easyai/transFormer/seflAttention/MultiSelfAttention.class */
/**
 * Multi-head self-attention layer.
 *
 * <p>On the forward pass ({@link #sendMatrixMessage}) every {@link SelfAttention} head receives the
 * same input matrix; the per-head feature matrices are concatenated column-wise, multiplied by
 * {@code powerMatrix} and handed to {@link LayNorm#addNorm}. On the backward pass
 * ({@link #backError}) {@code powerMatrix} is updated and the remaining error is split per head and
 * propagated to the heads and to the owning {@link CodecBlock}.
 *
 * <p>No synchronization is performed in this class.
 */
public class MultiSelfAttention {
    /** Owning block that receives back-propagated errors; may be {@code null}. */
    private final CodecBlock codecBlock;
    /** Attention heads, created in the constructor with self IDs {@code 0..multiNumber-1}. */
    private final List<SelfAttention> selfAttentions = new ArrayList<>();
    /** Normalization stage that receives the forward output; must be set before the forward pass. */
    private LayNorm layNorm;
    /** Factor applied to the incoming error before it is used to update {@code powerMatrix}. */
    private final double studyPoint;
    /** Output projection: X = multiNumber * wordVectorDimension, Y = wordVectorDimension. */
    private Matrix powerMatrix;
    /** Number of attention heads. */
    private final int multiNumber;
    /** Column count of each head's feature matrix. */
    private final int wordVectorDimension;
    /** Concatenated head output of the last forward pass that asked to keep it; read by backError. */
    private Matrix featureMatrix;
    /** Depth of this layer; a value of 1 enables position encoding on the input. */
    private final int depth;
    /** {@code true} when this layer is part of an encoder. */
    private final boolean encoder;
    /** Divisor used by the linear position code ({@link #addTimeCodeBySelf}). */
    private final int maxLength;
    /** Selects the linear position code instead of the sinusoidal one. */
    private final boolean selfTimeCode;
    private final MatrixOperation matrixOperation;

    /**
     * Sets the normalization stage that receives the output of {@link #sendMatrixMessage}.
     *
     * @param layNorm the stage to forward results to
     */
    public void setLayNorm(LayNorm layNorm) {
        this.layNorm = layNorm;
    }

    /**
     * @return the depth this layer was constructed with
     */
    public int getDepth() {
        return this.depth;
    }

    /**
     * Finds the stored head model whose self ID equals {@code selfId}.
     *
     * @return the matching model, or {@code null} when the list is {@code null} or has no match
     */
    private QKVModel getQKV(List<QKVModel> qkvModels, int selfId) {
        if (qkvModels == null) {
            return null;
        }
        for (QKVModel candidate : qkvModels) {
            if (candidate.getSelfID() == selfId) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Loads previously saved parameters into this layer.
     *
     * <p>The power model dimensions and the presence of one {@link QKVModel} per head are checked
     * before anything is written, so a mismatching model leaves this layer untouched.
     *
     * @param multiSelfAttentionModel the saved parameters
     * @throws Exception if the saved model does not provide a head model for every head, if the
     *     saved power model is smaller than {@code powerMatrix}, or if a head rejects its model
     */
    public void insertModel(MultiSelfAttentionModel multiSelfAttentionModel) throws Exception {
        double[][] powerModel = multiSelfAttentionModel.getPowerModel();
        List<QKVModel> qkvModelList = multiSelfAttentionModel.getQkvModelList();
        checkPowerModel(powerModel, this.powerMatrix);

        // Resolve every head's model first so that a missing head is detected before any write.
        int headCount = this.selfAttentions.size();
        List<QKVModel> modelPerHead = new ArrayList<>(headCount);
        for (int head = 0; head < headCount; head++) {
            QKVModel qkv = getQKV(qkvModelList, head);
            if (qkv == null) {
                throw new Exception("模型与激活参数不匹配!内存与模型文件的多头数量不一致！");
            }
            modelPerHead.add(qkv);
        }

        insertPower(powerModel, this.powerMatrix);
        for (int head = 0; head < headCount; head++) {
            this.selfAttentions.get(head).insertModel(modelPerHead.get(head));
        }
    }

    /**
     * Verifies that {@code source} holds at least as many rows and columns as {@code target}.
     *
     * @throws Exception if {@code source} is {@code null} or too small while {@code target} has
     *     at least one cell to fill
     */
    private void checkPowerModel(double[][] source, Matrix target) throws Exception {
        int rows = target.getX();
        int cols = target.getY();
        if (rows == 0 || cols == 0) {
            return; // nothing will be copied, so any source is acceptable
        }
        if (source == null || source.length < rows) {
            throw new Exception("Power model does not match the layer: expected at least "
                    + rows + " rows");
        }
        for (int row = 0; row < rows; row++) {
            if (source[row] == null || source[row].length < cols) {
                throw new Exception("Power model does not match the layer: row " + row
                        + " has fewer than " + cols + " columns");
            }
        }
    }

    /** Copies {@code source[row][col]} into every cell of {@code target}. */
    private void insertPower(double[][] source, Matrix target) throws Exception {
        int rows = target.getX();
        int cols = target.getY();
        for (int row = 0; row < rows; row++) {
            for (int col = 0; col < cols; col++) {
                target.setNub(row, col, source[row][col]);
            }
        }
    }

    /**
     * Builds a model object holding the current parameters of this layer.
     *
     * @return a new model carrying the power matrix data, one model per head and the depth
     */
    public MultiSelfAttentionModel getModel() {
        List<QKVModel> qkvModels = new ArrayList<>(this.selfAttentions.size());
        for (SelfAttention attention : this.selfAttentions) {
            qkvModels.add(attention.getModel());
        }
        MultiSelfAttentionModel model = new MultiSelfAttentionModel();
        model.setPowerModel(this.powerMatrix.getMatrix());
        model.setQkvModelList(qkvModels);
        model.setDepth(this.depth);
        return model;
    }

    /**
     * Writes {@code headFeature} into the column band of {@code merged} that belongs to head
     * {@code head}, i.e. columns {@code [head * wordVectorDimension, (head + 1) * wordVectorDimension)}.
     */
    private void mergeFeatureMatrix(Matrix merged, Matrix headFeature, int head) throws Exception {
        int firstCol = this.wordVectorDimension * head;
        int endCol = firstCol + this.wordVectorDimension;
        for (int row = 0; row < headFeature.getX(); row++) {
            for (int col = firstCol; col < endCol; col++) {
                merged.setNub(row, col, headFeature.getNumber(row, col - firstCol));
            }
        }
    }

    /**
     * Cuts {@code matrix} into one sub-matrix per head, each {@code wordVectorDimension} columns
     * wide and spanning all rows; the inverse of the layout built by {@link #mergeFeatureMatrix}.
     */
    private List<Matrix> splitMatrix(Matrix matrix) {
        int headCount = this.selfAttentions.size();
        List<Matrix> parts = new ArrayList<>(headCount);
        int rows = matrix.getX();
        for (int head = 0; head < headCount; head++) {
            parts.add(matrix.getSonOfMatrix(0, head * this.wordVectorDimension, rows, this.wordVectorDimension));
        }
        return parts;
    }

    /**
     * Back-propagates an error through this layer.
     *
     * <p>Steps, in order: the error scaled by {@code studyPoint} is turned into an update that is
     * added to {@code powerMatrix}; the unscaled error is turned into the error of the concatenated
     * head output (computed against the power matrix as it was before the update); that error is
     * split per head and passed to each head; the heads' next-feature errors are summed and, for a
     * non-encoder layer with depth greater than 1, so are their last-encoder errors. If a codec
     * block is attached, the sums are forwarded to it.
     *
     * @param matrix error arriving from the following stage
     * @param j identifier forwarded unchanged to the heads and to the codec block
     * @throws IllegalStateException if no head with a required self ID exists
     * @throws Exception if a matrix operation or a downstream stage fails
     */
    public void backError(Matrix matrix, long j) throws Exception {
        Matrix scaledError = this.matrixOperation.mathMulBySelf(matrix, this.studyPoint);
        Matrix powerDelta = this.matrixOperation.matrixMulPd(scaledError, this.featureMatrix, this.powerMatrix, false);
        Matrix mergedFeatureError = this.matrixOperation.matrixMulPd(matrix, this.featureMatrix, this.powerMatrix, true);
        this.powerMatrix = this.matrixOperation.add(this.powerMatrix, powerDelta);

        List<Matrix> headErrors = splitMatrix(mergedFeatureError);
        boolean collectEncoderError = !this.encoder && this.depth > 1;
        Matrix nextFeatureErrorSum = null;
        Matrix lastEncoderErrorSum = null;
        for (int head = 0; head < this.selfAttentions.size(); head++) {
            SelfAttention attention = getSefAttentionBySelfID(head);
            if (attention == null) {
                throw new IllegalStateException("No self-attention head with self ID " + head);
            }
            AttentionError headResult = attention.backError(headErrors.get(head), j);

            Matrix nextFeatureError = headResult.getNextFeatureError();
            if (nextFeatureErrorSum == null) {
                nextFeatureErrorSum = nextFeatureError;
            } else {
                nextFeatureErrorSum = this.matrixOperation.add(nextFeatureErrorSum, nextFeatureError);
            }

            if (collectEncoderError) {
                Matrix lastEncoderError = headResult.getLastEncoderError();
                if (lastEncoderErrorSum == null) {
                    lastEncoderErrorSum = lastEncoderError;
                } else {
                    lastEncoderErrorSum = this.matrixOperation.add(lastEncoderErrorSum, lastEncoderError);
                }
            }
        }

        // Both notifications need the codec block; previously only the second one was guarded,
        // so a decoder layer without a codec block failed with a NullPointerException here.
        if (this.codecBlock == null) {
            return;
        }
        if (collectEncoderError) {
            this.codecBlock.backLastEncoderError(lastEncoderErrorSum);
        }
        this.codecBlock.backCodecError(nextFeatureErrorSum, j, matrix);
    }

    /**
     * @return the head whose self ID equals {@code selfId}, or {@code null} if there is none
     */
    private SelfAttention getSefAttentionBySelfID(int selfId) {
        for (SelfAttention attention : this.selfAttentions) {
            if (attention.getSelfID() == selfId) {
                return attention;
            }
        }
        return null;
    }

    /**
     * Concatenates the heads' feature matrices (ordered by self ID) and multiplies the result by
     * {@code powerMatrix}.
     *
     * @param eventBodies one output per head
     * @param keepFeature when {@code true}, the concatenated matrix is stored in
     *     {@code featureMatrix} for a later {@link #backError} call
     * @return the concatenated matrix multiplied by {@code powerMatrix}
     * @throws IllegalStateException if {@code eventBodies} is empty or lacks an entry for a self ID
     *     in {@code 0..eventBodies.size()-1}
     */
    private Matrix countMultiSelfAttention(List<EventBody> eventBodies, boolean keepFeature) throws Exception {
        if (eventBodies.isEmpty()) {
            throw new IllegalStateException("No self-attention output to merge");
        }
        int mergedWidth = this.wordVectorDimension * this.multiNumber;
        Matrix merged = null;
        for (int head = 0; head < eventBodies.size(); head++) {
            EventBody body = getEventBodyBySelfID(head, eventBodies);
            if (body == null) {
                throw new IllegalStateException("No self-attention output with self ID " + head);
            }
            Matrix headFeature = body.getFeatureMatrix();
            if (merged == null) {
                // Sized from the first head; every head is written into its own column band.
                merged = new Matrix(headFeature.getX(), mergedWidth);
            }
            mergeFeatureMatrix(merged, headFeature, head);
        }
        Matrix projected = this.matrixOperation.mulMatrix(merged, this.powerMatrix);
        if (keepFeature) {
            this.featureMatrix = merged;
        }
        return projected;
    }

    /**
     * @return the entry whose self ID equals {@code selfId}, or {@code null} if there is none
     */
    private EventBody getEventBodyBySelfID(int selfId, List<EventBody> eventBodies) {
        for (EventBody body : eventBodies) {
            if (body.getSelfID() == selfId) {
                return body;
            }
        }
        return null;
    }

    /**
     * Adds a sinusoidal position code to {@code matrix} in place: for row {@code r} and column
     * {@code c} the value {@code sin(r * f)} (even {@code c}) or {@code cos(r * f)} (odd {@code c})
     * is added, with {@code f = 1 / 10000^(2 * (c / 2) / columns)}.
     */
    private void addTimeCode(Matrix matrix) throws Exception {
        int rows = matrix.getX();
        int cols = matrix.getY();
        // The frequency depends on the column only, so compute it once per column.
        double[] frequency = new double[cols];
        for (int col = 0; col < cols; col++) {
            // col / 2 is an integer division on purpose: each even/odd column pair shares one frequency.
            frequency[col] = 1.0d / Math.pow(10000.0d, (2.0d * (col / 2)) / cols);
        }
        for (int row = 0; row < rows; row++) {
            for (int col = 0; col < cols; col++) {
                double angle = frequency[col] * row;
                double code = col % 2 == 0 ? Math.sin(angle) : Math.cos(angle);
                matrix.setNub(row, col, matrix.getNumber(row, col) + code);
            }
        }
    }

    /**
     * Adds a linear position code to {@code matrix} in place: every cell of row {@code r} is
     * increased by {@code r / maxLength}.
     */
    private void addTimeCodeBySelf(Matrix matrix) throws Exception {
        double step = 1.0d / this.maxLength;
        int rows = matrix.getX();
        int cols = matrix.getY();
        // Row 0 would receive an offset of 0, so the loop starts at row 1.
        for (int row = 1; row < rows; row++) {
            double offset = row * step;
            for (int col = 0; col < cols; col++) {
                matrix.setNub(row, col, matrix.getNumber(row, col) + offset);
            }
        }
    }

    /**
     * Runs the forward pass of this layer and forwards the result to the {@link LayNorm} stage.
     *
     * <p>When {@code depth == 1} a position code is first added to {@code matrix} in place (linear
     * if {@code selfTimeCode} is set, sinusoidal otherwise). Every head then processes the matrix,
     * the head outputs are merged and projected, and {@code layNorm.addNorm} is called with the
     * input, the projected output and the remaining arguments.
     *
     * @param j identifier forwarded unchanged to the heads and to {@code layNorm}
     * @param matrix input features; modified in place when {@code depth == 1}
     * @param z when {@code true}, the merged head output is kept for {@link #backError}; also
     *     forwarded to the heads and to {@code layNorm}
     * @param outBack forwarded unchanged to {@code layNorm}
     * @param list forwarded unchanged to {@code layNorm}
     * @param matrix2 forwarded unchanged to the heads and to {@code layNorm}
     * @param z2 forwarded unchanged to {@code layNorm}
     * @throws Exception if a head, a matrix operation or the normalization stage fails
     */
    public void sendMatrixMessage(long j, Matrix matrix, boolean z, OutBack outBack, List<Integer> list, Matrix matrix2, boolean z2) throws Exception {
        if (this.depth == 1) {
            if (this.selfTimeCode) {
                addTimeCodeBySelf(matrix);
            } else {
                addTimeCode(matrix);
            }
        }
        List<EventBody> headOutputs = new ArrayList<>(this.selfAttentions.size());
        for (SelfAttention attention : this.selfAttentions) {
            headOutputs.add(attention.sendMatrixFeature(j, z, matrix, matrix2));
        }
        Matrix projected = countMultiSelfAttention(headOutputs, z);
        this.layNorm.addNorm(matrix, projected, j, z, outBack, list, matrix2, z2);
    }

    /**
     * Creates the layer, its heads and a randomly initialized power matrix.
     *
     * <p>Each cell of the power matrix is drawn from {@code Random.nextDouble()} and divided by
     * {@code i * i3}.
     *
     * @param i number of attention heads
     * @param d study point; passed to every head and used to scale the error in {@link #backError}
     * @param i2 depth of this layer; passed to every head
     * @param i3 word vector dimension; passed to every head
     * @param z {@code true} if this layer belongs to an encoder; passed to every head
     * @param codecBlock block notified during {@link #backError}; may be {@code null}
     * @param i4 maximum length used by the linear position code
     * @param z2 {@code true} to use the linear position code instead of the sinusoidal one
     * @param i5 passed unchanged to the {@link MatrixOperation} constructor and to every head
     * @throws Exception if a head or a matrix cannot be created
     */
    public MultiSelfAttention(int i, double d, int i2, int i3, boolean z, CodecBlock codecBlock, int i4, boolean z2, int i5) throws Exception {
        Random random = new Random();
        this.matrixOperation = new MatrixOperation(i5);
        this.selfTimeCode = z2;
        this.maxLength = i4;
        this.codecBlock = codecBlock;
        this.encoder = z;
        this.studyPoint = d;
        this.wordVectorDimension = i3;
        this.multiNumber = i;
        this.depth = i2;
        for (int selfId = 0; selfId < i; selfId++) {
            this.selfAttentions.add(new SelfAttention(d, i2, i3, selfId, z, i5));
        }
        int mergedWidth = i3 * i;
        this.powerMatrix = new Matrix(mergedWidth, i3);
        int rows = this.powerMatrix.getX();
        int cols = this.powerMatrix.getY();
        for (int row = 0; row < rows; row++) {
            for (int col = 0; col < cols; col++) {
                this.powerMatrix.setNub(row, col, random.nextDouble() / mergedWidth);
            }
        }
    }
}
