package com.koushikdutta.async;

import android.os.Build;

import com.koushikdutta.async.callback.CompletedCallback;
import com.koushikdutta.async.callback.DataCallback;
import com.koushikdutta.async.callback.WritableCallback;
import com.koushikdutta.async.util.Allocator;
import com.koushikdutta.async.wrapper.AsyncSocketWrapper;

import org.apache.http.conn.ssl.StrictHostnameVerifier;

import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.security.cert.X509Certificate;

import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;

/**
 * An {@link AsyncSSLSocket} that layers an {@link SSLEngine} on top of a plain {@link AsyncSocket}.
 * <p>
 * Ciphertext arriving on the wrapped socket is unwrapped into {@link #pending} and emitted to the
 * {@link DataCallback} registered via {@link #setDataCallback}; plaintext passed to
 * {@link #write(ByteBufferList)} is wrapped and written through a {@link BufferedDataSink}.
 * Instances are only created through {@link #handshake}.
 * <p>
 * NOTE(review): no field is synchronized; this presumably relies on every call arriving on the
 * {@link AsyncServer} thread of the wrapped socket — confirm.
 */
public class AsyncSSLSocketWrapper implements AsyncSocketWrapper, AsyncSSLSocket {
    /** Receives the outcome of {@link #handshake}. */
    public interface HandshakeCallback {
        /**
         * @param e      the failure, or {@code null} on success
         * @param socket the connected wrapper on success, {@code null} on failure
         */
        public void onHandshakeCompleted(Exception e, AsyncSSLSocket socket);
    }

    /** Context returned by {@link #getDefaultSSLContext()}; built once by the static initializer. */
    static SSLContext defaultSSLContext;

    /** The plain socket carrying the TLS records. */
    AsyncSocket mSocket;
    /** Sink over {@link #mSocket} used for all wrapped (encrypted) output. */
    BufferedDataSink mSink;
    /** Re-entrancy guard for {@link #dataCallback}. */
    boolean mUnwrapping;
    SSLEngine engine;
    /** Set once the post-handshake verification block in handleHandshakeStatus has run. */
    boolean finishedHandshake;
    private int mPort;
    private String mHost;
    /** Re-entrancy guard for {@link #write(ByteBufferList)}. */
    private boolean mWrapping;
    /** Custom hostname check; {@code null} selects {@link StrictHostnameVerifier}. */
    HostnameVerifier hostnameVerifier;
    /** Pending handshake listener; cleared once it has been notified. */
    HandshakeCallback handshakeCallback;
    /** Peer chain captured during client-mode verification. */
    X509Certificate[] peerCertificates;
    WritableCallback mWriteableCallback;
    DataCallback mDataCallback;
    /** Trust managers for client-mode verification; {@code null} selects the platform default. */
    TrustManager[] trustManagers;
    boolean clientMode;

    static {
        // Prefer the platform "Default" context, i.e. "trust the system" certificates.
        try {
            // Critical extension 2.5.29.15 (keyUsage) is implemented improperly prior to 4.0.3;
            // certificates that use it throw in Cipher.java.
            // https://code.google.com/p/android/issues/detail?id=9307
            // https://groups.google.com/forum/?fromgroups=#!topic/netty/UCfqPPk5O4s
            // On those releases, jump straight to the fallback context below.
            if (Build.VERSION.SDK_INT <= Build.VERSION_CODES.ICE_CREAM_SANDWICH_MR1)
                throw new Exception();
            defaultSSLContext = SSLContext.getInstance("Default");
        }
        catch (Exception ex) {
            // Fallback: a TLS context whose trust manager removes 2.5.29.15 from each
            // certificate's critical-extension set and otherwise accepts every chain.
            // NOTE(review): this manager never rejects a certificate. Safety depends on the
            // post-handshake verification in handleHandshakeStatus, and the OID removal only
            // has an effect if getCriticalExtensionOIDs() returns a live set — confirm.
            try {
                defaultSSLContext = SSLContext.getInstance("TLS");
                TrustManager[] trustAllCerts = new TrustManager[] { new X509TrustManager() {
                    public java.security.cert.X509Certificate[] getAcceptedIssuers() {
                        return new X509Certificate[0];
                    }

                    public void checkClientTrusted(java.security.cert.X509Certificate[] certs, String authType) {
                    }

                    public void checkServerTrusted(java.security.cert.X509Certificate[] certs, String authType) {
                        for (X509Certificate cert : certs) {
                            if (cert != null && cert.getCriticalExtensionOIDs() != null)
                                cert.getCriticalExtensionOIDs().remove("2.5.29.15");
                        }
                    }
                } };
                defaultSSLContext.init(null, trustAllCerts, null);
            }
            catch (Exception ex2) {
                ex.printStackTrace();
                ex2.printStackTrace();
            }
        }
    }

