package com.xdja.ca.pkcs7;

import com.xdja.ca.asn1.DigestObjectIdentifiers;
import com.xdja.ca.asn1.RsaObjectIdentifiers;
import com.xdja.ca.asn1.SM2ObjectIdentifiers;
import com.xdja.ca.constant.SdkConstants;
import com.xdja.ca.utils.DnUtil;
import com.xdja.ca.utils.SdkCertUtils;
import com.xdja.ca.utils.SdkHsmUtils;
import com.xdja.pki.gmssl.core.utils.GMSSLByteArrayUtils;
import com.xdja.pki.gmssl.crypto.sdf.SdfCryptoType;
import com.xdja.pki.gmssl.crypto.sdf.SdfSymmetricKeyParameters;
import com.xdja.pki.gmssl.crypto.utils.GMSSLSM2EncryptUtils;
import com.xdja.pki.gmssl.crypto.utils.GMSSLSM4ECBEncryptUtils;
import org.bouncycastle.asn1.*;
import org.bouncycastle.asn1.cms.*;
import org.bouncycastle.asn1.x509.AlgorithmIdentifier;
import org.bouncycastle.cert.X509CertificateHolder;
import org.bouncycastle.cert.jcajce.JcaCertStore;
import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter;
import org.bouncycastle.cms.CMSProcessableByteArray;
import org.bouncycastle.cms.CMSSignedData;
import org.bouncycastle.cms.CMSSignedDataGenerator;
import org.bouncycastle.crypto.AsymmetricBlockCipher;
import org.bouncycastle.crypto.InvalidCipherTextException;
import org.bouncycastle.crypto.encodings.PKCS1Encoding;
import org.bouncycastle.crypto.engines.RSABlindedEngine;
import org.bouncycastle.crypto.params.RSAKeyParameters;
import org.bouncycastle.util.Store;
import org.bouncycastle.util.encoders.Base64;
import org.bouncycastle.util.io.pem.PemObject;
import org.bouncycastle.util.io.pem.PemWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.crypto.BadPaddingException;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.naming.NamingException;
import java.io.IOException;
import java.io.StringWriter;
import java.math.BigInteger;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPublicKey;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;

/**
 * Helpers for building PKCS#7 / CMS structures (recipient infos, encrypted
 * content infos, signer infos) and for converting certificate chains to and
 * from PEM-armoured PKCS#7 {@code SignedData}.
 *
 * <p>Most members are static, but
 * {@link #makeEncryptedContentInfo(Integer, byte[], X509Certificate)} is an
 * instance method, so the class deliberately keeps its public default
 * constructor.
 *
 * @author zjx
 * @since 2017-09-05
 */
public class Pkcs7Utils {

    private static final Logger logger = LoggerFactory.getLogger(Pkcs7Utils.class);

    /** PEM type label used in the {@code -----BEGIN ...-----} armour. */
    public static final String PKCS7_TYPE = "PKCS7";
    public static final String PKCS_BEGIN_HEADER = "-----BEGIN PKCS7-----";
    public static final String PKCS_END_HEADER = "-----END PKCS7-----";

    /** Root arc from which the SM4 / SM2 identifiers below are branched. */
    static final ASN1ObjectIdentifier smAlgorithm = new ASN1ObjectIdentifier("1.2.156.10197.1");

    static final ASN1ObjectIdentifier sm4 = smAlgorithm.branch("104");

    static final ASN1ObjectIdentifier sm2256 = smAlgorithm.branch("301");
    static final ASN1ObjectIdentifier sm2256_encrypt = sm2256.branch("3");

    /** PKCS#1 rsaEncryption. */
    static final ASN1ObjectIdentifier rsa = new ASN1ObjectIdentifier("1.2.840.113549.1.1.1");

    /** id-sha256, used as digest algorithm for {@code SHA256WithRSA} signer infos. */
    private static final ASN1ObjectIdentifier SHA256_OID = new ASN1ObjectIdentifier("2.16.840.1.101.3.4.2.1");

    /** Content type written into every {@link EncryptedContentInfo} produced by this class. */
    private static final ASN1ObjectIdentifier ENCRYPTED_CONTENT_TYPE =
            new ASN1ObjectIdentifier("1.2.156.10197.6.1.4.2.104");

    private static final char[] HEX_DIGITS = "0123456789ABCDEF".toCharArray();

