package com.xdja.pki.gmssl.crypto.sdf;

import com.xdja.pki.gmssl.core.utils.GMSSLByteArrayUtils;
import com.xdja.pki.gmssl.core.utils.GMSSLX509Utils;
import com.xdja.pki.gmssl.sdf.SdfSDK;
import com.xdja.pki.gmssl.sdf.SdfSDKException;
import com.xdja.pki.gmssl.sdf.bean.SdfAlgIdAsymmetric;
import com.xdja.pki.gmssl.sdf.bean.SdfECCPublicKey;
import com.xdja.pki.gmssl.sdf.bean.SdfECCSignature;
import org.bouncycastle.asn1.*;
import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.CryptoException;
import org.bouncycastle.crypto.DataLengthException;
import org.bouncycastle.crypto.ExtendedDigest;
import org.bouncycastle.crypto.params.ParametersWithID;
import org.bouncycastle.util.Arrays;
import org.bouncycastle.util.BigIntegers;
import org.bouncycastle.util.encoders.Hex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.math.BigInteger;
import java.util.Locale;

/**
 * Base class for EC signers (SM2 / ECDSA) that hash locally and delegate the actual
 * signing and verification to an SDF crypto device through {@link SdfSDK}.
 *
 * <p>Subclasses are responsible for creating and priming {@link #digest} in the two
 * {@code initDigest} overloads; this class drives the digest and encodes/decodes the
 * (r, s) signature as a DER {@code SEQUENCE { INTEGER r, INTEGER s }}.
 */
public abstract class SdfECBaseSigner extends SdfSigner {

    /**
     * Hex form of the default user ID (ASCII "1234567812345678"), used when
     * {@link #init} is not given a {@link ParametersWithID}.
     */
    private static final String DEFAULT_USER_ID_HEX = "31323334353637383132333435363738";

    /**
     * Logger named after the concrete subclass.
     */
    protected Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * Message digest algorithm; assigned by the subclass in {@code initDigest}.
     */
    protected ExtendedDigest digest;

    /**
     * EC key: device key index/password for signing, public key for verification.
     */
    protected SdfECKeyParameters ecKey;

    /**
     * Signature algorithm name; defaults to {@link Constants#SM3_WITH_SM2}.
     * A {@code null} value is treated as SM3withSM2.
     */
    protected String signAlgName = Constants.SM3_WITH_SM2;

    /**
     * Key length in bits; selects the device call used for signing/verification
     * ({@link Constants#KEY_BITS_256}, {@link Constants#KEY_BITS_384} or {@link Constants#KEY_BITS_521}).
     */
    protected int bits = 256;

    /**
     * Creates a signer backed by the {@link SdfCryptoType#YUNHSM} device (the default).
     *
     * @throws SdfSDKException if the SDK fails to initialise
     */
    public SdfECBaseSigner() throws SdfSDKException {
        this(SdfCryptoType.YUNHSM);
    }

    /**
     * Creates a signer backed by the SDK of the given crypto type.
     *
     * @param sdfCryptoType device type whose SDK is used
     * @throws SdfSDKException if the SDK fails to initialise
     */
    public SdfECBaseSigner(SdfCryptoType sdfCryptoType) throws SdfSDKException {
        this(sdfCryptoType.getSdfSDK());
    }

    /**
     * Creates a signer backed by the given SDK and initialises it.
     *
     * @param sdfSDK device SDK
     * @throws SdfSDKException if the SDK fails to initialise
     */
    public SdfECBaseSigner(SdfSDK sdfSDK) throws SdfSDKException {
        this.sdfSDK = sdfSDK;
        this.sdfSDK.init();
    }

    /**
     * Creates a signer backed by the given SDK, using the given signature algorithm.
     *
     * @param sdfSDK      device SDK
     * @param signAlgName signature algorithm name (e.g. {@link Constants#SM3_WITH_SM2})
     * @throws SdfSDKException if the SDK fails to initialise
     */
    public SdfECBaseSigner(SdfSDK sdfSDK, String signAlgName) throws SdfSDKException {
        this.sdfSDK = sdfSDK;
        this.signAlgName = signAlgName;
        this.sdfSDK.init();
    }

