package com.xdja.ra.utils;


import com.xdja.pki.gmssl.core.utils.GMSSLX509Utils;
import org.bouncycastle.asn1.*;
import org.bouncycastle.asn1.cmp.CMPCertificate;
import org.bouncycastle.asn1.cms.ContentInfo;
import org.bouncycastle.asn1.pkcs.SignedData;
import org.bouncycastle.asn1.x500.X500Name;
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey;
import org.bouncycastle.jce.ECNamedCurveTable;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.jce.spec.ECParameterSpec;
import org.bouncycastle.jce.spec.ECPublicKeySpec;
import org.bouncycastle.math.ec.ECCurve;
import org.bouncycastle.util.BigIntegers;
import org.bouncycastle.util.encoders.Base64;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.naming.NamingException;
import javax.security.auth.x500.X500Principal;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStreamWriter;
import java.io.StringWriter;
import java.math.BigInteger;
import java.security.KeyFactory;
import java.security.PublicKey;
import java.security.Security;
import java.security.cert.Certificate;
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.spec.RSAPublicKeySpec;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;

import static com.xdja.pki.gmssl.core.utils.GMSSLX509Utils.ECC_SM2_NAME;

public class CertUtils {

    private static Logger logger = LoggerFactory.getLogger(CertUtils.class);

    public static final String CERT_HEAD = "-----BEGIN CERTIFICATE-----";
    public static final String CERT_TAIL = "-----END CERTIFICATE-----";

    public static final String PUBLIC_KEY_HEAD = "-----BEGIN PUBLIC KEY-----";
    public static final String PUBLIC_KEY_TAIL = "-----END PUBLIC KEY-----";

    private static String provider;
    static {
        BouncyCastleProvider bouncyCastleProvider = new BouncyCastleProvider();
        provider = bouncyCastleProvider.getName();
    }
    /**
     * 将各种编码的字符串转换成证书
     *
     * @param str
     * @return
     */
    public static X509Certificate getCertFromStr(String str) {
        str = str.replace(CERT_HEAD, "").replace(CERT_TAIL, "");
        str = str.replace("\r", "").replace("\n", "");
        str = str.replace("\\r", "").replace("\\n", "");

        X509Certificate x509Cert;
        x509Cert = getCertFromB64(str);
        if (x509Cert == null) {
            x509Cert = getCertFromNormalStr(str);
        }
        if (x509Cert == null) {
            x509Cert = getCertFromStr16(str);
        }
        return x509Cert;
    }

    /**
     * 从Base64进制字符串获取证书对象
     *
     * @param b64
     * @return
     */
    private synchronized static X509Certificate getCertFromB64(String b64) {
        CertificateFactory cf;
        X509Certificate x509Cert;
        try {
            cf = CertificateFactory.getInstance("X.509", provider);
            byte[] bsCert = Base64.decode(b64);
            InputStream inStream = new ByteArrayInputStream(bsCert);
            x509Cert = (X509Certificate) cf.generateCertificate(inStream);
            return x509Cert;
        } catch (Exception e) {
            System.err.println("getCertFromB64 error: " + e.toString());
        }
        return null;
    }

    /**
     * 从16进制字符串获取证书对象
     *
     * @param str
     * @return
     */
    private synchronized static X509Certificate getCertFromStr16(String str) {
        byte[] bs = hex2byte(str);
        CertificateFactory cf;
        X509Certificate x509Cert;
        try {
            cf = CertificateFactory.getInstance("X.509", provider);
            InputStream inStream = new ByteArrayInputStream(bs);
            x509Cert = (X509Certificate) cf.generateCertificate(inStream);
            return x509Cert;
        } catch (Exception e) {
            System.err.println("getCertFromFullStr error: " + e.toString());
        }
        return null;
    }

    /**
     * 16进制字符串转换为二进制
     *
     * @param str String 类型参数
     * @return
     */
    public static byte[] hex2byte(String str) {
        if (null == str || str.equals("")) {
            return null;
        }
        str = str.trim();
        StringBuffer sb = new StringBuffer(str);
        int len = sb.length();
        if (len == 0 || len % 2 == 1) {
            return null;
        }
        byte[] b = new byte[len / 2];
        try {
            for (int i = 0; i < len; i += 2) {
                b[i / 2] = (byte) Integer.decode("0x" + sb.substring(i, i + 2)).intValue();
            }
            return b;
        } catch (Exception e) {
            return null;
        }
    }

