/*
 *  Licensed to the Apache Software Foundation (ASF) under one or more
 *  contributor license agreements.  See the NOTICE file distributed with
 *  this work for additional information regarding copyright ownership.
 *  The ASF licenses this file to You under the Apache License, Version 2.0
 *  (the "License"); you may not use this file except in compliance with
 *  the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package com.xdja.pki.gmssl.tomcat.plugin;

import org.apache.tomcat.util.net.AbstractEndpoint;
import org.apache.tomcat.util.net.Constants;
import org.apache.tomcat.util.net.SSLUtil;
import org.apache.tomcat.util.net.ServerSocketFactory;
import org.apache.tomcat.util.res.StringManager;
import org.bouncycastle.jsse.provider.XDJAJsseProvider;

import javax.net.ssl.*;
import java.io.*;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.security.KeyStore;

/**
 * SSL server socket factory whose {@link SSLContext}, key managers and trust
 * managers all come from the XDJA JSSE provider
 * ({@link XDJAJsseProvider#PROVIDER_NAME}); it advertises {@code GMSSLv1.1}
 * as its only enableable protocol.
 * <p>
 * Adapted from Tomcat's {@code JSSESocketFactory}. The key type required in the
 * keystore depends on the XDJA provider (NOTE(review): presumably an SM2 key for
 * GMSSL &mdash; confirm). If no keystore password is configured, the default
 * {@value #DEFAULT_KEY_PASS} is used.
 *
 * @author Harish Prabandham
 * @author Costin Manolache
 * @author Stefan Freyr Stefansson
 * @author EKR -- renamed to JSSESocketFactory
 * @author Jan Luehe
 * @author Bill Barker
 */
public class XDJAJSSESocketFactory implements ServerSocketFactory, SSLUtil {

    private static final org.apache.juli.logging.Log log = org.apache.juli.logging.LogFactory.getLog(XDJAJSSESocketFactory.class);
    private static final StringManager sm = StringManager.getManager("org.apache.tomcat.util.net.jsse.res");

    // Defaults - made public where re-used
    private static final String defaultProtocol = "TLS";
    private static final String defaultKeystoreType = "JKS";
    private static final String defaultKeystoreFile = System.getProperty("user.home") + "/.keystore";
    /** 0 means "no limit" for {@link SSLSessionContext#setSessionCacheSize(int)}. */
    private static final int defaultSessionCacheSize = 0;
    /** Seconds (24 hours), as expected by {@link SSLSessionContext#setSessionTimeout(int)}. */
    private static final int defaultSessionTimeout = 86400;
    private static final String ALLOW_ALL_SUPPORTED_CIPHERS = "ALL";
    public static final String DEFAULT_KEY_PASS = "changeit";

    /** Source of all connector SSL configuration (keystore, truststore, client auth, ...). */
    private final AbstractEndpoint<?> endpoint;

    /** Server socket factory obtained from the most recently initialised SSLContext. */
    protected SSLServerSocketFactory sslProxy = null;
    protected boolean allowUnsafeLegacyRenegotiation = false;

    /**
     * Flag to state that we require client authentication.
     */
    protected boolean requireClientAuth = false;

    /**
     * Flag to state that we would like client authentication.
     */
    protected boolean wantClientAuth = false;


    /**
     * @param endpoint the endpoint whose SSL settings configure this factory
     */
    public XDJAJSSESocketFactory(AbstractEndpoint<?> endpoint) {
        this.endpoint = endpoint;
    }


    /**
     * (Re)initialises the SSL context and creates a server socket on {@code port}.
     */
    @Override
    public ServerSocket createSocket(int port) throws IOException {
        init();
        ServerSocket socket = sslProxy.createServerSocket(port);
        initServerSocket(socket);
        return socket;
    }

    /**
     * (Re)initialises the SSL context and creates a server socket on {@code port}
     * with the given accept backlog.
     */
    @Override
    public ServerSocket createSocket(int port, int backlog) throws IOException {
        init();
        ServerSocket socket = sslProxy.createServerSocket(port, backlog);
        initServerSocket(socket);
        return socket;
    }

    /**
     * (Re)initialises the SSL context and creates a server socket on {@code port}
     * bound to {@code ifAddress} with the given accept backlog.
     */
    @Override
    public ServerSocket createSocket(int port, int backlog, InetAddress ifAddress) throws IOException {
        init();
        ServerSocket socket = sslProxy.createServerSocket(port, backlog, ifAddress);
        initServerSocket(socket);
        return socket;
    }

    /**
     * Accepts a connection on an SSL server socket created by this factory.
     *
     * @throws SocketException if accepting fails with an {@link SSLException};
     *                         the original exception is attached as the cause
     */
    @Override
    public Socket acceptSocket(ServerSocket socket) throws IOException {
        try {
            return (SSLSocket) socket.accept();
        } catch (SSLException e) {
            // SocketException(String, Throwable) is not available on older JDKs, so attach the cause manually
            SocketException se = new SocketException("SSL handshake error: " + e.toString());
            se.initCause(e);
            throw se;
        }
    }

