package com.xdja.ca.utils;

import org.bouncycastle.asn1.x509.*;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.util.encoders.Base64;
import org.bouncycastle.x509.X509V3CertificateGenerator;
import org.bouncycastle.x509.extension.AuthorityKeyIdentifierStructure;
import org.bouncycastle.x509.extension.SubjectKeyIdentifierStructure;

import javax.security.auth.x500.X500Principal;
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.math.BigInteger;
import java.nio.charset.Charset;
import java.security.PrivateKey;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.Security;
import java.security.cert.Certificate;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.TimeUnit;


/**
 * Certificate utilities: parsing X.509 certificates from DER, PEM, Base64 and hex text,
 * PEM encoding, certificate-chain ordering, key-usage checks and issuing end-entity
 * signing certificates.
 *
 * <p>All certificate parsing goes through the BouncyCastle ("BC") JCA provider, which is
 * registered when this class is loaded if no provider named "BC" is installed yet.
 *
 * @author wyf
 */
public class SdkCertUtils {

    public static final String CERT_HEAD = "-----BEGIN CERTIFICATE-----";
    public static final String CERT_TAIL = "-----END CERTIFICATE-----";

    public static final String PKCS7_HEAD = "-----BEGIN PKCS7-----";
    public static final String PKCS7_TAIL = "-----END PKCS7-----";

    /**
     * Charset used to convert between bytes and Base64/PEM/hex text. Those encodings are pure
     * ASCII, so the platform default charset (which may not be ASCII-compatible) is avoided.
     */
    private static final Charset ASCII = Charset.forName("US-ASCII");

    /** Upper bound (and fallback) for certificate validity, in days (10950 days = 30 years). */
    private static final int MAX_VALIDITY_DAYS = 10950;

    /** Bit positions in {@link X509Certificate#getKeyUsage()} (RFC 5280 KeyUsage order). */
    private static final int KU_DIGITAL_SIGNATURE = 0;
    private static final int KU_NON_REPUDIATION = 1;
    private static final int KU_KEY_ENCIPHERMENT = 2;
    private static final int KU_DATA_ENCIPHERMENT = 3;
    private static final int KU_ENCIPHER_ONLY = 7;
    private static final int KU_DECIPHER_ONLY = 8;

    /** Source of randomness for certificate serial numbers. */
    private static final SecureRandom SERIAL_RANDOM = new SecureRandom();

    static {
        if (Security.getProvider("BC") == null) {
            Security.addProvider(new BouncyCastleProvider());
        }
    }

    // ------------------------------------------------------------------ parsing

    /**
     * Parses a certificate from a hexadecimal string of its encoded bytes.
     *
     * @param str hex string; surrounding whitespace is ignored
     * @return the certificate, or {@code null} if {@code str} is not valid hex or not a certificate
     */
    public static X509Certificate getCertFromStr16(String str) {
        byte[] der = hex2byte(str);
        if (der == null) {
            return null;
        }
        try {
            return parseCert(der);
        } catch (Exception e) {
            System.err.println("getCertFromStr16 error: " + e.toString());
            return null;
        }
    }

    /**
     * Decodes a hexadecimal string into bytes.
     *
     * @param str hex string (upper or lower case); surrounding whitespace is trimmed
     * @return the decoded bytes, or {@code null} if {@code str} is null, blank, of odd length,
     *         or contains a non-hex character
     */
    public static byte[] hex2byte(String str) {
        if (str == null) {
            return null;
        }
        String hex = str.trim();
        if (hex.length() == 0) {
            return null;
        }
        return decodeHex(hex);
    }

    /**
     * Parses a certificate from a string in any supported encoding. Tries, in order: Base64
     * (after removing PEM armor and line breaks), the untouched input as-is, and hex.
     *
     * @param str encoded certificate text
     * @return the certificate, or {@code null} if {@code str} is null or no encoding matched
     */
    public static X509Certificate getCertFromStr(String str) {
        if (str == null) {
            return null;
        }
        String body = stripPem(str);

        X509Certificate cert = getCertFromB64(body);
        if (cert == null) {
            // Retry with the untouched input so the provider can see any PEM armor that
            // stripPem did not recognise (the stripped body can never parse here).
            cert = getCertFromFullStr(str);
        }
        if (cert == null) {
            cert = getCertFromStr16(body);
        }
        return cert;
    }