    /**
     * 根据证书字符串生成证书实体
     *
     * @param str
     * @return
     */
    private synchronized static X509Certificate getCertFromNormalStr(String str) {
        CertificateFactory cf;
        X509Certificate x509Cert;
        try {
            cf = CertificateFactory.getInstance("X.509", provider);
            InputStream inStream = new ByteArrayInputStream(str.getBytes());
            x509Cert = (X509Certificate) cf.generateCertificate(inStream);
            return x509Cert;
        } catch (Exception e) {
            System.err.println("getCertFromFullStr error: " + e.toString());
        }
        return null;
    }

    /**
     * 将base64转换为publicKey
     *
     * @param base64
     * @return
     * @throws Exception
     */
    public static PublicKey convertSM2PublicKey(String base64) throws Exception {
        base64 = base64.replace(PUBLIC_KEY_HEAD, "").replace(PUBLIC_KEY_TAIL, "");
        base64 = base64.replace("\r", "").replace("\n", "");
        base64 = base64.replace("\\r", "").replace("\\n", "");
        int digestLen = 32;
        byte[] b = Base64.decode(base64);
        byte[] x = new byte[digestLen];
        System.arraycopy(b, 1, x, 0, digestLen);
        byte[] y = new byte[digestLen];
        System.arraycopy(b, digestLen + 1, y, 0, digestLen);
        return convertSM2PublicKey(x, y);
    }

    public static PublicKey convertSM2PublicKey(byte[] x, byte[] y) throws Exception {

        ECParameterSpec ecParameterSpec = ECNamedCurveTable.getParameterSpec(ECC_SM2_NAME);
        ECCurve curve = ecParameterSpec.getCurve();
        ECPublicKeySpec ecPublicKeySpec = new ECPublicKeySpec(curve.createPoint(BigIntegers.fromUnsignedByteArray(x), BigIntegers.fromUnsignedByteArray(y)), ecParameterSpec);
        System.out.println("==========" + ecPublicKeySpec.getClass().getName().toString());
        /*KeyFactory keyFactory = KeyFactory.getInstance(GMSSLX509Utils.ECC_ALGORITHM_NAME, BouncyCastleProvider.PROVIDER_NAME);
        return keyFactory.generatePublic(ecPublicKeySpec);*/
        return new BCECPublicKey(ECC_SM2_NAME, ecPublicKeySpec, BouncyCastleProvider.CONFIGURATION);
    }

