package com.xdja.pki.gmssl;

import com.xdja.pki.gmssl.core.utils.GMSSLCertPathUtils;
import com.xdja.pki.gmssl.core.utils.GMSSLFileUtils;
import com.xdja.pki.gmssl.http.bean.GMSSLProtocol;
import com.xdja.pki.gmssl.keystore.utils.GMSSLKeyStoreUtils;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.bouncycastle.jsse.provider.XDJAJsseProvider;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.*;
import java.io.*;
import java.security.*;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Enumeration;
import java.util.List;
import java.util.regex.Pattern;

/**
 * Factory for {@link SSLContext} instances used for GM/T SSL and TLS connections.
 *
 * <p>Loading this class registers {@link BouncyCastleProvider} and {@link XDJAJsseProvider} with
 * {@link Security} if they are not already installed. When the requested protocol equals
 * (case-insensitively) the value of {@link GMSSLProtocol#TLSV12}, the context comes from the
 * highest-priority installed provider. Otherwise it comes from {@link XDJAJsseProvider}.
 */
public class GMSSLContext {

    static {
        if (Security.getProvider(BouncyCastleProvider.PROVIDER_NAME) == null) {
            Security.addProvider(new BouncyCastleProvider());
        }
        if (Security.getProvider(XDJAJsseProvider.PROVIDER_NAME) == null) {
            Security.addProvider(new XDJAJsseProvider());
        }
    }

    /** Separator that packs a sign/enc keystore pair (and their passwords) into one string. */
    public final static String DOUBLE_CERTIFICATE_SEPARATOR = XDJAJsseProvider.DOUBLE_CERTIFICATE_SEPARATOR;
    private final static Logger logger = LoggerFactory.getLogger(GMSSLContext.class.getName());

    /**
     * {@link String#split} takes a regex, so the separator is quoted. This keeps splitting
     * consistent with the literal {@link String#contains} check.
     */
    private final static String DOUBLE_CERTIFICATE_SEPARATOR_REGEX = Pattern.quote(DOUBLE_CERTIFICATE_SEPARATOR);

    /** Checked exception that wraps JCA/JSSE failures raised while building a context. */
    public static class GMSSLException extends Exception {
        public GMSSLException(String message) {
            super(message);
        }

        public GMSSLException(String message, Throwable cause) {
            super(message, cause);
        }
    }

    private final SSLContext sslContext;

