package com.xdja.pki.oer.gbt.asn1.data;

import com.xdja.pki.gmssl.core.utils.GMSSLBCAeadUtils;
import com.xdja.pki.gmssl.core.utils.GMSSLByteArrayUtils;
import com.xdja.pki.gmssl.crypto.sdf.SdfPrivateKey;
import com.xdja.pki.gmssl.crypto.sdf.SdfSymmetricKeyParameters;
import com.xdja.pki.gmssl.crypto.utils.GMSSLRandomUtils;
import com.xdja.pki.gmssl.crypto.utils.GMSSLSHA256DigestUtils;
import com.xdja.pki.gmssl.crypto.utils.GMSSLSM3DigestUtils;
import com.xdja.pki.gmssl.crypto.utils.GMSSLSM4ECBEncryptUtils;
import com.xdja.pki.oer.base.Null;
import com.xdja.pki.oer.core.TimeUtils;
import com.xdja.pki.oer.core.calculate.CalculateFactory;
import com.xdja.pki.oer.core.calculate.CalculateService;
import com.xdja.pki.oer.gbt.asn1.*;
import com.xdja.pki.oer.gbt.asn1.bean.PKRecipientInfoType;
import com.xdja.pki.oer.gbt.asn1.utils.EccPointHolder;
import com.xdja.pki.oer.gbt.asn1.utils.KekBuilder;
import com.xdja.pki.oer.gbt.asn1.utils.KekResolveUtils;
import com.xdja.pki.oer.gbt.asn1.utils.SignatureBuilder;
import com.xdja.pki.oer.gbt.asn1.utils.bean.OEREccPoint;
import com.xdja.pki.oer.gbt.asn1.utils.enums.EccCurveTypeEnum;
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPrivateKey;
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey;
import org.bouncycastle.math.ec.custom.gm.SM2P256V1Curve;
import org.bouncycastle.util.encoders.Base64;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.security.PrivateKey;
import java.security.PublicKey;

/**
 * Static helpers that assemble {@link SecuredMessage} structures (signed, self-signed,
 * certificate-encrypted and pre-shared-key-encrypted) and that decrypt an encrypted
 * {@link SecuredMessage} with a private key addressed by index.
 *
 * <p>Two algorithm families are handled throughout: {@link EccCurveTypeEnum#SGD_SM2}
 * (paired with SM3 hashing and SM4-ECB encryption) and {@link EccCurveTypeEnum#NIST_P_256}
 * (paired with SHA-256 hashing and AES-CCM encryption). Any other curve type is rejected
 * with an {@link IllegalArgumentException} wherever the curve type selects the algorithm.
 */
public class SecuredMessageBuilder {
    private static final Logger logger = LoggerFactory.getLogger(SecuredMessageBuilder.class);
    private static final CalculateService calculateService = CalculateFactory.getInstance();

    /** Number of trailing digest bytes kept when deriving a {@link HashedId8}. */
    private static final int HASHED_ID8_LENGTH = 8;

