xds: implement alpnProtocols based on list from xDS (#6594)

diff --git a/xds/src/main/java/io/grpc/xds/sds/SdsSslContextProvider.java b/xds/src/main/java/io/grpc/xds/sds/SdsSslContextProvider.java
index bdb3fc1..3c7cf8b 100644
--- a/xds/src/main/java/io/grpc/xds/sds/SdsSslContextProvider.java
+++ b/xds/src/main/java/io/grpc/xds/sds/SdsSslContextProvider.java
@@ -31,6 +31,7 @@
 import io.grpc.Status;
 import io.grpc.netty.GrpcSslContexts;
 import io.grpc.xds.sds.trust.SdsTrustManagerFactory;
+import io.netty.handler.ssl.ApplicationProtocolConfig;
 import io.netty.handler.ssl.SslContext;
 import io.netty.handler.ssl.SslContextBuilder;
 import java.io.IOException;
@@ -108,7 +109,7 @@
           commonTlsContext.getCombinedValidationContext();
       if (combinedValidationContext.hasValidationContextSdsSecretConfig()) {
         validationContextSdsConfig =
-           combinedValidationContext.getValidationContextSdsSecretConfig();
+            combinedValidationContext.getValidationContextSdsSecretConfig();
       }
       if (combinedValidationContext.hasDefaultValidationContext()) {
         staticCertValidationContext = combinedValidationContext.getDefaultValidationContext();
@@ -223,7 +224,7 @@
     try {
       SslContextBuilder sslContextBuilder;
       CertificateValidationContext localCertValidationContext =
-              mergeStaticAndDynamicCertContexts();
+          mergeStaticAndDynamicCertContexts();
       if (server) {
         logger.log(Level.FINEST, "for server");
         sslContextBuilder =
@@ -248,6 +249,16 @@
               tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null);
         }
       }
+      CommonTlsContext commonTlsContext = getCommonTlsContext();
+      if (commonTlsContext != null && commonTlsContext.getAlpnProtocolsCount() > 0) {
+        List<String> alpnList = commonTlsContext.getAlpnProtocolsList();
+        ApplicationProtocolConfig apn = new ApplicationProtocolConfig(
+            ApplicationProtocolConfig.Protocol.ALPN,
+            ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
+            ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
+            alpnList);
+        sslContextBuilder.applicationProtocolConfig(apn);
+      }
       SslContext sslContextCopy = sslContextBuilder.build();
       sslContext = sslContextCopy;
       makePendingCallbacks(sslContextCopy);
diff --git a/xds/src/main/java/io/grpc/xds/sds/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/sds/SslContextProvider.java
index 817e7af..d3e6368 100644
--- a/xds/src/main/java/io/grpc/xds/sds/SslContextProvider.java
+++ b/xds/src/main/java/io/grpc/xds/sds/SslContextProvider.java
@@ -18,6 +18,9 @@
 
 import static com.google.common.base.Preconditions.checkNotNull;
 
+import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
+import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
+import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
 import io.grpc.Internal;
 import io.netty.handler.ssl.SslContext;
 import java.util.concurrent.Executor;
@@ -56,6 +59,15 @@
     return source;
   }
 
+  CommonTlsContext getCommonTlsContext() {
+    if (source instanceof UpstreamTlsContext) {
+      return ((UpstreamTlsContext) source).getCommonTlsContext();
+    } else if (source instanceof DownstreamTlsContext) {
+      return ((DownstreamTlsContext) source).getCommonTlsContext();
+    }
+    return null;
+  }
+
   /** Closes this provider and releases any resources. */
   void close() {}
 
diff --git a/xds/src/test/java/io/grpc/xds/sds/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/sds/CommonTlsContextTestsUtil.java
index 594df59..9759979 100644
--- a/xds/src/test/java/io/grpc/xds/sds/CommonTlsContextTestsUtil.java
+++ b/xds/src/test/java/io/grpc/xds/sds/CommonTlsContextTestsUtil.java
@@ -102,6 +102,7 @@
       String validationContextName,
       String validationContextTargetUri,
       Iterable<String> verifySubjectAltNames,
