netty: replace TLS protocol negotiator with new style handlers

diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
index 8fad41e..7d2a7a7 100644
--- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
+++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
@@ -25,6 +25,8 @@
 import io.grpc.Attributes;
 import io.grpc.Grpc;
 import io.grpc.InternalChannelz;
+import io.grpc.InternalChannelz.Security;
+import io.grpc.InternalChannelz.Tls;
 import io.grpc.SecurityLevel;
 import io.grpc.Status;
 import io.grpc.internal.GrpcAttributes;
@@ -268,62 +270,13 @@
     }
   }
 
-  /**
-   * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
-   * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
-   * may happen immediately, even before the TLS Handshake is complete.
-   */
-  public static ProtocolNegotiator tls(SslContext sslContext) {
-    return new TlsNegotiator(sslContext);
-  }
+  static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator {
 
-  @VisibleForTesting
-  static final class TlsNegotiator implements ProtocolNegotiator {
-    private final SslContext sslContext;
-
-    TlsNegotiator(SslContext sslContext) {
+    public ClientTlsProtocolNegotiator(SslContext sslContext) {
       this.sslContext = checkNotNull(sslContext, "sslContext");
     }
 
-    @VisibleForTesting
-    HostPort parseAuthority(String authority) {
-      URI uri = GrpcUtil.authorityToUri(Preconditions.checkNotNull(authority, "authority"));
-      String host;
-      int port;
-      if (uri.getHost() != null) {
-        host = uri.getHost();
-        port = uri.getPort();
-      } else {
-        /*
-         * Implementation note: We pick -1 as the port here rather than deriving it from the
-         * original socket address.  The SSL engine doesn't use this port number when contacting the
-         * remote server, but rather it is used for other things like SSL Session caching.  When an
-         * invalid authority is provided (like "bad_cert"), picking the original port and passing it
-         * in would mean that the port might used under the assumption that it was correct.   By
-         * using -1 here, it forces the SSL implementation to treat it as invalid.
-         */
-        host = authority;
-        port = -1;
-      }
-      return new HostPort(host, port);
-    }
-
-    @Override
-    public ChannelHandler newHandler(GrpcHttp2ConnectionHandler handler) {
-      final HostPort hostPort = parseAuthority(handler.getAuthority());
-
-      ChannelHandler sslBootstrap = new ChannelHandlerAdapter() {
-        @Override
-        public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
-          SSLEngine sslEngine = sslContext.newEngine(ctx.alloc(), hostPort.host, hostPort.port);
-          SSLParameters sslParams = sslEngine.getSSLParameters();
-          sslParams.setEndpointIdentificationAlgorithm("HTTPS");
-          sslEngine.setSSLParameters(sslParams);
-          ctx.pipeline().replace(this, null, new SslHandler(sslEngine, false));
-        }
-      };
-      return new BufferUntilTlsNegotiatedHandler(sslBootstrap, handler);
-    }
+    private final SslContext sslContext;
 
     @Override
     public AsciiString scheme() {
@@ -331,9 +284,113 @@
     }
 
     @Override
+    public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
+      ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler);
+      ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority());
+      WaitUntilActiveHandler wuah = new WaitUntilActiveHandler(cth);
+      return wuah;
+    }
+
+    @Override
     public void close() {}
   }
 