    /**
     * Hands the bytes of {@code str} directly to the certificate factory.
     *
     * @param str certificate text as accepted by the BC certificate factory
     * @return the certificate, or {@code null} on any failure (the error is printed to stderr)
     */
    public static X509Certificate getCertFromFullStr(String str) {
        try {
            return parseCert(str.getBytes(ASCII));
        } catch (Exception e) {
            System.err.println("getCertFromFullStr error: " + e.toString());
            return null;
        }
    }

    /**
     * Parses a certificate from a Base64 string of its encoded bytes.
     *
     * @param b64 Base64 text without PEM header/footer
     * @return the certificate, or {@code null} on any failure (the error is printed to stderr)
     */
    public static X509Certificate getCertFromB64(String b64) {
        try {
            return parseCert(Base64.decode(b64));
        } catch (Exception e) {
            System.err.println("getCertFromB64 error: " + e.toString());
            return null;
        }
    }

    // ------------------------------------------------------------------ serial numbers

    /**
     * Returns the certificate serial number as lower-case hex, without padding.
     *
     * @param cert certificate
     * @return serial number in hex
     */
    public static String getSn(X509Certificate cert) {
        // BigInteger.toString(16) already emits lower-case digits.
        return cert.getSerialNumber().toString(16);
    }

    /**
     * Returns the certificate serial number as lower-case hex, left-padded with one zero
     * when needed so the digit count is even (a whole number of bytes). A negative serial
     * keeps its leading '-' in front of the padded digits.
     *
     * @param cert certificate
     * @return even-length hex serial number
     */
    public static String getSnFillZero(X509Certificate cert) {
        BigInteger serial = cert.getSerialNumber();
        String hex = serial.abs().toString(16);
        if (hex.length() % 2 != 0) {
            hex = "0" + hex;
        }
        return serial.signum() < 0 ? "-" + hex : hex;
    }

    // ------------------------------------------------------------------ encoding

    /**
     * Encodes a certificate as PEM text (header, single-line Base64 body, footer).
     *
     * @param cert certificate; may be null
     * @return PEM text; {@code null} if {@code cert} is null; {@code ""} if encoding fails
     */
    public static String certToFullB64(Certificate cert) {
        if (cert == null) {
            return null;
        }
        try {
            return bytesToFullB64(cert.getEncoded());
        } catch (Exception e) {
            System.err.println("certToFullB64 error:" + e.toString());
            return "";
        }
    }

    /**
     * Wraps encoded certificate bytes as PEM text: {@link #CERT_HEAD}, a newline, the Base64
     * body on a single line, a newline, {@link #CERT_TAIL}, and a trailing newline.
     *
     * @param certder encoded certificate bytes
     * @return PEM text, or {@code null} if {@code certder} is null
     */
    public static String bytesToFullB64(byte[] certder) {
        if (certder == null) {
            return null;
        }
        String body = new String(Base64.encode(certder), ASCII);
        return CERT_HEAD + "\n" + body + "\n" + CERT_TAIL + "\n";
    }

    /**
     * Identical to {@link #bytesToFullB64(byte[])}; kept for existing callers.
     *
     * @param certder encoded certificate bytes
     * @return PEM text, or {@code null} if {@code certder} is null
     */
    public static String bytesToFullB642(byte[] certder) {
        return bytesToFullB64(certder);
    }

    /**
     * Converts the content of an uploaded certificate file into a certificate. Tries, in order:
     * the raw bytes as-is, then the bytes as text in Base64 (PEM armor and line breaks removed),
     * then as hex.
     *
     * @param certder uploaded file content
     * @return the certificate, or {@code null} if {@code certder} is null or nothing matched
     */
    public static X509Certificate convertUploadFileToCert(byte[] certder) {
        if (certder == null) {
            return null;
        }
        X509Certificate cert = null;
        try {
            cert = parseCert(certder);
        } catch (Exception e) {
            // Expected for textual uploads; fall through to the text decoders below.
            System.err.println("convertUploadFileToCert error: " + e.toString());
        }
        if (cert != null) {
            return cert;
        }

        String str = stripPem(new String(certder, ASCII));
        cert = getCertFromB64(str);
        if (cert == null) {
            cert = getCertFromStr16(str);
        }
        return cert;
    }

