package com.xdja.pki.gmssl.crypto.sdf;

import com.xdja.pki.gmssl.core.utils.GMSSLByteArrayUtils;
import com.xdja.pki.gmssl.sdf.SdfSDK;
import com.xdja.pki.gmssl.sdf.SdfSDKException;
import com.xdja.pki.gmssl.sdf.bean.SdfAlgIdHash;
import com.xdja.pki.gmssl.sdf.bean.SdfRSAPublicKey;
import org.bouncycastle.asn1.ASN1Encoding;
import org.bouncycastle.asn1.ASN1ObjectIdentifier;
import org.bouncycastle.asn1.x509.AlgorithmIdentifier;
import org.bouncycastle.asn1.x509.DigestInfo;
import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.DataLengthException;
import org.bouncycastle.crypto.ExtendedDigest;
import org.bouncycastle.crypto.digests.SHA1Digest;
import org.bouncycastle.crypto.digests.SHA256Digest;
import org.bouncycastle.crypto.digests.SHA384Digest;
import org.bouncycastle.crypto.digests.SHA512Digest;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.security.interfaces.RSAPublicKey;

/**
 * RSA signer backed by an {@link SdfSDK} device.
 *
 * <p>The message is hashed with a SHA-1/256/384/512 digest, the hash is wrapped in a
 * DER-encoded {@link DigestInfo}, padded to the key length by {@code pkcs1Padding}
 * (defined outside this class) and then passed to the SDK's raw RSA private-key
 * (sign) or public-key (verify) operation.
 *
 * <p>Usage: construct, call {@link #init(boolean, CipherParameters)} with an
 * {@link SdfRSAKeyParameters}, feed data through {@code update}, then call
 * {@link #generateSignature()} or {@link #verifySignature(byte[])}.
 */
public class SdfRSASigner extends SdfSigner {

    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    /** Digest accumulating the message; SDK-backed for YUNHSM, Bouncy Castle software digest otherwise. */
    private final ExtendedDigest sdfSHADigest;

    /** Key material supplied through {@link #init(boolean, CipherParameters)}. */
    private SdfRSAKeyParameters keyParameters;

    /** Algorithm identifier embedded in the DigestInfo structure. */
    private final AlgorithmIdentifier algId;

    /** Hash algorithm this signer was configured with. */
    private final SdfAlgIdHash sdfAlgIdHash;

    /**
     * Creates a signer from a hash algorithm OID.
     *
     * @param sdfCryptoType device type; also supplies the SDK instance
     * @param aid           OID that is converted to an {@link SdfAlgIdHash}
     * @throws SdfSDKException if the hash type is unsupported or SDK setup fails
     */
    public SdfRSASigner(SdfCryptoType sdfCryptoType, ASN1ObjectIdentifier aid) throws SdfSDKException {
        this(sdfCryptoType, SdfAlgIdHash.converSdfAlgIdHash(aid));
    }

    /**
     * Creates a signer from an {@link AlgorithmIdentifier}; only its OID is used.
     *
     * @param sdfCryptoType device type; also supplies the SDK instance
     * @param algorithm     algorithm identifier whose OID selects the hash
     * @throws SdfSDKException if the hash type is unsupported or SDK setup fails
     */
    public SdfRSASigner(SdfCryptoType sdfCryptoType, AlgorithmIdentifier algorithm) throws SdfSDKException {
        this(sdfCryptoType, algorithm.getAlgorithm());
    }

    /**
     * Creates a signer for the {@link SdfCryptoType#YUNHSM} device type.
     *
     * @param sdfAlgIdHash hash algorithm to use
     * @throws SdfSDKException if the hash type is unsupported or SDK setup fails
     */
    public SdfRSASigner(SdfAlgIdHash sdfAlgIdHash) throws SdfSDKException {
        this(SdfCryptoType.YUNHSM, sdfAlgIdHash);
    }

    /**
     * Creates a signer using the SDK instance provided by {@code sdfCryptoType}.
     *
     * @param sdfCryptoType device type; also supplies the SDK instance
     * @param sdfAlgIdHash  hash algorithm to use
     * @throws SdfSDKException if the hash type is unsupported or SDK setup fails
     */
    public SdfRSASigner(SdfCryptoType sdfCryptoType, SdfAlgIdHash sdfAlgIdHash) throws SdfSDKException {
        this(sdfCryptoType.getSdfSDK(), sdfCryptoType, sdfAlgIdHash);
    }

    /**
     * Creates a signer on an explicit SDK instance and initialises that SDK.
     *
     * @param sdfSDK        SDK used for hashing (YUNHSM only) and for the RSA operations
     * @param sdfCryptoType device type; YUNHSM selects the SDK-backed digest
     * @param sdfAlgIdHash  hash algorithm to use; SHA-1, SHA-256, SHA-384 or SHA-512
     * @throws SdfSDKException if {@code sdfAlgIdHash} is null or unsupported, or SDK setup fails
     */
    public SdfRSASigner(SdfSDK sdfSDK, SdfCryptoType sdfCryptoType, SdfAlgIdHash sdfAlgIdHash) throws SdfSDKException {
        // Fail with the declared exception type instead of an NPE further down
        // (e.g. when an OID could not be mapped to a known hash id).
        if (sdfAlgIdHash == null) {
            throw new SdfSDKException("unsupported null hash type");
        }
        this.sdfSDK = sdfSDK;
        this.algId = SdfAlgIdHash.convertAlgorithmIdentifier(sdfAlgIdHash);
        this.sdfAlgIdHash = sdfAlgIdHash;
        this.sdfSHADigest = createDigest(sdfSDK, sdfCryptoType, sdfAlgIdHash);
        this.sdfSDK.init();
    }

