package com.xdja.pki.gmssl.crypto.sdf;

import com.xdja.pki.gmssl.asn1.crypto.ASN1SM2Cipher;
import com.xdja.pki.gmssl.core.utils.GMSSLByteArrayUtils;
import com.xdja.pki.gmssl.core.utils.GMSSLX509Utils;
import com.xdja.pki.gmssl.sdf.SdfSDK;
import com.xdja.pki.gmssl.sdf.SdfSDKException;
import com.xdja.pki.gmssl.sdf.bean.SdfECCCipher;
import com.xdja.pki.gmssl.sdf.bean.SdfECCPublicKey;
import org.bouncycastle.crypto.CipherParameters;
import org.bouncycastle.crypto.InvalidCipherTextException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.security.PublicKey;

public class SdfECEngine {

    private Logger logger = LoggerFactory.getLogger(this.getClass());

    private SdfECKeyParameters ecKey;
    private boolean forEncryption;
    private SdfSDK sdfSDK;
    private String stdName;
    private int bits = 256;

    public SdfECEngine() throws SdfSDKException {
        this(SdfCryptoType.YUNHSM);
    }

    public SdfECEngine(SdfCryptoType sdfCryptoType) throws SdfSDKException {
        this(sdfCryptoType.getSdfSDK());
    }

    public SdfECEngine(SdfSDK sdfSDK) throws SdfSDKException {
        this.sdfSDK = sdfSDK;
        this.sdfSDK.init();
    }

    public SdfECEngine(SdfSDK sdfSDK, String stdName) throws SdfSDKException {
        this(sdfSDK, stdName, 256);
    }

    public SdfECEngine(SdfSDK sdfSDK, String stdName, int bits) throws SdfSDKException {
        this.sdfSDK = sdfSDK;
        this.stdName = stdName;
        this.sdfSDK.init();
        this.bits = bits;
    }

    public void init(boolean forEncryption, CipherParameters param) {
        this.forEncryption = forEncryption;
        this.ecKey = (SdfECKeyParameters) param;
    }

    public byte[] encryptASN1(byte[] data) {
        //SM2加密
        try {
            SdfECCPublicKey sdfECCPublicKey = this.ecKey.getSDFECCPublicKey();
            SdfECCCipher eccCipher;
            if (null == stdName) {
                eccCipher = this.sdfSDK.externalEncryptECC(sdfECCPublicKey, data);
            } else {
                eccCipher = this.sdfSDK.externalEncryptECC(sdfECCPublicKey, data, stdName);
            }
            byte[] c = new byte[eccCipher.getL()];
            System.arraycopy(eccCipher.getC(), 0, c, 0, c.length);
            ASN1SM2Cipher asn1SM2Cipher = new ASN1SM2Cipher(eccCipher.getX(), eccCipher.getY(), eccCipher.getM(), c);
            return asn1SM2Cipher.toASN1Primitive().getEncoded();
        } catch (SdfSDKException | IOException e) {
            logger.error("encrypt asn1", e);
            return new byte[0];
        }
    }

    public byte[] decryptASN1(byte[] cipher) {
        ASN1SM2Cipher sm2CipherASN1 = ASN1SM2Cipher.getInstance(cipher);
        try {
            assert sm2CipherASN1 != null;
            SdfECCCipher eccCipher = new SdfECCCipher(
                    GMSSLByteArrayUtils.changeByteArrayLength(sm2CipherASN1.getxCoordinate().toByteArray(), 32),
                    GMSSLByteArrayUtils.changeByteArrayLength(sm2CipherASN1.getyCoordinate().toByteArray(), 32),
                    GMSSLByteArrayUtils.changeByteArrayLength(sm2CipherASN1.getHash(), 32),
                    sm2CipherASN1.getCipherText().length,
                    sm2CipherASN1.getCipherText()
            );
            //EC解密
            if (null == stdName) {
                return sdfSDK.internalDecryptECC(ecKey.getEcIndex(), ecKey.getPassword(), sm2CipherASN1.getCipherText().length, eccCipher);
            } else {
                return sdfSDK.internalDecryptECC(ecKey.getEcIndex(), ecKey.getPassword(), sm2CipherASN1.getCipherText().length, eccCipher, stdName);
            }
        } catch (SdfSDKException e) {
            logger.error("decrypt asn1", e);
            return new byte[0];
        }

    }

