package org.dromara.easyai.transFormer;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.dromara.easyai.function.ReLu;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.transFormer.model.CodecBlockModel;
import org.dromara.easyai.transFormer.nerve.HiddenNerve;
import org.dromara.easyai.transFormer.nerve.Nerve;
import org.dromara.easyai.transFormer.seflAttention.LayNorm;
import org.dromara.easyai.transFormer.seflAttention.MultiSelfAttention;

/* loaded from: input_file:org/dromara/easyai/transFormer/CodecBlock.class */
/**
 * A single codec block of the transformer network.
 *
 * <p>The block owns one {@link MultiSelfAttention}, two {@link LayNorm} instances
 * (constructed with type ids 1 and 2) and two equally sized layers of {@link HiddenNerve}s.
 * The constructor wires them together as follows:
 * <ul>
 *   <li>the attention layer and the first layer norm reference each other;</li>
 *   <li>the first layer norm is given the first nerve layer, the second layer norm is given
 *       the second nerve layer;</li>
 *   <li>every nerve of the first layer is connected to every nerve of the second layer.</li>
 * </ul>
 *
 * <p>Neighbouring blocks, the {@link LineBlock} and the {@link FirstDecoderBlock} are injected
 * through setters after construction; methods that use them assume they have been set where
 * required.
 *
 * <p>NOTE(review): {@code outMatrixMap} is a plain {@link HashMap}; nothing in this class
 * synchronizes access to it. Confirm callers never touch one block from several threads.
 */
public class CodecBlock {
    private final MultiSelfAttention multiSelfAttention;
    private final LayNorm attentionLayNorm;
    private final LayNorm lineLayNorm;
    private CodecBlock afterEncoderBlock;
    private CodecBlock beforeEncoderBlock;
    private CodecBlock lastEncoderBlock;
    private final boolean encoder;
    private LineBlock lineBlock;
    private FirstDecoderBlock firstDecoderBlock;
    private final MatrixOperation matrixOperation;
    private final int coreNumber;
    private final List<HiddenNerve> fistHiddenNerves = new ArrayList<>();
    private final List<HiddenNerve> secondHiddenNerves = new ArrayList<>();
    /** Output matrices stored by {@link #sendOutputMatrix}, keyed by event id. */
    private final Map<Long, Matrix> outMatrixMap = new HashMap<>();

    /**
     * Collects the parameters of every component of this block into a new model object.
     *
     * @return a freshly created model holding the attention model, both layer-norm models and
     *         one entry per hidden nerve for each of the two nerve layers
     * @throws Exception if any component fails to export its model
     */
    public CodecBlockModel getModel() throws Exception {
        int nerveCount = fistHiddenNerves.size();
        List<float[][]> firstModels = new ArrayList<>(nerveCount);
        List<float[][]> secondModels = new ArrayList<>(nerveCount);
        // Both layers are created with the same size in initLine, so one index walks both.
        for (int index = 0; index < nerveCount; index++) {
            firstModels.add(fistHiddenNerves.get(index).getModel());
            secondModels.add(secondHiddenNerves.get(index).getModel());
        }
        CodecBlockModel model = new CodecBlockModel();
        model.setMultiSelfAttentionModel(multiSelfAttention.getModel());
        model.setAttentionLayNormModel(attentionLayNorm.getModel());
        model.setFistNervesModel(firstModels);
        model.setSecondNervesModel(secondModels);
        model.setLineLayNormModel(lineLayNorm.getModel());
        return model;
    }

    /**
     * Loads previously exported parameters into every component of this block.
     *
     * <p>The nerve-model lists are validated before anything is modified, so a model that is
     * too small for this block is rejected without leaving the block partially loaded.
     *
     * @param codecBlockModel the model to load
     * @throws IllegalArgumentException if this block has hidden nerves and either nerve-model
     *         list is {@code null} or holds fewer entries than there are nerves per layer
     * @throws Exception if any component fails to import its model
     */
    public void insertModel(CodecBlockModel codecBlockModel) throws Exception {
        int nerveCount = fistHiddenNerves.size();
        List<float[][]> firstModels = codecBlockModel.getFistNervesModel();
        List<float[][]> secondModels = codecBlockModel.getSecondNervesModel();
        requireModelCount(firstModels, nerveCount, "first");
        requireModelCount(secondModels, nerveCount, "second");

        multiSelfAttention.insertModel(codecBlockModel.getMultiSelfAttentionModel());
        attentionLayNorm.insertModel(codecBlockModel.getAttentionLayNormModel());
        for (int index = 0; index < nerveCount; index++) {
            fistHiddenNerves.get(index).insertModel(firstModels.get(index));
            secondHiddenNerves.get(index).insertModel(secondModels.get(index));
        }
        lineLayNorm.insertModel(codecBlockModel.getLineLayNormModel());
    }

