xds: Client and server proto negotiators and handlers added to SdsProtocolNegotiators (#6319)

diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java
index e71a695..9f3e599 100644
--- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java
+++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java
@@ -68,6 +68,82 @@
   }
 
   /**
+   * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will be
+   * negotiated, the server TLS {@code handler} is added and writes to the {@link
+   * io.netty.channel.Channel} may happen immediately, even before the TLS Handshake is complete.
+   */
+  public static InternalProtocolNegotiator.ProtocolNegotiator serverTls(SslContext sslContext) {
+    final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.serverTls(sslContext);
+    final class ServerTlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
+
+      @Override
+      public AsciiString scheme() {
+        return negotiator.scheme();
+      }
+
+      @Override
+      public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
+        return negotiator.newHandler(grpcHandler);
+      }
+
+      @Override
+      public void close() {
+        negotiator.close();
+      }
+    }
+
+    return new ServerTlsNegotiator();
+  }
+
+  /** Returns a {@link ProtocolNegotiator} for plaintext client channel. */
+  public static InternalProtocolNegotiator.ProtocolNegotiator plaintext() {
+    final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.plaintext();
+    final class PlaintextNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
+
+      @Override
+      public AsciiString scheme() {
+        return negotiator.scheme();
+      }
+
+      @Override
+      public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
+        return negotiator.newHandler(grpcHandler);
+      }
+
+      @Override
+      public void close() {
+        negotiator.close();
+      }
+    }
+
+    return new PlaintextNegotiator();
+  }
+
+  /** Returns a {@link ProtocolNegotiator} for plaintext server channel. */
+  public static InternalProtocolNegotiator.ProtocolNegotiator serverPlaintext() {
+    final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.serverPlaintext();
+    final class ServerPlaintextNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
+
+      @Override
+      public AsciiString scheme() {
+        return negotiator.scheme();
+      }
+
+      @Override
+      public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
+        return negotiator.newHandler(grpcHandler);
+      }
+
+      @Override
+      public void close() {
+        negotiator.close();
+      }
+    }
+
+    return new ServerPlaintextNegotiator();
+  }
+
+  /**
    * Internal version of {@link WaitUntilActiveHandler}.
    */
   public static ChannelHandler waitUntilActiveHandler(ChannelHandler next) {
diff --git a/xds/src/main/java/io/grpc/xds/sds/XdsChannelBuilder.java b/xds/src/main/java/io/grpc/xds/sds/XdsChannelBuilder.java
index 682b401..2e4e5d2 100644
--- a/xds/src/main/java/io/grpc/xds/sds/XdsChannelBuilder.java
+++ b/xds/src/main/java/io/grpc/xds/sds/XdsChannelBuilder.java
@@ -16,14 +16,17 @@
 
 package io.grpc.xds.sds;
 
+import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
 import io.grpc.ExperimentalApi;
 import io.grpc.ForwardingChannelBuilder;
 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
+import io.grpc.netty.InternalNettyChannelBuilder;
 import io.grpc.netty.NettyChannelBuilder;
 import io.grpc.xds.sds.internal.SdsProtocolNegotiators;
 import java.net.SocketAddress;
 import javax.annotation.CheckReturnValue;
+import javax.annotation.Nullable;
 
 /**
  * A version of {@link ManagedChannelBuilder} to create xDS managed channels that will use SDS to
@@ -34,9 +37,11 @@
 
   private final NettyChannelBuilder delegate;
 
+  // TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS
+  @Nullable private UpstreamTlsContext upstreamTlsContext;
+
   private XdsChannelBuilder(NettyChannelBuilder delegate) {
     this.delegate = delegate;
-    SdsProtocolNegotiators.setProtocolNegotiatorFactory(delegate);
   }
 
   /**
@@ -66,6 +71,15 @@
     return new XdsChannelBuilder(NettyChannelBuilder.forTarget(target));
   }
 
+  /**
+   * Set the UpstreamTlsContext for this channel. This is a temporary workaround until CDS is
+   * implemented in the XDS client. Passing {@code null} will fall back to plaintext.
+   */
+  public XdsChannelBuilder tlsContext(@Nullable UpstreamTlsContext upstreamTlsContext) {
+    this.upstreamTlsContext = upstreamTlsContext;
+    return this;
+  }
+
   @Override
   protected ManagedChannelBuilder<?> delegate() {
     return delegate;
@@ -73,6 +87,8 @@
 
   @Override
   public ManagedChannel build() {
+    InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
+        delegate, SdsProtocolNegotiators.clientProtocolNegotiatorFactory(upstreamTlsContext));
     return delegate.build();
   }
 }
