package com.xdja.pki.gmssl.crypto.sdf;

import com.xdja.pki.gmssl.core.utils.GMSSLByteArrayUtils;
import com.xdja.pki.gmssl.sdf.SdfSDK;
import com.xdja.pki.gmssl.sdf.SdfSDKException;
import org.bouncycastle.crypto.BlockCipher;
import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.DataLengthException;
import org.bouncycastle.crypto.OutputLengthException;
import org.bouncycastle.crypto.params.ParametersWithIV;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;

/**
 * {@link BlockCipher} adapter that performs symmetric encryption/decryption on an SDF
 * cryptographic device through {@link SdfSDK}.
 *
 * <p>Unlike a conventional BouncyCastle block cipher, {@link #processBlock} does not handle exactly
 * one block: everything from {@code inOff} to the end of {@code in} is sent to the device in a
 * single call, and the padding selected by {@link SdfSymmetricKeyParameters#getPaddingType()} is
 * added on encryption and stripped on decryption. The IV set by {@link #init} is reused unchanged
 * for every call.
 *
 * <p>Call {@link #release()} when finished to destroy the imported key handle and release the SDK.
 */
public class SdfSymmetric implements BlockCipher {

    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    private static final int BLOCK_SIZE = 16;

    private final SdfSDK sdfSDK;

    /** Device handle of the imported session key; {@code null} until {@link #init} succeeds. */
    private long[] phKeyHandle;
    private boolean forEncryption;

    private ParametersWithIV parametersWithIV;
    private SdfSymmetricKeyParameters sdfBlockCipherKeyParameters;

    /**
     * Creates a cipher backed by the {@link SdfCryptoType#YUNHSUM} device SDK.
     *
     * @throws SdfSDKException if the SDK cannot be initialised
     */
    public SdfSymmetric() throws SdfSDKException {
        this(SdfCryptoType.YUNHSUM);
    }

    /**
     * Creates a cipher backed by the SDK of the given crypto type.
     *
     * @param sdfCryptoType selects the SDK via {@link SdfCryptoType#getSdfSDK()}
     * @throws SdfSDKException if the SDK cannot be initialised
     */
    public SdfSymmetric(SdfCryptoType sdfCryptoType) throws SdfSDKException {
        this(sdfCryptoType.getSdfSDK());
    }

    /**
     * Creates a cipher backed by {@code sdfSDK} and calls {@link SdfSDK#init()} on it.
     *
     * @param sdfSDK the device SDK to use; released again by {@link #release()}
     * @throws SdfSDKException if the SDK cannot be initialised
     */
    public SdfSymmetric(SdfSDK sdfSDK) throws SdfSDKException {
        this.sdfSDK = sdfSDK;
        this.sdfSDK.init();
    }

    /**
     * Initialises the cipher and imports the session key into the device.
     *
     * @param forEncryption {@code true} to encrypt, {@code false} to decrypt
     * @param params        either {@link SdfSymmetricKeyParameters} (an all-zero 16-byte IV is then
     *                      used) or a {@link ParametersWithIV} wrapping {@link SdfSymmetricKeyParameters}
     * @throws IllegalArgumentException if {@code params} is of an unsupported type, a plaintext key
     *                                  is not 128 bits, the key cipher type is unknown, or the key
     *                                  import fails
     */
    @Override
    public void init(boolean forEncryption, CipherParameters params) throws IllegalArgumentException {
        if (params instanceof SdfSymmetricKeyParameters) {
            // No IV supplied: use an all-zero IV of one block (16 zero bytes).
            this.sdfBlockCipherKeyParameters = (SdfSymmetricKeyParameters) params;
            this.parametersWithIV = new ParametersWithIV(params, new byte[BLOCK_SIZE]);
        } else if (params instanceof ParametersWithIV
                && ((ParametersWithIV) params).getParameters() instanceof SdfSymmetricKeyParameters) {
            this.parametersWithIV = (ParametersWithIV) params;
            this.sdfBlockCipherKeyParameters = (SdfSymmetricKeyParameters) this.parametersWithIV.getParameters();
        } else {
            throw new IllegalArgumentException("invalid parameter passed to SdfSymmetric init - "
                    + (params == null ? "null" : params.getClass().getName()));
        }

        this.forEncryption = forEncryption;

        // Re-initialisation: free the handle imported by a previous init() so it does not leak.
        destroyKeyHandleQuietly();

        byte[] key = this.sdfBlockCipherKeyParameters.getKey();

        try {
            switch (this.sdfBlockCipherKeyParameters.getKeyCipherType()) {
                case ECC:
                    // Import the session key from its ciphertext form using the SDF private key.
                    SdfSymmetricKey sdfBlockCipherKey = new SdfSymmetricKey(this.sdfSDK);
                    this.phKeyHandle = sdfBlockCipherKey.importKeyWithIskEcc(this.sdfBlockCipherKeyParameters.getSdfPrivateKey(), key);
                    break;
                case None:
                    // Import the session key as plaintext.
                    if (key.length != 16) {
                        throw new IllegalArgumentException("SM4 requires a 128 bit key");
                    }
                    this.phKeyHandle = sdfSDK.importKey(key);
                    break;
                default:
                    throw new IllegalArgumentException("SDF SM4 init - unknown key cipher type " + this.sdfBlockCipherKeyParameters.getKeyCipherType());
            }
        } catch (SdfSDKException e) {
            logger.error("init import key", e);
            throw new IllegalArgumentException("SDF SM4 init - import key error", e);
        }
    }

