package org.dromara.easyai.nerveEntity;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.dromara.easyai.conv.ConvCount;
import org.dromara.easyai.entity.ThreeChannelMatrix;
import org.dromara.easyai.i.ActiveFunction;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.matrixTools.Matrix;
import org.dromara.easyai.matrixTools.MatrixOperation;

/* loaded from: input_file:org/dromara/easyai/nerveEntity/Nerve.class */
/**
 * Base class for a single neuron of the network.
 *
 * <p>A nerve works in one of two modes, chosen at construction time by {@code isMatrix}:
 * <ul>
 *   <li><b>scalar (fully connected) mode</b>: inputs are buffered per event id in
 *       {@link #features}, combined with the 1-based weight map {@link #dendrites} in
 *       {@link #calculation(long)}, and gradients are pushed back to every nerve registered
 *       via {@link #connectFather(List)};</li>
 *   <li><b>matrix (convolution) mode</b>: kernels are kept in {@link ConvParameter}, and
 *       matrices flow through the single links set by {@link #connectSonOnly(Nerve)} and
 *       {@link #connectFatherOnly(Nerve)}.</li>
 * </ul>
 *
 * <p>No state in this class is synchronized. NOTE(review): presumably a nerve is driven by one
 * thread at a time; confirm against callers.
 */
public abstract class Nerve extends ConvCount {
    /** Single downstream neighbour used when forwarding matrices. */
    private Nerve sonOnly;
    /** Single upstream neighbour that receives matrix gradients. */
    private Nerve fatherOnly;
    private final int id;
    /** Number of upstream inputs; in scalar mode also the number of weights created. */
    protected int upNub;
    /** Number of downstream gradients that must arrive before this nerve updates itself. */
    protected int downNub;
    protected float threshold;
    protected String name;
    /** Value passed to {@code activeFunction.functionG} during backprop; not assigned in this class. */
    protected float outNub;
    /** Not read or written in this class; presumably used by subclasses. */
    protected float E;
    /** Activation derivative times the accumulated {@link #sigmaW}, set in backGetMessage. */
    protected float gradient;
    /** Learning rate: multiplies the gradient in {@link #updatePower(long)}. */
    protected float studyPoint;
    /** Sum of weighted gradients received from downstream nerves (scalar mode). */
    protected float sigmaW;
    /** Sum of gradient matrices received from downstream nerves (matrix mode). */
    protected Matrix sigmaMatrix;
    protected ActiveFunction activeFunction;
    /** Regularization type: 0 = none, 1 = sign-based (L1-style), 2 = proportional (L2-style). */
    private final int rzType;
    /** Regularization strength used in updateW. */
    private final float lParam;
    /** Kernel edge length; each kernel holds kernLen * kernLen values. */
    private final int kernLen;
    /** Layer depth; the value 1 selects special first-layer handling in several methods. */
    protected final int depth;
    /** Row count used when reshaping weight gradients into a matrix. */
    protected final int matrixX;
    /** Column count used when reshaping weight gradients into a matrix. */
    protected final int matrixY;
    /** Number of kernels created and passed to the convolution routines. */
    protected final int convTimes;
    private final MatrixOperation matrixOperation;
    /** Number of oneConv weights created when depth == 1. */
    protected final int channelNo;
    /** Rate passed to backOneConv when depth == 1. */
    protected final float oneConvRate;
    private final List<Nerve> son = new ArrayList<>();
    private final List<Nerve> father = new ArrayList<>();
    /** Weights keyed by 1-based input index. */
    protected Map<Integer, Float> dendrites = new HashMap<>();
    /** Per weight key: pre-update weight times gradient, forwarded upstream. */
    protected Map<Integer, Float> wg = new HashMap<>();
    /** Inputs buffered per event id until they are consumed. */
    protected Map<Long, List<Float>> features = new HashMap<>();
    /** Gradients received since the last update. */
    private int backNub = 0;
    private final ConvParameter convParameter = new ConvParameter();

    /** Returns the live (mutable) weight map, keyed by 1-based input index. */
    public Map<Integer, Float> getDendrites() {
        return this.dendrites;
    }

    /** Returns the convolution parameters owned by this nerve. */
    public ConvParameter getConvParameter() {
        return this.convParameter;
    }

    /** Replaces the weight map; the given map is stored as-is, not copied. */
    public void setDendrites(Map<Integer, Float> dendrites) {
        this.dendrites = dendrites;
    }

