package org.bouncycastle.tls.crypto.impl.bc;

import com.xdja.pki.gmssl.core.utils.GMSSLByteArrayUtils;
import com.xdja.pki.gmssl.crypto.sdf.*;
import com.xdja.pki.gmssl.sdf.SdfSDKException;
import org.bouncycastle.asn1.x509.KeyUsage;
import org.bouncycastle.crypto.BlockCipher;
import org.bouncycastle.crypto.Digest;
import org.bouncycastle.crypto.ExtendedDigest;
import org.bouncycastle.crypto.digests.SM3Digest;
import org.bouncycastle.crypto.macs.HMac;
import org.bouncycastle.crypto.modes.CBCBlockCipher;
import org.bouncycastle.crypto.params.AsymmetricKeyParameter;
import org.bouncycastle.crypto.params.KeyParameter;
import org.bouncycastle.crypto.params.ParametersWithIV;
import org.bouncycastle.tls.*;
import org.bouncycastle.tls.crypto.*;
import org.bouncycastle.tls.crypto.impl.AbstractTlsCrypto;
import org.bouncycastle.tls.crypto.impl.TlsBlockCipher;
import org.bouncycastle.tls.crypto.impl.TlsBlockCipherImpl;
import org.bouncycastle.tls.crypto.impl.TlsEncryptor;
import org.bouncycastle.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.math.BigInteger;
import java.security.SecureRandom;

/**
 * Provides cryptographic services for TLS, built on the BC light-weight API and backed by the
 * SDF implementations selected through an {@link SdfCryptoType}.
 * <p>
 * The primitives actually wired up in this class are SM2 public-key encryption of the
 * pre-master secret, the SM3 digest, HMAC over SM3 and the SM4-128-CBC block cipher.
 * If you need to customise the behaviour, extend the class and override the appropriate methods.
 * </p>
 * <p>
 * NOTE(review): the DH, ECDH and SRP factory methods return {@code null} while the matching
 * {@code has*} capability queries report {@code true}; confirm that no negotiated cipher suite
 * can reach those factories.
 * </p>
 */
public class BcTlsCryptoSdf extends AbstractTlsCrypto {

    private static final Logger logger = LoggerFactory.getLogger(BcTlsCryptoSdf.class);

    /** Selects which SDF implementation every engine, digest and random source is created for. */
    private final SdfCryptoType sdfCryptoType;

    /**
     * Creates a crypto provider bound to the given SDF implementation type.
     *
     * @param sdfCryptoType the SDF type handed to every SDF primitive this instance creates
     */
    public BcTlsCryptoSdf(SdfCryptoType sdfCryptoType) {
        this.sdfCryptoType = sdfCryptoType;
    }

    /**
     * @return the SDF type this instance was constructed with
     */
    public SdfCryptoType getSdfCryptoType() {
        return sdfCryptoType;
    }

    /**
     * Wraps the given bytes in a secret owned by this crypto instance. The array is handed over
     * as-is (no copy is made here).
     *
     * @param data the secret bytes to adopt
     * @return a secret bound to this crypto instance
     */
    BcTlsSecretSdf adoptLocalSecret(byte[] data) {
        return new BcTlsSecretSdf(this, data);
    }

    /**
     * Creates a new {@link SdfRandom} for the configured SDF type on every call.
     *
     * @return a freshly constructed SDF random source
     * @throws IllegalStateException if the {@link SdfRandom} cannot be created
     */
    @Override
    public SecureRandom getSecureRandom() {
        try {
            return new SdfRandom(sdfCryptoType);
        } catch (SdfSDKException e) {
            logger.error("getSecureRandom new sdf random error", e);
            throw new IllegalStateException("unable to create SdfRandom: " + e.getMessage(), e);
        }
    }

    /**
     * Builds a certificate object from its encoded form.
     *
     * @param encoding the encoded certificate
     * @return the certificate bound to this crypto instance
     * @throws IOException if the certificate object cannot be created from the encoding
     */
    @Override
    public TlsCertificate createCertificate(byte[] encoding)
            throws IOException {
        return new BcTlsCertificateSdf(this, encoding);
    }