    /**
     * Returns the algorithm name of the configured SDF algorithm id.
     * Only valid after {@link #init} has been called.
     */
    @Override
    public String getAlgorithmName() {
        return this.sdfBlockCipherKeyParameters.getSdfAlgIdBlockCipher().getName();
    }

    @Override
    public int getBlockSize() {
        return BLOCK_SIZE;
    }

    /**
     * Returns the ciphertext length produced by encrypting {@code inLength} bytes.
     *
     * @param inLength    plaintext length in bytes
     * @param paddingType padding that will be applied
     * @return {@code inLength} for {@code NoPadding}; otherwise {@code inLength} plus 1..16 padding
     *         bytes, i.e. the next multiple of 16 strictly greater than {@code inLength}
     */
    public int getEncryptionLength(int inLength, SdfSymmetricKeyParameters.PaddingType paddingType) {
        if (paddingType == SdfSymmetricKeyParameters.PaddingType.NoPadding) {
            return inLength;
        } else {
            int paddingLength = BLOCK_SIZE - (inLength % BLOCK_SIZE);
            return inLength + paddingLength;
        }
    }

    /**
     * Encrypts or decrypts all bytes from {@code inOff} to the end of {@code in} in one device call.
     *
     * @return number of bytes written to {@code out} starting at {@code outOff}
     * @throws IllegalStateException if no key has been imported by {@link #init}
     * @throws DataLengthException   if {@code inOff} is out of range, the device call fails (the
     *                               {@link SdfSDKException} is attached as cause), or decrypted
     *                               padding is invalid
     * @throws OutputLengthException if {@code out} has too little room after {@code outOff}
     */
    @Override
    public int processBlock(byte[] in, int inOff, byte[] out, int outOff) throws DataLengthException, IllegalStateException {
        if (this.phKeyHandle == null) {
            throw new IllegalStateException("SdfSymmetric not initialised");
        }
        if (inOff < 0 || inOff > in.length) {
            throw new DataLengthException("input buffer too short");
        }
        // The whole remainder of `in` is processed as a single message.
        byte[] input = inOff == 0 ? in : Arrays.copyOfRange(in, inOff, in.length);
        // NOTE(review): these helpers dump raw input/plaintext; confirm they only log at debug level.
        GMSSLByteArrayUtils.printHexBinary(logger, "processBlock in", input);

        byte[] result;
        try {
            byte[] iv = parametersWithIV.getIV();
            GMSSLByteArrayUtils.printHexBinary(logger, "processBlock iv", iv);
            result = this.forEncryption ? encryptData(iv, input) : decryptData(iv, input);
        } catch (SdfSDKException e) {
            logger.error("processBlock", e);
            DataLengthException failure = new DataLengthException(e.getMessage());
            failure.initCause(e);
            throw failure;
        }

        if (outOff < 0 || out.length - outOff < result.length) {
            throw new OutputLengthException("output buffer too short");
        }
        System.arraycopy(result, 0, out, outOff, result.length);
        return result.length;
    }