    /**
     * Ensures the SSL handshake on {@code sock} completed successfully.
     *
     * @throws IOException if the session reports the invalid cipher suite
     *                     {@code SSL_NULL_WITH_NULL_NULL}
     */
    @Override
    public void handshake(Socket sock) throws IOException {
        // getSession() instead of startHandshake() so this can be called multiple times.
        // On a failed initial handshake getSession() returns an invalid session whose
        // cipher suite is SSL_NULL_WITH_NULL_NULL rather than throwing.
        SSLSession session = ((SSLSocket) sock).getSession();
        if ("SSL_NULL_WITH_NULL_NULL".equals(session.getCipherSuite())) {
            throw new IOException("SSL handshake failed. Cipher suite in SSL Session is SSL_NULL_WITH_NULL_NULL");
        }
    }

    /**
     * Returns every cipher suite the context supports.
     */
    @Override
    public String[] getEnableableCiphers(SSLContext context) {
        return context.getSupportedSSLParameters().getCipherSuites();
    }

    /*
     * Gets the SSL server's keystore password: keystorePass, else keyPass,
     * else DEFAULT_KEY_PASS.
     */
    protected String getKeystorePassword() {
        String keystorePass = endpoint.getKeystorePass();
        if (keystorePass == null) {
            keystorePass = endpoint.getKeyPass();
        }
        if (keystorePass == null) {
            keystorePass = DEFAULT_KEY_PASS;
        }
        return keystorePass;
    }

    /*
     * Gets the SSL server's keystore from the endpoint's keystoreFile,
     * defaulting to ~/.keystore.
     */
    protected KeyStore getKeystore(String type, String provider, String pass)
            throws IOException {

        String keystoreFile = endpoint.getKeystoreFile();
        if (keystoreFile == null) {
            keystoreFile = defaultKeystoreFile;
        }

        return getStore(type, provider, keystoreFile, pass);
    }

    /*
     * Gets the SSL server's truststore. Each setting falls back to the matching
     * javax.net.ssl.* system property. Returns null when no truststore file is
     * configured anywhere, so the TrustManagerFactory falls back to its default
     * trust material (NOTE(review): confirm what XDJAJsseProvider uses for a null
     * KeyStore).
     */
    protected KeyStore getTrustStore() throws IOException {
        String truststoreFile = endpoint.getTruststoreFile();
        if (truststoreFile == null) {
            truststoreFile = System.getProperty("javax.net.ssl.trustStore");
        }
        if (truststoreFile == null) {
            // Previously a null path reached getStore() and failed with an NPE-derived IOException
            return null;
        }

        String truststorePassword = endpoint.getTruststorePass();
        if (truststorePassword == null) {
            truststorePassword = System.getProperty("javax.net.ssl.trustStorePassword");
        }

        String truststoreType = endpoint.getTruststoreType();
        if (truststoreType == null) {
            truststoreType = System.getProperty("javax.net.ssl.trustStoreType");
        }
        if (truststoreType == null) {
            // KeyStore.getInstance(null) would fail; mirror the keystore type instead
            truststoreType = endpoint.getKeystoreType();
        }
        if (truststoreType == null) {
            truststoreType = defaultKeystoreType;
        }

        String truststoreProvider = endpoint.getTruststoreProvider();
        if (truststoreProvider == null) {
            truststoreProvider = System.getProperty("javax.net.ssl.trustStoreProvider");
        }

        return getStore(truststoreType, truststoreProvider, truststoreFile, truststorePassword);
    }

    /*
     * Gets the key- or truststore with the specified type, path, and password.
     * Relative paths are resolved against CATALINA_BASE; an empty password is
     * treated as "no password". IOExceptions are rethrown as-is; any other failure
     * is logged and wrapped in an IOException that keeps the original cause.
     */
    private KeyStore getStore(String type, String provider, String path, String pass) throws IOException {
        KeyStore ks = null;
        InputStream in = null;
        try {
            if (provider == null) {
                ks = KeyStore.getInstance(type);
            } else {
                ks = KeyStore.getInstance(type, provider);
            }
            File keyStoreFile = new File(path);
            if (!keyStoreFile.isAbsolute()) {
                keyStoreFile = new File(System.getProperty(
                        Constants.CATALINA_BASE_PROP), path);
            }
            in = new FileInputStream(keyStoreFile);
            char[] storePass = null;
            if (pass != null && !"".equals(pass)) {
                storePass = pass.toCharArray();
            }
            ks.load(in, storePass);

        } catch (FileNotFoundException fnfe) {
            log.error(sm.getString("jsse.keystore_load_failed", type, path, fnfe.getMessage()), fnfe);
            throw fnfe;
        } catch (IOException ioe) {
            // May be expected when working with a trust store
            // Re-throw. Caller will catch and log as required
            throw ioe;
        } catch (Exception ex) {
            String msg = sm.getString("jsse.keystore_load_failed", type, path, ex.getMessage());
            log.error(msg, ex);
            throw new IOException(msg, ex);
        } finally {
            if (in != null) {
                try {
                    in.close();
                } catch (IOException ignored) {
                    // Best-effort close: the store is already loaded (or loading already failed)
                }
            }
        }

        return ks;
    }