diff --git a/xds/src/main/java/io/grpc/xds/sds/XdsServerBuilder.java b/xds/src/main/java/io/grpc/xds/sds/XdsServerBuilder.java
index f596d19..8f3b28b 100644
--- a/xds/src/main/java/io/grpc/xds/sds/XdsServerBuilder.java
+++ b/xds/src/main/java/io/grpc/xds/sds/XdsServerBuilder.java
@@ -16,6 +16,7 @@
 
 package io.grpc.xds.sds;
 
+import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
 import io.grpc.BindableService;
 import io.grpc.CompressorRegistry;
 import io.grpc.DecompressorRegistry;
@@ -44,6 +45,9 @@
 
   private final NettyServerBuilder delegate;
 
+  // TODO (sanjaypujare) integrate with xDS client to get downstreamTlsContext from LDS
+  @Nullable private DownstreamTlsContext downstreamTlsContext;
+
   private XdsServerBuilder(NettyServerBuilder nettyDelegate) {
     this.delegate = nettyDelegate;
   }
@@ -119,6 +123,15 @@
     return this;
   }
 
+  /**
+   * Set the DownstreamTlsContext for the server. This is a temporary workaround until integration
+   * with xDS client is implemented to get LDS. Passing {@code null} will fall back to plaintext.
+   */
+  public XdsServerBuilder tlsContext(@Nullable DownstreamTlsContext downstreamTlsContext) {
+    this.downstreamTlsContext = downstreamTlsContext;
+    return this;
+  }
+
   /** Creates a gRPC server builder for the given port. */
   public static XdsServerBuilder forPort(int port) {
     NettyServerBuilder nettyDelegate = NettyServerBuilder.forAddress(new InetSocketAddress(port));
@@ -128,7 +141,8 @@
   @Override
   public Server build() {
     // note: doing it in build() will overwrite any previously set ProtocolNegotiator
-    delegate.protocolNegotiator(SdsProtocolNegotiators.serverProtocolNegotiator());
+    delegate.protocolNegotiator(
+        SdsProtocolNegotiators.serverProtocolNegotiator(this.downstreamTlsContext));
     return delegate.build();
   }
 }
diff --git a/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java
index 9bc5129..4a72ea2 100644
--- a/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java
+++ b/xds/src/main/java/io/grpc/xds/sds/internal/SdsProtocolNegotiators.java
@@ -16,15 +16,30 @@
 
 package io.grpc.xds.sds.internal;
 
+import static com.google.common.base.Preconditions.checkNotNull;
+
+import com.google.common.annotations.VisibleForTesting;
+import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
+import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
 import io.grpc.Internal;
 import io.grpc.netty.GrpcHttp2ConnectionHandler;
 import io.grpc.netty.InternalNettyChannelBuilder;
 import io.grpc.netty.InternalNettyChannelBuilder.ProtocolNegotiatorFactory;
 import io.grpc.netty.InternalProtocolNegotiator;
 import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator;
+import io.grpc.netty.InternalProtocolNegotiators;
 import io.grpc.netty.NettyChannelBuilder;
+import io.grpc.xds.sds.SecretProvider;
+import io.grpc.xds.sds.TlsContextManager;
 import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerAdapter;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.handler.ssl.SslContext;
 import io.netty.util.AsciiString;
+import java.util.ArrayList;
+import java.util.List;
+import javax.annotation.Nullable;
 
 /**
  * Provides client and server side gRPC {@link ProtocolNegotiator}s that use SDS to provide the SSL
@@ -35,12 +50,40 @@
 
   private static final AsciiString SCHEME = AsciiString.of("https");
 
+  /**
+   * Returns a {@link ProtocolNegotiatorFactory} to be used on {@link NettyChannelBuilder}. Passing
+   * {@code null} for upstreamTlsContext will fall back to plaintext.
+   */
+  // TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS
+  public static ProtocolNegotiatorFactory clientProtocolNegotiatorFactory(
+      @Nullable UpstreamTlsContext upstreamTlsContext) {
+    return new ClientSdsProtocolNegotiatorFactory(upstreamTlsContext);
+  }
+
+  /**
+   * Creates an SDS based {@link ProtocolNegotiator} for a {@link io.grpc.netty.NettyServerBuilder}.
+   * Passing {@code null} for downstreamTlsContext will fall back to plaintext.
+   */
+  // TODO (sanjaypujare) integrate with xDS client to get LDS
+  public static ProtocolNegotiator serverProtocolNegotiator(
+      @Nullable DownstreamTlsContext downstreamTlsContext) {
+    return new ServerSdsProtocolNegotiator(downstreamTlsContext);
+  }
+
   private static final class ClientSdsProtocolNegotiatorFactory
       implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory {
 
+    // TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS
+    private final UpstreamTlsContext upstreamTlsContext;
+
+    ClientSdsProtocolNegotiatorFactory(UpstreamTlsContext upstreamTlsContext) {
+      this.upstreamTlsContext = upstreamTlsContext;
+    }
+
     @Override
     public InternalProtocolNegotiator.ProtocolNegotiator buildProtocolNegotiator() {
-      final ClientSdsProtocolNegotiator negotiator = new ClientSdsProtocolNegotiator();
+      final ClientSdsProtocolNegotiator negotiator =
+          new ClientSdsProtocolNegotiator(upstreamTlsContext);
       final class LocalSdsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
 
         @Override
@@ -63,7 +106,15 @@
     }
   }
 