+  static final class ClientTlsHandler extends ChannelDuplexHandler {
+
+    private final ChannelHandler next;
+    private final SslContext sslContext;
+    private final String host;
+    private final int port;
+
+    private ProtocolNegotiationEvent pne = ProtocolNegotiationEvent.DEFAULT;
+
+    ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority) {
+      this.next = checkNotNull(next, "next");
+      this.sslContext = checkNotNull(sslContext, "sslContext");
+      HostPort hostPort = parseAuthority(authority);
+      this.host = hostPort.host;
+      this.port = hostPort.port;
+    }
+
+    @Override
+    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
+      SSLEngine sslEngine = sslContext.newEngine(ctx.alloc(), host, port);
+      SSLParameters sslParams = sslEngine.getSSLParameters();
+      sslParams.setEndpointIdentificationAlgorithm("HTTPS");
+      sslEngine.setSSLParameters(sslParams);
+      ctx.pipeline().addBefore(ctx.name(), /* name= */ null, new SslHandler(sslEngine, false));
+      super.handlerAdded(ctx);
+    }
+
+    @Override
+    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
+      if (evt instanceof ProtocolNegotiationEvent) {
+        pne = (ProtocolNegotiationEvent) evt;
+      } else if (evt instanceof SslHandshakeCompletionEvent) {
+        SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
+        if (handshakeEvent.isSuccess()) {
+          SslHandler handler = ctx.pipeline().get(SslHandler.class);
+          if (NEXT_PROTOCOL_VERSIONS.contains(handler.applicationProtocol())) {
+            // Successfully negotiated the protocol.
+            logSslEngineDetails(Level.FINER, ctx, "TLS negotiation succeeded.", null);
+            ctx.pipeline().replace(ctx.name(), null, next);
+            fireProtocolNegotiationEvent(ctx, handler.engine().getSession());
+          } else {
+            Exception ex = new Exception(
+                "Failed ALPN negotiation: Unable to find compatible protocol.");
+            logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed.", ex);
+            ctx.fireExceptionCaught(ex);
+          }
+        } else {
+          ctx.fireExceptionCaught(handshakeEvent.cause());
+        }
+      } else {
+        super.userEventTriggered(ctx, evt);
+      }
+    }
+
+    private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession session) {
+      Security security = new Security(new Tls(session));
+      Attributes attrs = pne.getAttributes().toBuilder()
+          .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY)
+          .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
+          .build();
+      ctx.fireUserEventTriggered(pne.withAttributes(attrs).withSecurity(security));
+    }
+  }
+
+  @VisibleForTesting
+  static HostPort parseAuthority(String authority) {
+    URI uri = GrpcUtil.authorityToUri(Preconditions.checkNotNull(authority, "authority"));
+    String host;
+    int port;
+    if (uri.getHost() != null) {
+      host = uri.getHost();
+      port = uri.getPort();
+    } else {
+      /*
+       * Implementation note: We pick -1 as the port here rather than deriving it from the
+       * original socket address.  The SSL engine doens't use this port number when contacting the
+       * remote server, but rather it is used for other things like SSL Session caching.  When an
+       * invalid authority is provided (like "bad_cert"), picking the original port and passing it
+       * in would mean that the port might used under the assumption that it was correct.   By
+       * using -1 here, it forces the SSL implementation to treat it as invalid.
+       */
+      host = authority;
+      port = -1;
+    }
+    return new HostPort(host, port);
+  }
+
+  /**
+   * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
+   * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
+   * may happen immediately, even before the TLS Handshake is complete.
+   */
+  public static ProtocolNegotiator tls(SslContext sslContext) {
+    return new ClientTlsProtocolNegotiator(sslContext);
+  }
+
   /** A tuple of (host, port). */
   @VisibleForTesting
   static final class HostPort {
@@ -636,59 +693,6 @@
   }
 
   /**
-   * Buffers all writes until the TLS Handshake is complete.
-   */
-  private static class BufferUntilTlsNegotiatedHandler extends AbstractBufferingHandler {
-
-    private final GrpcHttp2ConnectionHandler grpcHandler;
-
-    BufferUntilTlsNegotiatedHandler(
-        ChannelHandler bootstrapHandler, GrpcHttp2ConnectionHandler grpcHandler) {
-      super(bootstrapHandler);
-      this.grpcHandler = grpcHandler;
-    }
-
-    @Override
-    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
-      if (evt instanceof SslHandshakeCompletionEvent) {
-        SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
-        if (handshakeEvent.isSuccess()) {
-          SslHandler handler = ctx.pipeline().get(SslHandler.class);
-          if (NEXT_PROTOCOL_VERSIONS.contains(handler.applicationProtocol())) {
-            // Successfully negotiated the protocol.
-            logSslEngineDetails(Level.FINER, ctx, "TLS negotiation succeeded.", null);
-
-            // Wait until negotiation is complete to add gRPC.   If added too early, HTTP/2 writes
-            // will fail before we see the userEvent, and the channel is closed down prematurely.
-            ctx.pipeline().addBefore(ctx.name(), null, grpcHandler);
-
-            SSLSession session = handler.engine().getSession();
-            // Successfully negotiated the protocol.
-            // Notify about completion and pass down SSLSession in attributes.
-            grpcHandler.handleProtocolNegotiationCompleted(
-                Attributes.newBuilder()
-                    .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
-                    .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress())
-                    .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, ctx.channel().localAddress())
-                    .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY)
-                    .build(),
-                new InternalChannelz.Security(new InternalChannelz.Tls(session)));
-            writeBufferedAndRemove(ctx);
-          } else {
-            Exception ex = new Exception(
-                "Failed ALPN negotiation: Unable to find compatible protocol.");
-            logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed.", ex);
-            fail(ctx, ex);
-          }
-        } else {
-          fail(ctx, handshakeEvent.cause());
-        }
-      }
-      super.userEventTriggered(ctx, evt);
-    }
-  }
-
-  /**
    * Buffers all writes until the HTTP to HTTP/2 upgrade is complete.
    */
   private static class BufferingHttp2UpgradeHandler extends AbstractBufferingHandler {
diff --git a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java
index c4f44ab..c8ad48c 100644
--- a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java
+++ b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java
@@ -22,7 +22,6 @@
 
 import io.grpc.ManagedChannel;
 import io.grpc.netty.InternalNettyChannelBuilder.OverrideAuthorityChecker;
-import io.grpc.netty.ProtocolNegotiators.TlsNegotiator;
 import io.netty.handler.ssl.SslContext;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
@@ -168,9 +167,7 @@
         NegotiationType.TLS,
         GrpcSslContexts.forClient().build());
 
