package com.xdja.pki.gmssl.hsm.server.runner;

import com.xdja.pki.gmssl.core.utils.GMSSLCertPathUtils;
import com.xdja.pki.gmssl.core.utils.GMSSLFileUtils;
import com.xdja.pki.gmssl.core.utils.GMSSLRandomUtils;
import com.xdja.pki.gmssl.core.utils.GMSSLX509Utils;
import com.xdja.pki.gmssl.crypto.utils.GMSSLRSAKeyUtils;
import com.xdja.pki.gmssl.hsm.server.constant.Constants;
import com.xdja.pki.gmssl.keystore.utils.GMSSLKeyStoreUtils;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;

import javax.crypto.spec.SecretKeySpec;
import java.io.File;
import java.io.FileNotFoundException;
import java.security.Key;
import java.security.KeyPair;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.cert.Certificate;
import java.text.MessageFormat;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;

/**
 * Generates {@code key.keystore} at application startup.
 *
 * <p>The number of keys is taken from the raw command-line arguments:
 * <ul>
 *   <li>{@code <SM2_PREFIX>-<count>}: number of SM2 sign/enc private-key pairs;</li>
 *   <li>{@code <SYM_PREFIX>-<count>}: number of SM4 symmetric keys;</li>
 *   <li>{@code <RSA_PREFIX>-<bits>-<count>}: number of RSA sign/enc private-key pairs with the given
 *       key length (may be repeated for different lengths).</li>
 * </ul>
 * The prefixes are {@link Constants#SM2_PREFIX}, {@link Constants#SYM_PREFIX} and
 * {@link Constants#RSA_PREFIX}, matched case-insensitively.
 *
 * <p>Asymmetric keys share one index space: SM2 pairs use indices {@code 1..sm2Count}, followed by the
 * RSA pairs grouped by ascending key length. Symmetric keys use their own alias pattern with indices
 * {@code 1..symCount}.
 *
 * @author: houzhe
 * @date: 2021-11-01
 **/
@Component
public class KeyStoreRunner implements ApplicationRunner {

    /** Size in bytes of a generated SM4 key (128 bits). */
    private static final int SM4_KEY_BYTES = 16;

    private final Logger logger = LoggerFactory.getLogger(this.getClass());

    /**
     * Parses the requested key counts and (re)generates {@code key.keystore} unless an existing one
     * already contains every requested key.
     *
     * @param args application arguments; nothing is done when there are no raw arguments
     * @throws IllegalArgumentException if a recognised key argument has a non-numeric count or length
     */
    @Override
    public void run(ApplicationArguments args) {
        String[] sourceArgs = args.getSourceArgs();
        if (null == sourceArgs || sourceArgs.length == 0) {
            return;
        }

        int sm2Count = 0;
        int symCount = 0;
        // RSA key length -> pair count. A TreeMap iterates lengths in ascending order, which is what
        // fixes the (deterministic) index layout of the RSA keys in both checkFile and generateKeyStore.
        Map<Integer, Integer> rsaCount = new TreeMap<>();

        for (String arg : sourceArgs) {
            String[] split = arg.split("-");
            String alg = split[0];
            if (split.length == 2) {
                if (Constants.SM2_PREFIX.equalsIgnoreCase(alg)) {
                    sm2Count = parseNumber(split[1], arg);
                } else if (Constants.SYM_PREFIX.equalsIgnoreCase(alg)) {
                    symCount = parseNumber(split[1], arg);
                }
            } else if (split.length == 3 && Constants.RSA_PREFIX.equalsIgnoreCase(alg)) {
                rsaCount.put(parseNumber(split[1], arg), parseNumber(split[2], arg));
            }
        }

        if (this.checkFile(sm2Count, symCount, rsaCount)) {
            return;
        }
        this.generateKeyStore(sm2Count, symCount, rsaCount);
    }