-  private static final class ClientSdsProtocolNegotiator implements ProtocolNegotiator {
+  @VisibleForTesting
+  static final class ClientSdsProtocolNegotiator implements ProtocolNegotiator {
+
+    // TODO (sanjaypujare) integrate with xDS client to get upstreamTlsContext from CDS
+    UpstreamTlsContext upstreamTlsContext;
+
+    ClientSdsProtocolNegotiator(UpstreamTlsContext upstreamTlsContext) {
+      this.upstreamTlsContext = upstreamTlsContext;
+    }
 
     @Override
     public AsciiString scheme() {
@@ -72,16 +123,111 @@
 
     @Override
     public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
-      // TODO(sanjaypujare): once implemented return ClientSdsHandler
-      throw new UnsupportedOperationException("Not implemented yet");
+      // once CDS is implemented we will retrieve upstreamTlsContext as follows:
+      // grpcHandler.getEagAttributes().get(XdsAttributes.ATTR_UPSTREAM_TLS_CONTEXT);
+      if (isTlsContextEmpty(upstreamTlsContext)) {
+        return InternalProtocolNegotiators.plaintext().newHandler(grpcHandler);
+      }
+      return new ClientSdsHandler(grpcHandler, upstreamTlsContext);
+    }
+
+    private static boolean isTlsContextEmpty(UpstreamTlsContext upstreamTlsContext) {
+      return upstreamTlsContext == null || !upstreamTlsContext.hasCommonTlsContext();
     }
 
     @Override
     public void close() {}
   }
 
+  private static class BufferReadsHandler extends ChannelInboundHandlerAdapter {
+    private final List<Object> reads = new ArrayList<>();
+    private boolean readComplete;
+
+    @Override
+    public void channelRead(ChannelHandlerContext ctx, Object msg) {
+      reads.add(msg);
+    }
+
+    @Override
+    public void channelReadComplete(ChannelHandlerContext ctx) {
+      readComplete = true;
+    }
+
+    @Override
+    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
+      for (Object msg : reads) {
+        super.channelRead(ctx, msg);
+      }
+      if (readComplete) {
+        super.channelReadComplete(ctx);
+      }
+    }
+  }
+
+  @VisibleForTesting
+  static final class ClientSdsHandler
+      extends InternalProtocolNegotiators.ProtocolNegotiationHandler {
+    private final GrpcHttp2ConnectionHandler grpcHandler;
+    private final UpstreamTlsContext upstreamTlsContext;
+
+    ClientSdsHandler(
+        GrpcHttp2ConnectionHandler grpcHandler, UpstreamTlsContext upstreamTlsContext) {
+      super(
+          // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next'
+          // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior
+          // here and then manually add 'next' when we call fireProtocolNegotiationEvent()
+          new ChannelHandlerAdapter() {
+            @Override
+            public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
+              ctx.pipeline().remove(this);
+            }
+          });
+      checkNotNull(grpcHandler, "grpcHandler");
+      this.grpcHandler = grpcHandler;
+      this.upstreamTlsContext = upstreamTlsContext;
+    }
+
+    @Override
+    protected void handlerAdded0(final ChannelHandlerContext ctx) {
+      final BufferReadsHandler bufferReads = new BufferReadsHandler();
+      ctx.pipeline().addBefore(ctx.name(), null, bufferReads);
+
+      SecretProvider<SslContext> sslContextProvider =
+          TlsContextManager.getInstance().findOrCreateClientSslContextProvider(upstreamTlsContext);
+
+      sslContextProvider.addCallback(
+          new SecretProvider.Callback<SslContext>() {
+
+            @Override
+            public void updateSecret(SslContext sslContext) {
+              ChannelHandler handler =
+                  InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler);
+
+              // Delegate rest of handshake to TLS handler
+              ctx.pipeline().addAfter(ctx.name(), null, handler);
+              fireProtocolNegotiationEvent(ctx);
+              ctx.pipeline().remove(bufferReads);
+            }
+
+            @Override
+            public void onException(Throwable throwable) {
+              ctx.fireExceptionCaught(throwable);
+            }
+          },
+          ctx.executor());
+    }
+  }
+
   private static final class ServerSdsProtocolNegotiator implements ProtocolNegotiator {
 
+    // TODO (sanjaypujare) integrate with xDS client to get LDS. LDS watcher will
+    // inject/update the downstreamTlsContext from LDS
+    private DownstreamTlsContext downstreamTlsContext;
+
+    ServerSdsProtocolNegotiator(DownstreamTlsContext downstreamTlsContext) {
+      this.downstreamTlsContext = downstreamTlsContext;
+    }
+
     @Override
     public AsciiString scheme() {
       return SCHEME;
@@ -89,22 +235,72 @@
 
     @Override
     public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
-      // TODO(sanjaypujare): once implemented return ServerSdsHandler
-      throw new UnsupportedOperationException("Not implemented yet");
+      if (isTlsContextEmpty(downstreamTlsContext)) {
+        return InternalProtocolNegotiators.serverPlaintext().newHandler(grpcHandler);
+      }
+      return new ServerSdsHandler(grpcHandler, downstreamTlsContext);
+    }
+
+    private static boolean isTlsContextEmpty(DownstreamTlsContext downstreamTlsContext) {
+      return downstreamTlsContext == null || !downstreamTlsContext.hasCommonTlsContext();
     }
 
     @Override
     public void close() {}
   }
 
