package org.bouncycastle.tls;

import com.xdja.pki.gmssl.core.utils.GMSSLByteArrayUtils;
import com.xdja.pki.gmssl.crypto.sdf.SdfPrivateKey;
import org.bouncycastle.asn1.x509.Extensions;
import org.bouncycastle.asn1.x509.KeyUsage;
import org.bouncycastle.tls.crypto.*;
import org.bouncycastle.tls.crypto.impl.AbstractTlsCrypto;
import org.bouncycastle.tls.crypto.impl.bc.*;
import org.bouncycastle.util.encoders.Hex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.X509KeyManager;
import javax.net.ssl.X509TrustManager;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.security.PrivateKey;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Helpers for GMSSL (ECC-SM2) handshakes.
 *
 * <p>Covers four areas:
 * <ul>
 *   <li>selecting the signature and encryption credentials from a key manager;</li>
 *   <li>assembling the dual-certificate Certificate message;</li>
 *   <li>producing and verifying the ECC-SM2 ServerKeyExchange signature;</li>
 *   <li>producing and verifying the SM2/SM3 CertificateVerify signature.</li>
 * </ul>
 */
public class GMSSLUtils {

    private static final Logger logger = LoggerFactory.getLogger(GMSSLUtils.class.getName());

    /**
     * GMSSL support: builds the ECC-SM2 server credentials, which consist of a signature key/certificate
     * and an encryption key/certificate.
     *
     * <p>Aliases are first resolved by key usage via {@link #getSignatureAlias} and
     * {@link #getEncryptioneAlias}. A role that is still unresolved falls back in this order:
     * <ol>
     *   <li>the alias named {@code "sign"} / {@code "enc"};</li>
     *   <li>the alias at position 0 / 1;</li>
     *   <li>any alias not already chosen for the other role.</li>
     * </ol>
     *
     * @param keyType    key type passed to {@link X509KeyManager#getServerAliases(String, java.security.Principal[])}
     * @param crypto     crypto used to convert the X.509 chains into TLS certificates
     * @param keyManager source of aliases, private keys and certificate chains
     * @param sigAlg     signature algorithm handed to the created credentials
     * @param context    TLS context used to build the {@link TlsCryptoParameters}
     * @return the credentials, or {@code null} if the key manager has no alias for {@code keyType} or
     *         either chosen alias has no certificate chain
     * @throws IOException if a certificate cannot be encoded or parsed
     */
    public static DefaultTlsCredentialedSigner generateCredentials(
            String keyType,
            AbstractTlsCrypto crypto,
            X509KeyManager keyManager,
            SignatureAndHashAlgorithm sigAlg,
            TlsContext context
    ) throws IOException {

        String[] serverAliases = keyManager.getServerAliases(keyType, null);
        if (serverAliases == null || serverAliases.length == 0) {
            // No aliases for this key type: there is nothing to build credentials from.
            return null;
        }
        List<String> aliases = Arrays.asList(serverAliases);

        String signatureAlias = getSignatureAlias(keyType, keyManager, crypto);
        String encryptionAlias = getEncryptioneAlias(keyType, keyManager, crypto);

        if (signatureAlias == null) {
            signatureAlias = chooseFallbackAlias(aliases, "sign", encryptionAlias, 0);
        }
        if (encryptionAlias == null) {
            encryptionAlias = chooseFallbackAlias(aliases, "enc", signatureAlias, 1);
        }

        PrivateKey signaturePrivateKey = keyManager.getPrivateKey(signatureAlias);
        PrivateKey encryptionPrivateKey = keyManager.getPrivateKey(encryptionAlias);

        Certificate signatureCertificate = getCertificateMessage(crypto, keyManager.getCertificateChain(signatureAlias));
        Certificate encryptionCertificate = getCertificateMessage(crypto, keyManager.getCertificateChain(encryptionAlias));
        if (signatureCertificate.getLength() == 0 || encryptionCertificate.getLength() == 0) {
            // makeGMSSLCertificate needs the end-entity certificate of both chains.
            return null;
        }

        Certificate gmsslCertificate = makeGMSSLCertificate(signatureCertificate, encryptionCertificate);

        return new BcDefaultTlsCredentialedECCSM2(
                new TlsCryptoParameters(context),
                crypto,
                gmsslCertificate,
                signaturePrivateKey,
                encryptionPrivateKey,
                sigAlg
        );
    }