    /**
     * Reads the keystore and initializes the SSL socket factory.
     * Called on every createSocket(...), so all derived state is recomputed here.
     */
    void init() throws IOException {
        try {
            // Recompute both flags each time so a previous init() cannot leave a stale "true"
            String clientAuthStr = endpoint.getClientAuth();
            requireClientAuth = "true".equalsIgnoreCase(clientAuthStr) || "yes".equalsIgnoreCase(clientAuthStr);
            wantClientAuth = !requireClientAuth && "want".equalsIgnoreCase(clientAuthStr);

            SSLContext context = createSSLContext();
            context.init(getKeyManagers(), getTrustManagers(), null);

            // Configure SSL session cache
            SSLSessionContext sessionContext = context.getServerSessionContext();
            if (sessionContext != null) {
                configureSessionContext(sessionContext);
            }

            // create proxy
            sslProxy = context.getServerSocketFactory();

            allowUnsafeLegacyRenegotiation = "true".equals(endpoint.getAllowUnsafeLegacyRenegotiation());

        } catch (IOException ioe) {
            throw ioe;
        } catch (Exception e) {
            throw new IOException(e.getMessage(), e);
        }
    }

    /**
     * Creates an SSLContext from the XDJA provider for the endpoint's sslProtocol,
     * defaulting to "TLS".
     */
    @Override
    public SSLContext createSSLContext() throws Exception {
        // SSL protocol variant (e.g., TLS, SSL v3, etc.)
        String protocol = endpoint.getSslProtocol();
        if (protocol == null) {
            protocol = defaultProtocol;
        }

        return SSLContext.getInstance(protocol, XDJAJsseProvider.PROVIDER_NAME);
    }

    /**
     * Builds PKIX key managers (XDJA provider) from the server keystore.
     * The keystore password is also used as the private-key password.
     * NOTE(review): Tomcat's own JSSESocketFactory uses keyPass for the key;
     * this factory deliberately keeps using the keystore password.
     */
    @Override
    public KeyManager[] getKeyManagers() throws Exception {
        String password = getKeystorePassword();

        String keystoreType = endpoint.getKeystoreType();
        if (keystoreType == null) {
            // Otherwise KeyStore.getInstance(null) fails
            keystoreType = defaultKeystoreType;
        }

        KeyStore ks = getKeystore(keystoreType, endpoint.getKeystoreProvider(), password);

        KeyManagerFactory keyMgrFact = KeyManagerFactory.getInstance("PKIX", XDJAJsseProvider.PROVIDER_NAME);
        keyMgrFact.init(ks, password.toCharArray());

        return keyMgrFact.getKeyManagers();
    }

    /**
     * Builds PKIX trust managers (XDJA provider) from the truststore. A null
     * truststore (none configured) is passed through to the factory's default.
     */
    @Override
    public TrustManager[] getTrustManagers() throws Exception {
        KeyStore ts = getTrustStore();

        TrustManagerFactory trustMgrFact = TrustManagerFactory.getInstance("PKIX", XDJAJsseProvider.PROVIDER_NAME);
        trustMgrFact.init(ts);

        return trustMgrFact.getTrustManagers();
    }

    /**
     * Applies the endpoint's sessionCacheSize and sessionTimeout (seconds), or
     * the class defaults when unset.
     *
     * @throws NumberFormatException if a configured value is not an integer
     */
    @Override
    public void configureSessionContext(SSLSessionContext sslSessionContext) {
        sslSessionContext.setSessionCacheSize(
                parseIntOrDefault(endpoint.getSessionCacheSize(), defaultSessionCacheSize));
        sslSessionContext.setSessionTimeout(
                parseIntOrDefault(endpoint.getSessionTimeout(), defaultSessionTimeout));
    }

    /** Parses {@code value} (ignoring surrounding whitespace), or returns {@code defaultValue} if null. */
    private static int parseIntOrDefault(String value, int defaultValue) {
        return value == null ? defaultValue : Integer.parseInt(value.trim());
    }

    /**
     * Returns the only protocol this factory advertises: GMSSLv1.1.
     */
    @Override
    public String[] getEnableableProtocols(SSLContext context) {
        return new String[]{"GMSSLv1.1"};
    }

    /**
     * Configure Client authentication for this version of JSSE.  The
     * JSSE included in Java 1.4 supports the 'want' value.  Prior
     * versions of JSSE will treat 'want' as 'false'.
     *
     * @param socket the SSLServerSocket
     */
    protected void configureClientAuth(SSLServerSocket socket) {
        if (wantClientAuth) {
            socket.setWantClientAuth(wantClientAuth);
        } else {
            socket.setNeedClientAuth(requireClientAuth);
        }
    }

    /**
     * Configures the given SSL server socket's need for client authentication.
     */
    private void initServerSocket(ServerSocket ssocket) {

        SSLServerSocket socket = (SSLServerSocket) ssocket;

        // NOTE(review): enabled cipher suites / protocols are not set here, so the
        // provider's defaults apply (the endpoint's ciphers/protocols settings are ignored).

        // we don't know if client auth is needed -
        // after parsing the request we may re-handshake
        configureClientAuth(socket);
    }

}