    /**
     * Fails fast when a nerve-model list cannot cover {@code expected} nerves. Lists longer than
     * {@code expected} are accepted; the extra entries are simply never read.
     */
    private static void requireModelCount(List<float[][]> models, int expected, String layerName) {
        if (expected == 0) {
            return;
        }
        if (models == null || models.size() < expected) {
            String actual = models == null ? "null" : String.valueOf(models.size());
            throw new IllegalArgumentException("CodecBlock " + layerName
                    + " nerve layer needs " + expected + " nerve models but got " + actual);
        }
    }

    /** Sets the decoder block that {@link #backCodecError} falls back to when no after-block is set. */
    public void setFirstDecoderBlock(FirstDecoderBlock firstDecoderBlock) {
        this.firstDecoderBlock = firstDecoderBlock;
    }

    /** Sets the line block that receives this block's output when it is a terminal non-encoder block. */
    public void setLineBlock(LineBlock lineBlock) {
        this.lineBlock = lineBlock;
    }

    /** Sets the block that {@link #backLastEncoderError} delegates to. */
    public void setLastEncoderBlock(CodecBlock codecBlock) {
        this.lastEncoderBlock = codecBlock;
    }

    /** Sets the block that receives errors from {@link #backCodecError}. */
    public void setAfterEncoderBlock(CodecBlock codecBlock) {
        this.afterEncoderBlock = codecBlock;
    }

    /** Sets the block that receives this block's output in {@link #sendOutputMatrix}. */
    public void setBeforeEncoderBlock(CodecBlock codecBlock) {
        this.beforeEncoderBlock = codecBlock;
    }

    /**
     * Builds the block and wires its attention layer, layer norms and hidden nerves together.
     *
     * <p>The parameter names are those of the compiled class; each is described by where the
     * constructor forwards it.
     *
     * @param i   first argument of the {@link MultiSelfAttention} constructor
     * @param i2  size passed to both {@link LayNorm}s and to the attention layer; also the
     *            number of hidden nerves created per layer
     * @param f   value passed to the layer norms, the attention layer and every hidden nerve
     * @param i3  third argument of the {@link MultiSelfAttention} constructor
     * @param z   whether this block is an encoder; also forwarded to the attention layer
     * @param i4  forwarded to every {@link HiddenNerve}
     * @param f2  forwarded to every {@link HiddenNerve}
     * @param i5  forwarded to the attention layer
     * @param z2  forwarded to the attention layer
     * @param i6  core number used for the {@link MatrixOperation} and forwarded to all components
     * @throws Exception if any component fails to initialise
     */
    public CodecBlock(int i, int i2, float f, int i3, boolean z, int i4, float f2, int i5, boolean z2, int i6) throws Exception {
        this.matrixOperation = new MatrixOperation(i6);
        this.encoder = z;
        this.coreNumber = i6;
        this.attentionLayNorm = new LayNorm(1, i2, this, null, f, i6);
        this.lineLayNorm = new LayNorm(2, i2, this, null, f, i6);
        this.multiSelfAttention = new MultiSelfAttention(i, f, i3, i2, z, this, i5, z2, i6);
        this.multiSelfAttention.setLayNorm(this.attentionLayNorm);
        this.attentionLayNorm.setMultiSelfAttention(this.multiSelfAttention);
        // The nerve lists must be populated before they are handed to the layer norms.
        initLine(i2, f, i4, f2);
        this.attentionLayNorm.setHiddenNerves(this.fistHiddenNerves);
        this.lineLayNorm.setHiddenNerves(this.secondHiddenNerves);
    }

    /**
     * Feeds an error matrix into the second layer norm.
     *
     * @param j      event id
     * @param matrix error matrix
     * @throws Exception if back-propagation fails
     */
    public void backError(long j, Matrix matrix) throws Exception {
        lineLayNorm.backErrorFromLine(matrix, j);
    }

    /** Discards the output matrix stored for event id {@code j}, if any. */
    public void removeOutMatrix(long j) {
        outMatrixMap.remove(j);
    }