    /**
     * Builds a signed {@link SecuredMessage} whose signer info carries {@code caCertificate}.
     *
     * <p>The signed input is {@code hash(tbsData) || hash(caCertificate)}. SM3 is used when the
     * certificate's verify key lies on the SM2 curve, SHA-256 otherwise; the signature is built
     * with the matching curve type.
     *
     * @param itsAidInt     application identifier placed in the header
     * @param caPrivateKey  private key used to sign
     * @param caCertificate signer certificate; its verify key selects the hash and curve
     * @param data          payload placed in the to-be-signed data
     * @return the assembled message
     * @throws Exception if encoding, hashing or signing fails
     */
    public static SecuredMessage buildSignedDataSecuredMessage(
            ItsAidInt itsAidInt,
            PrivateKey caPrivateKey,
            Certificate caCertificate,
            byte[] data
    ) throws Exception {
        SequenceOfCertificate sequenceOfCertificate = new SequenceOfCertificate();
        sequenceOfCertificate.addCertificate(caCertificate);
        SignedData signedData = new SignedData();
        signedData.setSignerInfo(new SignerInfo(sequenceOfCertificate));

        HeaderInfo headerInfo = new HeaderInfo();
        headerInfo.setItsAid(itsAidInt);
        TBSData tbsData = new TBSData();
        tbsData.setHeaderInfo(headerInfo);
        tbsData.setData(data);
        signedData.setTbs(tbsData);

        byte[] ecaEncode = caCertificate.getEncode();
        byte[] tbsDataEncode = tbsData.getEncode();

        // The curve of the certificate's verify key decides the hash / signature algorithm.
        PublicVerifyKey verifyKey = caCertificate.getTbsCert().getSubjectAttribute().getVerifyKey();
        OEREccPoint oerSignEccPoint = EccPointHolder.build(verifyKey.getEccPoint().getEncode(), verifyKey.getEccCurve());
        PublicKey publicKey = oerSignEccPoint.getPublicKey();
        boolean sm2 = ((BCECPublicKey) publicKey).getParameters().getCurve() instanceof SM2P256V1Curve;

        byte[] ecaHash;
        byte[] tbsHash;
        EccCurveTypeEnum curveType;
        if (sm2) {
            ecaHash = calculateService.sm3Hash(ecaEncode);
            tbsHash = calculateService.sm3Hash(tbsDataEncode);
            curveType = EccCurveTypeEnum.SGD_SM2;
        } else {
            ecaHash = calculateService.sha256Hash(ecaEncode);
            tbsHash = calculateService.sha256Hash(tbsDataEncode);
            curveType = EccCurveTypeEnum.NIST_P_256;
        }
        Signature signature = SignatureBuilder.build(caPrivateKey, concat(tbsHash, ecaHash), curveType);
        signedData.setSign(signature);

        SecuredMessage securedMessage = new SecuredMessage();
        securedMessage.setPayload(new Payload(signedData));
        return securedMessage;
    }

    /**
     * Builds a self-signed {@link SecuredMessage}: the signer info is {@link Null} and the
     * signed input is {@code hash(tbsData) || hash(empty byte string)}.
     *
     * <p>For an {@link SdfPrivateKey} the hash is chosen from {@code eccCurveTypeEnum}. For any
     * other key (which must be a {@link BCECPrivateKey}) the hash is chosen from the key's own
     * curve, while {@code eccCurveTypeEnum} is still what is handed to the signature builder,
     * so callers must keep the two consistent.
     *
     * @param itsAidInt        application identifier placed in the header
     * @param privateKey       signing key, either an {@link SdfPrivateKey} or a {@link BCECPrivateKey}
     * @param data             payload placed in the to-be-signed data
     * @param eccCurveTypeEnum curve type passed to the signature builder
     * @return the assembled message
     * @throws IllegalArgumentException if {@code privateKey} is an {@link SdfPrivateKey} and
     *                                  {@code eccCurveTypeEnum} is neither SGD_SM2 nor NIST_P_256
     * @throws Exception                if encoding, hashing or signing fails
     */
    public static SecuredMessage buildSelfSignedDataSecuredMessage(
            ItsAidInt itsAidInt,
            PrivateKey privateKey,
            byte[] data, EccCurveTypeEnum eccCurveTypeEnum
    ) throws Exception {
        boolean useSm3;
        if (privateKey instanceof SdfPrivateKey) {
            // An SDF key exposes no curve here, so the caller-supplied curve type decides.
            // Previously an unsupported value fell through and failed later with a bare NPE.
            requireSupportedCurve(eccCurveTypeEnum);
            useSm3 = eccCurveTypeEnum == EccCurveTypeEnum.SGD_SM2;
        } else {
            BCECPrivateKey key = (BCECPrivateKey) privateKey;
            useSm3 = key.getParameters().getCurve() instanceof SM2P256V1Curve;
        }

        HeaderInfo headerInfo = new HeaderInfo();
        headerInfo.setItsAid(itsAidInt);
        // Long literal forces 64-bit multiplication whatever getNowTime() returns.
        headerInfo.setGenTime(new Time64(TimeUtils.getNowTime() * 1000L));
        TBSData tbsData = new TBSData();
        tbsData.setHeaderInfo(headerInfo);
        tbsData.setData(data);
        byte[] tbsDataEncode = tbsData.getEncode();

        // There is no signer certificate, so the second hash is taken over an empty input.
        byte[] emptyInput = new byte[0];
        byte[] nullHash;
        byte[] tbsHash;
        if (useSm3) {
            nullHash = GMSSLSM3DigestUtils.digestByYunhsm(emptyInput);
            tbsHash = GMSSLSM3DigestUtils.digestByYunhsm(tbsDataEncode);
        } else {
            nullHash = GMSSLSHA256DigestUtils.digestByBC(emptyInput);
            tbsHash = GMSSLSHA256DigestUtils.digestByBC(tbsDataEncode);
        }
        Signature signature = SignatureBuilder.build(privateKey, concat(tbsHash, nullHash), eccCurveTypeEnum);

        SignedData signedData = new SignedData();
        signedData.setSignerInfo(new SignerInfo(new Null()));
        signedData.setTbs(tbsData);
        signedData.setSign(signature);

        SecuredMessage securedMessage = new SecuredMessage();
        securedMessage.setPayload(new Payload(signedData));
        return securedMessage;
    }