    /**
     * @return the context chosen by the static initializer; may be {@code null} if both the
     *         "Default" and the fallback "TLS" context failed to initialize
     */
    public static SSLContext getDefaultSSLContext() {
        return defaultSSLContext;
    }

    /**
     * Wraps {@code socket} in a new {@link AsyncSSLSocketWrapper} and starts the TLS handshake.
     *
     * @param socket        the connected plain socket
     * @param host          peer host name used for hostname verification; {@code null} skips it
     * @param port          peer port, exposed via {@link #getPort()}
     * @param sslEngine     engine to drive; its client mode is set from {@code clientMode}
     * @param trustManagers managers used for client-mode verification, or {@code null} for the
     *                      platform default
     * @param verifier      hostname verifier, or {@code null} for {@link StrictHostnameVerifier}
     * @param clientMode    whether this side is the TLS client; peer verification only runs in
     *                      client mode
     * @param callback      receives {@code (null, wrapper)} on success or
     *                      {@code (exception, null)} on failure, including the socket closing
     *                      before the handshake finishes
     */
    public static void handshake(AsyncSocket socket,
                                 String host, int port,
                                 SSLEngine sslEngine,
                                 TrustManager[] trustManagers, HostnameVerifier verifier, boolean clientMode,
                                 final HandshakeCallback callback) {
        AsyncSSLSocketWrapper wrapper = new AsyncSSLSocketWrapper(socket, host, port, sslEngine, trustManagers, verifier, clientMode);
        wrapper.handshakeCallback = callback;
        // Removed again on success (handleHandshakeStatus) or failure (report).
        socket.setClosedCallback(new CompletedCallback() {
            @Override
            public void onCompleted(Exception ex) {
                if (ex != null)
                    callback.onHandshakeCompleted(ex, null);
                else
                    callback.onHandshakeCompleted(new SSLException("socket closed during handshake"), null);
            }
        });
        try {
            wrapper.engine.beginHandshake();
            wrapper.handleHandshakeStatus(wrapper.engine.getHandshakeStatus());
        } catch (SSLException e) {
            wrapper.report(e);
        }
    }

    /** Set when the underlying socket has ended; the end callback fires once pending drains. */
    boolean mEnded;
    /** Exception the underlying socket ended with, replayed to the end callback. */
    Exception mEndException;
    /** Decrypted bytes not yet delivered to the data callback. */
    final ByteBufferList pending = new ByteBufferList();