    /**
     * Picks an alias for a role that could not be resolved by key usage.
     *
     * <p>Candidates are tried in this order:
     * <ol>
     *   <li>{@code preferred}, if it is present;</li>
     *   <li>the alias at {@code defaultIndex}, clamped to the list;</li>
     *   <li>the first alias that differs from {@code otherAlias}.</li>
     * </ol>
     * Candidates equal to {@code otherAlias} are skipped when possible, so that the signature and
     * encryption roles do not end up sharing one key.
     *
     * @param aliases      non-empty list of server aliases
     * @param preferred    conventional alias name for this role
     * @param otherAlias   alias already chosen for the other role, may be {@code null}
     * @param defaultIndex positional fallback index
     * @return a non-null alias from {@code aliases}
     */
    private static String chooseFallbackAlias(List<String> aliases, String preferred, String otherAlias, int defaultIndex) {
        if (aliases.contains(preferred) && !preferred.equals(otherAlias)) {
            return preferred;
        }
        String byPosition = aliases.get(Math.min(defaultIndex, aliases.size() - 1));
        if (!byPosition.equals(otherAlias)) {
            return byPosition;
        }
        for (String alias : aliases) {
            if (!alias.equals(otherAlias)) {
                return alias;
            }
        }
        // Only one distinct alias exists; both roles have to share it.
        return byPosition;
    }

    /**
     * GMSSL support: builds the dual-certificate Certificate message.
     *
     * <p>The resulting order is:
     * <ol>
     *   <li>the signature end-entity certificate;</li>
     *   <li>the encryption end-entity certificate;</li>
     *   <li>the rest of the signature chain;</li>
     *   <li>the rest of the encryption chain, skipping certificates whose serial number already
     *       appeared in the signature chain.</li>
     * </ol>
     *
     * <p>NOTE(review): serial numbers are only unique per issuer. Two distinct CA certificates
     * that share a serial number would be collapsed here; comparing encodings would be stricter.
     *
     * @param signatureCertificate  signature chain, expected to be non-empty
     * @param encryptionCertificate encryption chain, expected to be non-empty
     * @return the combined certificate message
     */
    public static Certificate makeGMSSLCertificate(Certificate signatureCertificate, Certificate encryptionCertificate) {
        List<TlsCertificate> list = new ArrayList<>();
        list.add(signatureCertificate.getCertificateAt(0));
        list.add(encryptionCertificate.getCertificateAt(0));

        List<BigInteger> sns = new ArrayList<>();
        for (int i = 1; i < signatureCertificate.getLength(); i++) {
            TlsCertificate certificateAt = signatureCertificate.getCertificateAt(i);
            list.add(certificateAt);
            sns.add(certificateAt.getSerialNumber());
        }
        for (int i = 1; i < encryptionCertificate.getLength(); i++) {
            TlsCertificate certificateAt = encryptionCertificate.getCertificateAt(i);
            if (!sns.contains(certificateAt.getSerialNumber())) {
                list.add(certificateAt);
            }
        }
        return new Certificate(list.toArray(new TlsCertificate[list.size()]));
    }

    /**
     * GMSSL support: returns the signature certificate, which comes first in the message.
     */
    public static TlsCertificate getSignatureCertificate(Certificate gmsslCertificate) {
        return gmsslCertificate.getCertificateAt(0);
    }

    /**
     * GMSSL support: returns the encryption certificate, which comes second in the message.
     */
    public static TlsCertificate getEncryptionCertificate(Certificate gmsslCertificate) {
        return gmsslCertificate.getCertificateAt(1);
    }

    /**
     * Returns the alias whose certificate has both the digitalSignature and nonRepudiation key usages.
     * See {@link #getAliasWithKeyUsage} for how the alias count affects the result.
     */
    public static String getSignatureAlias(String keyType, X509KeyManager keyManager, TlsCrypto crypto) throws IOException {
        return getAliasWithKeyUsage(keyType, KeyUsage.digitalSignature | KeyUsage.nonRepudiation, keyManager, crypto);
    }