    public float getThreshold() {
        return this.threshold;
    }

    public void setThreshold(float threshold) {
        this.threshold = threshold;
    }

    /**
     * Creates a nerve and initializes its weights.
     *
     * @param id                 identifier returned by {@link #getId()}
     * @param upNub              number of upstream inputs
     * @param name               nerve name
     * @param downNub            number of downstream gradients awaited before each update
     * @param studyPoint         learning rate
     * @param init               when true, weights and threshold are random; otherwise zero
     * @param activeFunction     activation function
     * @param isMatrix           when true, convolution kernels are created instead of scalar weights
     * @param rzType             regularization type (0 none, 1 sign-based, 2 proportional)
     * @param lParam             regularization strength
     * @param kernLen            kernel edge length
     * @param depth              layer depth
     * @param matrixX            rows used when reshaping weight gradients
     * @param matrixY            columns used when reshaping weight gradients
     * @param matrixOperationArg forwarded to the {@code MatrixOperation} constructor; meaning
     *                           is defined there
     * @param convTimes          number of kernels
     * @param channelNo          number of oneConv weights created when depth == 1
     * @param oneConvRate        rate used by backOneConv
     * @throws Exception if matrix creation during initialization fails
     */
    public Nerve(int id, int upNub, String name, int downNub, float studyPoint, boolean init,
            ActiveFunction activeFunction, boolean isMatrix, int rzType, float lParam, int kernLen,
            int depth, int matrixX, int matrixY, int matrixOperationArg, int convTimes,
            int channelNo, float oneConvRate) throws Exception {
        this.matrixOperation = new MatrixOperation(matrixOperationArg);
        this.matrixX = matrixX;
        this.matrixY = matrixY;
        this.channelNo = channelNo;
        this.id = id;
        this.convTimes = convTimes;
        this.depth = depth;
        this.upNub = upNub;
        this.name = name;
        this.downNub = downNub;
        this.studyPoint = studyPoint;
        this.activeFunction = activeFunction;
        this.rzType = rzType;
        this.lParam = lParam;
        this.kernLen = kernLen;
        this.oneConvRate = oneConvRate;
        // Every field read by initialization is assigned above.
        initPower(init, isMatrix);
    }

    protected void setStudyPoint(float studyPoint) {
        this.studyPoint = studyPoint;
    }

    /**
     * Forwards a scalar value to every child nerve's {@code input} hook.
     *
     * @throws Exception if this nerve has no children
     */
    public void sendMessage(long eventId, float parameter, boolean isStudy,
            Map<Integer, Float> expected, OutBack outBack) throws Exception {
        if (this.son.isEmpty()) {
            throw new Exception("this layer is lastIndex");
        }
        for (Nerve child : this.son) {
            child.input(eventId, parameter, isStudy, expected, outBack);
        }
    }

    /**
     * Runs {@code downConvAndPooling} with this nerve's parameters. The trailing
     * {@code true} / {@code -1L} arguments are passed as-is; their meaning is defined in ConvCount.
     */
    public Matrix conv(Matrix matrix) throws Exception {
        return downConvAndPooling(matrix, this.convParameter, this.convTimes, this.activeFunction,
                this.kernLen, true, -1L);
    }

    /**
     * Merges the input matrices with oneConv, convolves the result and forwards it to the
     * single child.
     *
     * <p>When {@code isStudy} is true, the input list is stored as the feature matrix list,
     * which backMatrix later reads when depth == 1.
     */
    public void demRedByMatrixList(long eventId, List<Matrix> matrices, boolean isStudy,
            Map<Integer, Float> expected, OutBack outBack, boolean needMatrix) throws Exception {
        if (isStudy) {
            this.convParameter.setFeatureMatrixList(matrices);
        }
        Matrix merged = oneConv(matrices, this.convParameter.getOneConvPower());
        sendMatrix(eventId, conv(merged), isStudy, expected, outBack, needMatrix);
    }

    /**
     * Forwards a flattened feature list to every child nerve's {@code inputMatrixFeature} hook.
     *
     * @throws Exception if this nerve has no children
     */
    public void sendMatrixList(long eventId, List<Float> parameters, boolean isStudy,
            Map<Integer, Float> expected, OutBack outBack) throws Exception {
        if (this.son.isEmpty()) {
            throw new Exception("this layer is lastIndex");
        }
        for (Nerve child : this.son) {
            child.inputMatrixFeature(eventId, parameters, isStudy, expected, outBack);
        }
    }

