package com.xdja.pki.gmssl.crypto.utils;


import com.xdja.pki.gmssl.crypto.init.GMSSLPkiCryptoInit;
import com.xdja.pki.gmssl.crypto.sdf.SdfCryptoType;
import com.xdja.pki.gmssl.crypto.sdf.SdfSymmetricCipher;
import com.xdja.pki.gmssl.crypto.sdf.SdfSymmetricKeyParameters;
import com.xdja.pki.gmssl.sdf.SdfSDKException;
import com.xdja.pki.gmssl.sdf.bean.SdfAlgIdSymmetric;
import com.xdja.pki.gmssl.x509.utils.bean.GMSSLCryptoType;
import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.params.ParametersWithIV;
import org.bouncycastle.jce.provider.BouncyCastleProvider;

import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.security.Key;
import java.security.Security;

/**
 * Static helpers for symmetric encryption and decryption, either in software
 * through the BouncyCastle JCE provider ({@code ...ByBC} methods) or through an
 * SDF device / dedicated SM1/SM4 utilities ({@code ...BySdf} methods).
 *
 * <p>Every method handles both directions; the {@code forEncryption} flag selects
 * encryption ({@code true}) or decryption ({@code false}).
 */
public class GMSSLSymmetricEncryptUtils {

    static {
        // Register BouncyCastle once so that Cipher.getInstance(..., "BC") can resolve it.
        if (Security.getProvider(BouncyCastleProvider.PROVIDER_NAME) == null) {
            Security.addProvider(new BouncyCastleProvider());
        }
    }

    /**
     * JCE transformation strings supported by the BouncyCastle-backed helpers.
     */
    public enum EncryptTypeByBC {
        /**
         * SM4, ECB mode, no padding.
         */
        SM4_ECB_NoPadding("SM4/ECB/NoPadding"),

        /**
         * SM4, ECB mode, PKCS5 padding.
         */
        SM4_ECB_PKCS5Padding("SM4/ECB/PKCS5Padding"),

        /**
         * SM4, ECB mode, PKCS7 padding.
         */
        SM4_ECB_PKCS7Padding("SM4/ECB/PKCS7Padding"),

        /**
         * SM4, CBC mode, no padding.
         */
        SM4_CBC_NoPadding("SM4/CBC/NoPadding"),

        /**
         * SM4, CBC mode, PKCS5 padding.
         */
        SM4_CBC_PKCS5Padding("SM4/CBC/PKCS5Padding"),

        /**
         * SM4, CBC mode, PKCS7 padding.
         */
        SM4_CBC_PKCS7Padding("SM4/CBC/PKCS7Padding"),

        /**
         * AES, CCM mode, no padding.
         */
        AES_CCM_NoPadding("AES/CCM/NoPadding"),
        ;

        private final String name;

        EncryptTypeByBC(String name) {
            this.name = name;
        }

        /**
         * @return the JCE transformation string ({@code algorithm/mode/padding})
         */
        public String getName() {
            return name;
        }
    }

    /**
     * Encrypts or decrypts {@code data} with BouncyCastle using a transformation
     * that takes an initialization vector.
     *
     * @param forEncryption {@code true} to encrypt, {@code false} to decrypt
     * @param encryptType   transformation (algorithm / mode / padding) to use
     * @param key           raw session key bytes (16 bytes for SM4)
     * @param data          plaintext when encrypting, ciphertext when decrypting
     * @param initIV        initialization vector (16 bytes for SM4 CBC)
     * @return ciphertext when encrypting, plaintext when decrypting
     * @throws Exception if the cipher cannot be created, initialised or run
     */
    public static byte[] symmetricCBCEncryptByBC(
            boolean forEncryption, EncryptTypeByBC encryptType,
            byte[] key, byte[] data, byte[] initIV) throws Exception {
        String transformation = encryptType.getName();
        Cipher cipher = Cipher.getInstance(transformation, BouncyCastleProvider.PROVIDER_NAME);
        Key keySpec = new SecretKeySpec(key, keyAlgorithmOf(transformation));
        int opMode = forEncryption ? Cipher.ENCRYPT_MODE : Cipher.DECRYPT_MODE;
        cipher.init(opMode, keySpec, new IvParameterSpec(initIV));
        return cipher.doFinal(data);
    }