    /**
     * Builds an encrypted {@link SecuredMessage} for {@code caCertificate} using
     * {@link EccCurveTypeEnum#SGD_SM2}.
     *
     * @param caCertificate recipient certificate
     * @param data          plaintext to encrypt
     * @return the assembled message
     * @throws Exception if key generation, hashing or encryption fails
     */
    public static SecuredMessage buildEncryptedDataSecuredMessage(Certificate caCertificate, byte[] data) throws Exception {
        return buildEncryptedDataSecuredMessage(caCertificate, data, EccCurveTypeEnum.SGD_SM2);
    }

    /**
     * Builds an encrypted {@link SecuredMessage} addressed to {@code caCertificate}.
     *
     * <p>A fresh symmetric key is generated, wrapped for the certificate's encryption public
     * key, and used to encrypt {@code data}: SM4-ECB with PKCS7 padding for SGD_SM2, AES-CCM
     * with a random 12-byte nonce for NIST_P_256. The recipient id is the last 8 bytes of the
     * certificate hash (SM3 or SHA-256 respectively).
     *
     * @param caCertificate    recipient certificate
     * @param data             plaintext to encrypt
     * @param eccCurveTypeEnum algorithm family, SGD_SM2 or NIST_P_256
     * @return the assembled message
     * @throws IllegalArgumentException if {@code eccCurveTypeEnum} is neither SGD_SM2 nor NIST_P_256
     * @throws Exception                if key generation, hashing or encryption fails
     */
    public static SecuredMessage buildEncryptedDataSecuredMessage(Certificate caCertificate, byte[] data, EccCurveTypeEnum eccCurveTypeEnum) throws Exception {
        // Validate before touching the HSM; previously an unsupported curve ended in an NPE.
        requireSupportedCurve(eccCurveTypeEnum);

        byte[] key = GMSSLByteArrayUtils.base64Decode(GMSSLRandomUtils.generateRandomByYunhsm(16));
        EccPoint publicKey = caCertificate.getTbsCert().getSubjectAttribute().getEncryptionKey().getPublicKey();

        PKRecipientInfo certRecipInfo = new PKRecipientInfo();
        SequenceOfRecipientInfo sequenceOfRecipientInfo = new SequenceOfRecipientInfo();
        sequenceOfRecipientInfo.addRecipientInfo(new RecipientInfo(certRecipInfo, PKRecipientInfoType.CERT_RECIPINFO));
        EncryptedData encryptedData = new EncryptedData();
        encryptedData.setRecipients(sequenceOfRecipientInfo);

        SymmetricCipherText symmetricCipherText;
        byte[] certHash;
        if (eccCurveTypeEnum == EccCurveTypeEnum.SGD_SM2) {
            certHash = GMSSLSM3DigestUtils.digestByYunhsm(caCertificate.getEncode());
            certRecipInfo.setKek(KekBuilder.build(publicKey, key, EccCurveTypeEnum.SGD_SM2));
            certRecipInfo.setHashAlg(new HashAlgorithm(HashAlgorithm.SGD_SM3));
            byte[] cipher = GMSSLSM4ECBEncryptUtils.sm4SymmetricWithPaddingByYunHsm(true,
                    SdfSymmetricKeyParameters.PaddingType.PKCS7Padding, key, data);
            CipherText cipherText = new CipherText();
            cipherText.setString(cipher);
            symmetricCipherText = new SymmetricCipherText(cipherText);
        } else {
            certHash = GMSSLSHA256DigestUtils.digestByYunHsm(caCertificate.getEncode());
            certRecipInfo.setKek(KekBuilder.build(publicKey, key, EccCurveTypeEnum.NIST_P_256));
            certRecipInfo.setHashAlg(new HashAlgorithm(HashAlgorithm.SHA_256));
            symmetricCipherText = encryptWithAesCcm(key, data);
        }
        certRecipInfo.setRecipientId(toHashedId8(certHash));
        encryptedData.setCipherText(symmetricCipherText);

        SecuredMessage securedMessage = new SecuredMessage();
        securedMessage.setPayload(new Payload(encryptedData));
        return securedMessage;
    }