    /**
     * Converts a hexadecimal certificate string into a Base64 string.
     *
     * @param hexStr hex string to convert (upper or lower case, no whitespace)
     * @return the Base64 encoding of the decoded bytes
     * @throws NullPointerException if {@code hexStr} is null
     * @throws IllegalArgumentException if {@code hexStr} has odd length or a non-hex character
     */
    public static final String convertHexStr2Base64(String hexStr) {
        byte[] data = hexStr2Bytes(hexStr);
        return new String(Base64.encode(data), ASCII);
    }

    /**
     * Strictly decodes a hex string.
     *
     * @param hexStr hex string (no whitespace)
     * @return decoded bytes (empty for an empty string)
     * @throws IllegalArgumentException on odd length or a non-hex character, instead of
     *         silently producing garbage bytes
     */
    private static byte[] hexStr2Bytes(String hexStr) {
        byte[] data = decodeHex(hexStr);
        if (data == null) {
            throw new IllegalArgumentException("invalid hex string (odd length or non-hex character)");
        }
        return data;
    }

    // ------------------------------------------------------------------ key usage

    /**
     * Tells whether a Base64-encoded certificate is a signing certificate.
     *
     * @param cert certificate as Base64 text
     * @return true if digitalSignature or nonRepudiation is set; false otherwise
     * @throws IllegalArgumentException if {@code cert} cannot be parsed
     */
    public static final boolean isSignCert(String cert) {
        return isSignCert(requireCertFromB64(cert));
    }

    /**
     * Tells whether a certificate is a signing certificate.
     *
     * @param cert certificate
     * @return true if digitalSignature or nonRepudiation is set; false otherwise, including
     *         when the certificate has no KeyUsage extension
     */
    public static final boolean isSignCert(X509Certificate cert) {
        boolean[] ku = cert.getKeyUsage();
        return hasBit(ku, KU_DIGITAL_SIGNATURE) || hasBit(ku, KU_NON_REPUDIATION);
    }

    /**
     * Tells whether a Base64-encoded certificate is an encryption certificate.
     *
     * @param cert certificate as Base64 text
     * @return true if keyEncipherment, dataEncipherment, encipherOnly or decipherOnly is set
     * @throws IllegalArgumentException if {@code cert} cannot be parsed
     */
    public static final boolean isEncCert(String cert) {
        return isEncCert(requireCertFromB64(cert));
    }

    /**
     * Tells whether a certificate is an encryption certificate.
     *
     * @param cert certificate
     * @return true if keyEncipherment, dataEncipherment, encipherOnly or decipherOnly is set;
     *         false otherwise, including when the certificate has no KeyUsage extension
     */
    public static final boolean isEncCert(X509Certificate cert) {
        boolean[] ku = cert.getKeyUsage();
        return hasBit(ku, KU_KEY_ENCIPHERMENT) || hasBit(ku, KU_DATA_ENCIPHERMENT)
                || hasBit(ku, KU_ENCIPHER_ONLY) || hasBit(ku, KU_DECIPHER_ONLY);
    }

    // ------------------------------------------------------------------ chains

    /**
     * Parses a certificate chain and orders it leaf-first (see {@link #sortCerts(List)}).
     * Despite the name, {@code info} is raw bytes. The bytes are tried in order: as-is, as
     * Base64 text (PEM armor and line breaks removed), then as hex text.
     *
     * @param info encoded chain
     * @return the sorted certificates, or {@code null} if {@code info} is null or no
     *         certificate could be parsed
     */
    @SuppressWarnings("unchecked")
    public static List<X509Certificate> getCertListFromB64(byte[] info) {
        if (info == null) {
            return null;
        }
        List<X509Certificate> certs = parseCertList(info);
        if (certs.isEmpty()) {
            String str = stripPem(new String(info, ASCII));
            certs = parseCertList(decodeB64OrNull(str));
            if (certs.isEmpty()) {
                // Hex text is also valid Base64 alphabet, so hex is tried only after
                // the Base64 bytes failed to parse, not merely failed to decode.
                certs = parseCertList(hex2byte(str));
            }
        }
        return certs.isEmpty() ? null : sortCerts(certs);
    }