    /**
     * Encrypts or decrypts {@code data} with BouncyCastle using a transformation
     * that needs no initialization vector (ECB).
     *
     * @param forEncryption {@code true} to encrypt, {@code false} to decrypt
     * @param encryptType   transformation (algorithm / mode / padding) to use
     * @param key           raw session key bytes (16 bytes for SM4)
     * @param data          plaintext when encrypting, ciphertext when decrypting
     * @return ciphertext when encrypting, plaintext when decrypting
     * @throws Exception if the cipher cannot be created, initialised or run
     */
    public static byte[] symmetricECBEncryptByBC(
            boolean forEncryption, EncryptTypeByBC encryptType,
            byte[] key, byte[] data) throws Exception {
        String transformation = encryptType.getName();
        Cipher cipher = Cipher.getInstance(transformation, BouncyCastleProvider.PROVIDER_NAME);
        Key keySpec = new SecretKeySpec(key, keyAlgorithmOf(transformation));
        int opMode = forEncryption ? Cipher.ENCRYPT_MODE : Cipher.DECRYPT_MODE;
        cipher.init(opMode, keySpec);
        return cipher.doFinal(data);
    }

    /**
     * ECB encryption or decryption with a caller-supplied session key.
     *
     * <p>When {@link GMSSLPkiCryptoInit#isHsmServer()} is true or the configured
     * crypto type is {@code SANC_HSM}, the call is delegated to
     * {@code GMSSLSM4ECBEncryptUtils} (for {@code SGD_SM4_ECB}) or
     * {@code GMSSLSM1ECBEncryptUtils} (any other algorithm id). Otherwise an
     * {@link SdfSymmetricCipher} is used.
     *
     * @param forEncryption {@code true} to encrypt, {@code false} to decrypt
     * @param sdfCryptoType SDF device type (HSM or PCIe card); replaced by
     *                      {@code DONGJIN} when the configured crypto type is {@code DONGJIN_HSM}
     * @param paddingType   padding type (NoPadding, PKCS5Padding, PKCS7Padding, SSL3Padding)
     * @param key           16-byte session key
     * @param symmetric     symmetric algorithm identifier
     * @param data          plaintext when encrypting, ciphertext when decrypting
     * @return ciphertext when encrypting, plaintext when decrypting
     * @throws Exception if the underlying cipher fails
     */
    public static byte[] symmetricECBEncryptBySdf(boolean forEncryption, SdfCryptoType sdfCryptoType,
                                                  SdfSymmetricKeyParameters.PaddingType paddingType,
                                                  byte[] key, SdfAlgIdSymmetric symmetric, byte[] data) throws Exception {
        if (usesDedicatedSmUtils()) {
            if (symmetric == SdfAlgIdSymmetric.SGD_SM4_ECB) {
                return forEncryption
                        ? GMSSLSM4ECBEncryptUtils.encrypt(key, data, paddingType)
                        : GMSSLSM4ECBEncryptUtils.decrypt(key, data, paddingType);
            }
            // Every other algorithm id is treated as SM1 ECB.
            return forEncryption
                    ? GMSSLSM1ECBEncryptUtils.encrypt(key, data, paddingType)
                    : GMSSLSM1ECBEncryptUtils.decrypt(key, data, paddingType);
        }
        SdfCryptoType effectiveType = resolveSdfCryptoType(sdfCryptoType);
        CipherParameters param = new SdfSymmetricKeyParameters(paddingType, symmetric, key);
        return runSdfCipher(forEncryption, effectiveType, param, data);
    }

