netty: replace TLS protocol negotiator with new style handlers
diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
index 8fad41e..7d2a7a7 100644
--- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
+++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java
@@ -25,6 +25,8 @@
import io.grpc.Attributes;
import io.grpc.Grpc;
import io.grpc.InternalChannelz;
+import io.grpc.InternalChannelz.Security;
+import io.grpc.InternalChannelz.Tls;
import io.grpc.SecurityLevel;
import io.grpc.Status;
import io.grpc.internal.GrpcAttributes;
@@ -268,62 +270,13 @@
}
}
- /**
- * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
- * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
- * may happen immediately, even before the TLS Handshake is complete.
- */
- public static ProtocolNegotiator tls(SslContext sslContext) {
- return new TlsNegotiator(sslContext);
- }
+ static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator {
- @VisibleForTesting
- static final class TlsNegotiator implements ProtocolNegotiator {
- private final SslContext sslContext;
-
- TlsNegotiator(SslContext sslContext) {
+ public ClientTlsProtocolNegotiator(SslContext sslContext) {
this.sslContext = checkNotNull(sslContext, "sslContext");
}
- @VisibleForTesting
- HostPort parseAuthority(String authority) {
- URI uri = GrpcUtil.authorityToUri(Preconditions.checkNotNull(authority, "authority"));
- String host;
- int port;
- if (uri.getHost() != null) {
- host = uri.getHost();
- port = uri.getPort();
- } else {
- /*
- * Implementation note: We pick -1 as the port here rather than deriving it from the
- * original socket address. The SSL engine doesn't use this port number when contacting the
- * remote server, but rather it is used for other things like SSL Session caching. When an
- * invalid authority is provided (like "bad_cert"), picking the original port and passing it
- * in would mean that the port might used under the assumption that it was correct. By
- * using -1 here, it forces the SSL implementation to treat it as invalid.
- */
- host = authority;
- port = -1;
- }
- return new HostPort(host, port);
- }
-
- @Override
- public ChannelHandler newHandler(GrpcHttp2ConnectionHandler handler) {
- final HostPort hostPort = parseAuthority(handler.getAuthority());
-
- ChannelHandler sslBootstrap = new ChannelHandlerAdapter() {
- @Override
- public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
- SSLEngine sslEngine = sslContext.newEngine(ctx.alloc(), hostPort.host, hostPort.port);
- SSLParameters sslParams = sslEngine.getSSLParameters();
- sslParams.setEndpointIdentificationAlgorithm("HTTPS");
- sslEngine.setSSLParameters(sslParams);
- ctx.pipeline().replace(this, null, new SslHandler(sslEngine, false));
- }
- };
- return new BufferUntilTlsNegotiatedHandler(sslBootstrap, handler);
- }
+ private final SslContext sslContext;
@Override
public AsciiString scheme() {
@@ -331,9 +284,113 @@
}
@Override
+ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
+ ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler);
+ ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority());
+ WaitUntilActiveHandler wuah = new WaitUntilActiveHandler(cth);
+ return wuah;
+ }
+
+ @Override
public void close() {}
}
+ static final class ClientTlsHandler extends ChannelDuplexHandler {
+
+ private final ChannelHandler next;
+ private final SslContext sslContext;
+ private final String host;
+ private final int port;
+
+ private ProtocolNegotiationEvent pne = ProtocolNegotiationEvent.DEFAULT;
+
+ ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority) {
+ this.next = checkNotNull(next, "next");
+ this.sslContext = checkNotNull(sslContext, "sslContext");
+ HostPort hostPort = parseAuthority(authority);
+ this.host = hostPort.host;
+ this.port = hostPort.port;
+ }
+
+ @Override
+ public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
+ SSLEngine sslEngine = sslContext.newEngine(ctx.alloc(), host, port);
+ SSLParameters sslParams = sslEngine.getSSLParameters();
+ sslParams.setEndpointIdentificationAlgorithm("HTTPS");
+ sslEngine.setSSLParameters(sslParams);
+ ctx.pipeline().addBefore(ctx.name(), /* name= */ null, new SslHandler(sslEngine, false));
+ super.handlerAdded(ctx);
+ }
+
+ @Override
+ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
+ if (evt instanceof ProtocolNegotiationEvent) {
+ pne = (ProtocolNegotiationEvent) evt;
+ } else if (evt instanceof SslHandshakeCompletionEvent) {
+ SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
+ if (handshakeEvent.isSuccess()) {
+ SslHandler handler = ctx.pipeline().get(SslHandler.class);
+ if (NEXT_PROTOCOL_VERSIONS.contains(handler.applicationProtocol())) {
+ // Successfully negotiated the protocol.
+ logSslEngineDetails(Level.FINER, ctx, "TLS negotiation succeeded.", null);
+ ctx.pipeline().replace(ctx.name(), null, next);
+ fireProtocolNegotiationEvent(ctx, handler.engine().getSession());
+ } else {
+ Exception ex = new Exception(
+ "Failed ALPN negotiation: Unable to find compatible protocol.");
+ logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed.", ex);
+ ctx.fireExceptionCaught(ex);
+ }
+ } else {
+ ctx.fireExceptionCaught(handshakeEvent.cause());
+ }
+ } else {
+ super.userEventTriggered(ctx, evt);
+ }
+ }
+
+ private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession session) {
+ Security security = new Security(new Tls(session));
+ Attributes attrs = pne.getAttributes().toBuilder()
+ .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY)
+ .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
+ .build();
+ ctx.fireUserEventTriggered(pne.withAttributes(attrs).withSecurity(security));
+ }
+ }
+
+ @VisibleForTesting
+ static HostPort parseAuthority(String authority) {
+ URI uri = GrpcUtil.authorityToUri(Preconditions.checkNotNull(authority, "authority"));
+ String host;
+ int port;
+ if (uri.getHost() != null) {
+ host = uri.getHost();
+ port = uri.getPort();
+ } else {
+ /*
+ * Implementation note: We pick -1 as the port here rather than deriving it from the
+ * original socket address. The SSL engine doens't use this port number when contacting the
+ * remote server, but rather it is used for other things like SSL Session caching. When an
+ * invalid authority is provided (like "bad_cert"), picking the original port and passing it
+ * in would mean that the port might used under the assumption that it was correct. By
+ * using -1 here, it forces the SSL implementation to treat it as invalid.
+ */
+ host = authority;
+ port = -1;
+ }
+ return new HostPort(host, port);
+ }
+
+ /**
+ * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
+ * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
+ * may happen immediately, even before the TLS Handshake is complete.
+ */
+ public static ProtocolNegotiator tls(SslContext sslContext) {
+ return new ClientTlsProtocolNegotiator(sslContext);
+ }
+
/** A tuple of (host, port). */
@VisibleForTesting
static final class HostPort {
@@ -636,59 +693,6 @@
}
/**
- * Buffers all writes until the TLS Handshake is complete.
- */
- private static class BufferUntilTlsNegotiatedHandler extends AbstractBufferingHandler {
-
- private final GrpcHttp2ConnectionHandler grpcHandler;
-
- BufferUntilTlsNegotiatedHandler(
- ChannelHandler bootstrapHandler, GrpcHttp2ConnectionHandler grpcHandler) {
- super(bootstrapHandler);
- this.grpcHandler = grpcHandler;
- }
-
- @Override
- public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
- if (evt instanceof SslHandshakeCompletionEvent) {
- SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt;
- if (handshakeEvent.isSuccess()) {
- SslHandler handler = ctx.pipeline().get(SslHandler.class);
- if (NEXT_PROTOCOL_VERSIONS.contains(handler.applicationProtocol())) {
- // Successfully negotiated the protocol.
- logSslEngineDetails(Level.FINER, ctx, "TLS negotiation succeeded.", null);
-
- // Wait until negotiation is complete to add gRPC. If added too early, HTTP/2 writes
- // will fail before we see the userEvent, and the channel is closed down prematurely.
- ctx.pipeline().addBefore(ctx.name(), null, grpcHandler);
-
- SSLSession session = handler.engine().getSession();
- // Successfully negotiated the protocol.
- // Notify about completion and pass down SSLSession in attributes.
- grpcHandler.handleProtocolNegotiationCompleted(
- Attributes.newBuilder()
- .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, session)
- .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, ctx.channel().remoteAddress())
- .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, ctx.channel().localAddress())
- .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY)
- .build(),
- new InternalChannelz.Security(new InternalChannelz.Tls(session)));
- writeBufferedAndRemove(ctx);
- } else {
- Exception ex = new Exception(
- "Failed ALPN negotiation: Unable to find compatible protocol.");
- logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed.", ex);
- fail(ctx, ex);
- }
- } else {
- fail(ctx, handshakeEvent.cause());
- }
- }
- super.userEventTriggered(ctx, evt);
- }
- }
-
- /**
* Buffers all writes until the HTTP to HTTP/2 upgrade is complete.
*/
private static class BufferingHttp2UpgradeHandler extends AbstractBufferingHandler {
diff --git a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java
index c4f44ab..c8ad48c 100644
--- a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java
+++ b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java
@@ -22,7 +22,6 @@
import io.grpc.ManagedChannel;
import io.grpc.netty.InternalNettyChannelBuilder.OverrideAuthorityChecker;
-import io.grpc.netty.ProtocolNegotiators.TlsNegotiator;
import io.netty.handler.ssl.SslContext;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
@@ -168,9 +167,7 @@
NegotiationType.TLS,
GrpcSslContexts.forClient().build());
- assertTrue(negotiator instanceof ProtocolNegotiators.TlsNegotiator);
- ProtocolNegotiators.TlsNegotiator n = (TlsNegotiator) negotiator;
- ProtocolNegotiators.HostPort hostPort = n.parseAuthority("authority:1234");
+ ProtocolNegotiators.HostPort hostPort = ProtocolNegotiators.parseAuthority("authority:1234");
assertEquals("authority", hostPort.host);
assertEquals(1234, hostPort.port);
@@ -178,13 +175,7 @@
@Test
public void createProtocolNegotiatorByType_tlsWithAuthorityFallback() throws SSLException {
- ProtocolNegotiator negotiator = NettyChannelBuilder.createProtocolNegotiatorByType(
- NegotiationType.TLS,
- GrpcSslContexts.forClient().build());
-
- assertTrue(negotiator instanceof ProtocolNegotiators.TlsNegotiator);
- ProtocolNegotiators.TlsNegotiator n = (TlsNegotiator) negotiator;
- ProtocolNegotiators.HostPort hostPort = n.parseAuthority("bad_authority");
+ ProtocolNegotiators.HostPort hostPort = ProtocolNegotiators.parseAuthority("bad_authority");
assertEquals("bad_authority", hostPort.host);
assertEquals(-1, hostPort.port);
diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
index df1c052..fd21e61 100644
--- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
+++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java
@@ -29,13 +29,14 @@
import io.grpc.Attributes;
import io.grpc.Grpc;
+import io.grpc.InternalChannelz.Security;
import io.grpc.SecurityLevel;
import io.grpc.internal.GrpcAttributes;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.ProtocolNegotiators.AbstractBufferingHandler;
+import io.grpc.netty.ProtocolNegotiators.ClientTlsProtocolNegotiator;
import io.grpc.netty.ProtocolNegotiators.HostPort;
import io.grpc.netty.ProtocolNegotiators.ServerTlsHandler;
-import io.grpc.netty.ProtocolNegotiators.TlsNegotiator;
import io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler;
import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
@@ -52,7 +53,6 @@
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultEventLoop;
import io.netty.channel.DefaultEventLoopGroup;
-import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.local.LocalAddress;
@@ -68,9 +68,11 @@
import io.netty.handler.codec.http2.Http2Settings;
import io.netty.handler.proxy.ProxyConnectException;
import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
import io.netty.handler.ssl.SupportedCipherSuiteFilter;
+import io.netty.handler.ssl.util.SelfSignedCertificate;
import java.io.File;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
@@ -84,6 +86,7 @@
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
+import javax.net.ssl.SSLSession;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
@@ -107,7 +110,7 @@
@Rule public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(TIMEOUT_SECONDS));
@Rule public final ExpectedException thrown = ExpectedException.none();
- private final EventLoop group = new DefaultEventLoop();
+ private final EventLoopGroup group = new DefaultEventLoop();
private Channel chan;
private Channel server;
@@ -380,20 +383,16 @@
}
@Test
- public void tls_hostAndPort() throws SSLException {
- SslContext ctx = GrpcSslContexts.forClient().build();
- TlsNegotiator negotiator = (TlsNegotiator) ProtocolNegotiators.tls(ctx);
- HostPort hostPort = negotiator.parseAuthority("authority:1234");
+ public void tls_hostAndPort() {
+ HostPort hostPort = ProtocolNegotiators.parseAuthority("authority:1234");
assertEquals("authority", hostPort.host);
assertEquals(1234, hostPort.port);
}
@Test
- public void tls_host() throws SSLException {
- SslContext ctx = GrpcSslContexts.forClient().build();
- TlsNegotiator negotiator = (TlsNegotiator) ProtocolNegotiators.tls(ctx);
- HostPort hostPort = negotiator.parseAuthority("[::1]");
+ public void tls_host() {
+ HostPort hostPort = ProtocolNegotiators.parseAuthority("[::1]");
assertEquals("[::1]", hostPort.host);
assertEquals(-1, hostPort.port);
@@ -401,9 +400,7 @@
@Test
public void tls_invalidHost() throws SSLException {
- SslContext ctx = GrpcSslContexts.forClient().build();
- TlsNegotiator negotiator = (TlsNegotiator) ProtocolNegotiators.tls(ctx);
- HostPort hostPort = negotiator.parseAuthority("bad_host:1234");
+ HostPort hostPort = ProtocolNegotiators.parseAuthority("bad_host:1234");
// Even though it looks like a port, we treat it as part of the authority, since the host is
// invalid.
@@ -603,17 +600,72 @@
sf.sync();
}
+ @Test
+ public void clientTlsHandler_firesNegotiation() throws Exception {
+ SelfSignedCertificate cert = new SelfSignedCertificate("authority");
+ SslContext clientSslContext =
+ GrpcSslContexts.configure(SslContextBuilder.forClient().trustManager(cert.cert())).build();
+ SslContext serverSslContext =
+ GrpcSslContexts.configure(SslContextBuilder.forServer(cert.key(), cert.cert())).build();
+ FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
+
+ ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext);
+ WriteBufferingAndExceptionHandler wbaeh =
+ new WriteBufferingAndExceptionHandler(pn.newHandler(gh));
+
+ SocketAddress addr = new LocalAddress("addr");
+
+ ChannelHandler sh =
+ ProtocolNegotiators.serverTls(serverSslContext)
+ .newHandler(FakeGrpcHttp2ConnectionHandler.noopHandler());
+ Channel s = new ServerBootstrap()
+ .childHandler(sh)
+ .group(group)
+ .channel(LocalServerChannel.class)
+ .bind(addr)
+ .sync()
+ .channel();
+ Channel c = new Bootstrap()
+ .handler(wbaeh)
+ .channel(LocalChannel.class)
+ .group(group)
+ .register()
+ .sync()
+ .channel();
+ ChannelFuture write = c.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
+ c.connect(addr);
+
+ boolean completed = gh.negotiated.await(5, TimeUnit.SECONDS);
+ if (!completed) {
+ assertTrue("failed to negotiated", write.await(5, TimeUnit.SECONDS));
+ // sync should fail if we are in this block.
+ write.sync();
+ throw new AssertionError("neither wrote nor negotiated");
+ }
+ c.close();
+ s.close();
+
+ assertThat(gh.securityInfo).isNotNull();
+ assertThat(gh.securityInfo.tls).isNotNull();
+ assertThat(gh.attrs.get(GrpcAttributes.ATTR_SECURITY_LEVEL))
+ .isEqualTo(SecurityLevel.PRIVACY_AND_INTEGRITY);
+ assertThat(gh.attrs.get(Grpc.TRANSPORT_ATTR_SSL_SESSION)).isInstanceOf(SSLSession.class);
+ // This is not part of the ClientTls negotiation, but shows that the negotiation event happens
+ // in the right order.
+ assertThat(gh.attrs.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)).isEqualTo(addr);
+ }
+
private static class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
- static GrpcHttp2ConnectionHandler noopHandler() {
+ static FakeGrpcHttp2ConnectionHandler noopHandler() {
return newHandler(true);
}
- static GrpcHttp2ConnectionHandler newHandler() {
+ static FakeGrpcHttp2ConnectionHandler newHandler() {
return newHandler(false);
}
- private static GrpcHttp2ConnectionHandler newHandler(boolean noop) {
+ private static FakeGrpcHttp2ConnectionHandler newHandler(boolean noop) {
DefaultHttp2Connection conn = new DefaultHttp2Connection(/*server=*/ false);
DefaultHttp2ConnectionEncoder encoder =
new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter());
@@ -625,6 +677,9 @@
}
private final boolean noop;
+ private Attributes attrs;
+ private Security securityInfo;
+ private final CountDownLatch negotiated = new CountDownLatch(1);
FakeGrpcHttp2ConnectionHandler(ChannelPromise channelUnused,
Http2ConnectionDecoder decoder,
@@ -636,6 +691,14 @@
}
@Override
+ public void handleProtocolNegotiationCompleted(Attributes attrs, Security securityInfo) {
+ super.handleProtocolNegotiationCompleted(attrs, securityInfo);
+ this.attrs = attrs;
+ this.securityInfo = securityInfo;
+ negotiated.countDown();
+ }
+
+ @Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
if (noop) {
ctx.pipeline().remove(ctx.name());
@@ -643,6 +706,11 @@
super.handlerAdded(ctx);
}
}
+
+ @Override
+ public String getAuthority() {
+ return "authority";
+ }
}
private static ByteBuf bb(String s, Channel c) {