    /**
     * Parses a Base64-encoded certificate bundle and orders it leaf-first.
     *
     * @param b64P7b Base64 text of the encoded bundle
     * @return the sorted certificates, or {@code null} on any failure
     */
    @SuppressWarnings("unchecked")
    public static List<Certificate> getSortCertListFromB64(String b64P7b) {
        try {
            CertificateFactory cf = CertificateFactory.getInstance("X.509", "BC");
            byte[] der = Base64.decode(b64P7b);
            List<Certificate> list =
                    new ArrayList<Certificate>(cf.generateCertificates(new ByteArrayInputStream(der)));
            return sortCerts(list);
        } catch (Exception e) {
            System.err.println("getSortCertListFromB64 error: " + e.toString());
            return null;
        }
    }

    /**
     * Orders a certificate chain leaf-first: the first element is the end-entity certificate
     * and the last is the root.
     *
     * <p>If the list already forms a chain in that order, or has fewer than two elements, it is
     * returned as-is (same instance). Otherwise a new, sorted list is returned. If no single
     * chain covers every certificate (zero or several end-entity candidates, or leftover
     * certificates), a copy in the original order is returned. The argument is never modified.
     *
     * <p>The raw {@code List} signature is kept for source compatibility with existing callers.
     *
     * @param certs list of {@link X509Certificate}; may be null
     * @return the ordered chain, as described above
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static List sortCerts(List certs) {
        if (certs == null || certs.size() < 2 || isOrderedChain(certs)) {
            return certs;
        }

        List<X509Certificate> all = new ArrayList<X509Certificate>(certs.size());
        for (Object o : certs) {
            all.add((X509Certificate) o);
        }

        X509Certificate leaf = findSingleLeaf(all);
        if (leaf == null) {
            // Zero or several end-entity candidates: cannot build one chain, give up.
            return new ArrayList(certs);
        }

        List<X509Certificate> remaining = new ArrayList<X509Certificate>(all);
        remaining.remove(leaf);
        List<X509Certificate> sorted = new ArrayList<X509Certificate>(all.size());
        sorted.add(leaf);

        X509Certificate current = leaf;
        X509Certificate next;
        while ((next = takeIssuer(current, remaining)) != null) {
            sorted.add(next);
            current = next;
        }
        // Every certificate must be part of the chain; otherwise keep the original order.
        return remaining.isEmpty() ? sorted : new ArrayList(certs);
    }

    // ------------------------------------------------------------------ verification & issuance

    /**
     * Checks whether {@code userCert} carries a valid signature made with the public key of
     * {@code caCert}, i.e. whether that CA issued it. Both arguments may use any encoding
     * accepted by {@link #getCertFromStr(String)}.
     *
     * @param userCert user certificate text
     * @param caCert CA certificate text
     * @return true if the signature verifies; false if it does not, or if either certificate
     *         cannot be parsed or verification fails for any reason
     */
    public static final boolean verifyCertIssueCa(String userCert, String caCert) {
        X509Certificate user = getCertFromStr(userCert);
        X509Certificate ca = getCertFromStr(caCert);
        if (user == null || ca == null) {
            return false;
        }
        try {
            user.verify(ca.getPublicKey());
            return true;
        } catch (Exception e) {
            // Deliberately best-effort: any verification failure means "not issued by this CA".
            return false;
        }
    }

