xds: Client and server proto negotiators and handlers added to SdsProtocolNegotiators (#6319)
diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java
index e71a695..9f3e599 100644
--- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java
+++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java
@@ -68,6 +68,82 @@
}
/**
+ * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will be
+ * negotiated, the server TLS {@code handler} is added and writes to the {@link
+ * io.netty.channel.Channel} may happen immediately, even before the TLS Handshake is complete.
+ */
+ public static InternalProtocolNegotiator.ProtocolNegotiator serverTls(SslContext sslContext) {
+ final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.serverTls(sslContext);
+ final class ServerTlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
+
+ @Override
+ public AsciiString scheme() {
+ return negotiator.scheme();
+ }
+
+ @Override
+ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
+ return negotiator.newHandler(grpcHandler);
+ }
+
+ @Override
+ public void close() {
+ negotiator.close();
+ }
+ }
+
+ return new ServerTlsNegotiator();
+ }
+
+ /** Returns a {@link ProtocolNegotiator} for plaintext client channel. */
+ public static InternalProtocolNegotiator.ProtocolNegotiator plaintext() {
+ final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.plaintext();
+ final class PlaintextNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
+
+ @Override
+ public AsciiString scheme() {
+ return negotiator.scheme();
+ }
+
+ @Override
+ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
+ return negotiator.newHandler(grpcHandler);
+ }
+
+ @Override
+ public void close() {
+ negotiator.close();
+ }
+ }
+
+ return new PlaintextNegotiator();
+ }
+
+ /** Returns a {@link ProtocolNegotiator} for plaintext server channel. */
+ public static InternalProtocolNegotiator.ProtocolNegotiator serverPlaintext() {
+ final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.serverPlaintext();
+ final class ServerPlaintextNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
+
+ @Override
+ public AsciiString scheme() {
+ return negotiator.scheme();
+ }
+
+ @Override
+ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
+ return negotiator.newHandler(grpcHandler);
+ }
+
+ @Override
+ public void close() {
+ negotiator.close();
+ }
+ }
+
+ return new ServerPlaintextNegotiator();
+ }
+
+ /**
* Internal version of {@link WaitUntilActiveHandler}.
*/
public static ChannelHandler waitUntilActiveHandler(ChannelHandler next) {
diff --git a/xds/src/main/java/io/grpc/xds/sds/XdsChannelBuilder.java b/xds/src/main/java/io/grpc/xds/sds/XdsChannelBuilder.java
index 682b401..2e4e5d2 100644
--- a/xds/src/main/java/io/grpc/xds/sds/XdsChannelBuilder.java
+++ b/xds/src/main/java/io/grpc/xds/sds/XdsChannelBuilder.java
@@ -16,14 +16,17 @@
package io.grpc.xds.sds;
+import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.grpc.ExperimentalApi;
import io.grpc.ForwardingChannelBuilder;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
+import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.xds.sds.internal.SdsProtocolNegotiators;
import java.net.SocketAddress;
import javax.annotation.CheckReturnValue;
+import javax.annotation.Nullable;
/**
* A version of {@link ManagedChannelBuilder} to create xDS managed channels that will use SDS to
@@ -34,9 +37,11 @@
private final NettyChannelBuilder delegate;
+ // TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS
+ @Nullable private UpstreamTlsContext upstreamTlsContext;
+
private XdsChannelBuilder(NettyChannelBuilder delegate) {
this.delegate = delegate;
- SdsProtocolNegotiators.setProtocolNegotiatorFactory(delegate);
}
/**
@@ -66,6 +71,15 @@
return new XdsChannelBuilder(NettyChannelBuilder.forTarget(target));
}
+ /**
+ * Set the UpstreamTlsContext for this channel. This is a temporary workaround until CDS is
+ * implemented in the XDS client. Passing {@code null} will fall back to plaintext.
+ */
+ public XdsChannelBuilder tlsContext(@Nullable UpstreamTlsContext upstreamTlsContext) {
+ this.upstreamTlsContext = upstreamTlsContext;
+ return this;
+ }
+
@Override
protected ManagedChannelBuilder<?> delegate() {
return delegate;
@@ -73,6 +87,8 @@
@Override
public ManagedChannel build() {
+ InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
+ delegate, SdsProtocolNegotiators.clientProtocolNegotiatorFactory(upstreamTlsContext));
return delegate.build();
}
}
diff --git a/xds/src/main/java/io/grpc/xds/sds/XdsServerBuilder.java b/xds/src/main/java/io/grpc/xds/sds/XdsServerBuilder.java
index f596d19..8f3b28b 100644
--- a/xds/src/main/java/io/grpc/xds/sds/XdsServerBuilder.java
+++ b/xds/src/main/java/io/grpc/xds/sds/XdsServerBuilder.java
@@ -16,6 +16,7 @@
package io.grpc.xds.sds;
+import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
import io.grpc.BindableService;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
@@ -44,6 +45,9 @@
private final NettyServerBuilder delegate;
+ // TODO (sanjaypujare) integrate with xDS client to get downstreamTlsContext from LDS
+ @Nullable private DownstreamTlsContext downstreamTlsContext;
+
private XdsServerBuilder(NettyServerBuilder nettyDelegate) {
this.delegate = nettyDelegate;
}
@@ -119,6 +123,15 @@
return this;
}
+ /**
+ * Set the DownstreamTlsContext for the server. This is a temporary workaround until integration
+ * with xDS client is implemented to get LDS. Passing {@code null} will fall back to plaintext.
+ */
+ public XdsServerBuilder tlsContext(@Nullable DownstreamTlsContext downstreamTlsContext) {
+ this.downstreamTlsContext = downstreamTlsContext;
+ return this;
+ }
+
/** Creates a gRPC server builder for the given port. */
public static XdsServerBuilder forPort(int port) {
NettyServerBuilder nettyDelegate = NettyServerBuilder.forAddress(new InetSocketAddress(port));
@@ -128,7 +141,8 @@
@Override
public Server build() {
// note: doing it in build() will overwrite any previously set ProtocolNegotiator
- delegate.protocolNegotiator(SdsProtocolNegotiators.serverProtocolNegotiator());
+ delegate.protocolNegotiator(
+ SdsProtocolNegotiators.serverProtocolNegotiator(this.downstreamTlsContext));
return delegate.build();
}
}
diff --git a/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java
index 9bc5129..4a72ea2 100644
--- a/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java
+++ b/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java
@@ -16,15 +16,30 @@
package io.grpc.xds.sds.internal;
+import static com.google.common.base.Preconditions.checkNotNull;
+
+import com.google.common.annotations.VisibleForTesting;
+import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
+import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.grpc.Internal;
import io.grpc.netty.GrpcHttp2ConnectionHandler;
import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.InternalNettyChannelBuilder.ProtocolNegotiatorFactory;
import io.grpc.netty.InternalProtocolNegotiator;
import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
+import io.grpc.netty.InternalProtocolNegotiators;
import io.grpc.netty.NettyChannelBuilder;
+import io.grpc.xds.sds.SecretProvider;
+import io.grpc.xds.sds.TlsContextManager;
import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerAdapter;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.ssl.SslContext;
import io.netty.util.AsciiString;
+import java.util.ArrayList;
+import java.util.List;
+import javax.annotation.Nullable;
/**
* Provides client and server side gRPC {@link ProtocolNegotiator}s that use SDS to provide the SSL
@@ -35,12 +50,40 @@
private static final AsciiString SCHEME = AsciiString.of("https");
+ /**
+ * Returns a {@link ProtocolNegotiatorFactory} to be used on {@link NettyChannelBuilder}. Passing
+ * {@code null} for upstreamTlsContext will fall back to plaintext.
+ */
+ // TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS
+ public static ProtocolNegotiatorFactory clientProtocolNegotiatorFactory(
+ @Nullable UpstreamTlsContext upstreamTlsContext) {
+ return new ClientSdsProtocolNegotiatorFactory(upstreamTlsContext);
+ }
+
+ /**
+ * Creates an SDS based {@link ProtocolNegotiator} for a {@link io.grpc.netty.NettyServerBuilder}.
+ * Passing {@code null} for downstreamTlsContext will fall back to plaintext.
+ */
+ // TODO (sanjaypujare) integrate with xDS client to get LDS
+ public static ProtocolNegotiator serverProtocolNegotiator(
+ @Nullable DownstreamTlsContext downstreamTlsContext) {
+ return new ServerSdsProtocolNegotiator(downstreamTlsContext);
+ }
+
private static final class ClientSdsProtocolNegotiatorFactory
implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
+ // TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS
+ private final UpstreamTlsContext upstreamTlsContext;
+
+ ClientSdsProtocolNegotiatorFactory(UpstreamTlsContext upstreamTlsContext) {
+ this.upstreamTlsContext = upstreamTlsContext;
+ }
+
@Override
public InternalProtocolNegotiator.ProtocolNegotiator buildProtocolNegotiator() {
- final ClientSdsProtocolNegotiator negotiator = new ClientSdsProtocolNegotiator();
+ final ClientSdsProtocolNegotiator negotiator =
+ new ClientSdsProtocolNegotiator(upstreamTlsContext);
final class LocalSdsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
@Override
@@ -63,7 +106,15 @@
}
}
- private static final class ClientSdsProtocolNegotiator implements ProtocolNegotiator {
+ @VisibleForTesting
+ static final class ClientSdsProtocolNegotiator implements ProtocolNegotiator {
+
+ // TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS
+ UpstreamTlsContext upstreamTlsContext;
+
+ ClientSdsProtocolNegotiator(UpstreamTlsContext upstreamTlsContext) {
+ this.upstreamTlsContext = upstreamTlsContext;
+ }
@Override
public AsciiString scheme() {
@@ -72,16 +123,111 @@
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
- // TODO(sanjaypujare): once implemented return ClientSdsHandler
- throw new UnsupportedOperationException("Not implemented yet");
+ // once CDS is implemented we will retrieve upstreamTlsContext as follows:
+ // grpcHandler.getEagAttributes().get(XdsAttributes.ATTR_UPSTREAM_TLS_CONTEXT);
+ if (isTlsContextEmpty(upstreamTlsContext)) {
+ return InternalProtocolNegotiators.plaintext().newHandler(grpcHandler);
+ }
+ return new ClientSdsHandler(grpcHandler, upstreamTlsContext);
+ }
+
+ private static boolean isTlsContextEmpty(UpstreamTlsContext upstreamTlsContext) {
+ return upstreamTlsContext == null || !upstreamTlsContext.hasCommonTlsContext();
}
@Override
public void close() {}
}
+ private static class BufferReadsHandler extends ChannelInboundHandlerAdapter {
+ private final List<Object> reads = new ArrayList<>();
+ private boolean readComplete;
+
+ @Override
+ public void channelRead(ChannelHandlerContext ctx, Object msg) {
+ reads.add(msg);
+ }
+
+ @Override
+ public void channelReadComplete(ChannelHandlerContext ctx) {
+ readComplete = true;
+ }
+
+ @Override
+ public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
+ for (Object msg : reads) {
+ super.channelRead(ctx, msg);
+ }
+ if (readComplete) {
+ super.channelReadComplete(ctx);
+ }
+ }
+ }
+
+ @VisibleForTesting
+ static final class ClientSdsHandler
+ extends InternalProtocolNegotiators.ProtocolNegotiationHandler {
+ private final GrpcHttp2ConnectionHandler grpcHandler;
+ private final UpstreamTlsContext upstreamTlsContext;
+
+ ClientSdsHandler(
+ GrpcHttp2ConnectionHandler grpcHandler, UpstreamTlsContext upstreamTlsContext) {
+ super(
+ // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next'
+ // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior
+ // here and then manually add 'next' when we call fireProtocolNegotiationEvent()
+ new ChannelHandlerAdapter() {
+ @Override
+ public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
+ ctx.pipeline().remove(this);
+ }
+ });
+ checkNotNull(grpcHandler, "grpcHandler");
+ this.grpcHandler = grpcHandler;
+ this.upstreamTlsContext = upstreamTlsContext;
+ }
+
+ @Override
+ protected void handlerAdded0(final ChannelHandlerContext ctx) {
+ final BufferReadsHandler bufferReads = new BufferReadsHandler();
+ ctx.pipeline().addBefore(ctx.name(), null, bufferReads);
+
+ SecretProvider<SslContext> sslContextProvider =
+ TlsContextManager.getInstance().findOrCreateClientSslContextProvider(upstreamTlsContext);
+
+ sslContextProvider.addCallback(
+ new SecretProvider.Callback<SslContext>() {
+
+ @Override
+ public void updateSecret(SslContext sslContext) {
+ ChannelHandler handler =
+ InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler);
+
+ // Delegate rest of handshake to TLS handler
+ ctx.pipeline().addAfter(ctx.name(), null, handler);
+ fireProtocolNegotiationEvent(ctx);
+ ctx.pipeline().remove(bufferReads);
+ }
+
+ @Override
+ public void onException(Throwable throwable) {
+ ctx.fireExceptionCaught(throwable);
+ }
+ },
+ ctx.executor());
+ }
+ }
+
private static final class ServerSdsProtocolNegotiator implements ProtocolNegotiator {
+ // TODO (sanjaypujare) integrate with xDS client to get LDS. LDS watcher will
+ // inject/update the downstreamTlsContext from LDS
+ private DownstreamTlsContext downstreamTlsContext;
+
+ ServerSdsProtocolNegotiator(DownstreamTlsContext downstreamTlsContext) {
+ this.downstreamTlsContext = downstreamTlsContext;
+ }
+
@Override
public AsciiString scheme() {
return SCHEME;
@@ -89,22 +235,72 @@
@Override
public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
- // TODO(sanjaypujare): once implemented return ServerSdsHandler
- throw new UnsupportedOperationException("Not implemented yet");
+ if (isTlsContextEmpty(downstreamTlsContext)) {
+ return InternalProtocolNegotiators.serverPlaintext().newHandler(grpcHandler);
+ }
+ return new ServerSdsHandler(grpcHandler, downstreamTlsContext);
+ }
+
+ private static boolean isTlsContextEmpty(DownstreamTlsContext downstreamTlsContext) {
+ return downstreamTlsContext == null || !downstreamTlsContext.hasCommonTlsContext();
}
@Override
public void close() {}
}
- /** Sets the {@link ProtocolNegotiatorFactory} on a NettyChannelBuilder. */
- public static void setProtocolNegotiatorFactory(NettyChannelBuilder builder) {
- InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
- builder, new ClientSdsProtocolNegotiatorFactory());
- }
+ @VisibleForTesting
+ static final class ServerSdsHandler
+ extends InternalProtocolNegotiators.ProtocolNegotiationHandler {
+ private final GrpcHttp2ConnectionHandler grpcHandler;
+ private final DownstreamTlsContext downstreamTlsContext;
- /** Creates an SDS based {@link ProtocolNegotiator} for a server. */
- public static ProtocolNegotiator serverProtocolNegotiator() {
- return new ServerSdsProtocolNegotiator();
+ ServerSdsHandler(
+ GrpcHttp2ConnectionHandler grpcHandler, DownstreamTlsContext downstreamTlsContext) {
+ super(
+ // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next'
+ // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior
+ // here and then manually add 'next' when we call fireProtocolNegotiationEvent()
+ new ChannelHandlerAdapter() {
+ @Override
+ public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
+ ctx.pipeline().remove(this);
+ }
+ });
+ checkNotNull(grpcHandler, "grpcHandler");
+ this.grpcHandler = grpcHandler;
+ this.downstreamTlsContext = downstreamTlsContext;
+ }
+
+ @Override
+ protected void handlerAdded0(final ChannelHandlerContext ctx) {
+ final BufferReadsHandler bufferReads = new BufferReadsHandler();
+ ctx.pipeline().addBefore(ctx.name(), null, bufferReads);
+
+ SecretProvider<SslContext> sslContextProvider =
+ TlsContextManager.getInstance()
+ .findOrCreateServerSslContextProvider(downstreamTlsContext);
+
+ sslContextProvider.addCallback(
+ new SecretProvider.Callback<SslContext>() {
+
+ @Override
+ public void updateSecret(SslContext sslContext) {
+ ChannelHandler handler =
+ InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHandler);
+
+ // Delegate rest of handshake to TLS handler
+ ctx.pipeline().addAfter(ctx.name(), null, handler);
+ fireProtocolNegotiationEvent(ctx);
+ ctx.pipeline().remove(bufferReads);
+ }
+
+ @Override
+ public void onException(Throwable throwable) {
+ ctx.fireExceptionCaught(throwable);
+ }
+ },
+ ctx.executor());
+ }
}
}
diff --git a/xds/src/test/java/io/grpc/xds/sds/XdsSdsClientServerTest.java b/xds/src/test/java/io/grpc/xds/sds/XdsSdsClientServerTest.java
new file mode 100644
index 0000000..5a2687b
--- /dev/null
+++ b/xds/src/test/java/io/grpc/xds/sds/XdsSdsClientServerTest.java
@@ -0,0 +1,193 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds.sds;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import com.google.protobuf.BoolValue;
+import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
+import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
+import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
+import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
+import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
+import io.envoyproxy.envoy.api.v2.core.DataSource;
+import io.grpc.Server;
+import io.grpc.internal.testing.TestUtils;
+import io.grpc.stub.StreamObserver;
+import io.grpc.testing.GrpcCleanupRule;
+import io.grpc.testing.protobuf.SimpleRequest;
+import io.grpc.testing.protobuf.SimpleResponse;
+import io.grpc.testing.protobuf.SimpleServiceGrpc;
+import java.io.IOException;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Unit tests for {@link XdsChannelBuilder} and {@link XdsServerBuilder} for plaintext/TLS/mTLS
+ * modes.
+ */
+@RunWith(JUnit4.class)
+public class XdsSdsClientServerTest {
+
+ @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
+
+ @Test
+ public void plaintextClientServer() throws IOException {
+ Server server = getXdsServer(/* downstreamTlsContext= */ null);
+ buildClientAndTest(
+ /* upstreamTlsContext= */ null, /* overrideAuthority= */ null, "buddy", server.getPort());
+ }
+
+ /** TLS channel - no mTLS. */
+ @Test
+ public void tlsClientServer_noClientAuthentication() throws IOException {
+ String server1Pem = TestUtils.loadCert("server1.pem").getAbsolutePath();
+ String server1Key = TestUtils.loadCert("server1.key").getAbsolutePath();
+
+ TlsCertificate tlsCert =
+ TlsCertificate.newBuilder()
+ .setPrivateKey(DataSource.newBuilder().setFilename(server1Key).build())
+ .setCertificateChain(DataSource.newBuilder().setFilename(server1Pem).build())
+ .build();
+
+ CommonTlsContext commonTlsContext =
+ CommonTlsContext.newBuilder().addTlsCertificates(tlsCert).build();
+
+ DownstreamTlsContext downstreamTlsContext =
+ DownstreamTlsContext.newBuilder()
+ .setCommonTlsContext(commonTlsContext)
+ .setRequireClientCertificate(BoolValue.of(false))
+ .build();
+
+ Server server = getXdsServer(downstreamTlsContext);
+
+ // for TLS client doesn't need cert but needs trustCa
+ String trustCa = TestUtils.loadCert("ca.pem").getAbsolutePath();
+ CertificateValidationContext certContext =
+ CertificateValidationContext.newBuilder()
+ .setTrustedCa(DataSource.newBuilder().setFilename(trustCa).build())
+ .build();
+
+ CommonTlsContext commonTlsContext1 =
+ CommonTlsContext.newBuilder().setValidationContext(certContext).build();
+
+ UpstreamTlsContext upstreamTlsContext =
+ UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext1).build();
+ buildClientAndTest(upstreamTlsContext, "foo.test.google.fr", "buddy", server.getPort());
+ }
+
+ /** mTLS - client auth enabled. */
+ @Test
+ public void mtlsClientServer_withClientAuthentication() throws IOException, InterruptedException {
+ String server1Pem = TestUtils.loadCert("server1.pem").getAbsolutePath();
+ String server1Key = TestUtils.loadCert("server1.key").getAbsolutePath();
+ String trustCa = TestUtils.loadCert("ca.pem").getAbsolutePath();
+
+ TlsCertificate tlsCert =
+ TlsCertificate.newBuilder()
+ .setPrivateKey(DataSource.newBuilder().setFilename(server1Key).build())
+ .setCertificateChain(DataSource.newBuilder().setFilename(server1Pem).build())
+ .build();
+
+ CertificateValidationContext certContext =
+ CertificateValidationContext.newBuilder()
+ .setTrustedCa(DataSource.newBuilder().setFilename(trustCa).build())
+ .build();
+
+ CommonTlsContext commonTlsContext =
+ CommonTlsContext.newBuilder()
+ .addTlsCertificates(tlsCert)
+ .setValidationContext(certContext)
+ .build();
+
+ DownstreamTlsContext downstreamTlsContext =
+ DownstreamTlsContext.newBuilder()
+ .setCommonTlsContext(commonTlsContext)
+ .setRequireClientCertificate(BoolValue.of(false))
+ .build();
+
+ Server server = getXdsServer(downstreamTlsContext);
+
+ String clientPem = TestUtils.loadCert("client.pem").getAbsolutePath();
+ String clientKey = TestUtils.loadCert("client.key").getAbsolutePath();
+
+ TlsCertificate tlsCert1 =
+ TlsCertificate.newBuilder()
+ .setPrivateKey(DataSource.newBuilder().setFilename(clientKey).build())
+ .setCertificateChain(DataSource.newBuilder().setFilename(clientPem).build())
+ .build();
+
+ CommonTlsContext commonTlsContext1 =
+ CommonTlsContext.newBuilder()
+ .addTlsCertificates(tlsCert1)
+ .setValidationContext(certContext)
+ .build();
+
+ UpstreamTlsContext upstreamTlsContext =
+ UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext1).build();
+
+ buildClientAndTest(upstreamTlsContext, "foo.test.google.fr", "buddy", server.getPort());
+ }
+
+ private Server getXdsServer(DownstreamTlsContext downstreamTlsContext) throws IOException {
+ XdsServerBuilder serverBuilder =
+ XdsServerBuilder.forPort(0) // get unused port
+ .addService(new SimpleServiceImpl())
+ .tlsContext(downstreamTlsContext);
+ return cleanupRule.register(serverBuilder.build()).start();
+ }
+
+ private void buildClientAndTest(
+ UpstreamTlsContext upstreamTlsContext,
+ String overrideAuthority,
+ String requestMessage,
+ int serverPort) {
+
+ XdsChannelBuilder builder =
+ XdsChannelBuilder.forTarget("localhost:" + serverPort).tlsContext(upstreamTlsContext);
+ if (overrideAuthority != null) {
+ builder = builder.overrideAuthority(overrideAuthority);
+ }
+ SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
+ SimpleServiceGrpc.newBlockingStub(cleanupRule.register(builder.build()));
+ String resp = unaryRpc(requestMessage, blockingStub);
+ assertThat(resp).isEqualTo("Hello " + requestMessage);
+ }
+
+ /** Say hello to server. */
+ private static String unaryRpc(
+ String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) {
+ SimpleRequest request = SimpleRequest.newBuilder().setRequestMessage(requestMessage).build();
+ SimpleResponse response = blockingStub.unaryRpc(request);
+ return response.getResponseMessage();
+ }
+
+ private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase {
+
+ @Override
+ public void unaryRpc(SimpleRequest req, StreamObserver<SimpleResponse> responseObserver) {
+ SimpleResponse response =
+ SimpleResponse.newBuilder()
+ .setResponseMessage("Hello " + req.getRequestMessage())
+ .build();
+ responseObserver.onNext(response);
+ responseObserver.onCompleted();
+ }
+ }
+}
diff --git a/xds/src/test/java/io/grpc/xds/sds/internal/SdsProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/sds/internal/SdsProtocolNegotiatorsTest.java
new file mode 100644
index 0000000..ea3c653
--- /dev/null
+++ b/xds/src/test/java/io/grpc/xds/sds/internal/SdsProtocolNegotiatorsTest.java
@@ -0,0 +1,263 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds.sds.internal;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import com.google.common.base.Strings;
+import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
+import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
+import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
+import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
+import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
+import io.envoyproxy.envoy.api.v2.core.DataSource;
+import io.grpc.internal.testing.TestUtils;
+import io.grpc.netty.GrpcHttp2ConnectionHandler;
+import io.grpc.netty.InternalProtocolNegotiationEvent;
+import io.grpc.xds.sds.internal.SdsProtocolNegotiators.ClientSdsHandler;
+import io.grpc.xds.sds.internal.SdsProtocolNegotiators.ClientSdsProtocolNegotiator;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelPipeline;
+import io.netty.channel.ChannelPromise;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.http2.DefaultHttp2Connection;
+import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
+import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder;
+import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
+import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
+import io.netty.handler.codec.http2.Http2ConnectionDecoder;
+import io.netty.handler.codec.http2.Http2ConnectionEncoder;
+import io.netty.handler.codec.http2.Http2Settings;
+import io.netty.handler.ssl.SslHandler;
+import io.netty.handler.ssl.SslHandshakeCompletionEvent;
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.Map;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link SdsProtocolNegotiators}. */
+@RunWith(JUnit4.class)
+public class SdsProtocolNegotiatorsTest {
+
+ private static final String SERVER_1_PEM_FILE = "server1.pem";
+ private static final String SERVER_1_KEY_FILE = "server1.key";
+ private static final String CLIENT_PEM_FILE = "client.pem";
+ private static final String CLIENT_KEY_FILE = "client.key";
+ private static final String CA_PEM_FILE = "ca.pem";
+
+ private final GrpcHttp2ConnectionHandler grpcHandler =
+ FakeGrpcHttp2ConnectionHandler.newHandler();
+
+ private EmbeddedChannel channel = new EmbeddedChannel();
+ private ChannelPipeline pipeline = channel.pipeline();
+ private ChannelHandlerContext channelHandlerCtx;
+
+ private static String getTempFileNameForResourcesFile(String resFile) throws IOException {
+ return Strings.isNullOrEmpty(resFile) ? null : TestUtils.loadCert(resFile).getAbsolutePath();
+ }
+
+ /** Builds DownstreamTlsContext from file-names. */
+ private static DownstreamTlsContext buildDownstreamTlsContextFromFilenames(
+ String privateKey, String certChain, String trustCa) throws IOException {
+ return buildDownstreamTlsContext(
+ buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa));
+ }
+
+ /** Builds UpstreamTlsContext from file-names. */
+ private static UpstreamTlsContext buildUpstreamTlsContextFromFilenames(
+ String privateKey, String certChain, String trustCa) throws IOException {
+ return buildUpstreamTlsContext(
+ buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa));
+ }
+
+ /** Builds UpstreamTlsContext from commonTlsContext. */
+ private static UpstreamTlsContext buildUpstreamTlsContext(CommonTlsContext commonTlsContext) {
+ UpstreamTlsContext upstreamTlsContext =
+ UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build();
+ return upstreamTlsContext;
+ }
+
+ /** Builds DownstreamTlsContext from commonTlsContext. */
+ private static DownstreamTlsContext buildDownstreamTlsContext(CommonTlsContext commonTlsContext) {
+ DownstreamTlsContext downstreamTlsContext =
+ DownstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build();
+ return downstreamTlsContext;
+ }
+
+ private static CommonTlsContext buildCommonTlsContextFromFilenames(
+ String privateKey, String certChain, String trustCa) throws IOException {
+ TlsCertificate tlsCert = null;
+ privateKey = getTempFileNameForResourcesFile(privateKey);
+ certChain = getTempFileNameForResourcesFile(certChain);
+ trustCa = getTempFileNameForResourcesFile(trustCa);
+ if (!Strings.isNullOrEmpty(privateKey) && !Strings.isNullOrEmpty(certChain)) {
+ tlsCert =
+ TlsCertificate.newBuilder()
+ .setCertificateChain(DataSource.newBuilder().setFilename(certChain))
+ .setPrivateKey(DataSource.newBuilder().setFilename(privateKey))
+ .build();
+ }
+ CertificateValidationContext certContext = null;
+ if (!Strings.isNullOrEmpty(trustCa)) {
+ certContext =
+ CertificateValidationContext.newBuilder()
+ .setTrustedCa(DataSource.newBuilder().setFilename(trustCa))
+ .build();
+ }
+ return getCommonTlsContext(tlsCert, certContext);
+ }
+
+ private static CommonTlsContext getCommonTlsContext(
+ TlsCertificate tlsCertificate, CertificateValidationContext certContext) {
+ CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
+ if (tlsCertificate != null) {
+ builder = builder.addTlsCertificates(tlsCertificate);
+ }
+ if (certContext != null) {
+ builder = builder.setValidationContext(certContext);
+ }
+ return builder.build();
+ }
+
+ @Test
+ public void clientSdsProtocolNegotiatorNewHandler_nullTlsContext() {
+ ClientSdsProtocolNegotiator pn =
+ new ClientSdsProtocolNegotiator(/* upstreamTlsContext= */ null);
+ ChannelHandler newHandler = pn.newHandler(grpcHandler);
+ assertThat(newHandler).isNotNull();
+ // ProtocolNegotiators.WaitUntilActiveHandler not accessible, get canonical name
+ assertThat(newHandler.getClass().getCanonicalName())
+ .contains("io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler");
+ }
+
+ @Test
+ public void clientSdsProtocolNegotiatorNewHandler_nonNullTlsContext() {
+ UpstreamTlsContext upstreamTlsContext =
+ buildUpstreamTlsContext(getCommonTlsContext(null, null));
+ ClientSdsProtocolNegotiator pn = new ClientSdsProtocolNegotiator(upstreamTlsContext);
+ ChannelHandler newHandler = pn.newHandler(grpcHandler);
+ assertThat(newHandler).isNotNull();
+ assertThat(newHandler).isInstanceOf(ClientSdsHandler.class);
+ }
+
+ @Test
+ public void clientSdsHandler_addLast() throws IOException {
+ UpstreamTlsContext upstreamTlsContext =
+ buildUpstreamTlsContextFromFilenames(CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
+
+ SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler =
+ new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, upstreamTlsContext);
+ pipeline.addLast(clientSdsHandler);
+ channelHandlerCtx = pipeline.context(clientSdsHandler);
+ assertNotNull(channelHandlerCtx); // clientSdsHandler ctx is non-null since we just added it
+
+ // kick off protocol negotiation.
+ pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
+ channel.runPendingTasks(); // need this for tasks to execute on eventLoop
+ channelHandlerCtx = pipeline.context(clientSdsHandler);
+ assertThat(channelHandlerCtx).isNull();
+
+ // pipeline should have SslHandler and ClientTlsHandler
+ Iterator<Map.Entry<String, ChannelHandler>> iterator = pipeline.iterator();
+ assertThat(iterator.next().getValue()).isInstanceOf(SslHandler.class);
+ // ProtocolNegotiators.ClientTlsHandler.class not accessible, get canonical name
+ assertThat(iterator.next().getValue().getClass().getCanonicalName())
+ .contains("ProtocolNegotiators.ClientTlsHandler");
+ }
+
+ @Test
+ public void serverSdsHandler_addLast() throws IOException {
+ DownstreamTlsContext downstreamTlsContext =
+ buildDownstreamTlsContextFromFilenames(SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
+
+ SdsProtocolNegotiators.ServerSdsHandler serverSdsHandler =
+ new SdsProtocolNegotiators.ServerSdsHandler(grpcHandler, downstreamTlsContext);
+ pipeline.addLast(serverSdsHandler);
+ channelHandlerCtx = pipeline.context(serverSdsHandler);
+ assertNotNull(channelHandlerCtx); // serverSdsHandler ctx is non-null since we just added it
+
+ // kick off protocol negotiation
+ pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
+ channel.runPendingTasks(); // need this for tasks to execute on eventLoop
+ channelHandlerCtx = pipeline.context(serverSdsHandler);
+ assertThat(channelHandlerCtx).isNull();
+
+ // pipeline should have SslHandler and ServerTlsHandler
+ Iterator<Map.Entry<String, ChannelHandler>> iterator = pipeline.iterator();
+ assertThat(iterator.next().getValue()).isInstanceOf(SslHandler.class);
+ // ProtocolNegotiators.ServerTlsHandler.class is not accessible, get canonical name
+ assertThat(iterator.next().getValue().getClass().getCanonicalName())
+ .contains("ProtocolNegotiators.ServerTlsHandler");
+ }
+
+ @Test
+ public void clientSdsProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent()
+ throws IOException, InterruptedException {
+ UpstreamTlsContext upstreamTlsContext =
+ buildUpstreamTlsContextFromFilenames(CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
+
+ SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler =
+ new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, upstreamTlsContext);
+
+ pipeline.addLast(clientSdsHandler);
+ channelHandlerCtx = pipeline.context(clientSdsHandler);
+ assertNotNull(channelHandlerCtx); // non-null since we just added it
+
+ // kick off protocol negotiation.
+ pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
+ channel.runPendingTasks(); // need this for tasks to execute on eventLoop
+ channelHandlerCtx = pipeline.context(clientSdsHandler);
+ assertThat(channelHandlerCtx).isNull();
+ Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
+
+ pipeline.fireUserEventTriggered(sslEvent);
+ channel.runPendingTasks(); // need this for tasks to execute on eventLoop
+ assertTrue(channel.isOpen());
+ }
+
+ private static final class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
+
+ FakeGrpcHttp2ConnectionHandler(
+ ChannelPromise channelUnused,
+ Http2ConnectionDecoder decoder,
+ Http2ConnectionEncoder encoder,
+ Http2Settings initialSettings) {
+ super(channelUnused, decoder, encoder, initialSettings);
+ }
+
+ static FakeGrpcHttp2ConnectionHandler newHandler() {
+ DefaultHttp2Connection conn = new DefaultHttp2Connection(/*server=*/ false);
+ DefaultHttp2ConnectionEncoder encoder =
+ new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter());
+ DefaultHttp2ConnectionDecoder decoder =
+ new DefaultHttp2ConnectionDecoder(conn, encoder, new DefaultHttp2FrameReader());
+ Http2Settings settings = new Http2Settings();
+ return new FakeGrpcHttp2ConnectionHandler(
+ /*channelUnused=*/ null, decoder, encoder, settings);
+ }
+
+ @Override
+ public String getAuthority() {
+ return "authority";
+ }
+ }
+}