    /**
     * ECB encryption or decryption through an SDF device, with the key material
     * described by a key-encryption key (KEK).
     *
     * @param forEncryption     {@code true} to encrypt, {@code false} to decrypt
     * @param sdfCryptoType     SDF device type, see {@link SdfCryptoType}; replaced by
     *                          {@code DONGJIN} when the configured crypto type is {@code DONGJIN_HSM}
     * @param paddingType       padding type, see {@link SdfSymmetricKeyParameters.PaddingType}
     * @param keyEncryptKeyType algorithm identifier of the key-encryption key
     * @param kekIndex          key index of the key-encryption key
     * @param kek               KEK (Key Encrypt Key) bytes
     * @param symmetric         symmetric algorithm identifier
     * @param data              plaintext when encrypting, ciphertext when decrypting
     * @return ciphertext when encrypting, plaintext when decrypting
     * @throws Exception if the underlying cipher fails
     */
    public static byte[] symmetricECBEncryptWithKekBySdf(boolean forEncryption, SdfCryptoType sdfCryptoType,
                                                         SdfSymmetricKeyParameters.PaddingType paddingType,
                                                         SdfAlgIdSymmetric keyEncryptKeyType, int kekIndex, byte[] kek, SdfAlgIdSymmetric symmetric, byte[] data) throws Exception {
        SdfCryptoType effectiveType = resolveSdfCryptoType(sdfCryptoType);
        CipherParameters param = new SdfSymmetricKeyParameters(keyEncryptKeyType, paddingType, symmetric, kekIndex, kek);
        return runSdfCipher(forEncryption, effectiveType, param, data);
    }

    /**
     * CBC encryption or decryption with a caller-supplied session key.
     *
     * <p>When {@link GMSSLPkiCryptoInit#isHsmServer()} is true or the configured
     * crypto type is {@code SANC_HSM}, the call is delegated to
     * {@code GMSSLSM4CBCEncryptUtils} (for {@code SGD_SM4_CBC}) or
     * {@code GMSSLSM1CBCEncryptUtils} (any other algorithm id). Otherwise an
     * {@link SdfSymmetricCipher} is used.
     *
     * @param forEncryption {@code true} to encrypt, {@code false} to decrypt
     * @param sdfCryptoType SDF device type (HSM or PCIe card); replaced by
     *                      {@code DONGJIN} when the configured crypto type is {@code DONGJIN_HSM}
     * @param paddingType   padding type (NoPadding, PKCS5Padding, PKCS7Padding, SSL3Padding)
     * @param key           16-byte session key
     * @param symmetric     symmetric algorithm identifier
     * @param iv            initialization vector
     * @param data          plaintext when encrypting, ciphertext when decrypting
     * @return ciphertext when encrypting, plaintext when decrypting
     * @throws Exception if the underlying cipher fails
     */
    public static byte[] symmetricCBCEncryptBySdf(boolean forEncryption, SdfCryptoType sdfCryptoType,
                                                  SdfSymmetricKeyParameters.PaddingType paddingType,
                                                  byte[] key, SdfAlgIdSymmetric symmetric, byte[] iv, byte[] data) throws Exception {
        if (usesDedicatedSmUtils()) {
            if (symmetric == SdfAlgIdSymmetric.SGD_SM4_CBC) {
                return forEncryption
                        ? GMSSLSM4CBCEncryptUtils.encrypt(key, data, iv, paddingType)
                        : GMSSLSM4CBCEncryptUtils.decrypt(key, data, iv, paddingType);
            }
            // Every other algorithm id is treated as SM1 CBC.
            return forEncryption
                    ? GMSSLSM1CBCEncryptUtils.encrypt(key, data, iv, paddingType)
                    : GMSSLSM1CBCEncryptUtils.decrypt(key, data, iv, paddingType);
        }
        SdfCryptoType effectiveType = resolveSdfCryptoType(sdfCryptoType);
        CipherParameters param = new SdfSymmetricKeyParameters(paddingType, symmetric, key);
        ParametersWithIV ivParam = new ParametersWithIV(param, iv);
        return runSdfCipher(forEncryption, effectiveType, ivParam, data);
    }