    /**
     * Forwards a matrix to the single child's {@code inputMatrix} hook. {@code needMatrix} is
     * passed through unchanged; its meaning is defined by the receiving layer.
     *
     * @throws Exception if no single child is connected
     */
    public void sendMatrix(long eventId, Matrix matrix, boolean isStudy,
            Map<Integer, Float> expected, OutBack outBack, boolean needMatrix) throws Exception {
        if (this.sonOnly == null) {
            throw new Exception("this layer is lastIndex");
        }
        this.sonOnly.inputMatrix(eventId, matrix, isStudy, expected, outBack, needMatrix);
    }

    /**
     * Forwards a three-channel matrix to the single child's {@code inputThreeChannelMatrix} hook.
     *
     * @throws Exception if no single child is connected
     */
    public void sendThreeChannelMatrix(long eventId, ThreeChannelMatrix threeChannelMatrix,
            boolean isStudy, Map<Integer, Float> expected, OutBack outBack, boolean needMatrix)
            throws Exception {
        if (this.sonOnly == null) {
            throw new Exception("this layer is lastIndex");
        }
        this.sonOnly.inputThreeChannelMatrix(eventId, threeChannelMatrix, isStudy, expected,
                outBack, needMatrix);
    }

    /**
     * Forwards a list of matrices to the single child's {@link #demRedByMatrixList}.
     *
     * @throws Exception if no single child is connected
     */
    public void sendListMatrix(long eventId, List<Matrix> matrices, boolean isStudy,
            Map<Integer, Float> expected, OutBack outBack, boolean needMatrix) throws Exception {
        if (this.sonOnly == null) {
            throw new Exception("this layer is lastIndex");
        }
        this.sonOnly.demRedByMatrixList(eventId, matrices, isStudy, expected, outBack, needMatrix);
    }

    /**
     * Propagates the stored weight gradients ({@link #wg}) upstream.
     *
     * <p>With scalar parents, parent {@code i} receives {@code wg[i + 1]}. NOTE(review): this
     * assumes the parent order matches the order in which inputs were buffered. Without scalar
     * parents, a depth-1 nerve with a single matrix parent reshapes all gradients into a
     * matrixX x matrixY matrix and sends it to that parent.
     */
    private void backSendMessage(long eventId) throws Exception {
        if (!this.father.isEmpty()) {
            for (int i = 0; i < this.father.size(); i++) {
                this.father.get(i).backGetMessage(this.wg.get(i + 1), eventId);
            }
            return;
        }
        if (this.fatherOnly != null && this.depth == 1) {
            List<Float> weightGradients = new ArrayList<>(this.wg.size());
            for (int key = 1; key <= this.wg.size(); key++) {
                weightGradients.add(this.wg.get(key));
            }
            this.fatherOnly.backMatrix(
                    this.matrixOperation.ListToMatrix(weightGradients, this.matrixX, this.matrixY));
        }
    }

    /** Sends a gradient matrix to the single parent, if one is connected. */
    private void backMatrixMessage(Matrix matrix) throws Exception {
        if (this.fatherOnly != null) {
            this.fatherOnly.backMatrix(matrix);
        }
    }

    /** Hook for scalar input; no-op by default, overridden by subclasses. */
    protected void input(long eventId, float parameter, boolean isStudy,
            Map<Integer, Float> expected, OutBack outBack) throws Exception {
    }

    /** Hook for flattened feature input; no-op by default. */
    protected void inputMatrixFeature(long eventId, List<Float> parameters, boolean isStudy,
            Map<Integer, Float> expected, OutBack outBack) throws Exception {
    }

    /** Hook for matrix input; no-op by default. */
    protected void inputMatrix(long eventId, Matrix matrix, boolean isStudy,
            Map<Integer, Float> expected, OutBack outBack, boolean needMatrix) throws Exception {
    }

    /** Hook for three-channel matrix input; no-op by default. */
    protected void inputThreeChannelMatrix(long eventId, ThreeChannelMatrix threeChannelMatrix,
            boolean isStudy, Map<Integer, Float> expected, OutBack outBack, boolean needMatrix)
            throws Exception {
    }