    /**
     * Builds a single-element {@code RecipientInfos} set that carries the
     * session key encrypted under the recipient's public key.
     *
     * <p>When {@code caAlg} equals {@link SdkConstants#RSA_ALG_1} the key is
     * encrypted in-process with RSA and PKCS#1 v1.5 encoding; otherwise it is
     * encrypted through {@link GMSSLSM2EncryptUtils#encryptASN1ByYunhsm} and
     * tagged with the SM2 encryption OID.
     *
     * @param caAlg                       algorithm selector; must not be {@code null}
     * @param protectPublicKeyCertIssuerDN issuer DN of the recipient certificate
     * @param protectPublicKeyCertSN      serial number of the recipient certificate
     * @param protectPublicKey            recipient public key (must be an
     *                                    {@link RSAPublicKey} on the RSA path)
     * @param sessionKey                  raw session key bytes to protect
     * @return a DER set containing one {@link KeyTransRecipientInfo}
     * @throws Exception if the key cannot be encrypted; on the SM2 path the
     *                   underlying failure is attached as the cause
     */
    @SuppressWarnings("deprecation")
    public static ASN1Set makeRecipientInfos(Integer caAlg, String protectPublicKeyCertIssuerDN, BigInteger protectPublicKeyCertSN, PublicKey protectPublicKey,
                                             byte[] sessionKey) throws Exception {
        RecipientIdentifier rid = new RecipientIdentifier(
                new IssuerAndSerialNumber(DnUtil.getRFC4519X500Name(protectPublicKeyCertIssuerDN), protectPublicKeyCertSN));

        AlgorithmIdentifier keyEncryptionAlgorithm;
        ASN1OctetString encryptedKey;

        if (SdkConstants.RSA_ALG_1 == caAlg.intValue()) {
            keyEncryptionAlgorithm = new AlgorithmIdentifier(rsa);
            RSAPublicKey publicKey = (RSAPublicKey) protectPublicKey;
            AsymmetricBlockCipher engine = new PKCS1Encoding(new RSABlindedEngine());
            // First argument "false" marks the parameters as a public (not private) key.
            engine.init(true, new RSAKeyParameters(false, publicKey.getModulus(), publicKey.getPublicExponent()));
            encryptedKey = new DEROctetString(engine.processBlock(sessionKey, 0, sessionKey.length));
        } else {
            keyEncryptionAlgorithm = new AlgorithmIdentifier(sm2256_encrypt);
            try {
                // The helper works on Base64 text in both directions.
                String encData = GMSSLSM2EncryptUtils.encryptASN1ByYunhsm(protectPublicKey, GMSSLByteArrayUtils.base64Encode(sessionKey));
                encryptedKey = new DEROctetString(GMSSLByteArrayUtils.base64Decode(encData));
            } catch (Exception e) {
                logger.error("使用公钥加密会话密钥异常", e);
                // Keep the root cause so callers can diagnose the failure.
                throw new Exception("使用公钥加密会话密钥异常", e);
            }
        }

        // Assemble the recipient info.
        KeyTransRecipientInfo keyTransRecipientInfo = new KeyTransRecipientInfo(rid, keyEncryptionAlgorithm, encryptedKey);

        return new DERSet(keyTransRecipientInfo);
    }

    /**
     * Wraps already-encrypted content in an {@link EncryptedContentInfo}.
     *
     * <p>The content-encryption algorithm is always SM4, matching what
     * {@link #makeEncryptedContentInfo(Integer, byte[], X509Certificate)}
     * writes for every {@code caAlg}; {@code caAlg} is therefore currently
     * not consulted.
     *
     * @param caAlg                algorithm selector (unused)
     * @param encryptedContentInfo encrypted content bytes; when {@code null}
     *                             the optional encryptedContent field is omitted
     * @return the populated {@link EncryptedContentInfo}
     */
    public static EncryptedContentInfo makeEncryptedContentInfo(Integer caAlg,byte[] encryptedContentInfo) {
        AlgorithmIdentifier contentEncryptionAlgorithm = new AlgorithmIdentifier(sm4);
        ASN1OctetString encryptedContent =
                encryptedContentInfo == null ? null : new DEROctetString(encryptedContentInfo);

        return new EncryptedContentInfo(ENCRYPTED_CONTENT_TYPE, contentEncryptionAlgorithm, encryptedContent);
    }