    /**
     * Parses the numeric part of a key argument.
     *
     * @param value the numeric text to parse
     * @param arg   the whole argument, reported in the error message
     * @return the parsed integer
     * @throws IllegalArgumentException if {@code value} is not a valid integer
     */
    private static int parseNumber(String value, String arg) {
        try {
            return Integer.parseInt(value);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("密钥参数格式错误: " + arg, e);
        }
    }

    /**
     * Verifies that the existing {@code key.keystore} contains a key for every requested index.
     * A keystore that cannot be read or is incomplete is deleted so it can be regenerated.
     *
     * @return true if every key exists; false if the file is missing, unreadable or incomplete
     */
    private boolean checkFile(int sm2Count, int symCount, Map<Integer, Integer> rsaCount) {
        String keyStorePath = Constants.PATH + File.separator + Constants.KEY_STORE_NAME + ".keystore";
        boolean checkFlag;
        try {
            KeyStore keyStore = GMSSLKeyStoreUtils.readKeyStoreFromPath(keyStorePath, Constants.PWD.toCharArray());
            checkFlag = containsAllKeys(keyStore, sm2Count, symCount, rsaCount);
        } catch (FileNotFoundException e) {
            // No keystore yet: nothing to delete, a new one will be generated.
            return false;
        } catch (Exception e) {
            logger.warn("key.keystore解析失败", e);
            checkFlag = false;
        }

        if (!checkFlag) {
            logger.warn("key.keystore解析失败或缺失，重新生成key.keystore");
            GMSSLFileUtils.deleteFile(keyStorePath);
        }
        return checkFlag;
    }