-    assertTrue(negotiator instanceof ProtocolNegotiators.TlsNegotiator);
-    ProtocolNegotiators.TlsNegotiator n = (TlsNegotiator) negotiator;
-    ProtocolNegotiators.HostPort hostPort = n.parseAuthority("authority:1234");
+    ProtocolNegotiators.HostPort hostPort = ProtocolNegotiators.parseAuthority("authority:1234");
 
     assertEquals("authority", hostPort.host);
     assertEquals(1234, hostPort.port);
@@ -178,13 +175,7 @@
 
   @Test
   public void createProtocolNegotiatorByType_tlsWithAuthorityFallback() throws SSLException {
-    ProtocolNegotiator negotiator = NettyChannelBuilder.createProtocolNegotiatorByType(
-        NegotiationType.TLS,
-        GrpcSslContexts.forClient().build());
-
-    assertTrue(negotiator instanceof ProtocolNegotiators.TlsNegotiator);
-    ProtocolNegotiators.TlsNegotiator n = (TlsNegotiator) negotiator;
-    ProtocolNegotiators.HostPort hostPort = n.parseAuthority("bad_authority");
+    ProtocolNegotiators.HostPort hostPort = ProtocolNegotiators.parseAuthority("bad_authority");
 
     assertEquals("bad_authority", hostPort.host);
     assertEquals(-1, hostPort.port);
diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
index df1c052..fd21e61 100644
--- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
+++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
@@ -29,13 +29,14 @@
 
 import io.grpc.Attributes;
 import io.grpc.Grpc;
+import io.grpc.InternalChannelz.Security;
 import io.grpc.SecurityLevel;
 import io.grpc.internal.GrpcAttributes;
 import io.grpc.internal.testing.TestUtils;
 import io.grpc.netty.ProtocolNegotiators.AbstractBufferingHandler;
+import io.grpc.netty.ProtocolNegotiators.ClientTlsProtocolNegotiator;
 import io.grpc.netty.ProtocolNegotiators.HostPort;
 import io.grpc.netty.ProtocolNegotiators.ServerTlsHandler;
-import io.grpc.netty.ProtocolNegotiators.TlsNegotiator;
 import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler;
 import io.netty.bootstrap.Bootstrap;
 import io.netty.bootstrap.ServerBootstrap;
@@ -52,7 +53,6 @@
 import io.netty.channel.ChannelPromise;
 import io.netty.channel.DefaultEventLoop;
 import io.netty.channel.DefaultEventLoopGroup;
-import io.netty.channel.EventLoop;
 import io.netty.channel.EventLoopGroup;
 import io.netty.channel.embedded.EmbeddedChannel;
 import io.netty.channel.local.LocalAddress;
@@ -68,9 +68,11 @@
 import io.netty.handler.codec.http2.Http2Settings;
 import io.netty.handler.proxy.ProxyConnectException;
 import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.SslContextBuilder;
 import io.netty.handler.ssl.SslHandler;
 import io.netty.handler.ssl.SslHandshakeCompletionEvent;
 import io.netty.handler.ssl.SupportedCipherSuiteFilter;
+import io.netty.handler.ssl.util.SelfSignedCertificate;
 import java.io.File;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
@@ -84,6 +86,7 @@
 import javax.net.ssl.SSLContext;
 import javax.net.ssl.SSLEngine;
 import javax.net.ssl.SSLException;
+import javax.net.ssl.SSLSession;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
@@ -107,7 +110,7 @@
   @Rule public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(TIMEOUT_SECONDS));
   @Rule public final ExpectedException thrown = ExpectedException.none();
 