    /**
     * Encrypts a certificate under the session key with SM4-ECB and wraps the
     * ciphertext in an {@link EncryptedContentInfo}.
     *
     * <p>The plaintext is whatever {@link SdkCertUtils#writeObjectToByteArray}
     * returns for {@code encCert}. When {@code caAlg} equals
     * {@link SdkConstants#RSA_ALG_1} the encryption is done by
     * {@code Sm4.sm4_encrypt_ecb}; otherwise it is delegated to
     * {@link GMSSLSM4ECBEncryptUtils#sm4SymmetricSdfWithPadding} with
     * {@code SdfCryptoType.YUNHSM} and PKCS7 padding. Both paths label the
     * result with the SM4 OID.
     *
     * @param caAlg      algorithm selector; must not be {@code null}
     * @param sessionKey SM4 key bytes
     * @param encCert    certificate to encrypt
     * @return the populated {@link EncryptedContentInfo}
     * @throws Exception if encryption fails; on the first path the underlying
     *                   failure is attached as the cause
     */
    public EncryptedContentInfo makeEncryptedContentInfo(Integer caAlg, byte[] sessionKey, X509Certificate encCert) throws Exception {
        ASN1OctetString encryptedContent;

        if (SdkConstants.RSA_ALG_1 == caAlg.intValue()) {
            try {
                encryptedContent = new DEROctetString(Sm4.sm4_encrypt_ecb(sessionKey, SdkCertUtils.writeObjectToByteArray(encCert)));
            } catch (Exception e) {
                logger.error("加密异常", e);
                // Keep the root cause so callers can diagnose the failure.
                throw new Exception("加密异常", e);
            }
        } else {
            byte[] plainContent = SdkCertUtils.writeObjectToByteArray(encCert);
            byte[] encData = GMSSLSM4ECBEncryptUtils.sm4SymmetricSdfWithPadding(true, SdfCryptoType.YUNHSM, SdfSymmetricKeyParameters.PaddingType.PKCS7Padding, sessionKey, plainContent);
            encryptedContent = new DEROctetString(encData);
        }

        AlgorithmIdentifier contentEncryptionAlgorithm = new AlgorithmIdentifier(sm4);
        return new EncryptedContentInfo(ENCRYPTED_CONTENT_TYPE, contentEncryptionAlgorithm, encryptedContent);
    }

    /**
     * Dumps a byte array to standard output as upper-case hex, 20 bytes per
     * line, each line prefixed with the offset of its first byte and with an
     * extra gap after every 10th byte. A {@code null} array is reported
     * through the logger instead.
     *
     * @param bs bytes to dump, may be {@code null}
     */
    public static void printBytes(byte[] bs) {
        if (bs == null) {
            logger.info("bs is null =======================\n");
            return;
        }

        // Build the dump first so it reaches stdout as one uninterrupted write.
        StringBuilder dump = new StringBuilder(bs.length * 4 + 16);
        for (int i = 0; i < bs.length; i++) {
            if (i % 20 == 0) {
                dump.append(String.format("%4s:  ", i));
            }
            dump.append(toHex(bs[i]));
            if (i % 10 == 9) {
                dump.append("  ");
            }
            dump.append(i % 20 == 19 ? "\n" : " ");
        }
        dump.append("\n");
        System.out.println(dump);
    }

    /**
     * Formats one byte as exactly two upper-case hexadecimal digits.
     *
     * @param b the byte to format
     * @return two-character hex string, e.g. {@code "0A"} or {@code "FF"}
     */
    public static final String toHex(byte b) {
        // The mask discards the sign-extension bits produced by ">>" on negative bytes.
        char high = HEX_DIGITS[(b >> 4) & 0xf];
        char low = HEX_DIGITS[b & 0xf];
        return new String(new char[] {high, low});
    }