    /**
     * Creates the record-layer cipher for the given algorithms. Only SM4-128-CBC is supported.
     *
     * @param cryptoParams        the parameters of the current TLS session
     * @param encryptionAlgorithm the {@link EncryptionAlgorithm} constant to create a cipher for
     * @param macAlgorithm        the MAC algorithm used together with the cipher
     * @return the SM4-CBC cipher
     * @throws IOException a {@link TlsFatalAlert} ({@code internal_error}) for any other algorithm
     */
    @Override
    protected TlsCipher createCipher(TlsCryptoParameters cryptoParams, int encryptionAlgorithm, int macAlgorithm)
            throws IOException {
        switch (encryptionAlgorithm) {
            // GMSSL support: SM4_128_CBC with a 16-byte key
            case EncryptionAlgorithm.SM4_128_CBC:
                return createSM4Cipher(cryptoParams, 16, macAlgorithm);
            default:
                throw new TlsFatalAlert(AlertDescription.internal_error);
        }
    }

    /**
     * Not implemented by this provider.
     *
     * @return always {@code null}
     */
    @Override
    public TlsDHDomain createDHDomain(TlsDHConfig dhConfig) {
        return null;
    }

    /**
     * Not implemented by this provider.
     *
     * @return always {@code null}
     */
    @Override
    public TlsECDomain createECDomain(TlsECConfig ecConfig) {
        return null;
    }

    /**
     * Creates an encryptor that encrypts data (the pre-master secret in the GMSSL ECC/SM2 key
     * exchange) to the public key of the given certificate.
     *
     * @param certificate the peer certificate; must allow {@code keyEncipherment}
     * @return an encryptor producing the output of {@code SdfSM2Engine.encryptASN1}
     * @throws IOException if key-usage validation fails, or a {@link TlsFatalAlert}
     *                     ({@code internal_error}) when the public key is not an
     *                     {@link SdfECKeyParameters}
     */
    @Override
    protected TlsEncryptor createEncryptor(TlsCertificate certificate) throws IOException {
        final BcTlsCertificateSdf bcCert = BcTlsCertificateSdf.convert(this, certificate);
        bcCert.validateKeyUsage(KeyUsage.keyEncipherment);

        // GMSSL support (2018/8/2): for ECC/SM2 the pre-master secret must be encrypted
        final AsymmetricKeyParameter keyParameter = bcCert.getPublicKey();
        if (!(keyParameter instanceof SdfECKeyParameters)) {
            throw new TlsFatalAlert(AlertDescription.internal_error);
        }

        return new TlsEncryptor() {
            @Override
            public byte[] encrypt(byte[] input, int inOff, int length)
                    throws IOException {
                // Encrypt exactly the requested range; encryptASN1 only accepts a whole array.
                byte[] plaintext = new byte[length];
                System.arraycopy(input, inOff, plaintext, 0, length);
                try {
                    SdfSM2Engine sm2Engine = new SdfSM2Engine(sdfCryptoType);
                    try {
                        sm2Engine.init(true, keyParameter);
                        return sm2Engine.encryptASN1(plaintext);
                    } finally {
                        // Release the engine even when init/encrypt fails.
                        sm2Engine.release();
                    }
                } catch (SdfSDKException e) {
                    throw new IOException(e);
                }
            }
        };
    }

    /**
     * Creates a nonce generator that fills each nonce from {@link #getSecureRandom()}.
     *
     * @param additionalSeedMaterial ignored by this implementation
     * @return a generator producing nonces of the requested size
     */
    @Override
    public TlsNonceGenerator createNonceGenerator(byte[] additionalSeedMaterial) {
        return new TlsNonceGenerator() {
            @Override
            public byte[] generateNonce(int size) {
                byte[] nonce = new byte[size];
                getSecureRandom().nextBytes(nonce);
                return nonce;
            }
        };
    }

    /** @return always {@code true} */
    @Override
    public boolean hasAllRawSignatureAlgorithms() {
        return true;
    }

    /**
     * @return always {@code true}
     *         (NOTE(review): {@link #createDHDomain} returns {@code null} - confirm intended)
     */
    @Override
    public boolean hasDHAgreement() {
        return true;
    }

    /**
     * @return always {@code true}
     *         (NOTE(review): {@link #createECDomain} returns {@code null} - confirm intended)
     */
    @Override
    public boolean hasECDHAgreement() {
        return true;
    }