+      Iterable<String> alpnNames,
       String channelType) {
 
     CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
@@ -128,6 +129,9 @@
     } else if (certValidationContext != null) {
       builder.setValidationContext(certValidationContext);
     }
+    if (alpnNames != null) {
+      builder.addAllAlpnProtocols(alpnNames);
+    }
     return builder.build();
   }
 }
diff --git a/xds/src/test/java/io/grpc/xds/sds/SdsSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/sds/SdsSslContextProviderTest.java
index b2355a7..e5d3faa 100644
--- a/xds/src/test/java/io/grpc/xds/sds/SdsSslContextProviderTest.java
+++ b/xds/src/test/java/io/grpc/xds/sds/SdsSslContextProviderTest.java
@@ -66,7 +66,7 @@
   /** Helper method to build SdsSslContextProvider from given names. */
   private SdsSslContextProvider<?> getSdsSslContextProvider(
       boolean server, String certName, String validationContextName,
-      Iterable<String> verifySubjectAltNames) throws IOException {
+      Iterable<String> verifySubjectAltNames, Iterable<String> alpnProtocols) throws IOException {
 
     CommonTlsContext commonTlsContext =
         CommonTlsContextTestsUtil.buildCommonTlsContextWithAdditionalValues(
@@ -75,6 +75,7 @@
             validationContextName,
             /* validationContextTargetUri= */ "inproc",
             verifySubjectAltNames,
+            alpnProtocols,
             /* channelType= */ "inproc");
 
     return server
@@ -98,11 +99,11 @@
         .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
 
     SdsSslContextProvider<?> provider =
-        getSdsSslContextProvider(/* server= */ true, "cert1", "valid1", null);
+        getSdsSslContextProvider(/* server= */ true, "cert1", "valid1", null, null);
     SecretVolumeSslContextProviderTest.TestCallback testCallback =
         SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
 
-    doChecksOnSslContext(true, testCallback.updatedSslContext);
+    doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
   }
 
   @Test
@@ -117,11 +118,12 @@
             /* server= */ false,
             /* certName= */ "cert1",
             /* validationContextName= */ "valid1",
-            /* verifySubjectAltNames= */ null);
+            /* verifySubjectAltNames= */ null,
+            /* alpnProtocols= */ null);
     SecretVolumeSslContextProviderTest.TestCallback testCallback =
         SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
 
-    doChecksOnSslContext(false, testCallback.updatedSslContext);
+    doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
   }
 
   @Test
@@ -132,11 +134,11 @@
     SdsSslContextProvider<?> provider =
         getSdsSslContextProvider(
             /* server= */ true, /* certName= */ "cert1", /* validationContextName= */ null,
-            /* verifySubjectAltNames= */ null);
+            /* verifySubjectAltNames= */ null, /* alpnProtocols= */ null);
     SecretVolumeSslContextProviderTest.TestCallback testCallback =
         SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
 
-    doChecksOnSslContext(true, testCallback.updatedSslContext);
+    doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
   }
 
   @Test
@@ -147,11 +149,11 @@
     SdsSslContextProvider<?> provider =
         getSdsSslContextProvider(
             /* server= */ false, /* certName= */ null, /* validationContextName= */ "valid1",
-            /* verifySubjectAltNames= */ null);
+            /* verifySubjectAltNames= */ null, null);
     SecretVolumeSslContextProviderTest.TestCallback testCallback =
         SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
 
-    doChecksOnSslContext(false, testCallback.updatedSslContext);
+    doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
   }
 
   @Test
@@ -162,7 +164,7 @@
     SdsSslContextProvider<?> provider =
         getSdsSslContextProvider(
             /* server= */ true, /* certName= */ null, /* validationContextName= */ "valid1",
-            /* verifySubjectAltNames= */ null);
+            /* verifySubjectAltNames= */ null, /* alpnProtocols= */ null);
     SecretVolumeSslContextProviderTest.TestCallback testCallback =
         SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
 
@@ -189,10 +191,53 @@
             /* certName= */ "cert1",
             /* validationContextName= */ "valid1",
             Arrays.asList(
-                "spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"));
+                "spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"),
+            /* alpnProtocols= */ null);
 
     SecretVolumeSslContextProviderTest.TestCallback testCallback =
         SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