    /** Pads {@code input} and encrypts it on the device; returns exactly the padded length. */
    private byte[] encryptData(byte[] iv, byte[] input) throws SdfSDKException {
        byte[] plainText = pad(input);
        byte[] pucEncData = sdfSDK.encrypt(phKeyHandle, this.sdfBlockCipherKeyParameters.getSdfAlgIdBlockCipher(), iv, plainText);
        GMSSLByteArrayUtils.printHexBinary(logger, "processBlock pucEncData", pucEncData);
        if (pucEncData.length < plainText.length) {
            throw new DataLengthException("SDF encrypt returned " + pucEncData.length
                    + " bytes, expected " + plainText.length);
        }
        return pucEncData.length == plainText.length ? pucEncData : Arrays.copyOf(pucEncData, plainText.length);
    }

    /** Decrypts {@code input} on the device and strips the configured padding. */
    private byte[] decryptData(byte[] iv, byte[] input) throws SdfSDKException {
        byte[] plain = sdfSDK.decrypt(phKeyHandle, this.sdfBlockCipherKeyParameters.getSdfAlgIdBlockCipher(), iv, input);
        byte[] plainText = unpad(plain);
        GMSSLByteArrayUtils.printHexBinary(logger, "processBlock plainText", plainText);
        return plainText;
    }

    /** Returns a new array holding {@code data} followed by the configured padding (if any). */
    private byte[] pad(byte[] data) {
        SdfSymmetricKeyParameters.PaddingType paddingType = this.sdfBlockCipherKeyParameters.getPaddingType();
        int paddedLength = getEncryptionLength(data.length, paddingType);
        byte[] padded = Arrays.copyOf(data, paddedLength);
        int paddingLength = paddedLength - data.length;
        if (paddingLength > 0) {
            // SSL3 stores (padding length - 1) in each pad byte; the other padding types store the length.
            byte padByte = (byte) (paddingType == SdfSymmetricKeyParameters.PaddingType.SSL3Padding
                    ? paddingLength - 1
                    : paddingLength);
            Arrays.fill(padded, data.length, paddedLength, padByte);
        }
        return padded;
    }

    /**
     * Removes the configured padding from decrypted data. The pad length is taken from the last byte
     * and must lie in 1..16 and not exceed the data length; otherwise the data is rejected.
     */
    private byte[] unpad(byte[] plain) {
        int lastByte = plain.length == 0 ? -1 : plain[plain.length - 1] & 0xff;
        int paddingLength;
        switch (this.sdfBlockCipherKeyParameters.getPaddingType()) {
            case PKCS5Padding:
            case PKCS7Padding:
                paddingLength = lastByte;
                break;
            case SSL3Padding:
                paddingLength = lastByte + 1;
                break;
            default:
                return plain;
        }
        if (paddingLength < 1 || paddingLength > BLOCK_SIZE || paddingLength > plain.length) {
            throw new DataLengthException("invalid padding in decrypted data");
        }
        return Arrays.copyOf(plain, plain.length - paddingLength);
    }

    /** Best-effort destruction of the current key handle (if any); failures are logged, not thrown. */
    private void destroyKeyHandleQuietly() {
        if (this.phKeyHandle == null) {
            return;
        }
        try {
            sdfSDK.destroyKey(this.phKeyHandle);
        } catch (Exception e) {
            // Cleanup of a stale handle must not prevent importing the new key.
            logger.warn("init destroy previous key", e);
        } finally {
            this.phKeyHandle = null;
        }
    }

    /**
     * No-op: {@link #processBlock} keeps no state between calls (the key handle and IV from
     * {@link #init} are reused as-is).
     */
    @Override
    public void reset() {

    }

    /**
     * Destroys the imported key handle (if any) and releases the SDK initialised by the constructor.
     * The SDK is released even when no key was imported or key destruction fails.
     *
     * @throws SdfSDKException if destroying the key or releasing the SDK fails
     */
    public void release() throws SdfSDKException {
        if (sdfSDK == null) {
            return;
        }
        try {
            if (phKeyHandle != null) {
                sdfSDK.destroyKey(phKeyHandle);
                phKeyHandle = null;
            }
        } finally {
            sdfSDK.release();
        }
    }
}
