481437 - Port ConnectHandler connect and context functionality from Jetty 8.

Restored connect and context functionalities.
diff --git a/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/ConnectHandler.java b/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/ConnectHandler.java
index cc2a9f7..e7c36ef 100644
--- a/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/ConnectHandler.java
+++ b/jetty-proxy/src/main/java/org/eclipse/jetty/proxy/ConnectHandler.java
@@ -18,6 +18,7 @@
 
 package org.eclipse.jetty.proxy;
 
+import java.io.Closeable;
 import java.io.IOException;
 import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
@@ -50,6 +51,7 @@
 import org.eclipse.jetty.server.handler.HandlerWrapper;
 import org.eclipse.jetty.util.BufferUtil;
 import org.eclipse.jetty.util.Callback;
+import org.eclipse.jetty.util.Promise;
 import org.eclipse.jetty.util.TypeUtil;
 import org.eclipse.jetty.util.log.Log;
 import org.eclipse.jetty.util.log.Logger;
@@ -159,21 +161,17 @@
     protected void doStart() throws Exception
     {
         if (executor == null)
-        {
-            setExecutor(getServer().getThreadPool());
-        }
+            executor = getServer().getThreadPool();
+
         if (scheduler == null)
-        {
-            setScheduler(new ScheduledExecutorScheduler());
-            addBean(getScheduler());
-        }
+            addBean(scheduler = new ScheduledExecutorScheduler());
+
         if (bufferPool == null)
-        {
-            setByteBufferPool(new MappedByteBufferPool());
-            addBean(getByteBufferPool());
-        }
+            addBean(bufferPool = new MappedByteBufferPool());
+
         addBean(selector = newSelectorManager());
         selector.setConnectTimeout(getConnectTimeout());
+
         super.doStart();
     }
 
@@ -190,16 +188,8 @@
             String serverAddress = request.getRequestURI();
             if (LOG.isDebugEnabled())
                 LOG.debug("CONNECT request for {}", serverAddress);
-            try
-            {
-                handleConnect(baseRequest, request, response, serverAddress);
-            }
-            catch (Exception x)
-            {
-                // TODO
-                LOG.warn("ConnectHandler " + baseRequest.getUri() + " " + x);
-                LOG.debug(x);
-            }
+
+            handleConnect(baseRequest, request, response, serverAddress);
         }
         else
         {
@@ -217,7 +207,7 @@
      * @param response      the http response
      * @param serverAddress the remote server address in the form {@code host:port}
      */
-    protected void handleConnect(Request baseRequest, HttpServletRequest request, HttpServletResponse response, String serverAddress)
+    protected void handleConnect(Request baseRequest, final HttpServletRequest request, final HttpServletResponse response, String serverAddress)
     {
         baseRequest.setHandled(true);
         try
@@ -248,32 +238,40 @@
                 return;
             }
 
-            SocketChannel channel = SocketChannel.open();
-            channel.socket().setTcpNoDelay(true);
-            channel.configureBlocking(false);
-
-            AsyncContext asyncContext = request.startAsync();
-            asyncContext.setTimeout(0);
-
-            HttpTransport transport = baseRequest.getHttpChannel().getHttpTransport();
-            
+            final HttpTransport transport = baseRequest.getHttpChannel().getHttpTransport();
             // TODO Handle CONNECT over HTTP2!
             if (!(transport instanceof HttpConnection))
             {
                 if (LOG.isDebugEnabled())
-                    LOG.debug("CONNECT forbidden for {}", transport);
+                    LOG.debug("CONNECT not supported for {}", transport);
                 sendConnectResponse(request, response, HttpServletResponse.SC_FORBIDDEN);
                 return;
             }
 
-            InetSocketAddress address = newConnectAddress(host, port);
+            final AsyncContext asyncContext = request.startAsync();
+            asyncContext.setTimeout(0);
+
             if (LOG.isDebugEnabled())
-                LOG.debug("Connecting to {}", address);
-            ConnectContext connectContext = new ConnectContext(request, response, asyncContext, (HttpConnection)transport);
-            if (channel.connect(address))
-                selector.accept(channel, connectContext);
-            else
-                selector.connect(channel, connectContext);
+                LOG.debug("Connecting to {}:{}", host, port);
+
+            connectToServer(request, host, port, new Promise<SocketChannel>()
+            {
+                @Override
+                public void succeeded(SocketChannel channel)
+                {
+                    ConnectContext connectContext = new ConnectContext(request, response, asyncContext, (HttpConnection)transport);
+                    if (channel.isConnected())
+                        selector.accept(channel, connectContext);
+                    else
+                        selector.connect(channel, connectContext);
+                }
+
+                @Override
+                public void failed(Throwable x)
+                {
+                    onConnectFailure(request, response, asyncContext, x);
+                }
+            });
         }
         catch (Exception x)
         {
@@ -281,37 +279,59 @@
         }
     }
 