    /**
     * Issues an end-entity signing certificate (SM3withSM2, BC provider) signed by a CA.
     *
     * <p>The certificate gets a random positive serial number. Its validity runs from now for
     * {@code validity} days, capped at the CA's notAfter (see
     * {@link #computeNotBeforeAndAfter(int, X509Certificate)}). It carries these extensions:
     * AuthorityKeyIdentifier and SubjectKeyIdentifier (both derived from the public keys),
     * BasicConstraints cA=false, and a critical KeyUsage of digitalSignature | nonRepudiation.
     *
     * @param entityKey public key of the subject
     * @param caKey CA private key used to sign
     * @param caCert CA certificate (issuer DN and authority key)
     * @param endDN subject distinguished name
     * @param validity validity in days; values outside [1, 10950] fall back to 10950
     * @return the issued certificate
     * @throws IllegalStateException if the CA certificate has already expired
     * @throws Exception if certificate generation or signing fails
     */
    public static X509Certificate generateEndEntitySignCert(PublicKey entityKey, PrivateKey caKey, X509Certificate caCert,
                                                            String endDN, int validity) throws Exception {
        Date[] period = computeNotBeforeAndAfter(validity, caCert);
        if (period == null) {
            throw new IllegalStateException("CA certificate has expired, notAfter=" + caCert.getNotAfter());
        }

        X509V3CertificateGenerator certGen = new X509V3CertificateGenerator();
        // Serial numbers must be positive and unique per CA (RFC 5280 4.1.2.2); never hard-code.
        certGen.setSerialNumber(newSerialNumber());
        certGen.setIssuerDN(caCert.getSubjectX500Principal());
        certGen.setNotBefore(period[0]);
        certGen.setNotAfter(period[1]);
        certGen.setSubjectDN(new X500Principal(endDN));
        certGen.setPublicKey(entityKey);
        certGen.setSignatureAlgorithm("SM3WithSM2");

        certGen.addExtension(X509Extensions.AuthorityKeyIdentifier, false,
                new AuthorityKeyIdentifierStructure(caCert.getPublicKey()));
        // SubjectKeyIdentifier.getInstance(...) only accepts ASN.1 objects, not a PublicKey;
        // SubjectKeyIdentifierStructure derives the identifier from the key itself.
        certGen.addExtension(X509Extensions.SubjectKeyIdentifier, false,
                new SubjectKeyIdentifierStructure(entityKey));
        certGen.addExtension(X509Extensions.BasicConstraints, false, new BasicConstraints(false));
        certGen.addExtension(X509Extension.keyUsage, true,
                new KeyUsage(KeyUsage.digitalSignature | KeyUsage.nonRepudiation));

        return certGen.generate(caKey, "BC");
    }

    /**
     * Computes a validity period starting now.
     *
     * @param validity requested validity in days; values outside [1, 10950] fall back to 10950
     * @param caCert issuing CA certificate; the end of the period is capped at its notAfter
     * @return {@code {notBefore, notAfter}}, or {@code null} if the CA certificate has already
     *         expired (its notAfter lies before now)
     */
    public static Date[] computeNotBeforeAndAfter(int validity, X509Certificate caCert) {
        int days = (validity < 1 || validity > MAX_VALIDITY_DAYS) ? MAX_VALIDITY_DAYS : validity;

        Date notBefore = new Date();
        long requestedEnd = notBefore.getTime() + TimeUnit.DAYS.toMillis(days);
        long notAfter = Math.min(requestedEnd, caCert.getNotAfter().getTime());

        if (notAfter < notBefore.getTime()) {
            return null;
        }
        return new Date[]{notBefore, new Date(notAfter)};
    }

    // ------------------------------------------------------------------ private helpers

    /** Parses one X.509 certificate from encoded bytes with the BC provider (may return null for empty input). */
    private static X509Certificate parseCert(byte[] data) throws Exception {
        CertificateFactory cf = CertificateFactory.getInstance("X.509", "BC");
        return (X509Certificate) cf.generateCertificate(new ByteArrayInputStream(data));
    }

    /** Parses every certificate in {@code data}; returns an empty list if data is null/empty or unparseable. */
    private static List<X509Certificate> parseCertList(byte[] data) {
        List<X509Certificate> out = new ArrayList<X509Certificate>();
        if (data == null || data.length == 0) {
            return out;
        }
        try {
            CertificateFactory cf = CertificateFactory.getInstance("X.509", "BC");
            for (Certificate c : cf.generateCertificates(new ByteArrayInputStream(data))) {
                out.add((X509Certificate) c);
            }
        } catch (Exception e) {
            System.err.println("parseCertList error: " + e.toString());
            out.clear();
        }
        return out;
    }

