package org.dromara.easyai.unet;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.dromara.easyai.config.UNetConfig;
import org.dromara.easyai.conv.ConvCount;
import org.dromara.easyai.function.ReLu;
import org.dromara.easyai.function.Tanh;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.nerveEntity.ConvParameter;

/* loaded from: input_file:org/dromara/easyai/unet/UNetManager.class */
/**
 * Builds, wires and (de)serializes a U-Net made of {@link UNetEncoder} and {@link UNetDecoder}
 * stages.
 *
 * <p>The network has {@code deep} encoder stages and {@code deep + 1} decoder stages, where
 * {@code deep} is computed from the input size, kernel size, convolution count and minimum
 * feature value. Encoders are chained via setAfterEncoder/setBeforeEncoder, decoders via
 * setAfterDecoder/setBeforeDecoder, the last encoder is linked to the first decoder, and decoder
 * {@code deep - i} is handed encoder {@code i} through setMyUNetEncoder (presumably the skip
 * connection — TODO confirm in UNetDecoder).
 */
public class UNetManager extends ConvCount {
    private final List<UNetEncoder> encoderList = new ArrayList<>();
    private final List<UNetDecoder> decoderList = new ArrayList<>();
    private final int kernLen;
    private final int convTimes;
    private final int deep;
    private final float studyRate;
    private final float oneStudyRate;
    private UNetInput input;

    /**
     * Returns the input wrapper built around the first encoder stage.
     *
     * @return the network input entry point
     */
    public UNetInput getInput() {
        return this.input;
    }

    /**
     * Creates all encoder/decoder stages from the configuration and links them together.
     *
     * @param uNetConfig supplies input size, kernel size, convolution count, learning rates,
     *     minimum feature value and cutting settings
     * @throws Exception if the computed depth is {@code <= 1} (minFeatureValue too large), or if
     *     constructing any stage fails
     */
    public UNetManager(UNetConfig uNetConfig) throws Exception {
        int xSize = uNetConfig.getXSize();
        int ySize = uNetConfig.getYSize();
        int minFeatureValue = uNetConfig.getMinFeatureValue();
        this.kernLen = uNetConfig.getKerSize();
        this.convTimes = uNetConfig.getConvTimes();
        this.studyRate = uNetConfig.getStudyRate();
        this.oneStudyRate = uNetConfig.getOneStudyRate();
        this.deep = getConvMyDep(xSize, ySize, this.kernLen, minFeatureValue, this.convTimes);
        if (this.deep <= 1) {
            // Message: "the value set for minFeatureValue is too large".
            throw new Exception("minFeatureValue 设置的值太大了");
        }
        initEncoder(xSize, ySize);
        initDecoder(uNetConfig.isCutting(), uNetConfig.getCutTh());
        connectionCoder();
    }

    /** Unboxes a {@code Float[]} into a new {@code float[]} of the same length. */
    private float[] getFValue(Float[] boxed) {
        float[] result = new float[boxed.length];
        for (int i = 0; i < boxed.length; i++) {
            result[i] = boxed[i];
        }
        return result;
    }

    /** Boxes a {@code float[]} into a new {@code Float[]} of the same length. */
    private Float[] getValue(float[] values) {
        Float[] result = new Float[values.length];
        for (int i = 0; i < values.length; i++) {
            result[i] = values[i];
        }
        return result;
    }

    /**
     * Loads the weights of a previously exported model into this network.
     *
     * <p>The model is fully validated before any weight is written, so a mismatched model cannot
     * leave the network partially overwritten.
     *
     * @param uNetModel model produced by {@link #getModel()} for a network of the same shape
     * @throws Exception if the encoder/decoder stage counts or the per-stage nerve matrix counts
     *     do not match this network, or if writing a matrix fails
     */
    public void insertModel(UNetModel uNetModel) throws Exception {
        List<ConvModel> encoderModels = uNetModel.getEncoderModels();
        List<ConvModel> decoderModels = uNetModel.getDecoderModels();
        if (encoderModels == null || encoderModels.size() != this.deep
                || decoderModels == null || decoderModels.size() != this.deep + 1) {
            // Message: "model depth does not match".
            throw new Exception("模型深度不匹配");
        }
        for (int i = 0; i < this.deep; i++) {
            checkNerveCount(this.encoderList.get(i).getConvParameter(), encoderModels.get(i), "encoder", i);
        }
        for (int i = 0; i <= this.deep; i++) {
            checkNerveCount(this.decoderList.get(i).getConvParameter(), decoderModels.get(i), "decoder", i);
        }

        for (int i = 0; i < this.deep; i++) {
            ConvParameter convParameter = this.encoderList.get(i).getConvParameter();
            ConvModel convModel = encoderModels.get(i);
            convParameter.setOneConvPower(convModel.getOneNervePower());
            loadNerveMatrices(convParameter.getNerveMatrixList(), convModel.getDownNervePower());
        }
        for (int i = 0; i <= this.deep; i++) {
            ConvParameter convParameter = this.decoderList.get(i).getConvParameter();
            ConvModel convModel = decoderModels.get(i);
            float[] upPower = getFValue(convModel.getUpNervePower());
            convParameter.setOneConvPower(convModel.getOneNervePower());
            Matrix upNerveMatrix = convParameter.getUpNerveMatrix();
            upNerveMatrix.setCudaMatrix(upPower, upNerveMatrix.getX(), upNerveMatrix.getY());
            loadNerveMatrices(convParameter.getNerveMatrixList(), convModel.getDownNervePower());
        }
    }