    public byte[] processBlockASN1(byte[] in, int inOff, int inLen) throws InvalidCipherTextException {
        byte[] data = new byte[inLen];
        System.arraycopy(in, inOff, data, 0, inLen);
        if (forEncryption) {
            return encryptASN1(data);
        } else {
            return decryptASN1(data);
        }
    }

    public byte[] processBlock(byte[] in, int inOff, int inLen) throws InvalidCipherTextException {
        byte[] data = new byte[inLen];
        System.arraycopy(in, inOff, data, 0, inLen);
        if (forEncryption) {
            return encrypt(data);
        } else {
            return decrypt(data);
        }
    }

    public byte[] encrypt(byte[] data) {
        byte[] cipher = null;
        SdfECCPublicKey sdfECCPublicKey = ecKey.getSDFECCPublicKey();
        //SM2加密
        try {
            SdfECCCipher eccCipher;
            if (null == stdName) {
                eccCipher = sdfSDK.externalEncryptECC(sdfECCPublicKey, data);
            } else {
                eccCipher = sdfSDK.externalEncryptECC(sdfECCPublicKey, data, stdName);
            }

            byte[] x = eccCipher.getX();
            byte[] y = eccCipher.getY();
            byte[] c = new byte[eccCipher.getL()];
            System.arraycopy(eccCipher.getC(), 0, c, 0, c.length);
            byte[] m = eccCipher.getM();
            byte[] PO = new byte[x.length + y.length + 1];
            PO[0] = 0x04;
            System.arraycopy(x, 0, PO, 1, x.length);
            System.arraycopy(y, 0, PO, x.length + 1, y.length);
            cipher = new byte[PO.length + c.length + m.length];
            System.arraycopy(PO, 0, cipher, 0, PO.length);
            System.arraycopy(c, 0, cipher, PO.length, c.length);
            System.arraycopy(m, 0, cipher, PO.length + c.length, m.length);

            return cipher;
        } catch (SdfSDKException e) {
            GMSSLByteArrayUtils.printHexBinary(logger, "encrypt x", sdfECCPublicKey.getX(), true);
            GMSSLByteArrayUtils.printHexBinary(logger, "encrypt y", sdfECCPublicKey.getY(), true);
            GMSSLByteArrayUtils.printHexBinary(logger, "encrypt data", data, true);
            logger.error("encrypt error", e);
            return new byte[0];
        }
    }

    public byte[] decrypt(byte[] cipher) {
        GMSSLByteArrayUtils.printHexBinary(logger, "cipher", cipher);
        int curveLength = 32;
        int hashLength = 32;

        byte[] c1 = new byte[curveLength * 2 + 1];
        System.arraycopy(cipher, 0, c1, 0, c1.length);
        GMSSLByteArrayUtils.printHexBinary(logger, "c1", c1);

        byte[] x = new byte[curveLength];
        System.arraycopy(c1, 1, x, 0, x.length);
        byte[] y = new byte[curveLength];
        System.arraycopy(c1, x.length + 1, y, 0, y.length);

        byte[] c = new byte[cipher.length - hashLength - c1.length];
        System.arraycopy(cipher, c1.length, c, 0, c.length);

        byte[] m = new byte[hashLength];
        System.arraycopy(cipher, cipher.length - hashLength, m, 0, hashLength);

        SdfECCCipher eccCipher = new SdfECCCipher(x, y, m, c.length, c);

        try {
            //ECC解密
            if (null == stdName) {
                return sdfSDK.internalDecryptECC(ecKey.getEcIndex(), ecKey.getPassword(), c.length, eccCipher);
            } else {
                return sdfSDK.internalDecryptECC(ecKey.getEcIndex(), ecKey.getPassword(), c.length, eccCipher, stdName);
            }
        } catch (SdfSDKException e) {
            logger.error("decrypt error index={} password={}", ecKey.getEcIndex(), ecKey.getPassword(), e);
            return new byte[0];
        }

    }

    public void release() throws SdfSDKException {
        if (sdfSDK != null) {
            sdfSDK.release();
        }
    }
}