    /**
     * Returns the alias whose certificate has all of the keyEncipherment, dataEncipherment and
     * keyAgreement key usages. See {@link #getAliasWithKeyUsage} for how the alias count affects the result.
     */
    public static String getEncryptioneAlias(String keyType, X509KeyManager keyManager, TlsCrypto crypto) throws IOException {
        return getAliasWithKeyUsage(keyType, KeyUsage.keyEncipherment | KeyUsage.dataEncipherment | KeyUsage.keyAgreement, keyManager, crypto);
    }

    /**
     * Finds a server alias by the key usage of its end-entity certificate.
     *
     * <ul>
     *   <li>No aliases (including a {@code null} array from the key manager): returns {@code null}.</li>
     *   <li>Exactly one alias: returns it regardless of key usage.</li>
     *   <li>Anything other than two aliases: returns {@code null}.</li>
     *   <li>Exactly two aliases: returns the first one whose end-entity certificate carries a
     *       KeyUsage extension containing every bit of {@code keyUsage}.</li>
     * </ul>
     *
     * @param keyUsage bitwise OR of {@link KeyUsage} flags that must all be present
     * @return the matching alias, or {@code null}
     * @throws IOException if a certificate cannot be encoded or parsed
     */
    public static String getAliasWithKeyUsage(String keyType, int keyUsage, X509KeyManager keyManager, TlsCrypto crypto) throws IOException {
        String[] aliases = keyManager.getServerAliases(keyType, null);
        if (aliases == null) {
            return null;
        }
        if (aliases.length == 1) {
            return aliases[0];
        }
        if (aliases.length != 2) {
            return null;
        }
        for (String alias : aliases) {
            Certificate certificate = getCertificateMessage(crypto, keyManager.getCertificateChain(alias));
            if (certificate.getLength() == 0) {
                // No chain for this alias, so there is no key usage to inspect.
                continue;
            }

            byte[] encoding = certificate.getCertificateAt(0).getEncoded();
            Extensions exts = org.bouncycastle.asn1.x509.Certificate.getInstance(encoding).getTBSCertificate().getExtensions();
            if (exts == null) {
                continue;
            }

            KeyUsage ku = KeyUsage.fromExtensions(exts);
            // hasUsages requires every requested bit and tolerates an empty bit string.
            if (ku != null && ku.hasUsages(keyUsage)) {
                return alias;
            }
        }
        return null;
    }

    /**
     * Converts a JSSE certificate chain into a TLS {@link Certificate} message.
     *
     * @return {@link Certificate#EMPTY_CHAIN} if {@code chain} is {@code null} or empty
     * @throws IOException a {@link TlsFatalAlert} (internal_error) if a certificate cannot be encoded
     */
    public static Certificate getCertificateMessage(TlsCrypto crypto, X509Certificate[] chain) throws IOException {
        if (chain == null || chain.length < 1) {
            return Certificate.EMPTY_CHAIN;
        }

        TlsCertificate[] certificateList = new TlsCertificate[chain.length];
        try {
            for (int i = 0; i < chain.length; ++i) {
                certificateList[i] = crypto.createCertificate(chain[i].getEncoded());
            }
        } catch (CertificateEncodingException e) {
            throw new TlsFatalAlert(AlertDescription.internal_error, e);
        }

        return new Certificate(certificateList);
    }

    /**
     * GMSSL support: produces the ECC-SM2 ServerKeyExchange signature.
     *
     * <pre>
     * digitally-signed struct {
     *     opaque client_random[32];
     *     opaque server_random[32];
     *     opaque ASN.1Cert&lt;1..2^24-1&gt;;   // encryption certificate
     * } signed_params;
     * </pre>
     *
     * @return the raw signature produced by {@code credentials} over those parameters
     */
    static byte[] generateECCSM2ServerKeyExchangeSignature(TlsContext context, TlsCredentialedSigner credentials, TlsCertificate encryptionCertificate) throws IOException {
        byte[] m = calculateSignatureECCSM2(context, encryptionCertificate);
        return credentials.generateRawSignature(m);
    }