-  private final EventLoop group = new DefaultEventLoop();
+  private final EventLoopGroup group = new DefaultEventLoop();
   private Channel chan;
   private Channel server;
 
@@ -380,20 +383,16 @@
   }
 
   @Test
-  public void tls_hostAndPort() throws SSLException {
-    SslContext ctx = GrpcSslContexts.forClient().build();
-    TlsNegotiator negotiator = (TlsNegotiator) ProtocolNegotiators.tls(ctx);
-    HostPort hostPort = negotiator.parseAuthority("authority:1234");
+  public void tls_hostAndPort() {
+    HostPort hostPort = ProtocolNegotiators.parseAuthority("authority:1234");
 
     assertEquals("authority", hostPort.host);
     assertEquals(1234, hostPort.port);
   }
 
   @Test
-  public void tls_host() throws SSLException {
-    SslContext ctx = GrpcSslContexts.forClient().build();
-    TlsNegotiator negotiator = (TlsNegotiator) ProtocolNegotiators.tls(ctx);
-    HostPort hostPort = negotiator.parseAuthority("[::1]");
+  public void tls_host() {
+    HostPort hostPort = ProtocolNegotiators.parseAuthority("[::1]");
 
     assertEquals("[::1]", hostPort.host);
     assertEquals(-1, hostPort.port);
@@ -401,9 +400,7 @@
 
   @Test
   public void tls_invalidHost() throws SSLException {
-    SslContext ctx = GrpcSslContexts.forClient().build();
-    TlsNegotiator negotiator = (TlsNegotiator) ProtocolNegotiators.tls(ctx);
-    HostPort hostPort = negotiator.parseAuthority("bad_host:1234");
+    HostPort hostPort = ProtocolNegotiators.parseAuthority("bad_host:1234");
 
     // Even though it looks like a port, we treat it as part of the authority, since the host is
     // invalid.
@@ -603,17 +600,72 @@
     sf.sync();
   }
 
+  @Test
+  public void clientTlsHandler_firesNegotiation() throws Exception {
+    SelfSignedCertificate cert = new SelfSignedCertificate("authority");
+    SslContext clientSslContext =
+        GrpcSslContexts.configure(SslContextBuilder.forClient().trustManager(cert.cert())).build();
+    SslContext serverSslContext =
+        GrpcSslContexts.configure(SslContextBuilder.forServer(cert.key(), cert.cert())).build();
+    FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
+
+    ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext);
+    WriteBufferingAndExceptionHandler wbaeh =
+        new WriteBufferingAndExceptionHandler(pn.newHandler(gh));
+
+    SocketAddress addr = new LocalAddress("addr");
+
+    ChannelHandler sh =
+        ProtocolNegotiators.serverTls(serverSslContext)
+            .newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler());
+    Channel s = new ServerBootstrap()
+        .childHandler(sh)
+        .group(group)
+        .channel(LocalServerChannel.class)
+        .bind(addr)
+        .sync()
+        .channel();
+    Channel c = new Bootstrap()
+        .handler(wbaeh)
+        .channel(LocalChannel.class)
+        .group(group)
+        .register()
+        .sync()
+        .channel();
+    ChannelFuture write = c.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
+    c.connect(addr);
+
+    boolean completed = gh.negotiated.await(5, TimeUnit.SECONDS);
+    if (!completed) {
+      assertTrue("failed to negotiated", write.await(5, TimeUnit.SECONDS));
+      // sync should fail if we are in this block.
+      write.sync();
+      throw new AssertionError("neither wrote nor negotiated");
+    }
+    c.close();
+    s.close();
+
+    assertThat(gh.securityInfo).isNotNull();
+    assertThat(gh.securityInfo.tls).isNotNull();
+    assertThat(gh.attrs.get(GrpcAttributes.ATTR_SECURITY_LEVEL))
+        .isEqualTo(SecurityLevel.PRIVACY_AND_INTEGRITY);
+    assertThat(gh.attrs.get(Grpc.TRANSPORT_ATTR_SSL_SESSION)).isInstanceOf(SSLSession.class);
+    // This is not part of the ClientTls negotiation, but shows that the negotiation event happens
+    // in the right order.
+    assertThat(gh.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)).isEqualTo(addr);
+  }
+
   private static class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
 