-  /** Sets the {@link ProtocolNegotiatorFactory} on a NettyChannelBuilder. */
-  public static void setProtocolNegotiatorFactory(NettyChannelBuilder builder) {
-    InternalNettyChannelBuilder.setProtocolNegotiatorFactory(
-        builder, new ClientSdsProtocolNegotiatorFactory());
-  }
+  @VisibleForTesting
+  static final class ServerSdsHandler
+      extends InternalProtocolNegotiators.ProtocolNegotiationHandler {
+    private final GrpcHttp2ConnectionHandler grpcHandler;
+    private final DownstreamTlsContext downstreamTlsContext;
 
-  /** Creates an SDS based {@link ProtocolNegotiator} for a server. */
-  public static ProtocolNegotiator serverProtocolNegotiator() {
-    return new ServerSdsProtocolNegotiator();
+    ServerSdsHandler(
+        GrpcHttp2ConnectionHandler grpcHandler, DownstreamTlsContext downstreamTlsContext) {
+      super(
+          // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next'
+          // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior
+          // here and then manually add 'next' when we call fireProtocolNegotiationEvent()
+          new ChannelHandlerAdapter() {
+            @Override
+            public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
+              ctx.pipeline().remove(this);
+            }
+          });
+      checkNotNull(grpcHandler, "grpcHandler");
+      this.grpcHandler = grpcHandler;
+      this.downstreamTlsContext = downstreamTlsContext;
+    }
+
+    @Override
+    protected void handlerAdded0(final ChannelHandlerContext ctx) {
+      final BufferReadsHandler bufferReads = new BufferReadsHandler();
+      ctx.pipeline().addBefore(ctx.name(), null, bufferReads);
+
+      SecretProvider<SslContext> sslContextProvider =
+          TlsContextManager.getInstance()
+              .findOrCreateServerSslContextProvider(downstreamTlsContext);
+
+      sslContextProvider.addCallback(
+          new SecretProvider.Callback<SslContext>() {
+
+            @Override
+            public void updateSecret(SslContext sslContext) {
+              ChannelHandler handler =
+                  InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHandler);
+
+              // Delegate rest of handshake to TLS handler
+              ctx.pipeline().addAfter(ctx.name(), null, handler);
+              fireProtocolNegotiationEvent(ctx);
+              ctx.pipeline().remove(bufferReads);
+            }
+
+            @Override
+            public void onException(Throwable throwable) {
+              ctx.fireExceptionCaught(throwable);
+            }
+          },
+          ctx.executor());
+    }
   }
 }