    /**
     * Builds an encrypted {@link SecuredMessage} whose single recipient is identified by a
     * pre-shared symmetric key.
     *
     * <p>The recipient info carries the last 8 bytes of the key's hash (SM3 for SGD_SM2,
     * SHA-256 for NIST_P_256). {@code data} is encrypted directly with {@code key}: SM4-ECB
     * with PKCS7 padding for SGD_SM2, AES-CCM with a random 12-byte nonce for NIST_P_256.
     *
     * @param key              pre-shared symmetric key
     * @param data             plaintext to encrypt
     * @param eccCurveTypeEnum algorithm family, SGD_SM2 or NIST_P_256
     * @return the assembled message
     * @throws IllegalArgumentException if {@code eccCurveTypeEnum} is neither SGD_SM2 nor NIST_P_256
     * @throws Exception                if hashing or encryption fails
     */
    public static SecuredMessage buildEncPskpiSecuredMessage(byte[] key, byte[] data, EccCurveTypeEnum eccCurveTypeEnum) throws Exception {
        // Previously an unsupported curve silently produced a message without cipher text.
        requireSupportedCurve(eccCurveTypeEnum);

        PreSharedKeyRecipientInfo sharedKeyRecipientInfo = new PreSharedKeyRecipientInfo();
        SequenceOfRecipientInfo recipients = new SequenceOfRecipientInfo();
        recipients.addRecipientInfo(new RecipientInfo(sharedKeyRecipientInfo));

        SymmetricCipherText symmetricCipherText;
        if (eccCurveTypeEnum == EccCurveTypeEnum.SGD_SM2) {
            sharedKeyRecipientInfo.setSmyKeyHash(toHashedId8(GMSSLSM3DigestUtils.digestByBC(key)));
            sharedKeyRecipientInfo.setHashAlg(new HashAlgorithm(HashAlgorithm.SGD_SM3));
            // This HSM helper takes and returns Base64 text rather than raw bytes.
            String cipher = GMSSLSM4ECBEncryptUtils.encryptByYumhsmWithPKCS7Padding(
                    Base64.toBase64String(key), Base64.toBase64String(data));
            CipherText sm4Ecb = new CipherText();
            sm4Ecb.setString(Base64.decode(cipher));
            symmetricCipherText = new SymmetricCipherText(sm4Ecb);
        } else {
            sharedKeyRecipientInfo.setSmyKeyHash(toHashedId8(GMSSLSHA256DigestUtils.digestByYunHsm(key)));
            sharedKeyRecipientInfo.setHashAlg(new HashAlgorithm(HashAlgorithm.SHA_256));
            symmetricCipherText = encryptWithAesCcm(key, data);
        }

        EncryptedData encryptedData = new EncryptedData();
        encryptedData.setCipherText(symmetricCipherText);
        encryptedData.setRecipients(recipients);

        SecuredMessage securedMessage = new SecuredMessage();
        securedMessage.setPayload(new Payload(encryptedData));
        return securedMessage;
    }