    /**
     * CBC encryption or decryption through an SDF device, with the key material
     * described by a key-encryption key (KEK).
     *
     * @param forEncryption     {@code true} to encrypt, {@code false} to decrypt
     * @param sdfCryptoType     SDF device type, see {@link SdfCryptoType}; replaced by
     *                          {@code DONGJIN} when the configured crypto type is {@code DONGJIN_HSM}
     * @param paddingType       padding type, see {@link SdfSymmetricKeyParameters.PaddingType}
     * @param keyEncryptKeyType algorithm identifier of the key-encryption key
     * @param kekIndex          key index of the key-encryption key
     * @param kek               KEK (Key Encrypt Key) bytes
     * @param symmetric         symmetric algorithm identifier
     * @param iv                initialization vector
     * @param data              plaintext when encrypting, ciphertext when decrypting
     * @return ciphertext when encrypting, plaintext when decrypting
     * @throws Exception if the underlying cipher fails
     */
    public static byte[] symmetricCBCEncryptWithKekBySdf(boolean forEncryption, SdfCryptoType sdfCryptoType,
                                                         SdfSymmetricKeyParameters.PaddingType paddingType,
                                                         SdfAlgIdSymmetric keyEncryptKeyType, int kekIndex, byte[] kek, SdfAlgIdSymmetric symmetric, byte[] iv, byte[] data) throws Exception {
        SdfCryptoType effectiveType = resolveSdfCryptoType(sdfCryptoType);
        CipherParameters param = new SdfSymmetricKeyParameters(keyEncryptKeyType, paddingType, symmetric, kekIndex, kek);
        ParametersWithIV ivParam = new ParametersWithIV(param, iv);
        return runSdfCipher(forEncryption, effectiveType, ivParam, data);
    }

    /**
     * Extracts the algorithm component (the part before the first {@code '/'})
     * from a JCE transformation string, so the key is labelled with the cipher
     * actually requested instead of always "SM4".
     */
    private static String keyAlgorithmOf(String transformation) {
        int slash = transformation.indexOf('/');
        return slash < 0 ? transformation : transformation.substring(0, slash);
    }

    /**
     * Tells whether the session-key SDF methods should delegate to the dedicated
     * SM1/SM4 utility classes instead of driving an {@link SdfSymmetricCipher}.
     */
    private static boolean usesDedicatedSmUtils() {
        return GMSSLPkiCryptoInit.isHsmServer()
                || GMSSLPkiCryptoInit.getCryptoType() == GMSSLCryptoType.SANC_HSM;
    }

    /**
     * Returns {@code DONGJIN} when the configured crypto type is {@code DONGJIN_HSM},
     * otherwise the type requested by the caller.
     */
    private static SdfCryptoType resolveSdfCryptoType(SdfCryptoType requested) {
        if (GMSSLPkiCryptoInit.getCryptoType() == GMSSLCryptoType.DONGJIN_HSM) {
            return SdfCryptoType.DONGJIN;
        }
        return requested;
    }

    /**
     * Runs one init / doFinal / release cycle on a fresh {@link SdfSymmetricCipher}.
     *
     * <p>The cipher is released even when {@code init} or {@code doFinal} fails, so a
     * failed operation does not leave the cipher unreleased. A failure raised by
     * {@code release()} during that cleanup is attached to the original exception as
     * suppressed rather than replacing it.
     */
    private static byte[] runSdfCipher(boolean forEncryption, SdfCryptoType sdfCryptoType,
                                       CipherParameters params, byte[] data) throws Exception {
        SdfSymmetricCipher sdfSymmetric = new SdfSymmetricCipher(sdfCryptoType);
        byte[] output;
        try {
            sdfSymmetric.init(forEncryption, params);
            output = sdfSymmetric.doFinal(data);
        } catch (Exception failure) {
            try {
                sdfSymmetric.release();
            } catch (Exception releaseFailure) {
                failure.addSuppressed(releaseFailure);
            }
            throw failure;
        }
        sdfSymmetric.release();
        return output;
    }
}
