xds: implement alpnProtocols based on list from xDS (#6594)
diff --git a/xds/src/main/java/io/grpc/xds/sds/SdsSslContextProvider.java b/xds/src/main/java/io/grpc/xds/sds/SdsSslContextProvider.java
index bdb3fc1..3c7cf8b 100644
--- a/xds/src/main/java/io/grpc/xds/sds/SdsSslContextProvider.java
+++ b/xds/src/main/java/io/grpc/xds/sds/SdsSslContextProvider.java
@@ -31,6 +31,7 @@
import io.grpc.Status;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.sds.trust.SdsTrustManagerFactory;
+import io.netty.handler.ssl.ApplicationProtocolConfig;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import java.io.IOException;
@@ -108,7 +109,7 @@
commonTlsContext.getCombinedValidationContext();
if (combinedValidationContext.hasValidationContextSdsSecretConfig()) {
validationContextSdsConfig =
- combinedValidationContext.getValidationContextSdsSecretConfig();
+ combinedValidationContext.getValidationContextSdsSecretConfig();
}
if (combinedValidationContext.hasDefaultValidationContext()) {
staticCertValidationContext = combinedValidationContext.getDefaultValidationContext();
@@ -223,7 +224,7 @@
try {
SslContextBuilder sslContextBuilder;
CertificateValidationContext localCertValidationContext =
- mergeStaticAndDynamicCertContexts();
+ mergeStaticAndDynamicCertContexts();
if (server) {
logger.log(Level.FINEST, "for server");
sslContextBuilder =
@@ -248,6 +249,16 @@
tlsCertificate.hasPassword() ? tlsCertificate.getPassword().getInlineString() : null);
}
}
+ CommonTlsContext commonTlsContext = getCommonTlsContext();
+ if (commonTlsContext != null && commonTlsContext.getAlpnProtocolsCount() > 0) {
+ List<String> alpnList = commonTlsContext.getAlpnProtocolsList();
+ ApplicationProtocolConfig apn = new ApplicationProtocolConfig(
+ ApplicationProtocolConfig.Protocol.ALPN,
+ ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
+ ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
+ alpnList);
+ sslContextBuilder.applicationProtocolConfig(apn);
+ }
SslContext sslContextCopy = sslContextBuilder.build();
sslContext = sslContextCopy;
makePendingCallbacks(sslContextCopy);
diff --git a/xds/src/main/java/io/grpc/xds/sds/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/sds/SslContextProvider.java
index 817e7af..d3e6368 100644
--- a/xds/src/main/java/io/grpc/xds/sds/SslContextProvider.java
+++ b/xds/src/main/java/io/grpc/xds/sds/SslContextProvider.java
@@ -18,6 +18,9 @@
import static com.google.common.base.Preconditions.checkNotNull;
+import io.envoyproxy.envoy.api.v2.auth.CommonTlsContext;
+import io.envoyproxy.envoy.api.v2.auth.DownstreamTlsContext;
+import io.envoyproxy.envoy.api.v2.auth.UpstreamTlsContext;
import io.grpc.Internal;
import io.netty.handler.ssl.SslContext;
import java.util.concurrent.Executor;
@@ -56,6 +59,15 @@
return source;
}
+ CommonTlsContext getCommonTlsContext() {
+ if (source instanceof UpstreamTlsContext) {
+ return ((UpstreamTlsContext) source).getCommonTlsContext();
+ } else if (source instanceof DownstreamTlsContext) {
+ return ((DownstreamTlsContext) source).getCommonTlsContext();
+ }
+ return null;
+ }
+
/** Closes this provider and releases any resources. */
void close() {}
diff --git a/xds/src/test/java/io/grpc/xds/sds/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/sds/CommonTlsContextTestsUtil.java
index 594df59..9759979 100644
--- a/xds/src/test/java/io/grpc/xds/sds/CommonTlsContextTestsUtil.java
+++ b/xds/src/test/java/io/grpc/xds/sds/CommonTlsContextTestsUtil.java
@@ -102,6 +102,7 @@
String validationContextName,
String validationContextTargetUri,
Iterable<String> verifySubjectAltNames,
+ Iterable<String> alpnNames,
String channelType) {
CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();
@@ -128,6 +129,9 @@
} else if (certValidationContext != null) {
builder.setValidationContext(certValidationContext);
}
+ if (alpnNames != null) {
+ builder.addAllAlpnProtocols(alpnNames);
+ }
return builder.build();
}
}
diff --git a/xds/src/test/java/io/grpc/xds/sds/SdsSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/sds/SdsSslContextProviderTest.java
index b2355a7..e5d3faa 100644
--- a/xds/src/test/java/io/grpc/xds/sds/SdsSslContextProviderTest.java
+++ b/xds/src/test/java/io/grpc/xds/sds/SdsSslContextProviderTest.java
@@ -66,7 +66,7 @@
/** Helper method to build SdsSslContextProvider from given names. */
private SdsSslContextProvider<?> getSdsSslContextProvider(
boolean server, String certName, String validationContextName,
- Iterable<String> verifySubjectAltNames) throws IOException {
+ Iterable<String> verifySubjectAltNames, Iterable<String> alpnProtocols) throws IOException {
CommonTlsContext commonTlsContext =
CommonTlsContextTestsUtil.buildCommonTlsContextWithAdditionalValues(
@@ -75,6 +75,7 @@
validationContextName,
/* validationContextTargetUri= */ "inproc",
verifySubjectAltNames,
+ alpnProtocols,
/* channelType= */ "inproc");
return server
@@ -98,11 +99,11 @@
.thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
SdsSslContextProvider<?> provider =
- getSdsSslContextProvider(/* server= */ true, "cert1", "valid1", null);
+ getSdsSslContextProvider(/* server= */ true, "cert1", "valid1", null, null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
- doChecksOnSslContext(true, testCallback.updatedSslContext);
+ doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@Test
@@ -117,11 +118,12 @@
/* server= */ false,
/* certName= */ "cert1",
/* validationContextName= */ "valid1",
- /* verifySubjectAltNames= */ null);
+ /* verifySubjectAltNames= */ null,
+ /* alpnProtocols= */ null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
- doChecksOnSslContext(false, testCallback.updatedSslContext);
+ doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@Test
@@ -132,11 +134,11 @@
SdsSslContextProvider<?> provider =
getSdsSslContextProvider(
/* server= */ true, /* certName= */ "cert1", /* validationContextName= */ null,
- /* verifySubjectAltNames= */ null);
+ /* verifySubjectAltNames= */ null, /* alpnProtocols= */ null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
- doChecksOnSslContext(true, testCallback.updatedSslContext);
+ doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@Test
@@ -147,11 +149,11 @@
SdsSslContextProvider<?> provider =
getSdsSslContextProvider(
/* server= */ false, /* certName= */ null, /* validationContextName= */ "valid1",
- /* verifySubjectAltNames= */ null);
+ /* verifySubjectAltNames= */ null, null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
- doChecksOnSslContext(false, testCallback.updatedSslContext);
+ doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@Test
@@ -162,7 +164,7 @@
SdsSslContextProvider<?> provider =
getSdsSslContextProvider(
/* server= */ true, /* certName= */ null, /* validationContextName= */ "valid1",
- /* verifySubjectAltNames= */ null);
+ /* verifySubjectAltNames= */ null, /* alpnProtocols= */ null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
@@ -189,10 +191,53 @@
/* certName= */ "cert1",
/* validationContextName= */ "valid1",
Arrays.asList(
- "spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"));
+ "spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob"),
+ /* alpnProtocols= */ null);
SecretVolumeSslContextProviderTest.TestCallback testCallback =
SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
- doChecksOnSslContext(false, testCallback.updatedSslContext);
+ doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
+ }
+
+ @Test
+ public void testProviderForClient_withAlpnProtocols() throws IOException {
+ when(serverMock.getSecretFor(/* name= */ "cert1"))
+ .thenReturn(getOneTlsCertSecret(/* name= */ "cert1", CLIENT_KEY_FILE, CLIENT_PEM_FILE));
+ when(serverMock.getSecretFor("valid1"))
+ .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
+
+ SdsSslContextProvider<?> provider =
+ getSdsSslContextProvider(
+ /* server= */ false,
+ /* certName= */ "cert1",
+ /* validationContextName= */ "valid1",
+ /* verifySubjectAltNames= */ null,
+ /* alpnProtocols= */ Arrays.asList("managed-mtls", "h2"));
+ SecretVolumeSslContextProviderTest.TestCallback testCallback =
+ SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
+
+ doChecksOnSslContext(
+ false, testCallback.updatedSslContext, Arrays.asList("managed-mtls", "h2"));
+ }
+
+ @Test
+ public void testProviderForServer_withAlpnProtocols() throws IOException {
+ when(serverMock.getSecretFor(/* name= */ "cert1"))
+ .thenReturn(getOneTlsCertSecret(/* name= */ "cert1", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE));
+ when(serverMock.getSecretFor(/* name= */ "valid1"))
+ .thenReturn(getOneCertificateValidationContextSecret(/* name= */ "valid1", CA_PEM_FILE));
+
+ SdsSslContextProvider<?> provider =
+ getSdsSslContextProvider(
+ /* server= */ true,
+ /* certName= */ "cert1",
+ /* validationContextName= */ "valid1",
+ /* verifySubjectAltNames= */ null,
+ /* alpnProtocols= */ Arrays.asList("managed-mtls", "h2"));
+ SecretVolumeSslContextProviderTest.TestCallback testCallback =
+ SecretVolumeSslContextProviderTest.getValueThruCallback(provider);
+
+ doChecksOnSslContext(
+ true, testCallback.updatedSslContext, Arrays.asList("managed-mtls", "h2"));
}
}
diff --git a/xds/src/test/java/io/grpc/xds/sds/SecretVolumeSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/sds/SecretVolumeSslContextProviderTest.java
index 766b12d..826f05a 100644
--- a/xds/src/test/java/io/grpc/xds/sds/SecretVolumeSslContextProviderTest.java
+++ b/xds/src/test/java/io/grpc/xds/sds/SecretVolumeSslContextProviderTest.java
@@ -395,10 +395,11 @@
getSslContextSecretVolumeSecretProvider(server, pemFile, keyFile, caFile);
SslContext sslContext = provider.buildSslContextFromSecrets();
- doChecksOnSslContext(server, sslContext);
+ doChecksOnSslContext(server, sslContext, /* expectedApnProtos= */ null);
}
- static void doChecksOnSslContext(boolean server, SslContext sslContext) {
+ static void doChecksOnSslContext(boolean server, SslContext sslContext,
+ List<String> expectedApnProtos) {
if (server) {
assertThat(sslContext.isServer()).isTrue();
} else {
@@ -406,7 +407,11 @@
}
List<String> apnProtos = sslContext.applicationProtocolNegotiator().protocols();
assertThat(apnProtos).isNotNull();
- assertThat(apnProtos).contains("h2");
+ if (expectedApnProtos != null) {
+ assertThat(apnProtos).isEqualTo(expectedApnProtos);
+ } else {
+ assertThat(apnProtos).contains("h2");
+ }
}
/**
@@ -544,7 +549,7 @@
true, SERVER_1_PEM_FILE, SERVER_1_KEY_FILE, CA_PEM_FILE);
TestCallback testCallback = getValueThruCallback(provider);
- doChecksOnSslContext(true, testCallback.updatedSslContext);
+ doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
@Test
@@ -554,7 +559,7 @@
false, CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE);
TestCallback testCallback = getValueThruCallback(provider);
- doChecksOnSslContext(false, testCallback.updatedSslContext);
+ doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
}
// note this test generates stack-trace but can be safely ignored