    /**
     * Returns the output matrix stored for event id {@code j}.
     *
     * @return the stored matrix, or {@code null} when none is stored for that id
     */
    public Matrix getOutMatrix(long j) {
        return outMatrixMap.get(j);
    }

    /**
     * Routes this block's output.
     *
     * <p>If a before-block is set the output becomes that block's input. Otherwise an encoder
     * block stores the output under {@code j}, and a non-encoder block hands it to the line
     * block (note that {@code matrix2} is not forwarded in that case).
     *
     * @throws Exception if the receiving component fails
     */
    public void sendOutputMatrix(long j, Matrix matrix, boolean z, OutBack outBack, List<Integer> list, Matrix matrix2, boolean z2) throws Exception {
        if (beforeEncoderBlock != null) {
            beforeEncoderBlock.sendInputMatrix(j, matrix, z, outBack, list, matrix2, z2);
            return;
        }
        if (encoder) {
            outMatrixMap.put(j, matrix);
            return;
        }
        lineBlock.sendParameter(j, matrix, z, outBack, list, z2);
    }

    /**
     * Adds the two error matrices and passes the sum to the after-block, or to the first
     * decoder block when no after-block is set. When neither is set the sum is dropped.
     *
     * @throws Exception if the addition or the downstream back-propagation fails
     */
    public void backCodecError(Matrix matrix, long j, Matrix matrix2) throws Exception {
        // The sum is computed unconditionally so that an invalid addition still surfaces
        // even when there is no receiver.
        Matrix summedError = matrixOperation.add(matrix, matrix2);
        if (afterEncoderBlock != null) {
            afterEncoderBlock.backError(j, summedError);
        } else if (firstDecoderBlock != null) {
            firstDecoderBlock.backError(j, summedError);
        }
    }

    /**
     * Delegates the error matrix to the last encoder block.
     *
     * @throws Exception if back-propagation fails
     */
    public void backLastEncoderError(Matrix matrix) throws Exception {
        lastEncoderBlock.backLastError(matrix);
    }

    /** Feeds the error matrix into this block's second layer norm. */
    private void backLastError(Matrix matrix) throws Exception {
        lineLayNorm.backLastError(matrix);
    }

    /**
     * Asks the second layer norm to start encoder back-propagation for event id {@code j}.
     *
     * @throws Exception if back-propagation fails
     */
    public void encoderBackStart(long j) throws Exception {
        lineLayNorm.encoderBackStart(j);
    }

    /**
     * Passes an input to this block's attention layer; all arguments are forwarded unchanged.
     *
     * @throws Exception if the attention layer fails
     */
    public void sendInputMatrix(long j, Matrix matrix, boolean z, OutBack outBack, List<Integer> list, Matrix matrix2, boolean z2) throws Exception {
        multiSelfAttention.sendMatrixMessage(j, matrix, z, outBack, list, matrix2, z2);
    }

    /**
     * Creates both hidden-nerve layers ({@code width} nerves each) and connects them.
     *
     * <p>First-layer nerves use {@link ReLu} and are attached to the attention layer norm;
     * second-layer nerves have no activation function and are attached to the line layer norm.
     * {@code f}, {@code i4} and {@code f2} are forwarded to every nerve unchanged.
     */
    private void initLine(int width, float f, int i4, float f2) throws Exception {
        // Separate Nerve-typed lists are handed to connect/connectFather, so the nerves never
        // share a list instance with this block's own fields.
        List<Nerve> firstLayer = new ArrayList<>();
        List<Nerve> secondLayer = new ArrayList<>();
        for (int index = 0; index < width; index++) {
            HiddenNerve first = new HiddenNerve(index + 1, 1, f, new ReLu(), width, width, null, i4, f2, coreNumber);
            fistHiddenNerves.add(first);
            first.setAfterLayNorm(attentionLayNorm);
            firstLayer.add(first);
        }
        for (int index = 0; index < width; index++) {
            HiddenNerve second = new HiddenNerve(index + 1, 2, f, null, width, 1, null, i4, f2, coreNumber);
            second.setBeforeLayNorm(lineLayNorm);
            secondHiddenNerves.add(second);
            secondLayer.add(second);
        }
        for (Nerve nerve : firstLayer) {
            nerve.connect(secondLayer);
        }
        for (Nerve nerve : secondLayer) {
            nerve.connectFather(firstLayer);
        }
    }
}