    /**
     * Signs {@code structData} and returns a single-element
     * {@code SignerInfos} set describing that signature.
     *
     * <p>The digest and signature algorithm identifiers are chosen from
     * {@code caAlg}: {@code SHA-1WithRSA}/{@code SHA1WithRSA} map to SHA-1 with
     * rsaEncryption, {@code SHA256WithRSA} maps to SHA-256 with rsaEncryption,
     * and anything else maps to SM3 with the SM2 signature OID. The signature
     * itself is produced by {@link SdkHsmUtils#signByYunHsm} when
     * {@code isUseHsm} is set, otherwise by {@link SdkHsmUtils#signByBC}.
     * No signed or unsigned attributes are included.
     *
     * @param raSignPriKey       private key bytes passed to the non-HSM signer
     * @param isUseHsm           {@code true} to sign via the HSM helper
     * @param caAlg              signature algorithm name
     * @param issueCaCertDN      issuer DN placed in the signer identifier
     * @param keyGeneraterCertSN certificate serial number placed in the signer identifier
     * @param privateKeyIndex    key index passed to the HSM signer
     * @param privateKeyPassword key password passed to the HSM signer
     * @param structData         data to sign
     * @return a DER set containing one {@link SignerInfo}
     * @throws Exception if DN parsing or signing fails
     */
    public static ASN1Set makeSignerInfos(byte[] raSignPriKey, boolean isUseHsm, String caAlg, String issueCaCertDN, BigInteger keyGeneraterCertSN, int privateKeyIndex,String privateKeyPassword, byte[] structData)
            throws Exception {
        SignerIdentifier sid = new SignerIdentifier(new IssuerAndSerialNumber(DnUtil.getRFC4519X500Name(issueCaCertDN), keyGeneraterCertSN));

        AlgorithmIdentifier digAlgorithm;
        AlgorithmIdentifier digEncryptionAlgorithm;
        if ("SHA-1WithRSA".equalsIgnoreCase(caAlg) || "SHA1WithRSA".equalsIgnoreCase(caAlg)) {
            digAlgorithm = new AlgorithmIdentifier(DigestObjectIdentifiers.sha1);
            digEncryptionAlgorithm = new AlgorithmIdentifier(RsaObjectIdentifiers.rsaEncryption);
        } else if ("SHA256WithRSA".equalsIgnoreCase(caAlg)) {
            // The advertised digest must match the one the signature was computed with.
            digAlgorithm = new AlgorithmIdentifier(SHA256_OID);
            digEncryptionAlgorithm = new AlgorithmIdentifier(RsaObjectIdentifiers.rsaEncryption);
        } else {
            digAlgorithm = new AlgorithmIdentifier(DigestObjectIdentifiers.sm3);
            digEncryptionAlgorithm = new AlgorithmIdentifier(SM2ObjectIdentifiers.sm2256_sign);
        }

        // Both signers take the payload as Base64 text and return a Base64 signature.
        String payload = Base64.toBase64String(structData);
        String signature = isUseHsm
                ? SdkHsmUtils.signByYunHsm(caAlg, privateKeyIndex, privateKeyPassword, payload)
                : SdkHsmUtils.signByBC(caAlg, raSignPriKey, payload);

        // The raw (Base64-decoded) signature bytes are embedded, not the Base64
        // text; this was changed for the key-recovery flow.
        ASN1OctetString encryptedDigest = new DEROctetString(GMSSLByteArrayUtils.base64Decode(signature));

        // Typed nulls select the ASN1Set-based SignerInfo constructor.
        ASN1Set authenticatedAttributes = null;
        ASN1Set unauthenticatedAttributes = null;
        SignerInfo signerInfo = new SignerInfo(sid, digAlgorithm, authenticatedAttributes, digEncryptionAlgorithm, encryptedDigest, unauthenticatedAttributes);

        return new DERSet(signerInfo);
    }

    /**
     * Packs certificates into a PKCS#7 {@code SignedData} structure with empty
     * content and no signers, and returns it PEM-encoded.
     *
     * @param certificateList certificates to include
     * @author ssh
     * @return the PEM text of the resulting PKCS#7 object
     * @throws Exception if the structure cannot be built or encoded
     */
    public static String createCertChainByCerts(List<X509Certificate> certificateList) throws Exception {

        CMSSignedDataGenerator gen = new CMSSignedDataGenerator();
        try {
            // Empty payload: the structure only transports certificates.
            CMSProcessableByteArray msg = new CMSProcessableByteArray(new byte[0]);
            gen.addCertificates(new JcaCertStore(certificateList));
            CMSSignedData cmsSignedData = gen.generate(msg);
            return Pkcs7Utils.writeP7bPem(cmsSignedData.toASN1Structure());
        } catch (Exception e) {
            throw new Exception("创建证书链异常",e);
        }
    }

