package com.xdja.pki.oer.gbt.asn1.utils;

import com.xdja.pki.gmssl.core.utils.GMSSLX509Utils;
import com.xdja.pki.oer.core.ByteArrayUtils;
import com.xdja.pki.oer.core.TimeUtils;
import com.xdja.pki.oer.core.calculate.CalculateFactory;
import com.xdja.pki.oer.core.calculate.CalculateService;
import com.xdja.pki.oer.gbt.asn1.*;
import com.xdja.pki.oer.gbt.asn1.utils.bean.OERCertificate;
import com.xdja.pki.oer.gbt.asn1.utils.bean.OEREccPoint;
import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey;
import org.bouncycastle.math.ec.custom.gm.SM2P256V1Curve;
import org.bouncycastle.util.BigIntegers;
import org.bouncycastle.util.encoders.Hex;

import java.math.BigInteger;
import java.security.PublicKey;
import java.util.Date;

/**
 * @ClassName OERUtils
 * @Description Static helpers for OER certificates: SM2 public-key parsing, certificate and
 *              secured-message signature verification, validity start-time calculation and
 *              fixed-point number formatting.
 * @Date 2020/3/23 16:37
 * @Author FengZhen
 */
public class OERUtils {
    private static final CalculateService calculateService = CalculateFactory.getInstance();

    /** Number of fractional digits in an OER fixed-point value (units of 1e-7). */
    private static final int OER_DECIMAL_DIGITS = 7;