    /**
     * Accumulates one scalar gradient from a child. Once {@link #downNub} gradients have
     * arrived, it computes {@link #gradient} and runs {@link #updatePower(long)}.
     */
    private void backGetMessage(float weightedGradient, long eventId) throws Exception {
        this.backNub++;
        this.sigmaW += weightedGradient;
        if (this.backNub == this.downNub) {
            this.backNub = 0;
            this.gradient = this.activeFunction.functionG(this.outNub) * this.sigmaW;
            updatePower(eventId);
        }
    }

    /**
     * Accumulates one gradient matrix from a child. Once {@link #downNub} matrices have arrived,
     * it back-propagates through pooling and convolution. The result then either updates the
     * oneConv weights (depth == 1) or is forwarded to the single parent.
     */
    protected void backMatrix(Matrix matrix) throws Exception {
        this.backNub++;
        this.sigmaMatrix = this.sigmaMatrix == null
                ? matrix
                : this.matrixOperation.add(matrix, this.sigmaMatrix);
        if (this.backNub == this.downNub) {
            this.backNub = 0;
            Matrix pooled = backDownPooling(this.sigmaMatrix, this.convParameter.getOutX(),
                    this.convParameter.getOutY());
            Matrix backAllDownConv = backAllDownConv(this.convParameter, pooled, this.studyPoint,
                    this.activeFunction, this.convTimes, this.kernLen);
            this.sigmaMatrix = null;
            if (this.depth == 1) {
                backOneConv(backAllDownConv, this.convParameter.getFeatureMatrixList(),
                        this.convParameter.getOneConvPower(), this.oneConvRate, false);
            } else {
                backMatrixMessage(backAllDownConv);
            }
        }
    }

    /**
     * Applies one gradient step. It lowers the threshold by {@code gradient * studyPoint},
     * updates the weights, clears {@link #sigmaW} and then propagates gradients upstream.
     */
    public void updatePower(long eventId) throws Exception {
        float step = this.gradient * this.studyPoint;
        this.threshold -= step;
        updateW(step, eventId);
        this.sigmaW = 0.0f;
        backSendMessage(eventId);
    }

    /**
     * Returns the regularization adjustment for one weight.
     *
     * @param weight  the current weight
     * @param penalty the scaled weight norm computed in updateW
     * @return {@code -penalty * weight} for type 2, {@code -penalty * sign(weight)} for type 1,
     *         and 0 otherwise (also 0 for a zero weight under type 1)
     */
    private float regularization(float weight, float penalty) {
        if (this.rzType == 2) {
            return penalty * (-weight);
        }
        if (this.rzType == 1) {
            if (weight > 0.0f) {
                return -penalty;
            }
            if (weight < 0.0f) {
                return penalty;
            }
        }
        return 0.0f;
    }

    /**
     * Updates every weight using the inputs buffered for {@code eventId}, then discards those
     * inputs.
     *
     * <p>For each key {@code k}, it first stores {@code oldWeight * gradient} in {@link #wg}
     * (this is what gets sent upstream). It then sets
     * {@code weight = oldWeight + regularization + input[k - 1] * step}.
     *
     * @param step    gradient * learning rate
     * @param eventId id under which the inputs were buffered
     * @throws IllegalStateException if there are weights to update but no buffered inputs
     */
    private void updateW(float step, long eventId) {
        List<Float> inputs = this.features.get(eventId);
        if (inputs == null && !this.dendrites.isEmpty()) {
            throw new IllegalStateException("no buffered features for event " + eventId);
        }
        float penalty = 0.0f;
        if (this.rzType != 0) {
            float norm = 0.0f;
            for (Float boxed : this.dendrites.values()) {
                float w = boxed;
                // Type 2 sums squared weights; any other non-zero type sums absolute weights.
                norm += this.rzType == 2 ? w * w : Math.abs(w);
            }
            penalty = norm * this.lParam * this.studyPoint;
        }
        for (Map.Entry<Integer, Float> entry : this.dendrites.entrySet()) {
            int key = entry.getKey();
            float weight = entry.getValue();
            float inputStep = inputs.get(key - 1) * step;
            // The pre-update weight is used for the upstream gradient.
            this.wg.put(key, weight * this.gradient);
            entry.setValue(weight + regularization(weight, penalty) + inputStep);
        }
        this.features.remove(eventId);
    }