    /**
     * Parses a PEM-armoured (or bare Base64) PKCS#7 certificate chain.
     *
     * @param certChain PKCS#7 text, with or without the
     *                  {@link #PKCS_BEGIN_HEADER}/{@link #PKCS_END_HEADER} lines
     * @return the certificates contained in the structure
     * @throws IllegalArgumentException if {@code certChain} is {@code null}
     * @throws Exception if the text cannot be decoded or parsed
     */
    public static List<X509Certificate> resolveCertChain(String certChain) throws Exception {
        if (certChain == null) {
            throw new IllegalArgumentException("certChain must not be null");
        }
        String b64Cert = certChain
                .replace(Pkcs7Utils.PKCS_BEGIN_HEADER, "")
                .replace(Pkcs7Utils.PKCS_END_HEADER, "")
                .replace("\r", "")
                .replace("\n", "");
        try {
            CMSSignedData cmsSignedDataResolve = new CMSSignedData(Base64.decode(b64Cert));
            return Pkcs7Utils.getX509CertificateListFromSignedData(cmsSignedDataResolve);
        } catch (Exception e) {
            throw new Exception("解析证书链异常",e);
        }
    }

    /**
     * Extracts every certificate stored in a {@link CMSSignedData}.
     *
     * <p>Conversion uses the provider named {@code "BC"}, so that provider
     * must be available at runtime.
     *
     * @param cmsSignedData parsed PKCS#7 SignedData
     * @return the contained certificates, in the store's iteration order
     * @throws Exception if the certificates cannot be read or converted
     */
    public static List<X509Certificate> getX509CertificateListFromSignedData(CMSSignedData cmsSignedData) throws Exception {
        try {
            Store<X509CertificateHolder> store = cmsSignedData.getCertificates();
            // A null selector matches every entry in the store.
            Collection<X509CertificateHolder> holders = store.getMatches(null);

            JcaX509CertificateConverter converter = new JcaX509CertificateConverter().setProvider("BC");
            List<X509Certificate> certificateList = new ArrayList<X509Certificate>(holders.size());
            for (X509CertificateHolder holder : holders) {
                certificateList.add(converter.getCertificate(holder));
            }
            return certificateList;
        } catch (Exception e) {
            throw new Exception("从cmsSignedData中获取证书链异常",e);
        }
    }

    /**
     * Converts a PKCS#7 {@link ContentInfo} to PEM text.
     *
     * @param contentInfo the structure to encode
     * @author ssh
     * @return PEM text with the {@link #PKCS7_TYPE} label
     * @throws Exception if DER encoding or PEM writing fails
     */
    public static String writeP7bPem(ContentInfo contentInfo) throws Exception {
        try {
            return encodeContentInfoAsPem(contentInfo);
        } catch (Exception e) {
            throw new Exception("将p7b对象转换为Pem格式异常",e);
        }
    }

    /**
     * Converts a {@link CMSSignedData} to PEM text.
     *
     * @param cmsSignedData the structure to encode
     * @author ssh
     * @return PEM text with the {@link #PKCS7_TYPE} label
     * @throws Exception if DER encoding or PEM writing fails
     */
    public static String writeP7bPem(CMSSignedData cmsSignedData) throws Exception {
        try {
            return encodeContentInfoAsPem(cmsSignedData.toASN1Structure());
        } catch (Exception e) {
            throw new Exception("将p7b对象转换为Pem格式异常",e);
        }
    }

    /** DER-encodes the content info and renders it as a PKCS7 PEM object. */
    private static String encodeContentInfoAsPem(ContentInfo contentInfo) throws Exception {
        PemObject pemObject = new PemObject(Pkcs7Utils.PKCS7_TYPE, contentInfo.getEncoded(ASN1Encoding.DER));
        return Pkcs7Utils.writePemObject(pemObject);
    }

    /**
     * Renders a {@link PemObject} as PEM text.
     *
     * @param pemObject the object to write
     * @return the PEM text
     * @throws Exception if writing fails
     */
    public static String writePemObject(PemObject pemObject) throws Exception {
        StringWriter stringWriter = new StringWriter();
        PemWriter pemWriter = new PemWriter(stringWriter);
        try {
            pemWriter.writeObject(pemObject);
            pemWriter.flush();
            return stringWriter.toString();
        } catch (Exception e) {
            throw new Exception("打印pemObject对象异常",e);
        } finally {
            try {
                pemWriter.close();
            } catch (IOException closeFailure) {
                // The text was already flushed into the in-memory writer, so a
                // failing close must not replace the result or the primary error.
                logger.warn("Failed to close PemWriter", closeFailure);
            }
        }
    }
}