    /**
     * Initialises the signer and its digest.
     *
     * <p>For SM3withSM2, the public key and user ID are handed to
     * {@link #initDigest(byte[], SdfECCPublicKey)}. When signing, the public key is
     * exported from the device by key index; when verifying, it comes from the key
     * parameters. For the ECDSA algorithms, {@link #initDigest(String)} is used.
     * SDK failures and unsupported algorithm names are logged, not thrown.
     *
     * @param forSigning {@code true} to sign, {@code false} to verify
     * @param param      an {@code SdfECKeyParameters}, optionally wrapped in a
     *                   {@link ParametersWithID} carrying the user ID
     */
    @Override
    public void init(boolean forSigning, CipherParameters param) {
        byte[] userID;
        try {
            // Kept compatible with the earlier API: SM2 mixes the public key into the digest, ECDSA does not.
            if (param instanceof ParametersWithID) {
                ParametersWithID withID = (ParametersWithID) param;
                this.ecKey = (SdfECKeyParameters) withID.getParameters();
                userID = withID.getID();
            } else {
                this.ecKey = (SdfECKeyParameters) param;
                userID = Hex.decode(DEFAULT_USER_ID_HEX);
            }
            // A null name means SM3withSM2, matching verifySignature(), which falls back to the
            // plain SM2 verify call in that case (previously init() threw a NullPointerException).
            String algName = signAlgName == null ? Constants.SM3_WITH_SM2 : signAlgName;
            // Locale.ROOT: locale-sensitive upper-casing (e.g. Turkish dotted 'i') would break the match.
            switch (algName.toUpperCase(Locale.ROOT)) {
                case Constants.SM3_WITH_SM2:
                    SdfECCPublicKey sdfECCPublicKey;
                    if (forSigning) {
                        sdfECCPublicKey = this.sdfSDK.exportSignPublicKeyEcc(this.ecKey.getEcIndex());
                    } else {
                        sdfECCPublicKey = this.ecKey.getSDFECCPublicKey();
                    }
                    initDigest(userID, sdfECCPublicKey);
                    break;
                case Constants.SHA1_WITH_ECDSA:
                case Constants.SHA256_WITH_ECDSA:
                case Constants.SHA384_WITH_ECDSA:
                case Constants.SHA512_WITH_ECDSA:
                    initDigest(signAlgName);
                    break;
                default:
                    logger.error("init error  un support sign alg name {}", signAlgName);
                    break;
            }
        } catch (SdfSDKException e) {
            logger.error("init", e);
        }
    }

    /**
     * Creates the digest and starts it for SM2 (presumably feeding the user-ID/public-key
     * prefix; the implementation lives in the subclass).
     *
     * @param userID          ID of the signer (IDA)
     * @param sdfECCPublicKey public key
     * @throws SdfSDKException if the device call fails
     */
    public abstract void initDigest(byte[] userID, SdfECCPublicKey sdfECCPublicKey) throws SdfSDKException;

    /**
     * Creates the digest for the given signature algorithm.
     *
     * @param signAlgName signature algorithm name
     * @throws SdfSDKException if the device call fails
     */
    public abstract void initDigest(String signAlgName) throws SdfSDKException;

    /**
     * Feeds a single byte into the digest.
     *
     * @param b byte to add
     */
    @Override
    public void update(byte b) {
        byte[] bytes = new byte[]{b};
        update(bytes, 0, bytes.length);
    }

    /**
     * Feeds {@code len} bytes of {@code in}, starting at {@code off}, into the digest.
     *
     * @param in  source buffer
     * @param off start offset
     * @param len number of bytes
     */
    @Override
    public void update(byte[] in, int off, int len) {
        digest.update(in, off, len);
    }

    /**
     * Finishes the digest and signs it on the device with the key at {@code ecKey.getEcIndex()}.
     *
     * @return DER-encoded {@code SEQUENCE { r, s }}, or {@code null} if {@link #bits} is unsupported
     * @throws CryptoException if the device signing call or DER encoding fails
     */
    @Override
    public byte[] generateSignature() throws CryptoException, DataLengthException {
        byte[] out = new byte[digest.getDigestSize()];
        try {
            digest.doFinal(out, 0);
            SdfECCSignature sdfECCSignature;
            switch (bits) {
                case Constants.KEY_BITS_256:
                    // TODO: 2021/6/30 zero padding - confirm the length for SHA384 / SHA512
                    out = GMSSLByteArrayUtils.fillByteArrayWithZeroInHead(out, 32);
                    // TODO: 2021/6/30 confirm whether the exported public key is SM2 or NIST format
                    // SM2 signature
                    sdfECCSignature = sdfSDK.internalSignECC(ecKey.getEcIndex(), ecKey.getPassword(), out);
                    break;
                case Constants.KEY_BITS_384:
                case Constants.KEY_BITS_521:
                    sdfECCSignature = sdfSDK.internalSignECCEx(ecKey.getEcIndex(), ecKey.getPassword(), out, bits);
                    break;
                default:
                    // Previously a silent null; log it like verifySignature() does.
                    logger.error("un support key bits  {}", bits);
                    return null;
            }
            return derEncode(sdfECCSignature.getR(), sdfECCSignature.getS());
        } catch (SdfSDKException | IOException e) {
            GMSSLByteArrayUtils.printHexBinary(logger, "signature  digest", out);
            // Never log the key password: it is a secret, and converting a null password here
            // would throw and mask the original failure.
            logger.error("generateSignature error index={}", ecKey.getEcIndex(), e);
            throw new CryptoException("sdf ec sign error", e);
        }
    }

