package org.dromara.easyai.conv;

import java.util.List;
import org.dromara.easyai.i.ActiveFunction;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;
import org.dromara.easyai.nerveEntity.ConvParameter;
import org.dromara.easyai.nerveEntity.ConvSize;

/* loaded from: input_file:org/dromara/easyai/conv/ConvCount.class */
public abstract class ConvCount {
    /** Shared helper for im2col, matrix products and element-wise arithmetic. */
    private final MatrixOperation matrixOperation = new MatrixOperation();

    /**
     * Returns how many conv + pool stages both spatial dimensions can go through; this is the
     * smaller of the two per-dimension depths computed by {@link #getConvDeep}.
     *
     * @param xSize      input size along the X dimension
     * @param ySize      input size along the Y dimension
     * @param kernelSize convolution kernel edge length
     * @param minSize    size at or below which the simulated shrinking stops
     * @param convTimes  number of stride-1 convolutions applied before each pooling step
     * @return {@code min(depth(xSize), depth(ySize))}
     */
    public int getConvMyDep(int xSize, int ySize, int kernelSize, int minSize, int convTimes) {
        return Math.min(getConvDeep(xSize, kernelSize, minSize, convTimes),
                getConvDeep(ySize, kernelSize, minSize, convTimes));
    }

    /**
     * Simulates one spatial dimension through repeated stages. Each stage is
     * {@code convTimes} stride-1 valid convolutions followed by a ceil-halving pool. Stages
     * are counted until the size drops to {@code minSize} or below.
     *
     * @return number of stages run minus one (0 when the first stage already reaches
     *     {@code minSize})
     */
    private int getConvDeep(int size, int kernelSize, int minSize, int convTimes) {
        int current = size;
        int stages = 0;
        while (true) {
            int next = current;
            for (int t = 0; t < convTimes; t++) {
                // A stride-1 valid convolution shrinks the dimension by kernelSize - 1.
                next -= kernelSize - 1;
            }
            // Pooling keeps ceil(n / 2) cells, matching downPooling's output size.
            next = (next / 2) + (next % 2);
            stages++;
            // The loop state is only the current size, so a stage that leaves it unchanged
            // (e.g. kernelSize == 1 at size 1 with minSize < 1) would repeat forever.
            if (next <= minSize || next == current) {
                return stages - 1;
            }
            current = next;
        }
    }

