/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.norm;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;

/**
 * Block that applies layer normalization to a single input array.
 *
 * <p>The computation itself is delegated to the engine through
 * {@link NDArrayEx#layerNorm(NDArray, Shape, NDArray, NDArray, float)}. This class owns the two
 * parameters ({@code gamma} and {@code beta}), derives the shape they are normalized over from
 * the input shape at initialization time, and persists that shape with the block metadata.
 *
 * <p>Instances are created through {@link #builder()}.
 */
public class LayerNorm
extends AbstractBlock {
    /** Epsilon value forwarded unchanged to the engine's layer-norm operator. */
    private final float epsilon;
    /** Shape of gamma/beta; set in {@link #beforeInitialize(Shape...)} or {@link #loadMetadata}. */
    private Shape normalizedShape;
    /** Whether {@code beta} requires a gradient. */
    private final boolean center;
    /** Whether {@code gamma} requires a gradient. */
    private final boolean scale;
    /** Input axes whose sizes form {@link #normalizedShape}; {@code null} selects the default. */
    private final int[] axis;
    private final Parameter gamma;
    private final Parameter beta;

    /**
     * Creates the block from a builder and registers its {@code gamma} and {@code beta}
     * parameters.
     *
     * @param builder the builder holding the configuration
     */
    LayerNorm(Builder builder) {
        this.epsilon = builder.epsilon;
        this.scale = builder.scale;
        this.center = builder.center;
        // The builder already holds a private copy of the axes, so sharing the reference is safe.
        this.axis = builder.axis;
        // gamma only receives gradients when scaling is enabled
        this.gamma = this.addParameter(
                Parameter.builder()
                        .setName("gamma")
                        .setType(Parameter.Type.GAMMA)
                        .optRequiresGrad(this.scale)
                        .build());
        // beta only receives gradients when centering is enabled
        this.beta = this.addParameter(
                Parameter.builder()
                        .setName("beta")
                        .setType(Parameter.Type.BETA)
                        .optRequiresGrad(this.center)
                        .build());
    }

    /**
     * Applies layer normalization by delegating to the input's {@link NDArrayEx} implementation.
     *
     * @param input the array to normalize
     * @param normalizedShape the shape passed to the engine as the normalized shape
     * @param gamma the gamma array passed to the engine
     * @param beta the beta array passed to the engine
     * @param eps the epsilon value passed to the engine
     * @return the list returned by {@link NDArrayEx#layerNorm}
     */
    public static NDList layerNorm(NDArray input, Shape normalizedShape, NDArray gamma, NDArray beta, float eps) {
        NDArrayEx ex = input.getNDArrayInternal();
        return ex.layerNorm(input, normalizedShape, gamma, beta, eps);
    }

    /**
     * Creates a builder to configure and build a {@code LayerNorm}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Normalizes the single input array using the current values of {@code gamma} and
     * {@code beta}, fetched from the parameter store for the input's device.
     *
     * <p>Both parameters are always passed to the operator; the {@code scale}/{@code center}
     * options only affect whether they require gradients.
     */
    @Override
    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        Device device = input.getDevice();
        NDArray gammaArr = parameterStore.getValue(this.gamma, device, training);
        NDArray betaArr = parameterStore.getValue(this.beta, device, training);
        return LayerNorm.layerNorm(input, this.normalizedShape, gammaArr, betaArr, this.epsilon);
    }

    /**
     * Returns a one-element array containing the first input shape.
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputShapes) {
        return new Shape[]{inputShapes[0]};
    }

    /**
     * Derives {@link #normalizedShape} from the first input shape.
     *
     * <p>Without configured axes the shape is {@code inputShapes[0].slice(1)} (presumably every
     * dimension after the first). Otherwise it is built from the sizes of the configured axes,
     * in the order they were given.
     */
    @Override
    protected void beforeInitialize(Shape ... inputShapes) {
        super.beforeInitialize(inputShapes);
        if (this.axis == null) {
            this.normalizedShape = inputShapes[0].slice(1);
        } else {
            long[] sizes = Arrays.stream(this.axis).mapToLong(dim -> inputShapes[0].get(dim)).toArray();
            this.normalizedShape = new Shape(sizes);
        }
    }

    /**
     * Sets the shape of both {@code gamma} and {@code beta} to {@link #normalizedShape}.
     */
    @Override
    public void prepare(Shape[] inputShapes) {
        this.gamma.setShape(this.normalizedShape);
        this.beta.setShape(this.normalizedShape);
    }

    /**
     * Writes the input shapes followed by the encoded normalized shape.
     *
     * @throws IOException if writing to the stream fails
     */
    @Override
    protected void saveMetadata(DataOutputStream os) throws IOException {
        this.saveInputShapes(os);
        os.write(this.normalizedShape.getEncoded());
    }

    /**
     * Reads back what {@link #saveMetadata(DataOutputStream)} wrote: the input shapes, then the
     * normalized shape.
     *
     * @throws IOException if reading from the stream fails
     * @throws MalformedModelException if {@code loadVersion} differs from this block's version
     */
    @Override
    public void loadMetadata(byte loadVersion, DataInputStream is) throws IOException, MalformedModelException {
        if (loadVersion != this.version) {
            throw new MalformedModelException("Unsupported encoding version: " + loadVersion);
        }
        this.readInputShapes(is);
        this.normalizedShape = Shape.decode(is);
    }

    /**
     * Builder for {@link LayerNorm}.
     *
     * <p>Defaults: epsilon {@code 1.0E-5f}, scale {@code true}, center {@code true}, and no
     * explicit axes.
     */
    public static final class Builder {
        private float epsilon = 1.0E-5f;
        private boolean scale = true;
        private boolean center = true;
        private int[] axis;

        Builder() {
        }

        /**
         * Sets the input axes whose sizes make up the normalized shape.
         *
         * <p>The array is copied, so changing it after this call does not affect the builder or
         * any block built from it. Passing {@code null} clears the setting.
         *
         * @param axis the axes to normalize over, or {@code null} to use the default
         * @return this builder
         */
        public Builder axis(int ... axis) {
            // Defensive copy: varargs callers may pass (and later mutate) their own array.
            this.axis = axis == null ? null : axis.clone();
            return this;
        }

        /**
         * Sets whether {@code beta} requires a gradient.
         *
         * @param val {@code true} to make beta require a gradient
         * @return this builder
         */
        public Builder optCenter(boolean val) {
            this.center = val;
            return this;
        }

        /**
         * Sets whether {@code gamma} requires a gradient.
         *
         * @param val {@code true} to make gamma require a gradient
         * @return this builder
         */
        public Builder optScale(boolean val) {
            this.scale = val;
            return this;
        }

        /**
         * Sets the epsilon value passed to the layer-norm operator.
         *
         * @param val the epsilon value
         * @return this builder
         */
        public Builder optEpsilon(float val) {
            this.epsilon = val;
            return this;
        }

        /**
         * Builds a {@link LayerNorm} from the current settings.
         *
         * @return the new block
         */
        public LayerNorm build() {
            return new LayerNorm(this);
        }
    }
}