    /** Base64-decodes {@code b64}, returning null instead of throwing on malformed input. */
    private static byte[] decodeB64OrNull(String b64) {
        try {
            return Base64.decode(b64);
        } catch (Exception e) {
            return null;
        }
    }

    /**
     * Removes the certificate PEM header and footer, real CR/LF characters, and literal
     * two-character backslash escapes (backslash-r, backslash-n) from {@code str}.
     */
    private static String stripPem(String str) {
        return str.replace(CERT_HEAD, "").replace(CERT_TAIL, "")
                .replace("\r", "").replace("\n", "")
                .replace("\\r", "").replace("\\n", "");
    }

    /** Decodes an even-length hex string; returns null on odd length or any non-hex character. */
    private static byte[] decodeHex(String hex) {
        int len = hex.length();
        if (len % 2 != 0) {
            return null;
        }
        byte[] out = new byte[len / 2];
        for (int i = 0; i < len; i += 2) {
            int hi = Character.digit(hex.charAt(i), 16);
            int lo = Character.digit(hex.charAt(i + 1), 16);
            if (hi < 0 || lo < 0) {
                return null;
            }
            out[i / 2] = (byte) ((hi << 4) | lo);
        }
        return out;
    }

    /** Parses a Base64 certificate or throws the class's standard "invalid certificate" error. */
    private static X509Certificate requireCertFromB64(String b64) {
        X509Certificate x509 = getCertFromB64(b64);
        if (x509 == null) {
            throw new IllegalArgumentException("证书转换非法");
        }
        return x509;
    }

    /** True if the KeyUsage array is present, long enough, and has bit {@code index} set. */
    private static boolean hasBit(boolean[] keyUsage, int index) {
        return keyUsage != null && index < keyUsage.length && keyUsage[index];
    }

    /** True if each certificate's issuer is the subject of the next one (leaf-first order). */
    private static boolean isOrderedChain(List<?> certs) {
        X500Principal issuer = ((X509Certificate) certs.get(0)).getIssuerX500Principal();
        for (int i = 1; i < certs.size(); i++) {
            X509Certificate cert = (X509Certificate) certs.get(i);
            if (!issuer.equals(cert.getSubjectX500Principal())) {
                return false;
            }
            issuer = cert.getIssuerX500Principal();
        }
        return true;
    }

    /**
     * Returns the only certificate whose subject is not the issuer of any certificate in the
     * set, or null if there is not exactly one. The check always runs against the complete set.
     * A self-signed certificate counts as issuing itself, so a root is never the leaf.
     */
    private static X509Certificate findSingleLeaf(List<X509Certificate> certs) {
        X509Certificate leaf = null;
        for (X509Certificate candidate : certs) {
            X500Principal subject = candidate.getSubjectX500Principal();
            boolean issuesSomething = false;
            for (X509Certificate other : certs) {
                if (subject.equals(other.getIssuerX500Principal())) {
                    issuesSomething = true;
                    break;
                }
            }
            if (!issuesSomething) {
                if (leaf != null) {
                    return null;
                }
                leaf = candidate;
            }
        }
        return leaf;
    }

    /** Removes and returns the first certificate in {@code pool} whose subject issued {@code cert}, or null. */
    private static X509Certificate takeIssuer(X509Certificate cert, List<X509Certificate> pool) {
        X500Principal issuer = cert.getIssuerX500Principal();
        for (Iterator<X509Certificate> it = pool.iterator(); it.hasNext(); ) {
            X509Certificate candidate = it.next();
            if (issuer.equals(candidate.getSubjectX500Principal())) {
                it.remove();
                return candidate;
            }
        }
        return null;
    }

    /** Returns a random, strictly positive serial number (up to 2^64). */
    private static BigInteger newSerialNumber() {
        return new BigInteger(64, SERIAL_RANDOM).add(BigInteger.ONE);
    }
}