    /**
     * Installs the data, end and writable plumbing on {@code socket}. Does not start the
     * handshake; see {@link #handshake}.
     */
    private AsyncSSLSocketWrapper(AsyncSocket socket,
                                  String host, int port,
                                  SSLEngine sslEngine,
                                  TrustManager[] trustManagers, HostnameVerifier verifier, boolean clientMode) {
        mSocket = socket;
        hostnameVerifier = verifier;
        this.clientMode = clientMode;
        this.trustManagers = trustManagers;
        this.engine = sslEngine;

        mHost = host;
        mPort = port;
        engine.setUseClientMode(clientMode);
        mSink = new BufferedDataSink(socket);
        mSink.setWriteableCallback(new WritableCallback() {
            @Override
            public void onWriteable() {
                if (mWriteableCallback != null)
                    mWriteableCallback.onWriteable();
            }
        });

        // Record the end of the underlying stream, but defer the user's end callback until
        // every decrypted byte in pending has been delivered (see onDataAvailable()).
        mSocket.setEndCallback(new CompletedCallback() {
            @Override
            public void onCompleted(Exception ex) {
                if (mEnded)
                    return;
                mEnded = true;
                mEndException = ex;
                if (!pending.hasRemaining() && mEndCallback != null)
                    mEndCallback.onCompleted(ex);
            }
        });

        mSocket.setDataCallback(dataCallback);
    }

    /**
     * Receives ciphertext from {@link #mSocket}, feeds it through {@link SSLEngine#unwrap} into
     * {@link #pending}, drives the handshake, and finally emits whatever plaintext is pending.
     * Also invoked directly (with an empty list) by handleHandshakeStatus on NEED_UNWRAP.
     */
    final DataCallback dataCallback = new DataCallback() {
        // Sizes unwrap output buffers; grown whenever the engine reports BUFFER_OVERFLOW.
        final Allocator allocator = new Allocator().setMinAlloc(8192);
        // Ciphertext received but not yet consumed by the engine.
        final ByteBufferList buffered = new ByteBufferList();

        @Override
        public void onDataAvailable(DataEmitter emitter, ByteBufferList bb) {
            // Ignore re-entrant calls, e.g. NEED_UNWRAP reached from inside this loop.
            if (mUnwrapping)
                return;
            try {
                mUnwrapping = true;

                // Move the newly received bytes into buffered (presumably appended after any
                // leftover ciphertext from a previous call).
                bb.get(buffered);

                // Repack; presumably getAll() coalesces the list into a single buffer.
                if (buffered.hasRemaining()) {
                    ByteBuffer all = buffered.getAll();
                    buffered.add(all);
                }

                ByteBuffer b = ByteBufferList.EMPTY_BYTEBUFFER;
                while (true) {
                    if (b.remaining() == 0 && buffered.size() > 0) {
                        b = buffered.remove();
                    }
                    int remaining = b.remaining();
                    int before = pending.remaining();

                    SSLEngineResult res;
                    {
                        // Scoped so readBuf cannot be touched after it is handed to pending.
                        ByteBuffer readBuf = allocator.allocate();
                        res = engine.unwrap(b, readBuf);
                        addToPending(pending, readBuf);
                        allocator.track(pending.remaining() - before);
                    }
                    if (res.getStatus() == Status.BUFFER_OVERFLOW) {
                        // Output buffer too small: grow it and force another pass
                        // (b.remaining() can never equal -1).
                        allocator.setMinAlloc(allocator.getMinAlloc() * 2);
                        remaining = -1;
                    }
                    else if (res.getStatus() == Status.BUFFER_UNDERFLOW) {
                        // b does not hold a complete record: put it back.
                        buffered.addFirst(b);
                        if (buffered.size() <= 1) {
                            // Nothing to merge it with; wait for more data.
                            break;
                        }
                        // Pack the buffered chunks into one buffer and retry.
                        remaining = -1;
                        b = buffered.getAll();
                        buffered.addFirst(b);
                        b = ByteBufferList.EMPTY_BYTEBUFFER;
                    }
                    handleHandshakeStatus(res.getHandshakeStatus());
                    // No input consumed and no output produced: keep the rest and stop.
                    if (b.remaining() == remaining && before == pending.remaining()) {
                        buffered.addFirst(b);
                        break;
                    }
                }

                AsyncSSLSocketWrapper.this.onDataAvailable();
            }
            catch (SSLException ex) {
                ex.printStackTrace();
                report(ex);
            }
            finally {
                mUnwrapping = false;
            }
        }
    };