    /**
     * Checks whether a string can be converted into an SM2 public key.
     *
     * @param key expected to be a 128-character hex string (64 hex chars of X followed by 64 of Y)
     * @return {@code true} if {@link #getPublicKeyFromStr(String)} succeeds, {@code false} otherwise
     */
    public static boolean checkPublicKey(String key) {
        try {
            getPublicKeyFromStr(key);
            return true;
        } catch (Exception e) {
            // Invalid keys are reported via the return value; the trace is printed for diagnostics.
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Builds an SM2 public key from a hex string of the form {@code X || Y}.
     *
     * @param key hex string; the first 64 characters are the X coordinate and the remainder is Y
     * @return the SM2 public key produced by {@code GMSSLX509Utils.convertSM2PublicKey}
     * @throws Exception if the string is too short, is not valid hex, or the point cannot be converted
     */
    public static PublicKey getPublicKeyFromStr(String key) throws Exception {
        byte[] xCoord = Hex.decode(key.substring(0, 64));
        byte[] yCoord = Hex.decode(key.substring(64));
        return GMSSLX509Utils.convertSM2PublicKey(xCoord, yCoord);
    }

    /**
     * Verifies the signature of an OER certificate against its issuer certificate.
     *
     * <p>SM3 is used when the issuer's signing key lies on the SM2P256V1 curve, otherwise SHA-256.
     * The signed digest is {@code H(tbsCert) || H(issuerCert)}; when both byte arrays are
     * considered equal by {@link #isEqual(byte[], byte[])} (self-signed), the second part is the
     * hash of empty input.
     *
     * @param verifyCertData encoded certificate whose signature is checked
     * @param signCertData   encoded issuer certificate
     * @return {@code true} if the signature verifies
     * @throws Exception on any parsing or conversion failure
     */
    public static boolean verifyOerSignature(byte[] verifyCertData, byte[] signCertData) throws Exception {
        OERCertificate signCert = CertificateHolder.build(signCertData);
        PublicKey signPublicKey = signCert.getSignPublicKey();
        return verifyCertSignature(verifyCertData, signCertData, signPublicKey, isSm2Key(signPublicKey));
    }

    /**
     * Verifies a secured message signed with a certificate.
     *
     * <p>The signed digest is {@code H(tbsData) || H(certEncoding)}. The hash algorithm is SM3
     * when the certificate's verify key is on the SM2P256V1 curve, otherwise SHA-256.
     *
     * @param cert           signer certificate
     * @param securedMessage message carrying the signed data
     * @return {@code true} if the signature verifies
     * @throws Exception on any parsing or conversion failure
     */
    public static boolean verifySecuredMessageSignatureByCert(Certificate cert, SecuredMessage securedMessage) throws Exception {
        SignedData signedData = securedMessage.getPayload().getSignedData();
        byte[] tbsEncoded = signedData.getTbs().getEncode();
        PublicVerifyKey verifyKey = cert.getTbsCert().getSubjectAttribute().getVerifyKey();
        OEREccPoint verifyEccPoint = EccPointHolder.build(verifyKey.getEccPoint().getEncode(), verifyKey.getEccCurve());
        boolean sm2 = isSm2Key(verifyEccPoint.getPublicKey());
        byte[] certEncoded = cert.getEncode();
        byte[] hash = ByteArrayUtils.buildUpByte(digest(sm2, tbsEncoded), digest(sm2, certEncoded));
        OERCertificate oerCert = CertificateHolder.build(certEncoded);
        return SignatureVerify.verify(oerCert.getSignPublicKey(), hash, signedData.getSign());
    }

    /**
     * Verifies a self-signed secured message with the given public key.
     *
     * <p>The signed digest is {@code H(tbsData) || H(empty)}. The hash algorithm is SM3 when
     * the key is on the SM2P256V1 curve, otherwise SHA-256.
     *
     * @param publicKey      verification key; must be a {@code BCECPublicKey}
     * @param securedMessage message carrying the signed data
     * @return {@code true} if the signature verifies
     * @throws Exception on any parsing or conversion failure
     */
    public static boolean verifySecuredMessageSignatureBySelf(PublicKey publicKey, SecuredMessage securedMessage) throws Exception {
        SignedData signedData = securedMessage.getPayload().getSignedData();
        byte[] tbsEncoded = signedData.getTbs().getEncode();
        boolean sm2 = isSm2Key(publicKey);
        byte[] hash = ByteArrayUtils.buildUpByte(digest(sm2, tbsEncoded), digest(sm2, new byte[0]));
        return SignatureVerify.verify(publicKey, hash, signedData.getSign());
    }

    /**
     * Shared implementation of certificate-signature verification for both SM2 and ECDSA.
     */
    private static boolean verifyCertSignature(byte[] verifyCertData, byte[] signCertData,
                                               PublicKey publicKey, boolean sm2) throws Exception {
        Certificate verifyCert = Certificate.getInstance(verifyCertData);
        byte[] tbsHash = digest(sm2, verifyCert.getTbsCert().getEncode());
        // Self-signed certificates hash empty input in place of the issuer certificate.
        byte[] issuerHash = isEqual(verifyCertData, signCertData)
                ? digest(sm2, new byte[0])
                : digest(sm2, signCertData);
        byte[] hash = ByteArrayUtils.buildUpByte(tbsHash, issuerHash);
        return SignatureVerify.verify(publicKey, hash, verifyCert.getSignature());
    }

    /** Returns whether the (Bouncy Castle EC) key lies on the SM2P256V1 curve. */
    private static boolean isSm2Key(PublicKey publicKey) {
        return ((BCECPublicKey) publicKey).getParameters().getCurve() instanceof SM2P256V1Curve;
    }

    /** Hashes data with SM3 when {@code sm2} is true, otherwise with SHA-256. */
    private static byte[] digest(boolean sm2, byte[] data) throws Exception {
        return sm2 ? calculateService.sm3Hash(data) : calculateService.sha256Hash(data);
    }

    /**
     * Computes the effective validity start time of a certificate request.
     *
     * @param validityPeriod requested validity period
     * @param caStartDate    CA certificate start date
     * @param caEndDate      CA certificate end date
     * @return the requested start time, raised to {@code caStartDate} if earlier; {@code null}
     *         when the validity period carries no start/end pair
     * @throws Exception if the requested start time is not before {@code caEndDate}
     */
    public static Date getStartTime(ValidityPeriod validityPeriod, Date caStartDate, Date caEndDate) throws Exception {
        if (null == validityPeriod.getTimeStartAndEnd()) {
            return null;
        }
        BigInteger startTime = BigIntegers.fromUnsignedByteArray(
                validityPeriod.getTimeStartAndEnd().getStartValidity().getEncode());
        Date reqStartTime = TimeUtils.getTimeFromNumber(startTime.longValue());
        if (reqStartTime.getTime() >= caEndDate.getTime()) {
            throw new Exception("设置的起始时间不能超过CA证书有效期");
        }
        // A start earlier than the CA certificate's start date is clamped to the CA start date.
        return caStartDate.getTime() > reqStartTime.getTime() ? caStartDate : reqStartTime;
    }

    /**
     * Inserts a decimal point into an OER fixed-point integer string with 7 fractional digits,
     * e.g. {@code "123"} becomes {@code "0.0000123"} and {@code "-5000000"} becomes
     * {@code "-0.5000000"}.
     *
     * <p>A leading minus sign is kept in front of the result rather than being treated as a
     * digit, so small negative values (magnitude below 10^7) are formatted correctly.
     *
     * @param point decimal integer string, optionally prefixed with {@code '-'}
     * @return the same value with a decimal point and at least one integer digit
     */
    public static String pointTo(String point) {
        boolean negative = point.startsWith("-");
        StringBuilder digits = new StringBuilder(negative ? point.substring(1) : point);
        // Pad so that at least one digit remains in front of the decimal point.
        while (digits.length() <= OER_DECIMAL_DIGITS) {
            digits.insert(0, '0');
        }
        int split = digits.length() - OER_DECIMAL_DIGITS;
        return (negative ? "-" : "") + digits.substring(0, split) + "." + digits.substring(split);
    }

    /**
     * Checks whether two byte arrays are equal.
     *
     * <p>NOTE(review): besides exact equality, this also returns {@code true} when {@code arr2} is
     * {@code arr1} in reverse byte order. This is kept for backward compatibility; confirm whether
     * callers (including the self-signed check in certificate verification) really want that.
     *
     * @param arr1 first array (must not be null)
     * @param arr2 second array (must not be null)
     * @return {@code true} if the arrays match element-wise, forwards or reversed
     */
    public static boolean isEqual(byte[] arr1, byte[] arr2) {
        if (arr1.length != arr2.length) {
            return false;
        }
        return matchesForward(arr1, arr2) || matchesReversed(arr1, arr2);
    }

    /** Returns whether equal-length arrays match element by element. */
    private static boolean matchesForward(byte[] a, byte[] b) {
        for (int i = 0; i < a.length; i++) {
            if (a[i] != b[i]) {
                return false;
            }
        }
        return true;
    }

    /** Returns whether equal-length array {@code b} is {@code a} reversed. */
    private static boolean matchesReversed(byte[] a, byte[] b) {
        int last = a.length - 1;
        for (int i = 0; i < a.length; i++) {
            if (a[i] != b[last - i]) {
                return false;
            }
        }
        return true;
    }

}