-    /* ------------------------------------------------------------ */
-    /** Create the address the connect channel will connect to.
-     * @param host The host from the connect request
-     * @param port The port from the connect request
+    protected void connectToServer(HttpServletRequest request, String host, int port, Promise<SocketChannel> promise)
+    {
+        SocketChannel channel = null;
+        try
+        {
+            channel = SocketChannel.open();
+            channel.socket().setTcpNoDelay(true);
+            channel.configureBlocking(false);
+            InetSocketAddress address = newConnectAddress(host, port);
+            channel.connect(address);
+            promise.succeeded(channel);
+        }
+        catch (Throwable x)
+        {
+            close(channel);
+            promise.failed(x);
+        }
+    }
+
+    private void close(Closeable closeable)
+    {
+        try
+        {
+            if (closeable != null)
+                closeable.close();
+        }
+        catch (Throwable x)
+        {
+            LOG.ignore(x);
+        }
+    }
+
+    /**
+     * Creates the server address to connect to.
+     *
+     * @param host The host from the CONNECT request
+     * @param port The port from the CONNECT request
      * @return The InetSocketAddress to connect to.
      */
     protected InetSocketAddress newConnectAddress(String host, int port)
     {
         return new InetSocketAddress(host, port);
     }
-    
+
     protected void onConnectSuccess(ConnectContext connectContext, UpstreamConnection upstreamConnection)
     {
-        HttpConnection httpConnection = connectContext.getHttpConnection();
-        ByteBuffer requestBuffer = httpConnection.getRequestBuffer();
-        ByteBuffer buffer = BufferUtil.EMPTY_BUFFER;
-        int remaining = requestBuffer.remaining();
-        if (remaining > 0)
-        {
-            buffer = bufferPool.acquire(remaining, requestBuffer.isDirect());
-            BufferUtil.flipToFill(buffer);
-            buffer.put(requestBuffer);
-            buffer.flip();
-        }
-
         ConcurrentMap<String, Object> context = connectContext.getContext();
         HttpServletRequest request = connectContext.getRequest();
         prepareContext(request, context);
 
+        HttpConnection httpConnection = connectContext.getHttpConnection();
         EndPoint downstreamEndPoint = httpConnection.getEndPoint();
-        DownstreamConnection downstreamConnection = newDownstreamConnection(downstreamEndPoint, context, buffer);
+        DownstreamConnection downstreamConnection = newDownstreamConnection(downstreamEndPoint, context, BufferUtil.EMPTY_BUFFER);
         downstreamConnection.setInputBufferSize(getBufferSize());
 
         upstreamConnection.setConnection(downstreamConnection);
@@ -323,6 +343,7 @@
         sendConnectResponse(request, response, HttpServletResponse.SC_OK);
 
         upgradeConnection(request, response, downstreamConnection);
+
         connectContext.getAsyncContext().complete();
     }
 
@@ -348,7 +369,8 @@
         }
         catch (IOException x)
         {
-            // TODO: nothing we can do, close the connection
+            if (LOG.isDebugEnabled())
+                LOG.debug("Could not send CONNECT response", x);
         }
     }
 
@@ -366,9 +388,18 @@
         return true;
     }
 
+    /**
+     * @deprecated use {@link #newDownstreamConnection(EndPoint, ConcurrentMap)} instead
+     */
+    @Deprecated
     protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context, ByteBuffer buffer)
     {
-        return new DownstreamConnection(endPoint, getExecutor(), getByteBufferPool(), context, buffer);
+        return newDownstreamConnection(endPoint, context);
+    }
+
+    protected DownstreamConnection newDownstreamConnection(EndPoint endPoint, ConcurrentMap<String, Object> context)
+    {
+        return new DownstreamConnection(endPoint, getExecutor(), getByteBufferPool(), context);
     }
 
     protected UpstreamConnection newUpstreamConnection(EndPoint endPoint, ConnectContext connectContext)
@@ -395,10 +426,23 @@
      *
      * @param endPoint the endPoint to read from
      * @param buffer   the buffer to read data into
+     * @param context  the context information related to the connection
      * @return the number of bytes read (possibly 0 since the read is non-blocking)
      *         or -1 if the channel has been closed remotely
      * @throws IOException if the endPoint cannot be read
      */
+    protected int read(EndPoint endPoint, ByteBuffer buffer, ConcurrentMap<String, Object> context) throws IOException
+    {
+        int read = read(endPoint, buffer);
+        if (LOG.isDebugEnabled())
+            LOG.debug("{} read {} bytes", this, read);
+        return read;
+    }
+
+    /**
+     * @deprecated override {@link #read(EndPoint, ByteBuffer, ConcurrentMap)} instead
+     */
+    @Deprecated
     protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
     {
         return endPoint.fill(buffer);
@@ -411,10 +455,19 @@
      * @param buffer   the buffer to write
      * @param callback the completion callback to invoke
      */
-    protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
+    protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback, ConcurrentMap<String, Object> context)
     {
         if (LOG.isDebugEnabled())
             LOG.debug("{} writing {} bytes", this, buffer.remaining());
+        write(endPoint, buffer, callback);
+    }
+
+    /**
+     * @deprecated override {@link #write(EndPoint, ByteBuffer, Callback, ConcurrentMap)} instead
+     */
+    @Deprecated
+    protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
+    {
         endPoint.write(callback, buffer);
     }
 