-    doChecksOnSslContext(false, testCallback.updatedSslContext);
+    doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
+  }
+
+  @Test
+  public void testProviderForClient_withAlpnProtocols() throws IOException {
+    when(serverMock.getSecretFor(/* name= */ "cert1"))
+        .thenReturn(getOneTlsCertSecret(/* name= */ "cert1", CLIENT_KEY_FILE, CLIENT_PEM_FILE));
+    when(serverMock.getSecretFor("valid1"))
+        .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
+
+    SdsSslContextProvider<?> provider =
+        getSdsSslContextProvider(
+            /* server= */ false,
+            /* certName= */ "cert1",
+            /* validationContextName= */ "valid1",
+            /* verifySubjectAltNames= */ null,
+            /* alpnProtocols= */ Arrays.asList("managed-mtls", "h2"));
+    SecretVolumeSslContextProviderTest.TestCallback testCallback =
+        SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
+
+    doChecksOnSslContext(
+        false, testCallback.updatedSslContext, Arrays.asList("managed-mtls", "h2"));
+  }
+
+  @Test
+  public void testProviderForServer_withAlpnProtocols() throws IOException {
+    when(serverMock.getSecretFor(/* name= */ "cert1"))
+        .thenReturn(getOneTlsCertSecret(/* name= */ "cert1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE));
+    when(serverMock.getSecretFor(/* name= */ "valid1"))
+        .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
+
+    SdsSslContextProvider<?> provider =
+        getSdsSslContextProvider(
+            /* server= */ true,
+            /* certName= */ "cert1",
+            /* validationContextName= */ "valid1",
+            /* verifySubjectAltNames= */ null,
+            /* alpnProtocols= */ Arrays.asList("managed-mtls", "h2"));
+    SecretVolumeSslContextProviderTest.TestCallback testCallback =
+        SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
+
+    doChecksOnSslContext(
+        true, testCallback.updatedSslContext, Arrays.asList("managed-mtls", "h2"));
   }
 }
diff --git a/xds/src/test/java/io/grpc/xds/sds/SecretVolumeSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/sds/SecretVolumeSslContextProviderTest.java
index 766b12d..826f05a 100644
--- a/xds/src/test/java/io/grpc/xds/sds/SecretVolumeSslContextProviderTest.java
+++ b/xds/src/test/java/io/grpc/xds/sds/SecretVolumeSslContextProviderTest.java
@@ -395,10 +395,11 @@
         getSslContextSecretVolumeSecretProvider(server, pemFile, keyFile, caFile);
 
     SslContext sslContext = provider.buildSslContextFromSecrets();
-    doChecksOnSslContext(server, sslContext);
+    doChecksOnSslContext(server, sslContext, /* expectedApnProtos= */ null);
   }
 
-  static void doChecksOnSslContext(boolean server, SslContext sslContext) {
+  static void doChecksOnSslContext(boolean server, SslContext sslContext,
+      List<String> expectedApnProtos) {
     if (server) {
       assertThat(sslContext.isServer()).isTrue();
     } else {
@@ -406,7 +407,11 @@
     }
     List<String> apnProtos = sslContext.applicationProtocolNegotiator().protocols();
     assertThat(apnProtos).isNotNull();
-    assertThat(apnProtos).contains("h2");
+    if (expectedApnProtos != null) {
+      assertThat(apnProtos).isEqualTo(expectedApnProtos);
+    } else {
+      assertThat(apnProtos).contains("h2");
+    }
   }
 
   /**
@@ -544,7 +549,7 @@
             true, SERVER_1_PEM_FILE, SERVER_1_KEY_FILE, CA_PEM_FILE);
 
     TestCallback testCallback = getValueThruCallback(provider);
-    doChecksOnSslContext(true, testCallback.updatedSslContext);
+    doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
   }
 
   @Test
@@ -554,7 +559,7 @@
             false, CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE);
 
     TestCallback testCallback = getValueThruCallback(provider);
-    doChecksOnSslContext(false, testCallback.updatedSslContext);
+    doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
   }
 
   // note this test generates stack-trace but can be safely ignored