    /**
     * GMSSL support: verifies the ECC-SM2 ServerKeyExchange signature over the same parameters
     * that {@link #generateECCSM2ServerKeyExchangeSignature} signs.
     *
     * @throws IOException a {@link TlsFatalAlert} (decrypt_error) if verification fails
     */
    static void verifyECCSM2ServerKeyExchangeSignature(TlsContext context, TlsVerifier verifier, DigitallySigned signedParams, TlsCertificate encryptionCertificate) throws IOException {
        byte[] m = calculateSignatureECCSM2(context, encryptionCertificate);
        GMSSLByteArrayUtils.printHexBinary(logger, "verifyECCSM2ServerKeyExchangeSignature m", m);

        boolean verified = verifier.verifyRawSignature(signedParams, m);
        if (!verified) {
            throw new TlsFatalAlert(AlertDescription.decrypt_error);
        }
    }

    /**
     * GMSSL support: builds the data covered by the ECC-SM2 ServerKeyExchange signature.
     *
     * <p>The layout is client random, then server random, then the encryption certificate's
     * encoding with a 24-bit length prefix.
     */
    static byte[] calculateSignatureECCSM2(TlsContext context, TlsCertificate encryptionCertificate) throws IOException {
        ByteArrayOutputStream output = new ByteArrayOutputStream();

        SecurityParameters securityParameters = context.getSecurityParameters();
        output.write(securityParameters.clientRandom);
        output.write(securityParameters.serverRandom);

        TlsUtils.writeOpaque24(encryptionCertificate.getEncoded(), output);

        return output.toByteArray();
    }

    /**
     * Signs the SM3 handshake hash with {@code credentialedSigner} to build CertificateVerify.
     * {@code context} and {@code streamSigner} are not used by this implementation.
     */
    static byte[] generateCertificateVerify(TlsContext context, TlsCredentialedSigner credentialedSigner,
                                            TlsStreamSigner streamSigner, TlsHandshakeHash handshakeHash) throws IOException {
        byte[] hash = handshakeHash.getFinalHash(HashAlgorithm.sm3);
        GMSSLByteArrayUtils.printHexBinary(logger, "verify certificate verify hash", hash);

        byte[] signature = credentialedSigner.generateRawSignature(hash);
        GMSSLByteArrayUtils.printHexBinary(logger, "verify certificate verify", signature);

        return signature;
    }

    /**
     * Reads a 16-bit-length-prefixed signature from {@code buf} and verifies it against the SM3
     * handshake hash. The SM2 verifier is taken from the first certificate of {@code certificate}.
     * {@code context} and {@code certificateRequest} are not used by this implementation.
     *
     * @throws IOException a {@link TlsFatalAlert} (decrypt_error) if the signature does not verify
     */
    static void verifyCertificateVerify(TlsContext context, CertificateRequest certificateRequest, Certificate certificate,
                                        ByteArrayInputStream buf, TlsHandshakeHash handshakeHash) throws IOException {
        byte[] certificateVerify = TlsUtils.readOpaque16(buf);
        GMSSLByteArrayUtils.printHexBinary(logger, "certificate verify", certificateVerify);

        byte[] hash = handshakeHash.getFinalHash(HashAlgorithm.sm3);
        GMSSLByteArrayUtils.printHexBinary(logger, "certificate verify hash", hash);

        // Verify the CertificateVerify message contains a correct signature.
        TlsVerifier verifier = certificate.getCertificateAt(0).createVerifier(SignatureAlgorithm.sm2);

        DigitallySigned signedParams = new DigitallySigned(new SignatureAndHashAlgorithm(HashAlgorithm.sm3, SignatureAlgorithm.sm2), certificateVerify);

        boolean verified = verifier.verifyRawSignature(signedParams, hash);
        logger.info("verify certificate verified: {}", verified);
        if (!verified) {
            throw new TlsFatalAlert(AlertDescription.decrypt_error);
        }
    }
}