    /**
     * @param encryptionAlgorithm an {@link EncryptionAlgorithm} constant
     * @return {@code true} only for {@code SM4_128_CBC}
     */
    @Override
    public boolean hasEncryptionAlgorithm(int encryptionAlgorithm) {
        switch (encryptionAlgorithm) {
            case EncryptionAlgorithm.SM4_128_CBC:
                return true;
            default:
                return false;
        }
    }

    /**
     * @return always {@code true}
     *         (NOTE(review): {@link #createDigest} only accepts SM3 - confirm intended)
     */
    @Override
    public boolean hasHashAlgorithm(short hashAlgorithm) {
        return true;
    }

    /**
     * @return always {@code true}
     *         (NOTE(review): {@link #createHMAC} only accepts SM3-based MACs - confirm intended)
     */
    @Override
    public boolean hasMacAlgorithm(int macAlgorithm) {
        return true;
    }

    /**
     * @param namedGroup the named group identifier
     * @return the result of {@code NamedGroup.refersToASpecificGroup(namedGroup)}
     */
    @Override
    public boolean hasNamedGroup(int namedGroup) {
        return NamedGroup.refersToASpecificGroup(namedGroup);
    }

    /** @return always {@code true} */
    @Override
    public boolean hasRSAEncryption() {
        return true;
    }

    /** @return always {@code true} */
    @Override
    public boolean hasSignatureAlgorithm(int signatureAlgorithm) {
        return true;
    }

    /** @return always {@code true} */
    @Override
    public boolean hasSignatureAndHashAlgorithm(SignatureAndHashAlgorithm sigAndHashAlgorithm) {
        return true;
    }

    /**
     * @return always {@code true}
     *         (NOTE(review): the SRP factory methods return {@code null} - confirm intended)
     */
    @Override
    public boolean hasSRPAuthentication() {
        return true;
    }

    /**
     * Creates a secret from a copy of the given bytes, so later changes to {@code data} do not
     * affect the secret.
     *
     * @param data the secret bytes
     * @return a secret holding a clone of {@code data}
     */
    @Override
    public TlsSecret createSecret(byte[] data) {
        return adoptLocalSecret(Arrays.clone(data));
    }

    /**
     * Generates a 48-byte pre-master secret: random bytes from {@link #getSecureRandom()} with the
     * protocol version written over the leading bytes.
     * <p>
     * The secret is deliberately not logged: it is key material.
     * </p>
     *
     * @param version the protocol version to embed at offset 0
     * @return the new pre-master secret
     */
    @Override
    public TlsSecret generatePreMasterSecret(ProtocolVersion version) {
        byte[] data = new byte[48];
        // Fill the whole buffer with random bytes first ...
        getSecureRandom().nextBytes(data);
        // ... then overwrite the start of the buffer with the protocol version.
        TlsUtils.writeVersion(version, data, 0);
        return adoptLocalSecret(data);
    }

    /**
     * Creates a digest for the given hash algorithm. Only SM3 is supported.
     *
     * @param hashAlgorithm a {@link HashAlgorithm} constant
     * @return a new {@link SdfSM3Digest} for the configured SDF type
     * @throws IllegalArgumentException if the algorithm is not SM3, or (wrapping the
     *                                  {@link SdfSDKException}) if the digest cannot be created
     */
    public Digest createDigest(short hashAlgorithm) {
        switch (hashAlgorithm) {
            // GMSSL support: SM3
            case HashAlgorithm.sm3:
                try {
                    return new SdfSM3Digest(sdfCryptoType);
                } catch (SdfSDKException e) {
                    throw new IllegalArgumentException(e);
                }
            default:
                throw new IllegalArgumentException("unknown HashAlgorithm: " + HashAlgorithm.getText(hashAlgorithm));
        }
    }

    /**
     * Creates a TLS hash for the given algorithm, backed by {@link #createDigest(short)}.
     *
     * @param algorithm a {@link HashAlgorithm} constant
     * @return the hash wrapper
     */
    @Override
    public TlsHash createHash(short algorithm) {
        return new BcTlsHash(algorithm, createDigest(algorithm));
    }

    /**
     * {@link TlsHash} adapter around a BC {@link Digest}.
     */
    public static class BcTlsHash implements TlsHash {
        private final short hashAlgorithm;
        private final Digest digest;

        BcTlsHash(short hashAlgorithm, Digest digest) {
            this.hashAlgorithm = hashAlgorithm;
            this.digest = digest;
        }

        @Override
        public void update(byte[] data, int offSet, int length) {
            digest.update(data, offSet, length);
        }