diff --git a/xds/src/test/java/io/grpc/xds/sds/XdsSdsClientServerTest.java b/xds/src/test/java/io/grpc/xds/sds/XdsSdsClientServerTest.java
new file mode 100644
index 0000000..5a2687b
--- /dev/null
+++ b/xds/src/test/java/io/grpc/xds/sds/XdsSdsClientServerTest.java
@@ -0,0 +1,193 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds.sds;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import com.google.protobuf.BoolValue;
+import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
+import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
+import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
+import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
+import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
+import io.envoyproxy.envoy.api.v2.core.DataSource;
+import io.grpc.Server;
+import io.grpc.internal.testing.TestUtils;
+import io.grpc.stub.StreamObserver;
+import io.grpc.testing.GrpcCleanupRule;
+import io.grpc.testing.protobuf.SimpleRequest;
+import io.grpc.testing.protobuf.SimpleResponse;
+import io.grpc.testing.protobuf.SimpleServiceGrpc;
+import java.io.IOException;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * Unit tests for {@link XdsChannelBuilder} and {@link XdsServerBuilder} for plaintext/TLS/mTLS
+ * modes.
+ */
+@RunWith(JUnit4.class)
+public class XdsSdsClientServerTest {
+
+  @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
+
+  @Test
+  public void plaintextClientServer() throws IOException {
+    Server server = getXdsServer(/* downstreamTlsContext= */ null);
+    buildClientAndTest(
+        /* upstreamTlsContext= */ null, /* overrideAuthority= */ null, "buddy", server.getPort());
+  }
+
+  /** TLS channel - no mTLS. */
+  @Test
+  public void tlsClientServer_noClientAuthentication() throws IOException {
+    String server1Pem = TestUtils.loadCert("server1.pem").getAbsolutePath();
+    String server1Key = TestUtils.loadCert("server1.key").getAbsolutePath();
+
+    TlsCertificate tlsCert =
+        TlsCertificate.newBuilder()
+            .setPrivateKey(DataSource.newBuilder().setFilename(server1Key).build())
+            .setCertificateChain(DataSource.newBuilder().setFilename(server1Pem).build())
+            .build();
+
+    CommonTlsContext commonTlsContext =
+        CommonTlsContext.newBuilder().addTlsCertificates(tlsCert).build();
+
+    DownstreamTlsContext downstreamTlsContext =
+        DownstreamTlsContext.newBuilder()
+            .setCommonTlsContext(commonTlsContext)
+            .setRequireClientCertificate(BoolValue.of(false))
+            .build();
+
+    Server server = getXdsServer(downstreamTlsContext);
+
+    // for TLS client doesn't need cert but needs trustCa
+    String trustCa = TestUtils.loadCert("ca.pem").getAbsolutePath();
+    CertificateValidationContext certContext =
+        CertificateValidationContext.newBuilder()
+            .setTrustedCa(DataSource.newBuilder().setFilename(trustCa).build())
+            .build();
+
+    CommonTlsContext commonTlsContext1 =
+        CommonTlsContext.newBuilder().setValidationContext(certContext).build();
+
+    UpstreamTlsContext upstreamTlsContext =
+        UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext1).build();
+    buildClientAndTest(upstreamTlsContext, "foo.test.google.fr", "buddy", server.getPort());
+  }
+
+  /** mTLS - client auth enabled. */
+  @Test
+  public void mtlsClientServer_withClientAuthentication() throws IOException, InterruptedException {
+    String server1Pem = TestUtils.loadCert("server1.pem").getAbsolutePath();
+    String server1Key = TestUtils.loadCert("server1.key").getAbsolutePath();
+    String trustCa = TestUtils.loadCert("ca.pem").getAbsolutePath();
+
+    TlsCertificate tlsCert =
+        TlsCertificate.newBuilder()
+            .setPrivateKey(DataSource.newBuilder().setFilename(server1Key).build())
+            .setCertificateChain(DataSource.newBuilder().setFilename(server1Pem).build())
+            .build();
+
+    CertificateValidationContext certContext =
+        CertificateValidationContext.newBuilder()
+            .setTrustedCa(DataSource.newBuilder().setFilename(trustCa).build())
+            .build();
+
+    CommonTlsContext commonTlsContext =
+        CommonTlsContext.newBuilder()
+            .addTlsCertificates(tlsCert)
+            .setValidationContext(certContext)
+            .build();
+
+    DownstreamTlsContext downstreamTlsContext =
+        DownstreamTlsContext.newBuilder()
+            .setCommonTlsContext(commonTlsContext)
+            .setRequireClientCertificate(BoolValue.of(false))
+            .build();
+
+    Server server = getXdsServer(downstreamTlsContext);
+
+    String clientPem = TestUtils.loadCert("client.pem").getAbsolutePath();
+    String clientKey = TestUtils.loadCert("client.key").getAbsolutePath();
+
+    TlsCertificate tlsCert1 =
+        TlsCertificate.newBuilder()
+            .setPrivateKey(DataSource.newBuilder().setFilename(clientKey).build())
+            .setCertificateChain(DataSource.newBuilder().setFilename(clientPem).build())
+            .build();
+
+    CommonTlsContext commonTlsContext1 =
+        CommonTlsContext.newBuilder()
+            .addTlsCertificates(tlsCert1)
+            .setValidationContext(certContext)
+            .build();
+
+    UpstreamTlsContext upstreamTlsContext =
+        UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext1).build();
+
+    buildClientAndTest(upstreamTlsContext, "foo.test.google.fr", "buddy", server.getPort());
+  }
+
+  private Server getXdsServer(DownstreamTlsContext downstreamTlsContext) throws IOException {
+    XdsServerBuilder serverBuilder =
+        XdsServerBuilder.forPort(0) // get unused port
+            .addService(new SimpleServiceImpl())
+            .tlsContext(downstreamTlsContext);
+    return cleanupRule.register(serverBuilder.build()).start();
+  }
+
+  private void buildClientAndTest(
+      UpstreamTlsContext upstreamTlsContext,
+      String overrideAuthority,
+      String requestMessage,
+      int serverPort) {
+
+    XdsChannelBuilder builder =
+        XdsChannelBuilder.forTarget("localhost:" + serverPort).tlsContext(upstreamTlsContext);
+    if (overrideAuthority != null) {
+      builder = builder.overrideAuthority(overrideAuthority);
+    }
+    SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub =
+        SimpleServiceGrpc.newBlockingStub(cleanupRule.register(builder.build()));
+    String resp = unaryRpc(requestMessage, blockingStub);
+    assertThat(resp).isEqualTo("Hello " + requestMessage);
+  }
+
+  /** Say hello to server. */
+  private static String unaryRpc(
+      String requestMessage, SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub) {
+    SimpleRequest request = SimpleRequest.newBuilder().setRequestMessage(requestMessage).build();
+    SimpleResponse response = blockingStub.unaryRpc(request);
+    return response.getResponseMessage();
+  }
+
+  private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase {
+
+    @Override
+    public void unaryRpc(SimpleRequest req, StreamObserver<SimpleResponse> responseObserver) {
+      SimpleResponse response =
+          SimpleResponse.newBuilder()
+              .setResponseMessage("Hello " + req.getRequestMessage())
+              .build();
+      responseObserver.onNext(response);
+      responseObserver.onCompleted();
+    }
+  }
+}
diff --git a/xds/src/test/java/io/grpc/xds/sds/internal/SdsProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/sds/internal/SdsProtocolNegotiatorsTest.java
new file mode 100644
index 0000000..ea3c653
--- /dev/null
+++ b/xds/src/test/java/io/grpc/xds/sds/internal/SdsProtocolNegotiatorsTest.java
@@ -0,0 +1,263 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds.sds.internal;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+import com.google.common.base.Strings;
+import io.envoyproxy.envoy.api.v2.auth.CertificateValidationContext;
+import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
+import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
+import io.envoyproxy.envoy.api.v2.auth.TlsCertificate;
+import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
+import io.envoyproxy.envoy.api.v2.core.DataSource;
+import io.grpc.internal.testing.TestUtils;
+import io.grpc.netty.GrpcHttp2ConnectionHandler;
+import io.grpc.netty.InternalProtocolNegotiationEvent;
+import io.grpc.xds.sds.internal.SdsProtocolNegotiators.ClientSdsHandler;
+import io.grpc.xds.sds.internal.SdsProtocolNegotiators.ClientSdsProtocolNegotiator;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelPipeline;
+import io.netty.channel.ChannelPromise;
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.handler.codec.http2.DefaultHttp2Connection;
+import io.netty.handler.codec.http2.DefaultHttp2ConnectionDecoder;
+import io.netty.handler.codec.http2.DefaultHttp2ConnectionEncoder;
+import io.netty.handler.codec.http2.DefaultHttp2FrameReader;
+import io.netty.handler.codec.http2.DefaultHttp2FrameWriter;
+import io.netty.handler.codec.http2.Http2ConnectionDecoder;
+import io.netty.handler.codec.http2.Http2ConnectionEncoder;
+import io.netty.handler.codec.http2.Http2Settings;
+import io.netty.handler.ssl.SslHandler;
+import io.netty.handler.ssl.SslHandshakeCompletionEvent;
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.Map;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link SdsProtocolNegotiators}. */
+@RunWith(JUnit4.class)
+public class SdsProtocolNegotiatorsTest {
+
+  private static final String SERVER_1_PEM_FILE = "server1.pem";
+  private static final String SERVER_1_KEY_FILE = "server1.key";
+  private static final String CLIENT_PEM_FILE = "client.pem";
+  private static final String CLIENT_KEY_FILE = "client.key";
+  private static final String CA_PEM_FILE = "ca.pem";
+
+  private final GrpcHttp2ConnectionHandler grpcHandler =
+      FakeGrpcHttp2ConnectionHandler.newHandler();
+
+  private EmbeddedChannel channel = new EmbeddedChannel();
+  private ChannelPipeline pipeline = channel.pipeline();
+  private ChannelHandlerContext channelHandlerCtx;
+
+  private static String getTempFileNameForResourcesFile(String resFile) throws IOException {
+    return Strings.isNullOrEmpty(resFile) ? null : TestUtils.loadCert(resFile).getAbsolutePath();
+  }
+
+  /** Builds DownstreamTlsContext from file-names. */
+  private static DownstreamTlsContext buildDownstreamTlsContextFromFilenames(
+      String privateKey, String certChain, String trustCa) throws IOException {
+    return buildDownstreamTlsContext(
+        buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa));
+  }
+
+  /** Builds UpstreamTlsContext from file-names. */
+  private static UpstreamTlsContext buildUpstreamTlsContextFromFilenames(
+      String privateKey, String certChain, String trustCa) throws IOException {
+    return buildUpstreamTlsContext(
+        buildCommonTlsContextFromFilenames(privateKey, certChain, trustCa));
+  }
+
+  /** Builds UpstreamTlsContext from commonTlsContext. */
+  private static UpstreamTlsContext buildUpstreamTlsContext(CommonTlsContext commonTlsContext) {
+    UpstreamTlsContext upstreamTlsContext =
+        UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build();
+    return upstreamTlsContext;
+  }
+
+  /** Builds DownstreamTlsContext from commonTlsContext. */
+  private static DownstreamTlsContext buildDownstreamTlsContext(CommonTlsContext commonTlsContext) {
+    DownstreamTlsContext downstreamTlsContext =
+        DownstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build();
+    return downstreamTlsContext;
+  }
+
+  private static CommonTlsContext buildCommonTlsContextFromFilenames(
+      String privateKey, String certChain, String trustCa) throws IOException {
+    TlsCertificate tlsCert = null;
+    privateKey = getTempFileNameForResourcesFile(privateKey);
+    certChain = getTempFileNameForResourcesFile(certChain);
+    trustCa = getTempFileNameForResourcesFile(trustCa);
+    if (!Strings.isNullOrEmpty(privateKey) && !Strings.isNullOrEmpty(certChain)) {
+      tlsCert =
+          TlsCertificate.newBuilder()
+              .setCertificateChain(DataSource.newBuilder().setFilename(certChain))
+              .setPrivateKey(DataSource.newBuilder().setFilename(privateKey))
+              .build();
+    }
+    CertificateValidationContext certContext = null;
+    if (!Strings.isNullOrEmpty(trustCa)) {
+      certContext =
+          CertificateValidationContext.newBuilder()
+              .setTrustedCa(DataSource.newBuilder().setFilename(trustCa))
+              .build();
+    }
+    return getCommonTlsContext(tlsCert, certContext);
+  }
+
+  private static CommonTlsContext getCommonTlsContext(
+      TlsCertificate tlsCertificate, CertificateValidationContext certContext) {
+    CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
+    if (tlsCertificate != null) {
+      builder = builder.addTlsCertificates(tlsCertificate);
+    }
+    if (certContext != null) {
+      builder = builder.setValidationContext(certContext);
+    }
+    return builder.build();
+  }
+
+  @Test
+  public void clientSdsProtocolNegotiatorNewHandler_nullTlsContext() {
+    ClientSdsProtocolNegotiator pn =
+        new ClientSdsProtocolNegotiator(/* upstreamTlsContext= */ null);
+    ChannelHandler newHandler = pn.newHandler(grpcHandler);
+    assertThat(newHandler).isNotNull();
+    // ProtocolNegotiators.WaitUntilActiveHandler not accessible, get canonical name
+    assertThat(newHandler.getClass().getCanonicalName())
+        .contains("io.grpc.netty.ProtocolNegotiators.WaitUntilActiveHandler");
+  }
+
+  @Test
+  public void clientSdsProtocolNegotiatorNewHandler_nonNullTlsContext() {
+    UpstreamTlsContext upstreamTlsContext =
+        buildUpstreamTlsContext(getCommonTlsContext(null, null));
+    ClientSdsProtocolNegotiator pn = new ClientSdsProtocolNegotiator(upstreamTlsContext);
+    ChannelHandler newHandler = pn.newHandler(grpcHandler);
+    assertThat(newHandler).isNotNull();
+    assertThat(newHandler).isInstanceOf(ClientSdsHandler.class);
+  }
+
+  @Test
+  public void clientSdsHandler_addLast() throws IOException {
+    UpstreamTlsContext upstreamTlsContext =
+        buildUpstreamTlsContextFromFilenames(CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
+
+    SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler =
+        new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, upstreamTlsContext);
+    pipeline.addLast(clientSdsHandler);
+    channelHandlerCtx = pipeline.context(clientSdsHandler);
+    assertNotNull(channelHandlerCtx); // clientSdsHandler ctx is non-null since we just added it
+
+    // kick off protocol negotiation.
+    pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
+    channel.runPendingTasks(); // need this for tasks to execute on eventLoop
+    channelHandlerCtx = pipeline.context(clientSdsHandler);
+    assertThat(channelHandlerCtx).isNull();
+
+    // pipeline should have SslHandler and ClientTlsHandler
+    Iterator<Map.Entry<String, ChannelHandler>> iterator = pipeline.iterator();
+    assertThat(iterator.next().getValue()).isInstanceOf(SslHandler.class);
+    // ProtocolNegotiators.ClientTlsHandler.class not accessible, get canonical name
+    assertThat(iterator.next().getValue().getClass().getCanonicalName())
+        .contains("ProtocolNegotiators.ClientTlsHandler");
+  }
+
+  @Test
+  public void serverSdsHandler_addLast() throws IOException {
+    DownstreamTlsContext downstreamTlsContext =
+        buildDownstreamTlsContextFromFilenames(SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE);
+
+    SdsProtocolNegotiators.ServerSdsHandler serverSdsHandler =
+        new SdsProtocolNegotiators.ServerSdsHandler(grpcHandler, downstreamTlsContext);
+    pipeline.addLast(serverSdsHandler);
+    channelHandlerCtx = pipeline.context(serverSdsHandler);
+    assertNotNull(channelHandlerCtx); // serverSdsHandler ctx is non-null since we just added it
+
+    // kick off protocol negotiation
+    pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
+    channel.runPendingTasks(); // need this for tasks to execute on eventLoop
+    channelHandlerCtx = pipeline.context(serverSdsHandler);
+    assertThat(channelHandlerCtx).isNull();
+
+    // pipeline should have SslHandler and ServerTlsHandler
+    Iterator<Map.Entry<String, ChannelHandler>> iterator = pipeline.iterator();
+    assertThat(iterator.next().getValue()).isInstanceOf(SslHandler.class);
+    // ProtocolNegotiators.ServerTlsHandler.class is not accessible, get canonical name
+    assertThat(iterator.next().getValue().getClass().getCanonicalName())
+        .contains("ProtocolNegotiators.ServerTlsHandler");
+  }
+
+  @Test
+  public void clientSdsProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent()
+      throws IOException, InterruptedException {
+    UpstreamTlsContext upstreamTlsContext =
+        buildUpstreamTlsContextFromFilenames(CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE);
+
+    SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler =
+        new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, upstreamTlsContext);
+
+    pipeline.addLast(clientSdsHandler);
+    channelHandlerCtx = pipeline.context(clientSdsHandler);
+    assertNotNull(channelHandlerCtx); // non-null since we just added it
+
+    // kick off protocol negotiation.
+    pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault());
+    channel.runPendingTasks(); // need this for tasks to execute on eventLoop
+    channelHandlerCtx = pipeline.context(clientSdsHandler);
+    assertThat(channelHandlerCtx).isNull();
+    Object sslEvent = SslHandshakeCompletionEvent.SUCCESS;
+
+    pipeline.fireUserEventTriggered(sslEvent);
+    channel.runPendingTasks(); // need this for tasks to execute on eventLoop
+    assertTrue(channel.isOpen());
+  }
+
+  private static final class FakeGrpcHttp2ConnectionHandler extends GrpcHttp2ConnectionHandler {
+
+    FakeGrpcHttp2ConnectionHandler(
+        ChannelPromise channelUnused,
+        Http2ConnectionDecoder decoder,
+        Http2ConnectionEncoder encoder,
+        Http2Settings initialSettings) {
+      super(channelUnused, decoder, encoder, initialSettings);
+    }
+
+    static FakeGrpcHttp2ConnectionHandler newHandler() {
+      DefaultHttp2Connection conn = new DefaultHttp2Connection(/*server=*/ false);
+      DefaultHttp2ConnectionEncoder encoder =
+          new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter());
+      DefaultHttp2ConnectionDecoder decoder =
+          new DefaultHttp2ConnectionDecoder(conn, encoder, new DefaultHttp2FrameReader());
+      Http2Settings settings = new Http2Settings();
+      return new FakeGrpcHttp2ConnectionHandler(
+          /*channelUnused=*/ null, decoder, encoder, settings);
+    }
+
+    @Override
+    public String getAuthority() {
+      return "authority";
+    }
+  }
+}