    /**
     * Finishes the digest and verifies {@code signature} against it on the device,
     * using the public key from {@link #ecKey}.
     *
     * @param signature DER-encoded {@code SEQUENCE { r, s }}
     * @return {@code true} if the device verify call returns normally; {@code false} for an
     *         unsupported key size, a decoding error, or an {@link SdfSDKException}
     *         (presumably how the SDK reports a mismatching signature)
     */
    @Override
    public boolean verifySignature(byte[] signature) {
        byte[] out = new byte[digest.getDigestSize()];
        try {
            digest.doFinal(out, 0);
            BigInteger[] bigIntegers = GMSSLX509Utils.derSignatureDecode(signature);
            // NOTE(review): toByteArray() is two's complement and may carry a leading 0x00 sign
            // byte; confirm SdfECCSignature accepts that form.
            byte[] r = bigIntegers[0].toByteArray();
            byte[] s = bigIntegers[1].toByteArray();
            SdfECCSignature sdfECCSignature = new SdfECCSignature(r, s);
            switch (bits) {
                case Constants.KEY_BITS_256:
                    out = GMSSLByteArrayUtils.fillByteArrayWithZeroInHead(out, 32);
                    // SM2 verification
                    if (null == signAlgName) {
                        sdfSDK.externalVerifyECC(ecKey.getSDFECCPublicKey(), out, sdfECCSignature);
                    } else {
                        sdfSDK.externalVerifyECC(ecKey.getSDFECCPublicKey(), out, sdfECCSignature, signAlgName);
                    }
                    break;
                case Constants.KEY_BITS_384:
                    sdfSDK.externalVerifyECCEx(ecKey.getSDFECCPublicKey(), out, sdfECCSignature, SdfAlgIdAsymmetric.SGD_ECC_NISTP384, bits);
                    break;
                case Constants.KEY_BITS_521:
                    sdfSDK.externalVerifyECCEx(ecKey.getSDFECCPublicKey(), out, sdfECCSignature, SdfAlgIdAsymmetric.SGD_ECC_NISTP521, bits);
                    break;
                default:
                    logger.error("un support key bits  {}", bits);
                    return false;
            }
            return true;
        } catch (IOException | SdfSDKException e) {
            if (signature != null) {
                GMSSLByteArrayUtils.printHexBinary(logger, "verifySignature digest", out, true);
                // Guard the diagnostics so a missing public key cannot turn a logged failure into an NPE.
                SdfECCPublicKey publicKey = ecKey == null ? null : ecKey.getSDFECCPublicKey();
                if (publicKey != null) {
                    GMSSLByteArrayUtils.printHexBinary(logger, "verifySignature x", publicKey.getX(), true);
                    GMSSLByteArrayUtils.printHexBinary(logger, "verifySignature y", publicKey.getY(), true);
                }
                GMSSLByteArrayUtils.printHexBinary(logger, "verifySignature signature", signature, true);
            }
            logger.error("verifySignature", e);
            return false;
        }
    }

    /**
     * No-op.
     *
     * <p>NOTE(review): the digest is not cleared here, so data passed to {@code update}
     * before {@code reset()} is still hashed. Simply calling {@code digest.reset()} would
     * presumably also discard any prefix fed by {@code initDigest} (e.g. the SM2 Z value);
     * confirm with the subclasses before changing.
     */
    @Override
    public void reset() {

    }

    /**
     * Strictly decodes a DER {@code SEQUENCE { INTEGER r, INTEGER s }}.
     *
     * <p>The decoded values are re-encoded and compared in constant time with the input,
     * so non-canonical encodings are rejected.
     *
     * @param encoding DER-encoded signature
     * @return {@code {r, s}}, or {@code null} if the sequence does not hold exactly two
     *         elements or the encoding is not canonical DER
     * @throws IOException if {@code encoding} is not valid ASN.1
     */
    protected BigInteger[] derDecode(byte[] encoding) throws IOException {
        ASN1Sequence seq = ASN1Sequence.getInstance(ASN1Primitive.fromByteArray(encoding));
        if (seq.size() != 2) {
            return null;
        }
        BigInteger r = ASN1Integer.getInstance(seq.getObjectAt(0)).getValue();
        BigInteger s = ASN1Integer.getInstance(seq.getObjectAt(1)).getValue();

        byte[] expectedEncoding = derEncode(r, s);
        if (!Arrays.constantTimeAreEqual(expectedEncoding, encoding)) {
            return null;
        }

        return new BigInteger[]{r, s};
    }

    /**
     * DER-encodes (r, s) given as unsigned big-endian byte arrays.
     *
     * @param r r component, unsigned big-endian
     * @param s s component, unsigned big-endian
     * @return DER {@code SEQUENCE { INTEGER r, INTEGER s }}
     * @throws IOException if encoding fails
     */
    protected byte[] derEncode(byte[] r, byte[] s) throws IOException {
        return derEncode(
                BigIntegers.fromUnsignedByteArray(r),
                BigIntegers.fromUnsignedByteArray(s)
        );
    }

    /**
     * DER-encodes (r, s) as {@code SEQUENCE { INTEGER r, INTEGER s }}.
     *
     * @param r r component
     * @param s s component
     * @return DER encoding
     * @throws IOException if encoding fails
     */
    protected byte[] derEncode(BigInteger r, BigInteger s) throws IOException {
        ASN1EncodableVector v = new ASN1EncodableVector();
        v.add(new ASN1Integer(r));
        v.add(new ASN1Integer(s));
        return new DERSequence(v).getEncoded(ASN1Encoding.DER);
    }
}