    /**
     * Returns whether {@code keyStore} holds every alias implied by the requested counts, stopping at the
     * first missing one. Uses the same index layout as {@link #generateKeyStore}.
     */
    private boolean containsAllKeys(KeyStore keyStore, int sm2Count, int symCount,
                                    Map<Integer, Integer> rsaCount) throws KeyStoreException {
        for (int i = 1; i <= sm2Count; i++) {
            if (!containsAsymmetricPair(keyStore, i)) {
                return false;
            }
        }

        int rsaIndexBegin = sm2Count;
        for (int count : rsaCount.values()) {
            for (int i = 1; i <= count; i++) {
                if (!containsAsymmetricPair(keyStore, rsaIndexBegin + i)) {
                    return false;
                }
            }
            rsaIndexBegin += count;
        }

        for (int i = 1; i <= symCount; i++) {
            if (!keyStore.containsAlias(MessageFormat.format(Constants.SYMMETRIC_ALIAS, i))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns whether both the sign and the enc alias exist for the given asymmetric key index.
     *
     * <p>NOTE(review): MessageFormat formats numeric arguments with the default locale's grouping
     * separator (e.g. 1000 -> "1,000"), so aliases for indices >= 1000 are locale-dependent. Left
     * unchanged because the code that looks these aliases up is not visible here.
     */
    private static boolean containsAsymmetricPair(KeyStore keyStore, int index) throws KeyStoreException {
        return keyStore.containsAlias(MessageFormat.format(Constants.ASYMMETRIC_SIGN_ALIAS, index))
                && keyStore.containsAlias(MessageFormat.format(Constants.ASYMMETRIC_ENC_ALIAS, index));
    }

    /**
     * Builds a brand-new BKS keystore holding freshly generated keys and saves it as {@code key.keystore}.
     * Keys from any previous keystore are not carried over.
     *
     * <p>Every private key is stored with the {@code "sign"} certificate of {@code server.keystore} as its
     * one-element chain. NOTE(review): that certificate does not belong to the generated keys; it
     * presumably only satisfies the chain requirement of {@link KeyStore#setKeyEntry}.
     *
     * @param sm2Count number of SM2 sign/enc pairs
     * @param symCount number of SM4 symmetric keys
     * @param rsaCount RSA key length -> number of sign/enc pairs; iteration order defines the index layout
     */
    private void generateKeyStore(int sm2Count, int symCount, Map<Integer, Integer> rsaCount) {
        Certificate serverSignCert = readServerSignCert();
        if (null == serverSignCert) {
            logger.error("生成key.keystore失败，需先初始化");
            return;
        }

        char[] password = Constants.PWD.toCharArray();
        try {
            KeyStore keyStore = KeyStore.getInstance(GMSSLCertPathUtils.KEYSTORE_TYPE_BKS, BouncyCastleProvider.PROVIDER_NAME);
            keyStore.load(null, null);

            // SM2 pairs: indices 1..sm2Count
            for (int i = 1; i <= sm2Count; i++) {
                setPrivateKeyEntry(keyStore, MessageFormat.format(Constants.ASYMMETRIC_SIGN_ALIAS, i),
                        GMSSLX509Utils.generateSM2KeyPair(), password, serverSignCert);
                setPrivateKeyEntry(keyStore, MessageFormat.format(Constants.ASYMMETRIC_ENC_ALIAS, i),
                        GMSSLX509Utils.generateSM2KeyPair(), password, serverSignCert);
                logger.info("SM2密钥索引[{}]生成成功", i);
            }

            // RSA pairs: continue after the SM2 indices, one group per key length
            int rsaIndexBegin = sm2Count;
            for (Map.Entry<Integer, Integer> entry : rsaCount.entrySet()) {
                int len = entry.getKey();
                int count = entry.getValue();
                for (int i = 1; i <= count; i++) {
                    int index = rsaIndexBegin + i;
                    setPrivateKeyEntry(keyStore, MessageFormat.format(Constants.ASYMMETRIC_SIGN_ALIAS, index),
                            GMSSLRSAKeyUtils.generateKeyPairByBC(len), password, serverSignCert);
                    setPrivateKeyEntry(keyStore, MessageFormat.format(Constants.ASYMMETRIC_ENC_ALIAS, index),
                            GMSSLRSAKeyUtils.generateKeyPairByBC(len), password, serverSignCert);
                    logger.info("RSA-{}密钥索引[{}]生成成功", len, index);
                }
                rsaIndexBegin += count;
            }

            // Symmetric SM4 keys: separate alias pattern, indices 1..symCount
            for (int i = 1; i <= symCount; i++) {
                Key secretKey = new SecretKeySpec(GMSSLRandomUtils.generateRandom(SM4_KEY_BYTES), "SM4");
                keyStore.setKeyEntry(MessageFormat.format(Constants.SYMMETRIC_ALIAS, i), secretKey, password, null);
                logger.info("对称密钥索引[{}]生成成功", i);
            }

            GMSSLKeyStoreUtils.saveGMSSLKeyStore(keyStore, Constants.PWD, Constants.PATH, Constants.KEY_STORE_NAME);
            logger.info("key.keystore保存成功");
        } catch (Exception e) {
            logger.error("生成key.keystore失败", e);
        }
    }

    /**
     * Reads the {@code "sign"} certificate from {@code server.keystore}.
     *
     * @return the certificate, or null if the keystore or certificate cannot be read
     */
    private Certificate readServerSignCert() {
        String serverKeyStorePath = Constants.PATH + File.separator + Constants.SERVER_KEY_STORE_NAME + ".keystore";
        try {
            KeyStore serverKeyStore = GMSSLKeyStoreUtils.readKeyStoreFromPath(serverKeyStorePath, Constants.PWD.toCharArray());
            if (null == serverKeyStore) {
                return null;
            }
            return GMSSLKeyStoreUtils.readCertificateFromKeyStore(serverKeyStore, "sign");
        } catch (Exception e) {
            // The caller reports the failure; keep the cause available for troubleshooting.
            logger.debug("读取server.keystore失败", e);
            return null;
        }
    }

    /** Stores the private key of {@code keyPair} under {@code alias} with a one-element certificate chain. */
    private static void setPrivateKeyEntry(KeyStore keyStore, String alias, KeyPair keyPair,
                                           char[] password, Certificate chainCert) throws KeyStoreException {
        keyStore.setKeyEntry(alias, keyPair.getPrivate(), password, new Certificate[]{chainCert});
    }
}