        /**
         * Finalises the digest and returns the hash value.
         *
         * @return a new array of {@code digest.getDigestSize()} bytes
         */
        @Override
        public byte[] calculateHash() {
            byte[] rv = new byte[digest.getDigestSize()];
            digest.doFinal(rv, 0);
            return rv;
        }

        /**
         * @return a new hash whose digest is produced by {@link #cloneDigest(short, Digest)}
         */
        @Override
        public Object clone() {
            return new BcTlsHash(hashAlgorithm, cloneDigest(hashAlgorithm, digest));
        }

        @Override
        public void reset() {
            digest.reset();
        }

        /**
         * Calls {@code releaseConnection()} on the wrapped digest when it is an
         * {@link SdfSM3Digest}; does nothing otherwise.
         */
        public void releaseConnection() {
            if (digest instanceof SdfSM3Digest) {
                ((SdfSM3Digest) digest).releaseConnection();
            }
        }
    }

    /**
     * Creates a copy of the given digest via the {@link SdfSM3Digest} copy constructor.
     * Only SM3 is supported.
     *
     * @param hashAlgorithm a {@link HashAlgorithm} constant
     * @param hash          the digest to copy; must be an {@link SdfSM3Digest} for SM3
     * @return the copied digest
     * @throws IllegalArgumentException if the algorithm is not SM3, or (wrapping the
     *                                  {@link SdfSDKException}) if the copy cannot be created
     */
    public static Digest cloneDigest(short hashAlgorithm, Digest hash) {
        switch (hashAlgorithm) {
            // GMSSL support: SM3
            case HashAlgorithm.sm3:
                try {
                    return new SdfSM3Digest((SdfSM3Digest) hash);
                } catch (SdfSDKException e) {
                    throw new IllegalArgumentException(e);
                }
            default:
                throw new IllegalArgumentException("unknown HashAlgorithm: " + HashAlgorithm.getText(hashAlgorithm));
        }
    }

    /**
     * GMSSL support: creates the SM4-CBC record cipher from separate encrypting and decrypting
     * block operators and two HMAC instances.
     *
     * @param cryptoParams  the parameters of the current TLS session
     * @param cipherKeySize the cipher key size in bytes
     * @param macAlgorithm  the MAC algorithm for both HMAC instances
     * @return the assembled block cipher
     * @throws IOException if the {@link TlsBlockCipher} cannot be constructed
     */
    protected TlsCipher createSM4Cipher(TlsCryptoParameters cryptoParams, int cipherKeySize, int macAlgorithm)
            throws IOException {
        return new TlsBlockCipher(this, cryptoParams, new BlockOperator(createSM4BlockCipher(), true), new BlockOperator(createSM4BlockCipher(), false),
                createHMAC(macAlgorithm), createHMAC(macAlgorithm), cipherKeySize);
    }

    /**
     * GMSSL support: creates the raw SM4 engine for the configured SDF type.
     *
     * @return a new {@link SdfSM4Engine}
     * @throws IllegalArgumentException wrapping the {@link SdfSDKException} if creation fails
     */
    protected BlockCipher createSM4Engine() {
        try {
            return new SdfSM4Engine(sdfCryptoType);
        } catch (SdfSDKException e) {
            throw new IllegalArgumentException(e);
        }
    }

    /**
     * GMSSL support: wraps a new SM4 engine in a {@link CBCBlockCipherSdf}.
     *
     * @return the CBC-mode SM4 cipher
     */
    protected BlockCipher createSM4BlockCipher() {
        return new CBCBlockCipherSdf(createSM4Engine());
    }

    /**
     * Creates an HMAC for the given MAC algorithm. Only MAC algorithms whose hash is SM3 are
     * supported, and the HMAC is computed with the software {@link SM3Digest}.
     * <p>
     * No {@link SdfSM3Digest} is created here: it would not be used by the HMAC and nothing
     * would call its {@code releaseConnection()}.
     * </p>
     *
     * @param macAlgorithm the MAC algorithm identifier
     * @return the HMAC wrapper
     * @throws IllegalArgumentException if the MAC algorithm's hash is not SM3
     */
    @Override
    public TlsHMAC createHMAC(int macAlgorithm) {
        short hashAlgorithm = TlsUtils.getHashAlgorithmForHMACAlgorithm(macAlgorithm);
        if (hashAlgorithm != HashAlgorithm.sm3) {
            throw new IllegalArgumentException("unknown HashAlgorithm: " + HashAlgorithm.getText(hashAlgorithm));
        }
        // TODO: 2019/4/3 use pcie sm3
        return new HMacOperator(new SM3Digest());
    }