    /**
     * Hands pending plaintext to {@link Util#emitAllData}, then fires the end callback if the
     * underlying socket has ended and nothing is left in {@link #pending}.
     */
    public void onDataAvailable() {
        Util.emitAllData(this, pending);

        if (mEnded && !pending.hasRemaining() && mEndCallback != null)
            mEndCallback.onCompleted(mEndException);
    }


    @Override
    public SSLEngine getSSLEngine() {
        return engine;
    }

    /**
     * Flips {@code mReadTmp} (just filled by unwrap) and appends it to {@code out} when it holds
     * data; an empty buffer is reclaimed instead.
     */
    void addToPending(ByteBufferList out, ByteBuffer mReadTmp) {
        mReadTmp.flip();
        if (mReadTmp.hasRemaining()) {
            out.add(mReadTmp);
        }
        else {
            ByteBufferList.reclaim(mReadTmp);
        }
    }


    /** Ends the underlying socket. */
    @Override
    public void end() {
        mSocket.end();
    }

    /** @return the host passed to {@link #handshake}, possibly {@code null} */
    public String getHost() {
        return mHost;
    }

    /** @return the port passed to {@link #handshake} */
    public int getPort() {
        return mPort;
    }

    /**
     * Advances the handshake according to {@code status}. The first time the engine reports
     * NOT_HANDSHAKING or FINISHED, it verifies the peer (client mode only) and notifies the
     * handshake callback.
     * <ul>
     * <li>NEED_TASK: runs the engine's delegated tasks inline.</li>
     * <li>NEED_WRAP: calls {@code write(writeList)} so the engine can emit handshake records.</li>
     * <li>NEED_UNWRAP: re-runs the unwrap loop over already-buffered ciphertext.</li>
     * </ul>
     * Verification failures are delivered through {@link #report(Exception)}.
     *
     * @throws RuntimeException wrapping a {@link NoSuchAlgorithmException} if the default
     *                          trust-manager algorithm is unavailable
     */
    private void handleHandshakeStatus(HandshakeStatus status) {
        if (status == HandshakeStatus.NEED_TASK) {
            // getDelegatedTask() may hand out several tasks and returns null once none remain;
            // run them all instead of assuming exactly one (which could NPE).
            Runnable task;
            while ((task = engine.getDelegatedTask()) != null) {
                task.run();
            }
        }

        if (status == HandshakeStatus.NEED_WRAP) {
            write(writeList);
        }

        if (status == HandshakeStatus.NEED_UNWRAP) {
            dataCallback.onDataAvailable(this, new ByteBufferList());
        }

        try {
            if (!finishedHandshake && (engine.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING || engine.getHandshakeStatus() == HandshakeStatus.FINISHED)) {
                if (clientMode) {
                    TrustManager[] trustManagers = this.trustManagers;
                    if (trustManagers == null) {
                        TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
                        tmf.init((KeyStore) null);
                        trustManagers = tmf.getTrustManagers();
                    }
                    // The first manager that accepts both the chain and the hostname wins.
                    boolean trusted = false;
                    Exception peerUnverifiedCause = null;
                    for (TrustManager tm : trustManagers) {
                        try {
                            X509TrustManager xtm = (X509TrustManager) tm;
                            peerCertificates = (X509Certificate[]) engine.getSession().getPeerCertificates();
                            // NOTE(review): X509TrustManager documents authType as the cipher
                            // suite's key-exchange algorithm (e.g. "RSA"); "SSL" is passed
                            // unconditionally — confirm the platform managers accept this.
                            xtm.checkServerTrusted(peerCertificates, "SSL");
                            if (mHost != null) {
                                if (hostnameVerifier == null) {
                                    StrictHostnameVerifier verifier = new StrictHostnameVerifier();
                                    verifier.verify(mHost, StrictHostnameVerifier.getCNs(peerCertificates[0]), StrictHostnameVerifier.getDNSSubjectAlts(peerCertificates[0]));
                                }
                                else {
                                    if (!hostnameVerifier.verify(mHost, engine.getSession())) {
                                        throw new SSLException("hostname <" + mHost + "> has been denied");
                                    }
                                }
                            }
                            trusted = true;
                            break;
                        }
                        catch (GeneralSecurityException ex) {
                            peerUnverifiedCause = ex;
                        }
                        catch (SSLException ex) {
                            peerUnverifiedCause = ex;
                        }
                    }
                    finishedHandshake = true;
                    if (!trusted) {
                        AsyncSSLException e = new AsyncSSLException(peerUnverifiedCause);
                        report(e);
                        if (!e.getIgnore())
                            throw e;
                    }
                }
                else {
                    finishedHandshake = true;
                }
                // report() clears handshakeCallback when it delivers a failure. If the
                // AsyncSSLException above was reported and then ignored, execution reaches this
                // point with the callback already notified and cleared; don't dereference null.
                if (handshakeCallback != null) {
                    handshakeCallback.onHandshakeCompleted(null, this);
                    handshakeCallback = null;
                }

                mSocket.setClosedCallback(null);
                // handshake can complete during a wrap, so make sure that the call
                // stack and wrap flag is cleared before invoking writable
                getServer().post(new Runnable() {
                    @Override
                    public void run() {
                        if (mWriteableCallback != null)
                            mWriteableCallback.onWriteable();
                    }
                });
                onDataAvailable();
            }
        }
        catch (NoSuchAlgorithmException ex) {
            throw new RuntimeException(ex);
        }
        catch (GeneralSecurityException ex) {
            report(ex);
        }
        catch (AsyncSSLException ex) {
            report(ex);
        }
    }