    /**
     * Returns a trust manager that accepts every non-empty certificate chain without validation.
     *
     * <p><b>Security warning:</b> peer authentication is effectively disabled, so connections built
     * with this manager are open to man-in-the-middle attacks. The manager only logs the subject of
     * the leaf certificate it accepts.
     *
     * @return an {@link X509TrustManager} that trusts any peer
     */
    public static TrustManager getTrustAllManager() {
        return new X509TrustManager() {
            @Override
            public X509Certificate[] getAcceptedIssuers() {
                return new X509Certificate[0];
            }

            @Override
            public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
                logAutoTrusted("server", chain, authType);
            }

            @Override
            public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
                logAutoTrusted("client", chain, authType);
            }
        };
    }

    /**
     * Validates trust-manager arguments and logs the leaf subject that is trusted without checks.
     *
     * @param peerRole "server" or "client", used only in the log message
     * @throws IllegalArgumentException if the chain or auth type is null or empty, as the
     *     {@link X509TrustManager} contract requires
     */
    private static void logAutoTrusted(String peerRole, X509Certificate[] chain, String authType) {
        if (chain == null || chain.length < 1 || authType == null || authType.isEmpty()) {
            throw new IllegalArgumentException("certificate chain and auth type must be non-empty");
        }
        String subject = chain[0].getSubjectX500Principal().getName();
        logger.info("Auto-trusted {} certificate chain for: {}", peerRole, subject);
    }

    /**
     * Creates a client context that trusts <b>any</b> server (see {@link #getTrustAllManager()}).
     * Prefer an overload that takes a trust store for anything beyond testing.
     */
    public static GMSSLContext getClientInstance(String protocol) throws GMSSLException {
        TrustManager trustManager = getTrustAllManager();
        return GMSSLContext.getClientInstance(new TrustManager[]{trustManager}, protocol);
    }

    /**
     * Creates a client context without a client keystore. {@code password} is therefore unused,
     * because no keystore is passed to the key-manager builder.
     */
    public static GMSSLContext getClientInstance(char[] password, KeyStore trustStore, String protocol) throws GMSSLException {
        return new GMSSLContext(null, password, trustStore, protocol);
    }

    /** Creates a client context with the given client keystore and trust store. */
    public static GMSSLContext getClientInstance(KeyStore keyStore, char[] password, KeyStore trustStore, String protocol) throws GMSSLException {
        return new GMSSLContext(keyStore, password, trustStore, protocol);
    }

    /** Creates a client context with no key managers and the given trust managers. */
    public static GMSSLContext getClientInstance(TrustManager[] trustManagers, String protocol) throws GMSSLException {
        return new GMSSLContext(null, trustManagers, protocol);
    }

    /** Creates a server context from the server keystore and an optional trust store. */
    public static GMSSLContext getServerInstance(KeyStore serverStore, char[] password, KeyStore trustStore, String protocol) throws GMSSLException {
        return new GMSSLContext(serverStore, password, trustStore, protocol);
    }

    /** Creates a server context from pre-built key and trust managers. */
    public static GMSSLContext getServerInstance(KeyManager[] keyManagers, TrustManager[] trustManagers, String protocol) throws GMSSLException {
        return new GMSSLContext(keyManagers, trustManagers, protocol);
    }

    /**
     * Builds key managers from {@code keyStore} using the XDJA key manager factory.
     *
     * @return the key managers, or {@code null} when no keystore or no non-empty password is given.
     *     In that case {@link SSLContext#init} falls back to the installed providers' default factory.
     * @throws GMSSLException if the factory cannot be created or the keys cannot be recovered
     */
    private static KeyManager[] getKeyManagers(KeyStore keyStore, char[] password) throws GMSSLException {
        // A null password is treated like an empty one instead of failing with a NullPointerException.
        if (keyStore == null || password == null || password.length == 0) {
            return null;
        }
        try {
            KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(
                    XDJAJsseProvider.KEY_MANAGER_FACTORY_NAME, XDJAJsseProvider.PROVIDER_NAME);
            keyManagerFactory.init(keyStore, password);
            return keyManagerFactory.getKeyManagers();
        } catch (NoSuchProviderException | NoSuchAlgorithmException | UnrecoverableKeyException | KeyStoreException e) {
            throw new GMSSLException("GMSSLContext get key managers exception", e);
        }
    }

    /**
     * Builds trust managers from {@code trustStore} using the XDJA provider.
     *
     * @return the trust managers, or {@code null} when {@code trustStore} is null
     * @throws GMSSLException if the factory cannot be created or initialised
     */
    private static TrustManager[] getTrustManagers(KeyStore trustStore) throws GMSSLException {
        if (trustStore == null) {
            return null;
        }
        try {
            // NOTE(review): this passes KEY_MANAGER_FACTORY_NAME to a *TrustManagerFactory*; presumably
            // the XDJA provider registers both factories under the same algorithm name. Confirm.
            TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(
                    XDJAJsseProvider.KEY_MANAGER_FACTORY_NAME, XDJAJsseProvider.PROVIDER_NAME);
            trustManagerFactory.init(trustStore);
            return trustManagerFactory.getTrustManagers();
        } catch (NoSuchProviderException | NoSuchAlgorithmException | KeyStoreException e) {
            throw new GMSSLException("GMSSLContext get trust managers exception", e);
        }
    }

    private GMSSLContext(KeyStore keyStore, char[] password, KeyStore trustStore, String protocol) throws GMSSLException {
        this(getKeyManagers(keyStore, password), getTrustManagers(trustStore), protocol);
    }

    private GMSSLContext(KeyManager[] keyManagers, TrustManager[] trustManagers, String protocol) throws GMSSLException {
        try {
            SSLContext context;
            if (GMSSLProtocol.TLSV12.getValue().equalsIgnoreCase(protocol)) {
                // Plain TLSv1.2 uses whichever installed provider has the highest priority.
                context = SSLContext.getInstance(protocol);
            } else {
                context = SSLContext.getInstance(protocol, XDJAJsseProvider.PROVIDER_NAME);
            }
            // A null SecureRandom lets the provider choose its default randomness source.
            context.init(keyManagers, trustManagers, null);
            this.sslContext = context;
        } catch (KeyManagementException | NoSuchProviderException | NoSuchAlgorithmException e) {
            throw new GMSSLException("GMSSLContext new instance exception", e);
        }
    }

    /** Returns the client socket factory of the wrapped context. */
    public SSLSocketFactory getSocketFactory() {
        return this.sslContext.getSocketFactory();
    }

    /** Returns the server socket factory of the wrapped context. */
    public SSLServerSocketFactory getServerSocketFactory() {
        return this.sslContext.getServerSocketFactory();
    }

    /** Returns the wrapped, already-initialised {@link SSLContext}. */
    public SSLContext getSslContext() {
        return sslContext;
    }

    /**
     * Loads the SSL server's keystore.
     *
     * <p>If {@code path} and {@code pass} each split into exactly two parts on
     * {@link #DOUBLE_CERTIFICATE_SEPARATOR}, the first part is the signing keystore and the second
     * is the encryption keystore. Every key entry of the encryption store is copied into the signing
     * store under {@code alias + "-enc"}, protected with the encryption password, and the signing
     * store is returned. Otherwise this simply delegates to {@link #getStore}.
     *
     * @throws IOException if a store cannot be loaded or the entries cannot be merged
     */
    public static KeyStore getKeystore(String path, String type, String provider, String pass) throws IOException {
        // Passwords are deliberately never logged.
        logger.debug("get key store path {}, type {}, provider {}", path, type, provider);
        String[] keys = splitDoubleCertificatePair(path);
        String[] passes = splitDoubleCertificatePair(pass);
        if (keys == null || passes == null) {
            return getStore(path, type, provider, pass);
        }

        String sign = keys[0];
        String enc = keys[1];
        String signPass = passes[0];
        String encPass = passes[1];
        logger.debug("use double certificate keystore, sign is: {}", sign);
        logger.debug("use double certificate keystore, enc is: {}", enc);
        KeyStore signKeyStore = getStore(sign, type, provider, signPass);
        KeyStore encKeyStore = getStore(enc, type, provider, encPass);
        try {
            char[] encPassword = encPass.toCharArray();
            Enumeration<String> aliases = encKeyStore.aliases();
            while (aliases.hasMoreElements()) {
                String alias = aliases.nextElement();
                if (encKeyStore.isKeyEntry(alias)) {
                    Key key = encKeyStore.getKey(alias, encPassword);
                    Certificate[] certificateChain = encKeyStore.getCertificateChain(alias);
                    // NOTE(review): the merged enc key keeps encPass while sign keys keep signPass.
                    // A KeyManagerFactory initialised with a single password may not recover both;
                    // confirm the XDJA provider handles this.
                    signKeyStore.setKeyEntry(alias + "-enc", key, encPassword, certificateChain);
                }
            }
        } catch (Exception e) {
            throw new IOException("merge keystore aliases error", e);
        }
        return signKeyStore;
    }

    /**
     * Splits a "first{separator}second" value.
     *
     * @return the two parts, or {@code null} if {@code value} is null or is not exactly such a pair
     */
    private static String[] splitDoubleCertificatePair(String value) {
        if (value == null || !value.contains(DOUBLE_CERTIFICATE_SEPARATOR)) {
            return null;
        }
        String[] parts = value.split(DOUBLE_CERTIFICATE_SEPARATOR_REGEX);
        return parts.length == 2 ? parts : null;
    }

    /**
     * Loads the SSL server's truststore.
     *
     * <p>Null arguments fall back to the {@code javax.net.ssl.trustStore},
     * {@code javax.net.ssl.trustStorePassword}, {@code javax.net.ssl.trustStoreType} and
     * {@code javax.net.ssl.trustStoreProvider} system properties. If
     * {@code GMSSLCertPathUtils.checkSupportType(type)} accepts the type, the file is read as a
     * certificate path and converted with {@code GMSSLKeyStoreUtils.generateGMSSLTrustStoreWithBKS}.
     * Otherwise it is loaded with {@link #getStore}.
     *
     * @throws IOException if the certificates or the store cannot be read
     */
    public static KeyStore getTrustStore(String path, String type, String provider, String pass) throws IOException {
        // Passwords are deliberately never logged.
        logger.debug("get trust store path {}, type {}, provider {}", path, type, provider);

        if (path == null) {
            path = System.getProperty("javax.net.ssl.trustStore");
        }
        if (pass == null) {
            pass = System.getProperty("javax.net.ssl.trustStorePassword");
        }
        if (type == null) {
            type = System.getProperty("javax.net.ssl.trustStoreType");
        }
        if (provider == null) {
            provider = System.getProperty("javax.net.ssl.trustStoreProvider");
        }

        if (GMSSLCertPathUtils.checkSupportType(type)) {
            // try-with-resources closes the stream, which the previous version leaked.
            try (InputStream in = GMSSLFileUtils.getResourceAsStream(path)) {
                List<X509Certificate> list = GMSSLCertPathUtils.readCertificatesFromCertPath(in);
                return GMSSLKeyStoreUtils.generateGMSSLTrustStoreWithBKS(list);
            } catch (Exception e) {
                logger.error("type: {}, provider: {}, path: {}", type, provider, path);
                throw new IOException("read certificate from cert path error", e);
            }
        }

        return getStore(path, type, provider, pass);
    }

    /**
     * Loads the key- or truststore with the specified type, path, provider and password.
     *
     * @param provider provider name, or {@code null} to use the highest-priority provider
     * @param pass store password. Null or empty means the store is loaded with a {@code null} password.
     * @throws FileNotFoundException if the resource cannot be found (logged before rethrowing)
     * @throws IOException for other load failures. Non-I/O failures are wrapped.
     */
    public static KeyStore getStore(String path, String type, String provider, String pass) throws IOException {
        KeyStore ks = null;
        InputStream in = null;
        try {
            ks = (provider == null) ? KeyStore.getInstance(type) : KeyStore.getInstance(type, provider);
            in = GMSSLFileUtils.getResourceAsStream(path);
            char[] storePass = (pass != null && !pass.isEmpty()) ? pass.toCharArray() : null;
            ks.load(in, storePass);
        } catch (FileNotFoundException fnfe) {
            logger.error("jsse.keystore_load_failed {} {} {}", type, path, fnfe.getMessage(), fnfe);
            throw fnfe;
        } catch (IOException ioe) {
            // May be expected when working with a trust store.
            // Re-throw; the caller will catch and log as required.
            throw ioe;
        } catch (Exception ex) {
            logger.error("jsse.keystore_load_failed {} {} {}", type, path, ex.getMessage(), ex);
            throw new IOException(ex);
        } finally {
            if (in != null) {
                try {
                    in.close();
                } catch (IOException ignored) {
                    // Best-effort close: the store is already loaded (or the load already failed).
                }
            }
        }

        return ks;
    }

}