    /**
     * Not implemented by this provider.
     *
     * @return always {@code null}
     */
    @Override
    public TlsSRP6Client createSRP6Client(TlsSRPConfig srpConfig) {
        return null;
    }

    /**
     * Not implemented by this provider.
     *
     * @return always {@code null}
     */
    @Override
    public TlsSRP6Server createSRP6Server(TlsSRPConfig srpConfig, BigInteger srpVerifier) {
        return null;
    }

    /**
     * Not implemented by this provider.
     *
     * @return always {@code null}
     */
    @Override
    public TlsSRP6VerifierGenerator createSRP6VerifierGenerator(TlsSRPConfig srpConfig) {
        return null;
    }

    /**
     * {@link TlsBlockCipherImpl} adapter that drives a BC {@link BlockCipher} in a fixed
     * direction (encrypting or decrypting).
     */
    public class BlockOperator implements TlsBlockCipherImpl {
        private final boolean isEncrypting;
        private final BlockCipher cipher;

        private KeyParameter key;

        BlockOperator(BlockCipher cipher, boolean isEncrypting) {
            this.cipher = cipher;
            this.isEncrypting = isEncrypting;
        }

        /**
         * Stores the key and initialises the cipher with it.
         */
        @Override
        public void setKey(byte[] key, int keyOff, int keyLen) {
            this.key = new KeyParameter(key, keyOff, keyLen);
            cipher.init(isEncrypting, this.key);
        }

        /**
         * Re-initialises the cipher with a new IV. The key parameter passed along is
         * {@code null}; presumably the cipher keeps the key from {@link #setKey} - verify
         * against {@code CBCBlockCipherSdf}.
         */
        @Override
        public void init(byte[] iv, int ivOff, int ivLen) {
            cipher.init(isEncrypting, new ParametersWithIV(null, iv, ivOff, ivLen));
        }

        /**
         * Processes the input one cipher block at a time.
         *
         * @return {@code inputLength}
         */
        @Override
        public int doFinal(byte[] input, int inputOffset, int inputLength, byte[] output, int outputOffset) {
            int blockSize = cipher.getBlockSize();

            for (int i = 0; i < inputLength; i += blockSize) {
                cipher.processBlock(input, inputOffset + i, output, outputOffset + i);
            }

            return inputLength;
        }

        @Override
        public int getBlockSize() {
            return cipher.getBlockSize();
        }

        /**
         * Calls {@code releaseConnection()} on the wrapped cipher when it is a
         * {@link CBCBlockCipherSdf}; does nothing otherwise.
         */
        public void releaseConnection() {
            if (cipher instanceof CBCBlockCipherSdf) {
                ((CBCBlockCipherSdf) cipher).releaseConnection();
            }
        }
    }

    /**
     * {@link TlsHMAC} adapter around a BC {@link HMac}.
     */
    private static class HMacOperator implements TlsHMAC {
        private final HMac hmac;

        /**
         * @param digest the digest the HMAC is computed over; must be an {@link ExtendedDigest}
         *               for {@link #getInternalBlockSize()} to work
         */
        HMacOperator(Digest digest) {
            this.hmac = new HMac(digest);
        }

        @Override
        public void setKey(byte[] key, int keyOff, int keyLen) {
            hmac.init(new KeyParameter(key, keyOff, keyLen));
        }

        @Override
        public void update(byte[] input, int inOff, int length) {
            hmac.update(input, inOff, length);
        }

        /**
         * Finalises the MAC.
         *
         * @return a new array of {@code hmac.getMacSize()} bytes
         */
        @Override
        public byte[] calculateMAC() {
            byte[] rv = new byte[hmac.getMacSize()];

            hmac.doFinal(rv, 0);

            return rv;
        }

        /**
         * @return the byte length reported by the underlying digest
         */
        @Override
        public int getInternalBlockSize() {
            return ((ExtendedDigest) hmac.getUnderlyingDigest()).getByteLength();
        }

        @Override
        public int getMacLength() {
            return hmac.getMacSize();
        }

        @Override
        public void reset() {
            hmac.reset();
        }
    }
}