    /** Appends a batch of inputs to the buffer for {@code eventId}, creating it if needed. */
    public void insertParameters(long eventId, List<Float> parameters) {
        this.features.computeIfAbsent(eventId, key -> new ArrayList<>()).addAll(parameters);
    }

    /**
     * Appends one input to the buffer for {@code eventId}, creating it if needed.
     *
     * @return true once the buffer holds at least {@link #upNub} values
     */
    public boolean insertParameter(long eventId, float parameter) {
        List<Float> buffered = this.features.computeIfAbsent(eventId, key -> new ArrayList<>());
        buffered.add(parameter);
        return buffered.size() >= this.upNub;
    }

    /** Discards any inputs buffered for {@code eventId}. */
    public void destoryParameter(long eventId) {
        this.features.remove(eventId);
    }

    /**
     * Computes the weighted sum of the inputs buffered for {@code eventId}, minus the threshold.
     * Input {@code i} is multiplied by weight key {@code i + 1}.
     *
     * @throws IllegalStateException if no inputs are buffered for {@code eventId}
     * @throws Exception             if the weight count and input count differ
     */
    public float calculation(long eventId) throws Exception {
        List<Float> inputs = this.features.get(eventId);
        if (inputs == null) {
            throw new IllegalStateException("no buffered features for event " + eventId);
        }
        if (this.dendrites.size() != inputs.size()) {
            throw new Exception("权重数量:" + this.dendrites.size() + ",特征数量:" + inputs.size());
        }
        float sum = 0.0f;
        for (int i = 0; i < inputs.size(); i++) {
            sum = (this.dendrites.get(i + 1) * inputs.get(i)) + sum;
        }
        return sum - this.threshold;
    }

    /**
     * Initializes weights. In matrix mode it creates kernels. Otherwise it creates weights
     * 1..upNub and the threshold, either uniformly random in [0, 1) / sqrt(upNub) or zero.
     */
    private void initPower(boolean init, boolean isMatrix) throws Exception {
        Random random = new Random();
        if (isMatrix) {
            initMatrixPower(random);
            return;
        }
        if (this.upNub > 0) {
            float scale = (float) Math.sqrt(this.upNub);
            for (int key = 1; key <= this.upNub; key++) {
                this.dendrites.put(key, init ? random.nextFloat() / scale : 0.0f);
            }
            this.threshold = init ? random.nextFloat() / scale : 0.0f;
        }
    }

    /**
     * Creates random convolution parameters.
     *
     * <p>At depth 1, it first creates {@link #channelNo} oneConv weights in
     * [0, 1) / channelNo. It then creates {@link #convTimes} kernels of
     * {@code kernLen * kernLen} values in [0, 1) / kernLen, each with a matching ConvSize.
     */
    private void initMatrixPower(Random random) throws Exception {
        int kernelSize = this.kernLen * this.kernLen;
        List<Matrix> nerveMatrixList = this.convParameter.getNerveMatrixList();
        List<ConvSize> convSizeList = this.convParameter.getConvSizeList();
        if (this.depth == 1) {
            List<Float> oneConvPower = new ArrayList<>();
            for (int channel = 0; channel < this.channelNo; channel++) {
                oneConvPower.add(random.nextFloat() / this.channelNo);
            }
            this.convParameter.setOneConvPower(oneConvPower);
        }
        for (int k = 0; k < this.convTimes; k++) {
            Matrix kernel = new Matrix(kernelSize, 1);
            convSizeList.add(new ConvSize());
            for (int row = 0; row < kernel.getX(); row++) {
                kernel.setNub(row, 0, random.nextFloat() / this.kernLen);
            }
            nerveMatrixList.add(kernel);
        }
    }

    public int getId() {
        return this.id;
    }

    /** Appends nerves to the list of scalar children. */
    public void connect(List<Nerve> nerves) {
        this.son.addAll(nerves);
    }

    /** Sets the single child that receives matrices. */
    public void connectSonOnly(Nerve nerve) {
        this.sonOnly = nerve;
    }

    /** Sets the single parent that receives matrix gradients. */
    public void connectFatherOnly(Nerve nerve) {
        this.fatherOnly = nerve;
    }

    /** Appends nerves to the list of scalar parents (order maps to weight keys 1..n). */
    public void connectFather(List<Nerve> nerves) {
        this.father.addAll(nerves);
    }
}