    /**
     * Decrypts an encoded, encrypted {@link SecuredMessage}.
     *
     * <p>Only the first recipient info is considered. Its wrapped key is unwrapped with the
     * {@link SdfPrivateKey} addressed by {@code privateKeyIndex} / {@code privateKeyPassword},
     * and the result is used to decrypt the SM4-ECB or AES-CCM cipher text.
     *
     * @param privateKeyIndex    index of the private key
     * @param privateKeyPassword password of the private key; encoded as UTF-8 bytes
     * @param data               encoded {@link SecuredMessage}
     * @return the decrypted plaintext
     * @throws IllegalArgumentException if the message carries no encrypted data, its first
     *                                  recipient is not a public-key recipient, or it has
     *                                  neither SM4-ECB nor AES-CCM cipher text
     * @throws Exception                if decoding, key unwrapping or decryption fails
     */
    public static byte[] resolveEncSecuredMessage(int privateKeyIndex, String privateKeyPassword, byte[] data) throws Exception {
        SecuredMessage securedMessage = SecuredMessage.getInstance(data);
        EncryptedData encData = securedMessage.getPayload().getEncData();
        // NOTE(review): assumes getEncData() yields null for a non-encrypted payload, like the
        // other choice getters null-checked below -- confirm.
        if (encData == null) {
            throw new IllegalArgumentException("SecuredMessage payload does not contain encrypted data");
        }

        RecipientInfo firstRecipient = encData.getRecipients().getRecipientInfos().get(0);
        PKRecipientInfo pkRecipientInfo = firstRecipient.getCertRecipInfo();
        if (pkRecipientInfo == null) {
            pkRecipientInfo = firstRecipient.getSignedDataRecipInfo();
        }
        if (pkRecipientInfo == null) {
            throw new IllegalArgumentException(
                    "First recipient is neither a certificate nor a signed-data recipient; cannot unwrap key");
        }

        // Explicit charset: the platform default must not influence the password bytes.
        SdfPrivateKey privateKey = new SdfPrivateKey(privateKeyIndex, privateKeyPassword.getBytes(StandardCharsets.UTF_8));
        byte[] key = KekResolveUtils.getPlain(pkRecipientInfo.getKek(), privateKey);

        SymmetricCipherText symmetricCipherText = encData.getCipherText();
        if (symmetricCipherText.getSm4Ecb() != null) {
            String decryptText = GMSSLSM4ECBEncryptUtils.decryptByYumhsmWithPKCS7Padding(
                    Base64.toBase64String(key),
                    Base64.toBase64String(symmetricCipherText.getSm4Ecb().getString()));
            return Base64.decode(decryptText);
        }

        AesCcmCipherText aesCcm = symmetricCipherText.getAesCcm();
        if (aesCcm == null) {
            throw new IllegalArgumentException("Cipher text is neither SM4-ECB nor AES-CCM");
        }
        return GMSSLBCAeadUtils.decryptAESCCM(key, 16, aesCcm.getNonce().getString(), null, aesCcm.getCipher().getString());
    }

    /** Rejects every curve type other than SGD_SM2 and NIST_P_256 (including {@code null}). */
    private static void requireSupportedCurve(EccCurveTypeEnum eccCurveTypeEnum) {
        if (eccCurveTypeEnum != EccCurveTypeEnum.SGD_SM2 && eccCurveTypeEnum != EccCurveTypeEnum.NIST_P_256) {
            throw new IllegalArgumentException("Unsupported ECC curve type: " + eccCurveTypeEnum);
        }
    }

    /** Returns {@code first || second} as a new array. */
    private static byte[] concat(byte[] first, byte[] second) {
        ByteArrayOutputStream bos = new ByteArrayOutputStream(first.length + second.length);
        bos.write(first, 0, first.length);
        bos.write(second, 0, second.length);
        return bos.toByteArray();
    }

    /** Wraps the last {@link #HASHED_ID8_LENGTH} bytes of {@code digest} in a {@link HashedId8}. */
    private static HashedId8 toHashedId8(byte[] digest) throws Exception {
        byte[] id = new byte[HASHED_ID8_LENGTH];
        System.arraycopy(digest, digest.length - id.length, id, 0, id.length);
        return new HashedId8(id);
    }

    /** Encrypts {@code data} with AES-CCM under {@code key} using a fresh random 12-byte nonce. */
    private static SymmetricCipherText encryptWithAesCcm(byte[] key, byte[] data) throws Exception {
        byte[] nonce = com.xdja.pki.gmssl.core.utils.GMSSLRandomUtils.generateRandom(12);
        AesCcmCipherText ccmCipherText = new AesCcmCipherText();
        ccmCipherText.setNonce(nonce);
        // The literal 16 must stay equal to the one passed to decryptAESCCM.
        byte[] cipher = GMSSLBCAeadUtils.encryptAESCCM(key, 16, nonce, null, data);
        CipherText cipherText = new CipherText();
        cipherText.setString(cipher);
        ccmCipherText.setCipher(cipherText);
        return new SymmetricCipherText(ccmCipherText);
    }
}