    /**
     * 2x up-sampling: every cell of {@code matrix} is copied into the corresponding 2x2
     * block of a matrix twice as large in each dimension.
     *
     * @param matrix values to up-sample
     * @return a (2X) x (2Y) matrix
     */
    private Matrix upPooling(Matrix matrix) throws Exception {
        int rows = matrix.getX();
        int cols = matrix.getY();
        Matrix upSampled = new Matrix(rows * 2, cols * 2);
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                insertMatrixValue(r * 2, c * 2, matrix.getNumber(r, c), upSampled);
            }
        }
        return upSampled;
    }

    /**
     * Backward pass of {@link #upPooling}. Each input cell was copied into a 2x2 block, so
     * its gradient is the sum of that block.
     *
     * @param matrix gradient w.r.t. the up-sampled output
     * @return gradient w.r.t. the up-sampling input, of size floor(X / 2) x floor(Y / 2)
     */
    public Matrix backUpPooling(Matrix matrix) throws Exception {
        int rows = matrix.getX();
        int cols = matrix.getY();
        Matrix grad = new Matrix(rows / 2, cols / 2);
        for (int r = 0; r < rows - 1; r += 2) {
            for (int c = 0; c < cols - 1; c += 2) {
                grad.setNub(r / 2, c / 2, matrix.getNumber(r, c) + matrix.getNumber(r, c + 1)
                        + matrix.getNumber(r + 1, c) + matrix.getNumber(r + 1, c + 1));
            }
        }
        return grad;
    }

    /**
     * 2x2 average pooling with stride 2. The output is ceil(X / 2) x ceil(Y / 2).
     *
     * <p>When a dimension is odd, the trailing partial window is zero-padded: its available
     * cells are summed and divided by 4. This matches {@link #backDownPooling}, which gives
     * every covered input cell a quarter of the window's gradient, including cells of the
     * odd trailing row/column. Skipping those windows would leave that output row/column
     * stuck at 0 while the backward pass still trains the inputs feeding it.
     *
     * @param matrix values to pool
     * @return the pooled matrix
     */
    private Matrix downPooling(Matrix matrix) throws Exception {
        int rows = matrix.getX();
        int cols = matrix.getY();
        Matrix pooled = new Matrix((rows / 2) + (rows % 2 == 1 ? 1 : 0),
                (cols / 2) + (cols % 2 == 1 ? 1 : 0));
        for (int r = 0; r < rows; r += 2) {
            boolean hasNextRow = r + 1 < rows;
            for (int c = 0; c < cols; c += 2) {
                boolean hasNextCol = c + 1 < cols;
                // Same left-to-right summation order as a full window, so full windows are
                // bit-identical to a plain four-term sum.
                float sum = matrix.getNumber(r, c);
                if (hasNextCol) {
                    sum += matrix.getNumber(r, c + 1);
                }
                if (hasNextRow) {
                    sum += matrix.getNumber(r + 1, c);
                }
                if (hasNextRow && hasNextCol) {
                    sum += matrix.getNumber(r + 1, c + 1);
                }
                pooled.setNub(r / 2, c / 2, sum / 4.0f);
            }
        }
        return pooled;
    }

    /**
     * Writes {@code value} into the 2x2 block of {@code matrix} whose top-left cell is
     * ({@code row}, {@code col}).
     *
     * <p>If the block would run past the last row/column, it is clipped by one row/column.
     * This happens for odd-sized targets in {@link #backDownPooling}.
     */
    private void insertMatrixValue(int row, int col, float value, Matrix matrix) throws Exception {
        int rowEnd = row + 2;
        int colEnd = col + 2;
        if (rowEnd > matrix.getX()) {
            rowEnd--;
        }
        if (colEnd > matrix.getY()) {
            colEnd--;
        }
        for (int r = row; r < rowEnd; r++) {
            for (int c = col; c < colEnd; c++) {
                matrix.setNub(r, c, value);
            }
        }
    }

    /**
     * Backward pass of {@link #downPooling}: every input cell covered by a pooling window
     * receives a quarter of that window's gradient.
     *
     * @param matrix gradient w.r.t. the pooled output
     * @param xInput pre-pooling size along X; only its parity is used (when odd, the result
     *     is one row shorter than {@code 2 * matrix.getX()})
     * @param yInput pre-pooling size along Y; same parity rule for columns
     * @return gradient w.r.t. the pooling input
     */
    public Matrix backDownPooling(Matrix matrix, int xInput, int yInput) throws Exception {
        int rows = matrix.getX();
        int cols = matrix.getY();
        Matrix grad = new Matrix((rows * 2) - (xInput % 2 == 1 ? 1 : 0),
                (cols * 2) - (yInput % 2 == 1 ? 1 : 0));
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                insertMatrixValue(r * 2, c * 2, matrix.getNumber(r, c) / 4.0f, grad);
            }
        }
        return grad;
    }

    /** Output size of a stride-1 up-convolution: {@code size + kernelSize - 1}. */
    private int getUpSize(int size, int kernelSize) {
        return (size + kernelSize) - 1;
    }

    /** Inverse of {@link #getUpSize}: {@code size - kernelSize + 1}. */
    private int backUpSize(int size, int kernelSize) {
        return (size - kernelSize) + 1;
    }

    /**
     * Backward pass of {@link #upConv}. Updates the up-convolution kernel stored in
     * {@code convParameter} and returns the gradient w.r.t. the up-convolution input.
     *
     * <p>NOTE: {@code matrix} is modified in place. On return it holds the incoming gradient
     * multiplied element-wise by {@code activeFunction.functionG} of the stored up-conv output.
     *
     * @param matrix         gradient w.r.t. the up-convolution output (overwritten, see above)
     * @param kernelSize     kernel edge length used by the forward up-convolution
     * @param convParameter  supplies the forward state (up output, up feature vector, up
     *                       kernel) and receives the updated kernel
     * @param rate           scale applied to the kernel gradient before it is added to the
     *                       kernel (presumably the learning rate)
     * @param activeFunction activation used in the forward pass
     * @return gradient w.r.t. the input of {@link #upConv}, sized
     *     (X - kernelSize + 1) x (Y - kernelSize + 1)
     */
    public Matrix backUpConv(Matrix matrix, int kernelSize, ConvParameter convParameter, float rate,
                             ActiveFunction activeFunction) throws Exception {
        int rows = matrix.getX();
        int cols = matrix.getY();
        Matrix upOutMatrix = convParameter.getUpOutMatrix();
        // functionG is evaluated on the stored (already activated) forward output.
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                matrix.setNub(r, c, activeFunction.functionG(upOutMatrix.getNumber(r, c))
                        * matrix.getNumber(r, c));
            }
        }
        int inRows = backUpSize(rows, kernelSize);
        int inCols = backUpSize(cols, kernelSize);
        Matrix upNerveMatrix = convParameter.getUpNerveMatrix();
        Matrix upFeatureMatrix = convParameter.getUpFeatureMatrix();
        Matrix gradCols = matrixOperation.im2col(matrix, kernelSize, 1);
        // The boolean selects which partial derivative matrixMulPd returns
        // (presumably false = w.r.t. the kernel, true = w.r.t. the input; confirm in MatrixOperation).
        Matrix kernelGrad = matrixOperation.matrixMulPd(gradCols, upFeatureMatrix, upNerveMatrix, false);
        Matrix inputGrad = matrixOperation.matrixMulPd(gradCols, upFeatureMatrix, upNerveMatrix, true);
        // Return value ignored: mathMul presumably scales kernelGrad in place.
        matrixOperation.mathMul(kernelGrad, rate);
        convParameter.setUpNerveMatrix(matrixOperation.add(kernelGrad, upNerveMatrix));
        return matrixOperation.vectorToMatrix(inputGrad, inRows, inCols);
    }

    /**
     * Stride-1 up-convolution: the output is (X + k - 1) x (Y + k - 1), with
     * {@code activeFunction.function} applied to every cell.
     *
     * @param matrix         input
     * @param kernelSize     kernel edge length k
     * @param kernel         up-convolution weights
     * @param activeFunction activation applied to the result
     * @return result holding the activated output and the vectorised input
     *     (kept for the backward pass)
     */
    private ConvResult upConv(Matrix matrix, int kernelSize, Matrix kernel, ActiveFunction activeFunction)
            throws Exception {
        ConvResult convResult = new ConvResult();
        int outRows = getUpSize(matrix.getX(), kernelSize);
        int outCols = getUpSize(matrix.getY(), kernelSize);
        Matrix inputVector = matrixOperation.matrixToVector(matrix, false);
        Matrix raw = matrixOperation.reverseIm2col(matrixOperation.mulMatrix(inputVector, kernel),
                kernelSize, 1, outRows, outCols);
        Matrix activated = new Matrix(outRows, outCols);
        for (int r = 0; r < outRows; r++) {
            for (int c = 0; c < outCols; c++) {
                activated.setNub(r, c, activeFunction.function(raw.getNumber(r, c)));
            }
        }
        convResult.setLeftMatrix(inputVector);
        convResult.setResultMatrix(activated);
        return convResult;
    }

    /**
     * Runs the down-convolution stack (without pooling) and optionally an up-convolution
     * followed by 2x up-sampling.
     *
     * @param matrix         input
     * @param convParameter  per-layer weights and forward-state storage
     * @param convTimes      number of stacked down-convolutions
     * @param activeFunction activation for every layer
     * @param kernelSize     kernel edge length
     * @param upConv         when {@code false}, return the down-convolution output directly
     * @return the down-convolution output, or the up-convolved and up-sampled matrix
     */
    public Matrix upConvAndPooling(Matrix matrix, ConvParameter convParameter, int convTimes,
                                   ActiveFunction activeFunction, int kernelSize, boolean upConv)
            throws Exception {
        Matrix downOut = downConvAndPooling(matrix, convParameter, convTimes, activeFunction,
                kernelSize, false, -1L);
        if (!upConv) {
            return downOut;
        }
        ConvResult up = upConv(downOut, kernelSize, convParameter.getUpNerveMatrix(), activeFunction);
        // Keep forward state for backUpConv.
        convParameter.setUpOutMatrix(up.getResultMatrix());
        convParameter.setUpFeatureMatrix(up.getLeftMatrix());
        return upPooling(up.getResultMatrix());
    }

    /**
     * Applies {@code convTimes} stacked stride-1 valid convolutions, optionally followed by
     * 2x2 average pooling.
     *
     * <p>For each layer, this records into {@code convParameter}:
     * <ul>
     *   <li>the input size (into its {@code ConvSize});</li>
     *   <li>the im2col matrix;</li>
     *   <li>the activated output.</li>
     * </ul>
     * {@link #backAllDownConv} replays these. Both lists are cleared first.
     *
     * @param matrix         input
     * @param convParameter  per-layer weights and forward-state storage
     * @param convTimes      number of stacked convolutions
     * @param activeFunction activation for every layer
     * @param kernelSize     kernel edge length
     * @param pooling        whether to pool the final convolution output
     * @param featureKey     when {@code >= 0}, the unpooled final output is stored in the
     *                       feature map under this key
     * @return the final (optionally pooled) output
     */
    public Matrix downConvAndPooling(Matrix matrix, ConvParameter convParameter, int convTimes,
                                     ActiveFunction activeFunction, int kernelSize, boolean pooling,
                                     long featureKey) throws Exception {
        List<ConvSize> convSizeList = convParameter.getConvSizeList();
        List<Matrix> nerveMatrixList = convParameter.getNerveMatrixList();
        List<Matrix> im2colMatrixList = convParameter.getIm2colMatrixList();
        List<Matrix> outMatrixList = convParameter.getOutMatrixList();
        im2colMatrixList.clear();
        outMatrixList.clear();
        Matrix current = matrix;
        for (int layer = 0; layer < convTimes; layer++) {
            ConvSize convSize = convSizeList.get(layer);
            convSize.setXInput(current.getX());
            convSize.setYInput(current.getY());
            ConvResult result = downConvCount(current, activeFunction, kernelSize,
                    nerveMatrixList.get(layer));
            im2colMatrixList.add(result.getLeftMatrix());
            Matrix out = result.getResultMatrix();
            outMatrixList.add(out);
            current = out;
        }
        if (featureKey >= 0) {
            convParameter.getFeatureMap().put(Long.valueOf(featureKey), current);
        }
        // Output size is recorded before pooling.
        convParameter.setOutX(current.getX());
        convParameter.setOutY(current.getY());
        return pooling ? downPooling(current) : current;
    }

    /**
     * Merges channels as a weighted sum: {@code sum_c weights[c] * matrices[c]}.
     *
     * @param list  channel matrices (at least as many as {@code list2})
     * @param list2 per-channel weights; its size sets the channel count
     * @return the weighted sum, or {@code null} when {@code list2} is empty
     */
    public Matrix oneConv(List<Matrix> list, List<Float> list2) throws Exception {
        int channels = list2.size();
        Matrix merged = null;
        for (int c = 0; c < channels; c++) {
            Matrix scaled = matrixOperation.mathMulBySelf(list.get(c), list2.get(c).floatValue());
            merged = (c == 0) ? scaled : matrixOperation.add(merged, scaled);
        }
        return merged;
    }

    /**
     * Backward pass of {@link #oneConv}. Updates each channel weight in {@code list2} in place.
     *
     * <p>For each channel, the weight increases by
     * {@code rate * sum(channel ⊙ matrix)}. When {@code normalize} is set, that sum is also
     * divided by {@code sqrt(X * Y)} of the channel.
     *
     * @param matrix    gradient w.r.t. the merged output
     * @param list      channel matrices from the forward pass
     * @param list2     per-channel weights; updated in place
     * @param rate      scale applied to each product
     * @param normalize whether to divide each update by {@code sqrt(X * Y)}
     */
    public void backOneConv(Matrix matrix, List<Matrix> list, List<Float> list2, float rate,
                            boolean normalize) throws Exception {
        int channels = list2.size();
        for (int ch = 0; ch < channels; ch++) {
            Matrix feature = list.get(ch);
            int rows = feature.getX();
            int cols = feature.getY();
            float weight = list2.get(ch).floatValue();
            float delta = 0.0f;
            for (int r = 0; r < rows; r++) {
                for (int c = 0; c < cols; c++) {
                    delta += feature.getNumber(r, c) * matrix.getNumber(r, c) * rate;
                }
            }
            if (normalize) {
                delta = delta / (float) Math.sqrt(rows * cols);
            }
            list2.set(ch, Float.valueOf(weight + delta));
        }
    }

    /**
     * Backward pass of {@link #downConvAndPooling} (convolution part only). Walks the
     * layers from last to first, replacing each kernel with its updated value.
     *
     * @param convParameter  forward state recorded by {@link #downConvAndPooling}
     * @param matrix         gradient w.r.t. the output of the last convolution
     * @param rate           scale applied to kernel gradients
     * @param activeFunction activation used in the forward pass
     * @param convTimes      number of stacked convolutions
     * @param kernelSize     kernel edge length
     * @return gradient w.r.t. the input of the first convolution
     */
    public Matrix backAllDownConv(ConvParameter convParameter, Matrix matrix, float rate,
                                  ActiveFunction activeFunction, int convTimes, int kernelSize)
            throws Exception {
        List<Matrix> outMatrixList = convParameter.getOutMatrixList();
        List<Matrix> im2colMatrixList = convParameter.getIm2colMatrixList();
        List<Matrix> nerveMatrixList = convParameter.getNerveMatrixList();
        List<ConvSize> convSizeList = convParameter.getConvSizeList();
        Matrix grad = matrix;
        for (int layer = convTimes - 1; layer >= 0; layer--) {
            ConvSize convSize = convSizeList.get(layer);
            ConvResult result = backDownConv(grad, outMatrixList.get(layer), activeFunction,
                    im2colMatrixList.get(layer), nerveMatrixList.get(layer), rate, kernelSize,
                    convSize.getXInput(), convSize.getYInput());
            nerveMatrixList.set(layer, result.getNervePowerMatrix());
            grad = result.getResultMatrix();
        }
        return grad;
    }

    /**
     * Backward pass of a single {@link #downConvCount} layer.
     *
     * @param matrix         gradient w.r.t. the layer's activated output
     * @param out            the layer's activated output (functionG is evaluated on it)
     * @param activeFunction activation used in the forward pass
     * @param im2col         im2col matrix of the layer input
     * @param kernel         current kernel
     * @param rate           scale applied to the kernel gradient
     * @param kernelSize     kernel edge length
     * @param xInput         layer input size along X
     * @param yInput         layer input size along Y
     * @return result holding the updated kernel and the gradient w.r.t. the layer input
     */
    private ConvResult backDownConv(Matrix matrix, Matrix out, ActiveFunction activeFunction,
                                    Matrix im2col, Matrix kernel, float rate, int kernelSize,
                                    int xInput, int yInput) throws Exception {
        ConvResult convResult = new ConvResult();
        int rows = matrix.getX();
        int cols = matrix.getY();
        // Flatten row-major into a column vector, matching downConvCount's (r * cols + c) layout.
        Matrix gradVector = new Matrix(rows * cols, 1);
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                gradVector.setNub((cols * r) + c, 0,
                        matrix.getNumber(r, c) * activeFunction.functionG(out.getNumber(r, c)));
            }
        }
        // The boolean selects which partial derivative matrixMulPd returns
        // (presumably false = w.r.t. the kernel, true = w.r.t. the input; confirm in MatrixOperation).
        Matrix kernelGrad = matrixOperation.matrixMulPd(gradVector, im2col, kernel, false);
        Matrix inputCols = matrixOperation.matrixMulPd(gradVector, im2col, kernel, true);
        // Return value ignored: mathMul presumably scales kernelGrad in place.
        matrixOperation.mathMul(kernelGrad, rate);
        convResult.setNervePowerMatrix(matrixOperation.add(kernel, kernelGrad));
        convResult.setResultMatrix(matrixOperation.reverseIm2col(inputCols, kernelSize, 1, xInput, yInput));
        return convResult;
    }

    /**
     * One stride-1 valid convolution with activation. The output is
     * (X - k + 1) x (Y - k + 1).
     *
     * @param matrix         input
     * @param activeFunction activation applied to every output cell
     * @param kernelSize     kernel edge length k
     * @param kernel         convolution weights
     * @return result holding the activated output and the im2col matrix
     *     (kept for the backward pass)
     */
    private ConvResult downConvCount(Matrix matrix, ActiveFunction activeFunction, int kernelSize,
                                     Matrix kernel) throws Exception {
        ConvResult convResult = new ConvResult();
        int shrink = kernelSize - 1;
        int outRows = matrix.getX() - shrink;
        int outCols = matrix.getY() - shrink;
        Matrix out = new Matrix(outRows, outCols);
        Matrix im2col = matrixOperation.im2col(matrix, kernelSize, 1);
        convResult.setLeftMatrix(im2col);
        Matrix product = matrixOperation.mulMatrix(im2col, kernel);
        // The product is a column vector laid out row-major over the output grid.
        for (int r = 0; r < outRows; r++) {
            for (int c = 0; c < outCols; c++) {
                out.setNub(r, c, activeFunction.function(product.getNumber((r * outCols) + c, 0)));
            }
        }
        convResult.setResultMatrix(out);
        return convResult;
    }
}