    /**
     * Picks the digest implementation for the given hash id and device type.
     *
     * @return an SDK-backed digest for YUNHSM, otherwise the matching Bouncy Castle digest
     * @throws SdfSDKException if the hash id is not one of SHA-1/256/384/512
     */
    private static ExtendedDigest createDigest(SdfSDK sdfSDK, SdfCryptoType sdfCryptoType, SdfAlgIdHash sdfAlgIdHash) throws SdfSDKException {
        boolean supported = sdfAlgIdHash == SdfAlgIdHash.SGD_SHA1
                || sdfAlgIdHash == SdfAlgIdHash.SGD_SHA256
                || sdfAlgIdHash == SdfAlgIdHash.SGD_SHA384
                || sdfAlgIdHash == SdfAlgIdHash.SGD_SHA512;
        if (!supported) {
            throw new SdfSDKException("unsupported " + sdfAlgIdHash.getName() + " hash type");
        }
        if (sdfCryptoType == SdfCryptoType.YUNHSM) {
            return new SdfSHADigest(sdfSDK, sdfAlgIdHash);
        }
        if (sdfAlgIdHash == SdfAlgIdHash.SGD_SHA1) {
            return new SHA1Digest();
        }
        if (sdfAlgIdHash == SdfAlgIdHash.SGD_SHA256) {
            return new SHA256Digest();
        }
        if (sdfAlgIdHash == SdfAlgIdHash.SGD_SHA384) {
            return new SHA384Digest();
        }
        return new SHA512Digest();
    }

    /**
     * Stores the key parameters used by later sign/verify calls.
     *
     * @param forSigning not used by this implementation
     * @param param      must be an {@link SdfRSAKeyParameters}; any other type causes a
     *                   {@link ClassCastException}
     */
    @Override
    public void init(boolean forSigning, CipherParameters param) {
        this.keyParameters = (SdfRSAKeyParameters) param;
    }

    /**
     * Feeds a single byte into the message digest.
     *
     * @param b byte to add
     */
    @Override
    public void update(byte b) {
        byte[] bytes = new byte[]{b};
        update(bytes, 0, bytes.length);
    }

    /**
     * Feeds {@code len} bytes of {@code in}, starting at {@code off}, into the message digest.
     */
    @Override
    public void update(byte[] in, int off, int len) {
        sdfSHADigest.update(in, off, len);
    }

    /**
     * Finalises the digest and signs it with the device-held private key.
     *
     * <p>The key length used for padding is read from the signing public key exported
     * by the SDK for the configured private key index.
     *
     * @return the signature bytes; an empty array if padding produced no data or if any
     *         exception occurred (the exception is logged, not propagated)
     */
    @Override
    public byte[] generateSignature() throws DataLengthException {
        try {
            byte[] out = new byte[sdfSHADigest.getDigestSize()];
            sdfSHADigest.doFinal(out, 0);
            SdfRSAPublicKey sdfRSAPublicKey = this.sdfSDK.exportSignPublicKeyRsa(this.keyParameters.getPrivateKeyIndex());
            byte[] data = derEncode(out, sdfRSAPublicKey.getBits() / 8);
            if (data.length == 0) {
                return data;
            }
            // RSA signing: raw private-key operation over the padded DigestInfo.
            return sdfSDK.internalPrivateKeyOperationRsa(this.keyParameters.getPrivateKeyIndex(), this.keyParameters.getPassword(), data);
        } catch (Exception e) {
            logger.error("generateSignature", e);
            return new byte[0];
        }
    }

    /**
     * Finalises the digest and checks {@code signature} against it.
     *
     * <p>The expected padded DigestInfo is rebuilt locally and compared with the result
     * of the SDK's public-key operation on {@code signature}.
     *
     * @param signature signature bytes to verify
     * @return {@code true} if the recovered block equals the expected one; {@code false}
     *         on mismatch or if any exception occurred (the exception is logged)
     */
    @Override
    public boolean verifySignature(byte[] signature) {
        try {
            byte[] out = new byte[sdfSHADigest.getDigestSize()];
            sdfSHADigest.doFinal(out, 0);
            RSAPublicKey rsaPublicKey = this.keyParameters.getPublicKey();
            SdfRSAPublicKey sdfRSAPublicKey = SdfRSAPublicKey.getInstance(rsaPublicKey);
            byte[] normal = derEncode(out, sdfRSAPublicKey.getBits() / 8);
            byte[] verify = sdfSDK.externalPublicKeyOperationRsa(sdfRSAPublicKey, signature);
            return GMSSLByteArrayUtils.isEqual(normal, verify);
        } catch (Exception e) {
            if (signature != null) {
                GMSSLByteArrayUtils.printHexBinary(logger, "verifySignature signature", signature);
            }
            logger.error("verifySignature", e);
            return false;
        }
    }

    /**
     * Discards any message data accumulated so far so the signer can be reused.
     */
    @Override
    public void reset() {
        // Null check is purely defensive (e.g. a call made before construction has finished).
        if (sdfSHADigest != null) {
            sdfSHADigest.reset();
        }
    }

    /**
     * Wraps {@code hash} in a DER-encoded DigestInfo and pads it to {@code length} bytes.
     *
     * @param hash   finished message digest
     * @param length target block length in bytes (key bits / 8)
     * @return the result of {@code pkcs1Padding} on the encoded DigestInfo
     * @throws IOException if DER encoding fails
     */
    private byte[] derEncode(byte[] hash, int length) throws IOException {
        DigestInfo dInfo = new DigestInfo(algId, hash);
        byte[] encoded = dInfo.getEncoded(ASN1Encoding.DER);
        return pkcs1Padding(encoded, length);
    }
}