    /**
     * Throws if {@code convModel} does not carry exactly one weight array per nerve matrix of
     * {@code convParameter}.
     */
    private void checkNerveCount(ConvParameter convParameter, ConvModel convModel, String stage, int index)
            throws Exception {
        int expected = convParameter.getNerveMatrixList().size();
        List<Float[]> downNervePower = convModel == null ? null : convModel.getDownNervePower();
        int actual = downNervePower == null ? -1 : downNervePower.size();
        if (actual != expected) {
            throw new Exception("model nerve matrix count mismatch at " + stage + " " + index
                    + ": expected " + expected + ", got " + actual);
        }
    }

    /** Copies each boxed weight array into the nerve matrix at the same list position. */
    private void loadNerveMatrices(List<Matrix> nerveMatrixList, List<Float[]> downNervePower) throws Exception {
        for (int i = 0; i < nerveMatrixList.size(); i++) {
            Matrix matrix = nerveMatrixList.get(i);
            matrix.setCudaMatrix(getFValue(downNervePower.get(i)), matrix.getX(), matrix.getY());
        }
    }

    /** Snapshots every nerve matrix as a freshly boxed weight array, in list order. */
    private List<Float[]> dumpNerveMatrices(List<Matrix> nerveMatrixList) {
        List<Float[]> powers = new ArrayList<>(nerveMatrixList.size());
        for (Matrix matrix : nerveMatrixList) {
            powers.add(getValue(matrix.getCudaMatrix()));
        }
        return powers;
    }

    /**
     * Exports the current weights of every encoder and decoder stage.
     *
     * @return a model holding {@code deep} encoder entries and {@code deep + 1} decoder entries
     */
    public UNetModel getModel() {
        List<ConvModel> encoderModels = new ArrayList<>(this.encoderList.size());
        for (UNetEncoder encoder : this.encoderList) {
            ConvParameter convParameter = encoder.getConvParameter();
            ConvModel convModel = new ConvModel();
            // NOTE(review): oneConvPower is stored by reference (not copied), unlike the matrix
            // weights, so later training may change an exported model — confirm intended.
            convModel.setOneNervePower(convParameter.getOneConvPower());
            convModel.setDownNervePower(dumpNerveMatrices(convParameter.getNerveMatrixList()));
            encoderModels.add(convModel);
        }

        List<ConvModel> decoderModels = new ArrayList<>(this.decoderList.size());
        for (UNetDecoder decoder : this.decoderList) {
            ConvParameter convParameter = decoder.getConvParameter();
            ConvModel convModel = new ConvModel();
            convModel.setOneNervePower(convParameter.getOneConvPower());
            convModel.setUpNervePower(getValue(convParameter.getUpNerveMatrix().getCudaMatrix()));
            convModel.setDownNervePower(dumpNerveMatrices(convParameter.getNerveMatrixList()));
            decoderModels.add(convModel);
        }

        UNetModel uNetModel = new UNetModel();
        uNetModel.setEncoderModels(encoderModels);
        uNetModel.setDecoderModels(decoderModels);
        return uNetModel;
    }

    /**
     * Links the deepest encoder to the first decoder and gives decoder {@code deep - i} a
     * reference to encoder {@code i}.
     */
    private void connectionCoder() {
        UNetEncoder bottomEncoder = this.encoderList.get(this.deep - 1);
        UNetDecoder firstDecoder = this.decoderList.get(0);
        bottomEncoder.setDecoder(firstDecoder);
        firstDecoder.setEncoder(bottomEncoder);
        for (int i = 0; i < this.deep; i++) {
            this.decoderList.get(this.deep - i).setMyUNetEncoder(this.encoderList.get(i));
        }
    }

    /**
     * Creates {@code deep + 1} Tanh decoder stages and chains them in order.
     *
     * @param isCutting whether to share a single {@link Cutting} (built from {@code cutTh}) among
     *     all decoders; {@code null} is passed otherwise
     * @param cutTh threshold handed to the Cutting constructor
     */
    private void initDecoder(boolean isCutting, float cutTh) throws Exception {
        Cutting cutting = isCutting ? new Cutting(cutTh) : null;
        for (int i = 0; i <= this.deep; i++) {
            boolean lastDecoder = i == this.deep;
            this.decoderList.add(new UNetDecoder(this.kernLen, i + 1, this.convTimes, new Tanh(),
                    lastDecoder, this.studyRate, cutting));
        }
        for (int i = 0; i < this.deep; i++) {
            UNetDecoder current = this.decoderList.get(i);
            UNetDecoder next = this.decoderList.get(i + 1);
            current.setAfterDecoder(next);
            next.setBeforeDecoder(current);
        }
    }

    /**
     * Creates {@code deep} ReLU encoder stages, wraps the first as the network input, and chains
     * them in order.
     *
     * @param xSize input width forwarded to each encoder
     * @param ySize input height forwarded to each encoder
     */
    private void initEncoder(int xSize, int ySize) throws Exception {
        for (int i = 0; i < this.deep; i++) {
            UNetEncoder encoder = new UNetEncoder(this.kernLen, this.convTimes, i + 1, new ReLu(),
                    this.studyRate, xSize, ySize, this.oneStudyRate);
            if (i == 0) {
                this.input = new UNetInput(encoder);
            }
            this.encoderList.add(encoder);
        }
        for (int i = 0; i < this.deep - 1; i++) {
            UNetEncoder current = this.encoderList.get(i);
            UNetEncoder next = this.encoderList.get(i + 1);
            current.setAfterEncoder(next);
            next.setBeforeEncoder(current);
        }
    }
}