    /**
     * Sizes a wrap output buffer: 1.5x {@code remaining}, or 8192 bytes when that is zero.
     */
    int calculateAlloc(int remaining) {
        // alloc 50% more than we need for writing; computed in long so very large inputs
        // clamp instead of overflowing (identical to remaining * 3 / 2 otherwise).
        int alloc = (int) Math.min(Integer.MAX_VALUE, remaining * 3L / 2);
        if (alloc == 0)
            alloc = 8192;
        return alloc;
    }

    /** Staging list for wrapped output; also the (empty) input used to emit handshake records. */
    ByteBufferList writeList = new ByteBufferList();

    /**
     * Encrypts {@code bb} and writes the resulting TLS records to {@link #mSink}. Does nothing if
     * a wrap is already in progress or if {@code mSink.remaining() > 0}. Unconsumed plaintext is
     * left in {@code bb}.
     */
    @Override
    public void write(ByteBufferList bb) {
        if (mWrapping)
            return;
        if (mSink.remaining() > 0)
            return;
        mWrapping = true;
        int remaining;
        SSLEngineResult res = null;
        ByteBuffer writeBuf = ByteBufferList.obtain(calculateAlloc(bb.remaining()));
        do {
            // if the handshake is finished, don't send
            // 0 bytes of data, since that makes the ssl connection die.
            // it wraps a 0 byte package, and craps out.
            if (finishedHandshake && bb.remaining() == 0)
                break;
            remaining = bb.remaining();
            try {
                ByteBuffer[] arr = bb.getAllArray();
                res = engine.wrap(arr, writeBuf);
                // Give back whatever plaintext the engine did not consume.
                bb.addAll(arr);
                writeBuf.flip();
                // The previous mSink.write() must have drained writeList. The check has to
                // precede the add: afterwards it fails whenever wrap produced output.
                assert !writeList.hasRemaining();
                writeList.add(writeBuf);
                if (writeList.remaining() > 0)
                    mSink.write(writeList);
                int previousCapacity = writeBuf.capacity();
                writeBuf = null;
                if (res.getStatus() == Status.BUFFER_OVERFLOW) {
                    writeBuf = ByteBufferList.obtain(previousCapacity * 2);
                    // Force another pass (bb.remaining() can never equal -1).
                    remaining = -1;
                }
                else {
                    writeBuf = ByteBufferList.obtain(calculateAlloc(bb.remaining()));
                    handleHandshakeStatus(res.getHandshakeStatus());
                }
            }
            catch (SSLException e) {
                report(e);
            }
        }
        // Continue while input is being consumed or the engine still wants to wrap, as long as
        // the sink is not holding back data.
        while ((remaining != bb.remaining() || (res != null && res.getHandshakeStatus() == HandshakeStatus.NEED_WRAP)) && mSink.remaining() == 0);
        mWrapping = false;
        ByteBufferList.reclaim(writeBuf);
    }