@@ -493,6 +546,7 @@
         @Override
         protected void connectionFailed(SocketChannel channel, final Throwable ex, final Object attachment)
         {
+            close(channel);
             getExecutor().execute(new Runnable()
             {
                 public void run()
@@ -573,24 +627,38 @@
         @Override
         protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
         {
-            return ConnectHandler.this.read(endPoint, buffer);
+            return ConnectHandler.this.read(endPoint, buffer, getContext());
         }
 
         @Override
         protected void write(EndPoint endPoint, ByteBuffer buffer,Callback callback)
         {
-            ConnectHandler.this.write(endPoint, buffer, callback);
+            ConnectHandler.this.write(endPoint, buffer, callback, getContext());
         }
     }
 
-    public class DownstreamConnection extends ProxyConnection
+    public class DownstreamConnection extends ProxyConnection implements Connection.UpgradeTo
     {
-        private final ByteBuffer buffer;
+        private ByteBuffer buffer;
 
-        public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context, ByteBuffer buffer)
+        public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context)
         {
             super(endPoint, executor, bufferPool, context);
-            this.buffer = buffer;
+        }
+
+        /**
+         * @deprecated use {@link #DownstreamConnection(EndPoint, Executor, ByteBufferPool, ConcurrentMap)} instead
+         */
+        @Deprecated
+        public DownstreamConnection(EndPoint endPoint, Executor executor, ByteBufferPool bufferPool, ConcurrentMap<String, Object> context, ByteBuffer buffer)
+        {
+            this(endPoint, executor, bufferPool, context);
+        }
+
+        @Override
+        public void onUpgradeTo(ByteBuffer buffer)
+        {
+            this.buffer = buffer == null ? BufferUtil.EMPTY_BUFFER : buffer;
         }
 
         @Override
@@ -622,13 +690,13 @@
         @Override
         protected int read(EndPoint endPoint, ByteBuffer buffer) throws IOException
         {
-            return ConnectHandler.this.read(endPoint, buffer);
+            return ConnectHandler.this.read(endPoint, buffer, getContext());
         }
 
         @Override
         protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback)
         {
-            ConnectHandler.this.write(endPoint, buffer, callback);
+            ConnectHandler.this.write(endPoint, buffer, callback, getContext());
         }
     }
 }
diff --git a/jetty-proxy/src/test/java/org/eclipse/jetty/proxy/ConnectHandlerTest.java b/jetty-proxy/src/test/java/org/eclipse/jetty/proxy/ConnectHandlerTest.java
index f601975..780e706 100644
--- a/jetty-proxy/src/test/java/org/eclipse/jetty/proxy/ConnectHandlerTest.java
+++ b/jetty-proxy/src/test/java/org/eclipse/jetty/proxy/ConnectHandlerTest.java
@@ -27,6 +27,8 @@
 import java.net.InetAddress;
 import java.net.Socket;
 import java.net.UnknownHostException;
+import java.nio.ByteBuffer;
+import java.nio.channels.SocketChannel;
 import java.nio.charset.StandardCharsets;
 import java.util.Locale;
 import java.util.concurrent.ConcurrentMap;
@@ -36,12 +38,15 @@
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
+import org.eclipse.jetty.io.EndPoint;
 import org.eclipse.jetty.server.Request;
 import org.eclipse.jetty.server.Server;
 import org.eclipse.jetty.server.ServerConnector;
 import org.eclipse.jetty.server.handler.AbstractHandler;
 import org.eclipse.jetty.toolchain.test.http.SimpleHttpResponse;
 import org.eclipse.jetty.util.B64Code;
+import org.eclipse.jetty.util.Callback;
+import org.eclipse.jetty.util.Promise;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
@@ -631,12 +636,33 @@
             }
 
             @Override
+            protected void connectToServer(HttpServletRequest request, String host, int port, Promise<SocketChannel> promise)
+            {
+                Assert.assertEquals(contextValue, request.getAttribute(contextKey));
+                super.connectToServer(request, host, port, promise);
+            }
+
+            @Override
             protected void prepareContext(HttpServletRequest request, ConcurrentMap<String, Object> context)
             {
                 // Transfer data from the HTTP request to the connection context
                 Assert.assertEquals(contextValue, request.getAttribute(contextKey));
                 context.put(contextKey, request.getAttribute(contextKey));
             }
+
+            @Override
+            protected int read(EndPoint endPoint, ByteBuffer buffer, ConcurrentMap<String, Object> context) throws IOException
+            {
+                Assert.assertEquals(contextValue, context.get(contextKey));
+                return super.read(endPoint, buffer, context);
+            }
+
+            @Override
+            protected void write(EndPoint endPoint, ByteBuffer buffer, Callback callback, ConcurrentMap<String, Object> context)
+            {
+                Assert.assertEquals(contextValue, context.get(contextKey));
+                super.write(endPoint, buffer, callback, context);
+            }
         });
         proxy.start();