    static {
        try {
            Security.addProvider(new BouncyCastleProvider());
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * 解析证书链
     *
     * @param
     * @return
     */
    public static List<X509Certificate> getCertListFromB64(byte[] info) {

        CertificateFactory cf = null;
        List<X509Certificate> certificates = null;
        try {
            InputStream inStream = new ByteArrayInputStream(info);
            cf = CertificateFactory.getInstance("X.509", "BC");
            certificates = (List<X509Certificate>) cf.generateCertificates(inStream);
        } catch (Exception e) {
            e.printStackTrace();
        }
        if (!(certificates.size()==0||null == certificates)) {
            certificates = sortCerts(certificates);
            return certificates;
        }

        String str = new String(info);
        str = str.replace(CERT_HEAD, "").replace(CERT_TAIL, "");
        str = str.replace("\r", "").replace("\n", "");
        str = str.replace("\\r", "").replace("\\n", "");
        try {
            byte[] certByte = Base64.decode(str);
            if (certByte == null || certByte.length == 0) {
                certByte = hex2byte(str);
            }
            InputStream inStream = new ByteArrayInputStream(certByte);
            certificates = (List<X509Certificate>) cf.generateCertificates(inStream);
        } catch (Exception e) {
            e.printStackTrace();
        }

        if (!(null == certificates || certificates.size() == 0)) {
            certificates = sortCerts(certificates);
            return certificates;
        }
        return certificates;

    }

    /**
     * 获取已排序证书链
     *
     * @param b64P7b
     * @return
     */
    public static List<Certificate> getSortCertListFromB64(String b64P7b) {
        CertificateFactory cf;
        try {
            cf = CertificateFactory.getInstance("X.509", "BC");
            byte[] bsCert = Base64.decode(b64P7b);
            InputStream inStream = new ByteArrayInputStream(bsCert);
            List<Certificate> certificates = (List<Certificate>) cf.generateCertificates(inStream);
            Iterator<Certificate> it = certificates.iterator();
            List<Certificate> list = new ArrayList<>();
            while (it.hasNext()) {
                Certificate elem = it.next();
                list.add(elem);
            }
            list = sortCerts(list);
            return list;
        } catch (Exception e) {
            e.printStackTrace();
            return null;
        }
    }


    /**
     * 对证书链进行排序
     * 按照证书链倒序排列
     * 集合第一个为用户证书  最后一个为根证书
     *
     * @param certs
     * @return
     */
    public static List sortCerts(List certs) {
        if (certs.size() < 2) {
            return certs;
        }

        X500Principal issuer = ((X509Certificate) certs.get(0)).getIssuerX500Principal();
        boolean okay = true;

        for (int i = 1; i != certs.size(); i++) {
            X509Certificate cert = (X509Certificate) certs.get(i);

            if (issuer.equals(cert.getSubjectX500Principal())) {
                issuer = ((X509Certificate) certs.get(i)).getIssuerX500Principal();
            } else {
                okay = false;
                break;
            }
        }

        if (okay) {
            return certs;
        }

        // find end-entity cert
        List retList = new ArrayList(certs.size());
        List orig = new ArrayList(certs);

        for (int i = 0; i < certs.size(); i++) {
            X509Certificate cert = (X509Certificate) certs.get(i);
            boolean found = false;

            X500Principal subject = cert.getSubjectX500Principal();

            for (int j = 0; j != certs.size(); j++) {
                X509Certificate c = (X509Certificate) certs.get(j);
                if (c.getIssuerX500Principal().equals(subject)) {
                    found = true;
                    break;
                }
            }

            if (!found) {
                retList.add(cert);
                certs.remove(i);
            }
        }

        // can only have one end entity cert - something's wrong, give up.
       /* if (retList.size() > 1) {
            return orig;
        }
*/
        for (int i = 0; i != retList.size(); i++) {
            issuer = ((X509Certificate) retList.get(i)).getIssuerX500Principal();
            for (int j = 0; j < certs.size(); j++) {
                X509Certificate c = (X509Certificate) certs.get(j);
                if (issuer.equals(c.getSubjectX500Principal())) {
                    retList.add(c);
                    certs.remove(j);
                    break;
                }
            }
        }
        // make sure all certificates are accounted for.
        if (certs.size() > 0) {
            return orig;
        }
        return retList;
    }


    /**
     * 验证用户证书与CA证书是否是同一CA签发
     *
     * @param userCert 用户证书
     * @param caCert   CA证书
     * @return true-是；false-否
     */
    public static final boolean verifyCertIssueCa(String userCert, String caCert) {
        X509Certificate userCert1 = getCertFromStr(userCert);
        X509Certificate caCert1 = getCertFromStr(caCert);

        try {
            userCert1.verify(caCert1.getPublicKey());
        } catch (Exception e) {
            return false;
        }

        return true;
    }

    /**
     * 从指定路径加载p7b文件获取证书集合 (已排序)
     *
     * @param bytes
     * @return
     */
    public static List<X509Certificate> getCertListFromP7b(byte[] bytes) {
        List<X509Certificate> list = new ArrayList<>();
        try {
            ASN1Sequence asn1Sequence = ASN1Sequence.getInstance(bytes);
            ContentInfo contentInfo = new ContentInfo(asn1Sequence);
            SignedData instance = SignedData.getInstance(contentInfo.getContent());

            ASN1Set instanceCertificates = instance.getCertificates();

            Enumeration objects = instanceCertificates.getObjects();

            while (objects.hasMoreElements()) {
                ASN1Encodable o = (ASN1Encodable) objects.nextElement();
                InputStream inStream = new ASN1InputStream(o.toASN1Primitive().getEncoded());
                CertificateFactory cf = CertificateFactory.getInstance("X.509", "BC");
                X509Certificate certificate = (X509Certificate) cf.generateCertificate(inStream);
                list.add(certificate);
            }

        } catch (Exception e) {
            throw new RuntimeException();
        }
        List sortCerts = CertUtils.sortCerts(list);
        return sortCerts;
    }

    /**
     * 转换上传的证书文件成为证书对象
     *
     * @param certder
     * @return
     */
    public static X509Certificate convertUploadFileToCert(byte[] certder) {
        // 解析二进制证书流
        CertificateFactory cf = null;
        X509Certificate x509Cert = null;
        try {
            cf = CertificateFactory.getInstance("X.509", "BC");
            InputStream inStream = new ByteArrayInputStream(certder);
            x509Cert = (X509Certificate) cf.generateCertificate(inStream);
        } catch (Exception e) {
            e.printStackTrace();
        }
        if (null != x509Cert) {
            return x509Cert;
        }
        String str = new String(certder);
        str = str.replace(CERT_HEAD, "").replace(CERT_TAIL, "");
        str = str.replace("\r", "").replace("\n", "");
        str = str.replace("\\r", "").replace("\\n", "");
        // 解析base64编码的证书流
        x509Cert = getCertFromB64(str);
        if (x509Cert == null) {
            // 解析16进制编码的证书流
            x509Cert = getCertFromStr16(str);
        }
        return x509Cert;
    }

    /**
     * 将Certificate转换为CMPCertificate
     *
     * @param cert
     * @return
     * @throws CertificateEncodingException
     * @throws IOException
     */
    public static CMPCertificate[] getCMPCert(Certificate cert) throws CertificateEncodingException, IOException {
        ASN1InputStream ins = new ASN1InputStream(cert.getEncoded());
        try {
            ASN1Primitive pcert = ins.readObject();
            org.bouncycastle.asn1.x509.Certificate c = org.bouncycastle.asn1.x509.Certificate.getInstance(pcert.toASN1Primitive());
            CMPCertificate[] res = {new CMPCertificate(c)};
            return res;
        } finally {
            ins.close();
        }
    }

    /**
     * 将对象转换为pem格式
     *
     * @param obj
     * @return
     * @throws Exception
     */
    public static String writeObject(Object obj) throws Exception {
        StringWriter stringWriter = new StringWriter();
        GMSSLX509Utils.writePEM(obj, stringWriter);
        return stringWriter.toString();
    }


    /**
     * 获取证书中的使用者DN
     * @param x509Certificate
     * @return
     */
    public static String getSubjectByX509Cert(X509Certificate x509Certificate) throws NamingException {
        X500Name x500Name;
        try {
            byte[] encoded = x509Certificate.getTBSCertificate();
            ASN1Sequence seq = ASN1Sequence.getInstance(encoded);
            int seqStart = 0;
            if (!(seq.getObjectAt(0) instanceof ASN1TaggedObject)) {
                // field 0 is missing!
                seqStart = -1;
            }
            ASN1Encodable objectAt = seq.getObjectAt(seqStart + 5);

            x500Name = X500Name.getInstance(RFC4519StyleUpperCase.INSTANCE, objectAt.toASN1Primitive());
        } catch (CertificateEncodingException e) {
            logger.error("从x509证书中获取使用者DN异常", e);
            return null;
        }
        return x500Name.toString();
    }

    /**
     * 将证书转化为byte数组
     *
     * @param certificate
     * @return
     * @throws Exception
     */
    public static byte[] writeObjectToByteArray(X509Certificate certificate) throws Exception {
        ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
        OutputStreamWriter outputStreamWriter = new OutputStreamWriter(byteArrayOutputStream);
        GMSSLX509Utils.writePEM(certificate, outputStreamWriter);
        return byteArrayOutputStream.toByteArray();
    }


    /**
     * 获取证书中公钥长度
     *
     * @param base64CertBytes
     * @return
     */
    public static int getPublicKeyLength(byte[] base64CertBytes) throws Exception {
        X509Certificate certificate = CertUtils.getCertFromStr(new String(base64CertBytes));
        PublicKey publicKey = certificate.getPublicKey();
        KeyFactory keyFactory = KeyFactory.getInstance(publicKey.getAlgorithm());
        String algorithm = certificate.getPublicKey().getAlgorithm();
        if ("RSA".equalsIgnoreCase(algorithm)) { // 如果是RSA加密
            RSAPublicKeySpec keySpec = keyFactory.getKeySpec(certificate.getPublicKey(), RSAPublicKeySpec.class);
            BigInteger modulus = keySpec.getModulus();
            return modulus.bitLength();
        } else if ("EC".equalsIgnoreCase(algorithm)) {
            return 256;
        } else {
            throw new Exception();
        }
    }

    /**
     * 获取证书中公钥长度
     *
     * @param certificate
     * @return
     */
    public static int getPublicKeyLength(X509Certificate certificate) throws Exception {
        byte[] bytes = CertUtils.writeObjectToByteArray(certificate);
        int publicKeyLength = CertUtils.getPublicKeyLength(bytes);
        return publicKeyLength;
    }


}