    @Override
    public void setWriteableCallback(WritableCallback handler) {
        mWriteableCallback = handler;
    }

    @Override
    public WritableCallback getWriteableCallback() {
        return mWriteableCallback;
    }

    /**
     * Delivers a failure. While a handshake callback is pending, this tears the connection down
     * and passes {@code e} to that callback: incoming data is discarded, and {@link #mSocket} is
     * ended and closed after its closed callback is removed. Otherwise {@code e} goes to the end
     * callback, if one is set.
     */
    private void report(Exception e) {
        final HandshakeCallback hs = handshakeCallback;
        if (hs != null) {
            handshakeCallback = null;
            mSocket.setDataCallback(new DataCallback.NullDataCallback());
            mSocket.end();
            // handshake sets this callback. unset it.
            mSocket.setClosedCallback(null);
            mSocket.close();
            hs.onHandshakeCompleted(e, null);
            return;
        }

        CompletedCallback cb = getEndCallback();
        if (cb != null)
            cb.onCompleted(e);
    }

    @Override
    public void setDataCallback(DataCallback callback) {
        mDataCallback = callback;
    }

    @Override
    public DataCallback getDataCallback() {
        return mDataCallback;
    }

    @Override
    public boolean isChunked() {
        return mSocket.isChunked();
    }

    @Override
    public boolean isOpen() {
        return mSocket.isOpen();
    }

    /** Closes the underlying socket. */
    @Override
    public void close() {
        mSocket.close();
    }

    /** Delegates to the underlying socket's closed callback. */
    @Override
    public void setClosedCallback(CompletedCallback handler) {
        mSocket.setClosedCallback(handler);
    }

    @Override
    public CompletedCallback getClosedCallback() {
        return mSocket.getClosedCallback();
    }

    /** Fired once the underlying socket has ended and all pending plaintext has been emitted. */
    CompletedCallback mEndCallback;
    @Override
    public void setEndCallback(CompletedCallback callback) {
        mEndCallback = callback;
    }

    @Override
    public CompletedCallback getEndCallback() {
        return mEndCallback;
    }

    /** Pauses the underlying socket, which stops new ciphertext from arriving. */
    @Override
    public void pause() {
        mSocket.pause();
    }

    /** Resumes the underlying socket and flushes any plaintext already decrypted into pending. */
    @Override
    public void resume() {
        mSocket.resume();
        onDataAvailable();
    }

    @Override
    public boolean isPaused() {
        return mSocket.isPaused();
    }

    @Override
    public AsyncServer getServer() {
        return mSocket.getServer();
    }

    @Override
    public AsyncSocket getSocket() {
        return mSocket;
    }

    @Override
    public DataEmitter getDataEmitter() {
        return mSocket;
    }

    /** @return the chain captured during client-mode verification, or {@code null} before it */
    @Override
    public X509Certificate[] getPeerCertificates() {
        return peerCertificates;
    }

    @Override
    public String charset() {
        return null;
    }
}