-    static GrpcHttp2ConnectionHandler noopHandler() {
+    static FakeGrpcHttp2ConnectionHandler noopHandler() {
       return newHandler(true);
     }
 
-    static GrpcHttp2ConnectionHandler newHandler() {
+    static FakeGrpcHttp2ConnectionHandler newHandler() {
       return newHandler(false);
     }
 
-    private static GrpcHttp2ConnectionHandler newHandler(boolean noop) {
+    private static FakeGrpcHttp2ConnectionHandler newHandler(boolean noop) {
       DefaultHttp2Connection conn = new DefaultHttp2Connection(/*server=*/ false);
       DefaultHttp2ConnectionEncoder encoder =
           new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter());
@@ -625,6 +677,9 @@
     }
 
     private final boolean noop;
+    private Attributes attrs;
+    private Security securityInfo;
+    private final CountDownLatch negotiated = new CountDownLatch(1);
 
     FakeGrpcHttp2ConnectionHandler(ChannelPromise channelUnused,
         Http2ConnectionDecoder decoder,
@@ -636,6 +691,14 @@
     }
 
     @Override
+    public void handleProtocolNegotiationCompleted(Attributes attrs, Security securityInfo) {
+      super.handleProtocolNegotiationCompleted(attrs, securityInfo);
+      this.attrs = attrs;
+      this.securityInfo = securityInfo;
+      negotiated.countDown();
+    }
+
+    @Override
     public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
       if (noop) {
         ctx.pipeline().remove(ctx.name());
@@ -643,6 +706,11 @@
         super.handlerAdded(ctx);
       }
     }
+
+    @Override
+    public String getAuthority() {
+      return "authority";
+    }
   }
 
   private static ByteBuf bb(String s, Channel c) {