blob: 840cced424f579d3614042d8830335e54f02de0f [file] [log] [blame]
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.sds;
import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.common.io.CharStreams;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.BoolValue;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext;
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.internal.testing.TestUtils;
import io.grpc.xds.EnvoyServerProtoData;
import io.grpc.xds.internal.sds.trust.CertificateUtils;
import io.netty.handler.ssl.SslContext;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
/** Utility class for client and server ssl provider tests. */
public class CommonTlsContextTestsUtil {
// File names of the certificate (.pem) and private key (.key) test resources.
// NOTE(review): presumably resolved under the "/certs/" resource directory of TestUtils (see
// getResourceContents) — confirm against callers.
/** First server certificate file name. */
public static final String SERVER_0_PEM_FILE = "server0.pem";
/** First server private key file name. */
public static final String SERVER_0_KEY_FILE = "server0.key";
/** Second server certificate file name. */
public static final String SERVER_1_PEM_FILE = "server1.pem";
/** Second server private key file name. */
public static final String SERVER_1_KEY_FILE = "server1.key";
/** Client certificate file name. */
public static final String CLIENT_PEM_FILE = "client.pem";
/** Client private key file name. */
public static final String CLIENT_KEY_FILE = "client.key";
/** CA certificate file name. */
public static final String CA_PEM_FILE = "ca.pem";
/** Bad/untrusted server certs. */
public static final String BAD_SERVER_PEM_FILE = "badserver.pem";
/** Private key file name matching {@link #BAD_SERVER_PEM_FILE}. */
public static final String BAD_SERVER_KEY_FILE = "badserver.key";
/** Bad/untrusted client certificate file name. */
public static final String BAD_CLIENT_PEM_FILE = "badclient.pem";
/** Private key file name matching {@link #BAD_CLIENT_PEM_FILE}. */
public static final String BAD_CLIENT_KEY_FILE = "badclient.key";
/**
 * Builds a {@link CommonTlsContext} whose identity certificate and validation context are both
 * backed by {@link CertificateProviderInstance}s.
 *
 * <p>The validation side is always emitted as a {@link CombinedCertificateValidationContext}
 * carrying the validation provider instance; a default {@link CertificateValidationContext}
 * holding the SAN matchers is attached to it only when {@code matchSubjectAltNames} is non-null.
 *
 * <p>NOTE(review): the four name arguments are handed straight to protobuf string setters, which
 * are expected to reject {@code null} with a {@link NullPointerException} — confirm before
 * passing nullable names.
 *
 * @param certInstanceName instance name of the identity certificate provider
 * @param certName certificate name of the identity certificate provider
 * @param validationContextCertInstanceName instance name of the validation provider
 * @param validationContextCertName certificate name of the validation provider
 * @param matchSubjectAltNames SAN matchers for the default validation context, or {@code null}
 *     to omit the default validation context
 * @param alpnNames ALPN protocol names to add, or {@code null} to add none
 * @return the assembled {@link CommonTlsContext}
 */
@SuppressWarnings("deprecation")
static CommonTlsContext buildCommonTlsContextWithAdditionalValues(
    String certInstanceName, String certName,
    String validationContextCertInstanceName, String validationContextCertName,
    Iterable<StringMatcher> matchSubjectAltNames,
    Iterable<String> alpnNames) {
  CommonTlsContext.Builder builder = CommonTlsContext.newBuilder();

  // build() never yields null, so the identity provider instance is set unconditionally.
  CertificateProviderInstance certificateProviderInstance =
      CertificateProviderInstance.newBuilder()
          .setInstanceName(certInstanceName)
          .setCertificateName(certName)
          .build();
  builder.setTlsCertificateCertificateProviderInstance(certificateProviderInstance);

  // Likewise the validation provider instance always exists, so the combined validation context
  // is the only shape this helper ever produces.
  CertificateProviderInstance validationCertificateProviderInstance =
      CertificateProviderInstance.newBuilder()
          .setInstanceName(validationContextCertInstanceName)
          .setCertificateName(validationContextCertName)
          .build();
  CombinedCertificateValidationContext.Builder combinedBuilder =
      CombinedCertificateValidationContext.newBuilder()
          .setValidationContextCertificateProviderInstance(
              validationCertificateProviderInstance);
  if (matchSubjectAltNames != null) {
    combinedBuilder.setDefaultValidationContext(
        CertificateValidationContext.newBuilder()
            .addAllMatchSubjectAltNames(matchSubjectAltNames)
            .build());
  }
  builder.setCombinedValidationContext(combinedBuilder);

  if (alpnNames != null) {
    builder.addAllAlpnProtocols(alpnNames);
  }
  return builder.build();
}
/**
 * Creates an Envoy proto {@link DownstreamTlsContext} with the given client-certificate
 * requirement, attaching {@code commonTlsContext} only when it is non-null.
 *
 * @param commonTlsContext common TLS settings to embed, or {@code null} to leave them unset
 * @param requireClientCert value stored in the require-client-certificate field
 * @return the built {@link DownstreamTlsContext}
 */
static DownstreamTlsContext buildDownstreamTlsContext(
    CommonTlsContext commonTlsContext, boolean requireClientCert) {
  BoolValue requireClientCertValue = BoolValue.of(requireClientCert);
  DownstreamTlsContext.Builder contextBuilder =
      DownstreamTlsContext.newBuilder().setRequireClientCertificate(requireClientCertValue);
  if (commonTlsContext == null) {
    return contextBuilder.build();
  }
  return contextBuilder.setCommonTlsContext(commonTlsContext).build();
}
/**
 * Creates an internal downstream TLS context that uses {@code commonInstanceName} as the
 * identity provider instance with certificate name {@code "default"}.
 *
 * <p>When {@code hasRootCert} is set, the same instance name is also used for the root provider
 * with certificate name {@code "ROOT"}; otherwise no root provider is passed on. No ALPN
 * protocols and no static validation context are supplied.
 *
 * @param commonInstanceName provider instance name used for the identity (and optionally root)
 * @param hasRootCert whether a root certificate provider should be configured
 * @param requireClientCertificate forwarded as the require-client-certificate flag
 * @return the result of {@code buildDownstreamTlsContextForCertProviderInstance}
 */
public static EnvoyServerProtoData.DownstreamTlsContext buildDownstreamTlsContext(
    String commonInstanceName, boolean hasRootCert,
    boolean requireClientCertificate) {
  String rootInstanceName = null;
  String rootCertName = null;
  if (hasRootCert) {
    rootInstanceName = commonInstanceName;
    rootCertName = "ROOT";
  }
  return buildDownstreamTlsContextForCertProviderInstance(
      commonInstanceName,
      "default",
      rootInstanceName,
      rootCertName,
      /* alpnProtocols= */ null,
      /* staticCertValidationContext= */ null,
      /* requireClientCert= */ requireClientCertificate);
}
/**
 * Builds the Envoy proto downstream TLS context from the arguments and converts it to the
 * internal {@link EnvoyServerProtoData.DownstreamTlsContext} representation.
 *
 * @param commonTlsContext common TLS settings, may be {@code null}
 * @param requireClientCert value for the require-client-certificate field
 * @return the converted internal downstream TLS context
 */
static EnvoyServerProtoData.DownstreamTlsContext buildInternalDownstreamTlsContext(
    CommonTlsContext commonTlsContext, boolean requireClientCert) {
  DownstreamTlsContext envoyProto =
      buildDownstreamTlsContext(commonTlsContext, requireClientCert);
  return EnvoyServerProtoData.DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext(
      envoyProto);
}
/**
 * Creates an Envoy proto {@link DownstreamTlsContext} for tests, never requiring a client
 * certificate.
 *
 * <p>If both names are {@code null} and {@code useSans} is false, no common TLS context is
 * attached. Otherwise one is built with provider instance name {@code "cert-instance-name"} on
 * both the identity and validation side, ALPN protocol {@code "managed-tls"}, and — when
 * {@code useSans} is true — a single exact-match SAN matcher.
 *
 * @param certName identity certificate name
 * @param validationContextCertName validation-context certificate name
 * @param useSans whether to add the fixed SPIFFE SAN matcher
 * @return the built {@link DownstreamTlsContext}
 */
public static DownstreamTlsContext buildTestDownstreamTlsContext(
    String certName, String validationContextCertName, boolean useSans) {
  boolean nothingRequested = certName == null && validationContextCertName == null && !useSans;

  CommonTlsContext commonTlsContext;
  if (nothingRequested) {
    commonTlsContext = null;
  } else {
    List<StringMatcher> sanMatchers = null;
    if (useSans) {
      sanMatchers =
          Arrays.asList(
              StringMatcher.newBuilder()
                  .setExact("spiffe://grpc-sds-testing.svc.id.goog/ns/default/sa/bob")
                  .build());
    }
    commonTlsContext =
        buildCommonTlsContextWithAdditionalValues(
            "cert-instance-name", certName,
            "cert-instance-name", validationContextCertName,
            sanMatchers,
            Arrays.asList("managed-tls"));
  }
  return buildDownstreamTlsContext(commonTlsContext, /* requireClientCert= */ false);
}
/**
 * Builds a test downstream TLS context (with SAN matching enabled) and converts it to the
 * internal {@link EnvoyServerProtoData.DownstreamTlsContext} representation.
 *
 * @param certName identity certificate name
 * @param validationContextName validation-context certificate name
 * @return the converted internal downstream TLS context
 */
public static EnvoyServerProtoData.DownstreamTlsContext buildTestInternalDownstreamTlsContext(
    String certName, String validationContextName) {
  DownstreamTlsContext envoyProto =
      buildTestDownstreamTlsContext(certName, validationContextName, /* useSans= */ true);
  return EnvoyServerProtoData.DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext(
      envoyProto);
}
/**
 * Returns the absolute path of the object that {@code TestUtils.loadCert} yields for the given
 * resource file name.
 *
 * @param resFile resource file name handed to {@code TestUtils.loadCert}
 * @return the absolute path reported for the loaded file
 * @throws IOException propagated from {@code TestUtils.loadCert}
 */
public static String getTempFileNameForResourcesFile(String resFile) throws IOException {
  final String absolutePath = TestUtils.loadCert(resFile).getAbsolutePath();
  return absolutePath;
}
/**
 * Wraps {@code commonTlsContext} in an Envoy proto {@link UpstreamTlsContext} and converts it to
 * the internal {@link EnvoyServerProtoData.UpstreamTlsContext}. Also called from other classes.
 *
 * @param commonTlsContext common TLS settings to embed
 * @return the converted internal upstream TLS context
 */
static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext(
    CommonTlsContext commonTlsContext) {
  UpstreamTlsContext.Builder protoBuilder =
      UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext);
  return EnvoyServerProtoData.UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext(
      protoBuilder.build());
}
/**
 * Creates an internal upstream TLS context whose root provider is {@code commonInstanceName}
 * with certificate name {@code "ROOT"}.
 *
 * <p>When {@code hasIdentityCert} is set, the same instance name is also used for the identity
 * provider with certificate name {@code "default"}; otherwise no identity provider is passed on.
 * No ALPN protocols and no static validation context are supplied.
 *
 * @param commonInstanceName provider instance name used for the root (and optionally identity)
 * @param hasIdentityCert whether an identity certificate provider should be configured
 * @return the result of {@code buildUpstreamTlsContextForCertProviderInstance}
 */
public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext(
    String commonInstanceName, boolean hasIdentityCert) {
  String identityInstanceName = null;
  String identityCertName = null;
  if (hasIdentityCert) {
    identityInstanceName = commonInstanceName;
    identityCertName = "default";
  }
  return buildUpstreamTlsContextForCertProviderInstance(
      identityInstanceName,
      identityCertName,
      commonInstanceName,
      "ROOT",
      /* alpnProtocols= */ null,
      /* staticCertValidationContext= */ null);
}
/**
 * Reads the named resource as text, re-encodes it as UTF-8 bytes and hands those bytes to
 * {@code CertificateUtils.toX509Certificate}.
 *
 * @param resourceName resource name passed to {@link #getResourceContents(String)}
 * @return the certificate produced by {@code CertificateUtils.toX509Certificate}
 * @throws IOException if reading the resource or closing the stream fails
 * @throws CertificateException propagated from the certificate conversion
 */
public static X509Certificate getCertFromResourceName(String resourceName)
    throws IOException, CertificateException {
  byte[] pemBytes = getResourceContents(resourceName).getBytes(UTF_8);
  try (ByteArrayInputStream pemStream = new ByteArrayInputStream(pemBytes)) {
    return CertificateUtils.toX509Certificate(pemStream);
  }
}
/**
 * Gets the contents of a resource located under {@code /certs/} via the {@code TestUtils} class
 * loader, decoded as UTF-8 text.
 *
 * @param resourceName name of the resource, relative to {@code /certs/}
 * @return the full text of the resource
 * @throws IOException if the resource does not exist or cannot be read
 */
public static String getResourceContents(String resourceName) throws IOException {
  String resourcePath = "/certs/" + resourceName;
  InputStream inputStream = TestUtils.class.getResourceAsStream(resourcePath);
  if (inputStream == null) {
    // getResourceAsStream signals a missing resource with null; fail with a message that names
    // the resource instead of an anonymous NullPointerException from the reader constructor.
    throw new IOException("Resource not found: " + resourcePath);
  }
  // Closing the reader also closes the underlying stream.
  try (Reader reader = new InputStreamReader(inputStream, UTF_8)) {
    return CharStreams.toString(reader);
  }
}
/**
 * Assembles a {@link CommonTlsContext} using the {@link CertificateProviderInstance}-based
 * fields (accessed under a deprecation suppression).
 *
 * <p>The identity provider is set only when {@code certInstanceName} is non-null; the validation
 * side is delegated to {@code addCertificateValidationContext}; ALPN protocols are added only
 * when {@code alpnProtocols} is non-null.
 *
 * @param certInstanceName identity provider instance name, or {@code null} for none
 * @param certName identity certificate name (used only with a non-null instance name)
 * @param rootInstanceName root provider instance name, forwarded to the validation helper
 * @param rootCertName root certificate name, forwarded to the validation helper
 * @param alpnProtocols ALPN protocol names, or {@code null} for none
 * @param staticCertValidationContext static validation context, forwarded to the helper
 * @return the built {@link CommonTlsContext}
 */
@SuppressWarnings("deprecation")
private static CommonTlsContext buildCommonTlsContextForCertProviderInstance(
    String certInstanceName,
    String certName,
    String rootInstanceName,
    String rootCertName,
    Iterable<String> alpnProtocols,
    CertificateValidationContext staticCertValidationContext) {
  CommonTlsContext.Builder tlsContextBuilder = CommonTlsContext.newBuilder();
  if (certInstanceName != null) {
    CertificateProviderInstance.Builder identityProvider =
        CommonTlsContext.CertificateProviderInstance.newBuilder()
            .setInstanceName(certInstanceName)
            .setCertificateName(certName);
    tlsContextBuilder =
        tlsContextBuilder.setTlsCertificateCertificateProviderInstance(identityProvider);
  }
  tlsContextBuilder =
      addCertificateValidationContext(
          tlsContextBuilder, rootInstanceName, rootCertName, staticCertValidationContext);
  if (alpnProtocols != null) {
    tlsContextBuilder.addAllAlpnProtocols(alpnProtocols);
  }
  return tlsContextBuilder.build();
}
/**
 * Assembles a {@link CommonTlsContext} using the {@link CertificateProviderPluginInstance}-based
 * fields.
 *
 * <p>The identity provider is set only when {@code certInstanceName} is non-null; the validation
 * side is delegated to {@code addNewCertificateValidationContext}; ALPN protocols are added only
 * when {@code alpnProtocols} is non-null.
 *
 * @param certInstanceName identity provider instance name, or {@code null} for none
 * @param certName identity certificate name (used only with a non-null instance name)
 * @param rootInstanceName root provider instance name, forwarded to the validation helper
 * @param rootCertName root certificate name, forwarded to the validation helper
 * @param alpnProtocols ALPN protocol names, or {@code null} for none
 * @param staticCertValidationContext static validation context, forwarded to the helper
 * @return the built {@link CommonTlsContext}
 */
private static CommonTlsContext buildNewCommonTlsContextForCertProviderInstance(
    String certInstanceName,
    String certName,
    String rootInstanceName,
    String rootCertName,
    Iterable<String> alpnProtocols,
    CertificateValidationContext staticCertValidationContext) {
  CommonTlsContext.Builder tlsContextBuilder = CommonTlsContext.newBuilder();
  if (certInstanceName != null) {
    CertificateProviderPluginInstance.Builder identityProvider =
        CertificateProviderPluginInstance.newBuilder()
            .setInstanceName(certInstanceName)
            .setCertificateName(certName);
    tlsContextBuilder = tlsContextBuilder.setTlsCertificateProviderInstance(identityProvider);
  }
  tlsContextBuilder =
      addNewCertificateValidationContext(
          tlsContextBuilder, rootInstanceName, rootCertName, staticCertValidationContext);
  if (alpnProtocols != null) {
    tlsContextBuilder.addAllAlpnProtocols(alpnProtocols);
  }
  return tlsContextBuilder.build();
}
/**
 * Adds validation settings to {@code builder} using a {@link CertificateProviderInstance} for
 * the root.
 *
 * <p>With a {@code null} {@code rootInstanceName} the builder is returned untouched (the static
 * validation context is not applied in that case). Otherwise, a combined validation context is
 * set when {@code staticCertValidationContext} is non-null, and the bare provider instance is
 * set when it is null.
 *
 * @param builder builder to update
 * @param rootInstanceName root provider instance name, or {@code null} to add nothing
 * @param rootCertName root certificate name
 * @param staticCertValidationContext default validation context for the combined form, or
 *     {@code null}
 * @return the builder returned by the last setter call, or {@code builder} itself if unchanged
 */
@SuppressWarnings("deprecation")
private static CommonTlsContext.Builder addCertificateValidationContext(
    CommonTlsContext.Builder builder,
    String rootInstanceName,
    String rootCertName,
    CertificateValidationContext staticCertValidationContext) {
  if (rootInstanceName == null) {
    return builder;
  }
  CertificateProviderInstance rootProvider =
      CertificateProviderInstance.newBuilder()
          .setInstanceName(rootInstanceName)
          .setCertificateName(rootCertName)
          .build();
  if (staticCertValidationContext == null) {
    return builder.setValidationContextCertificateProviderInstance(rootProvider);
  }
  return builder.setCombinedValidationContext(
      CombinedCertificateValidationContext.newBuilder()
          .setDefaultValidationContext(staticCertValidationContext)
          .setValidationContextCertificateProviderInstance(rootProvider)
          .build());
}
/**
 * Adds validation settings to {@code builder} using a {@link CertificateProviderPluginInstance}
 * as the CA certificate provider of the validation context.
 *
 * <p>With a {@code null} {@code rootInstanceName} the builder is returned untouched (the static
 * validation context is not applied in that case). Otherwise the validation context starts from
 * {@code staticCertValidationContext} when it is non-null, or from an empty one, and has the CA
 * provider instance set on it.
 *
 * @param builder builder to update
 * @param rootInstanceName root provider instance name, or {@code null} to add nothing
 * @param rootCertName root certificate name
 * @param staticCertValidationContext base validation context to extend, or {@code null}
 * @return the builder returned by the setter call, or {@code builder} itself if unchanged
 */
private static CommonTlsContext.Builder addNewCertificateValidationContext(
    CommonTlsContext.Builder builder,
    String rootInstanceName,
    String rootCertName,
    CertificateValidationContext staticCertValidationContext) {
  if (rootInstanceName == null) {
    return builder;
  }
  CertificateProviderPluginInstance caProvider =
      CertificateProviderPluginInstance.newBuilder()
          .setInstanceName(rootInstanceName)
          .setCertificateName(rootCertName)
          .build();
  CertificateValidationContext.Builder baseContext;
  if (staticCertValidationContext == null) {
    baseContext = CertificateValidationContext.newBuilder();
  } else {
    baseContext = staticCertValidationContext.toBuilder();
  }
  CertificateValidationContext.Builder contextWithCa =
      baseContext.setCaCertificateProviderInstance(caProvider);
  return builder.setValidationContext(contextWithCa);
}
/**
 * Builds an internal upstream TLS context for CertProvider tests from a common TLS context
 * produced by {@code buildCommonTlsContextForCertProviderInstance}.
 *
 * @param certInstanceName identity provider instance name, may be {@code null}
 * @param certName identity certificate name, may be {@code null}
 * @param rootInstanceName root provider instance name, may be {@code null}
 * @param rootCertName root certificate name, may be {@code null}
 * @param alpnProtocols ALPN protocol names, forwarded unchanged
 * @param staticCertValidationContext static validation context, forwarded unchanged
 * @return the internal upstream TLS context
 */
public static EnvoyServerProtoData.UpstreamTlsContext
    buildUpstreamTlsContextForCertProviderInstance(
        @Nullable String certInstanceName,
        @Nullable String certName,
        @Nullable String rootInstanceName,
        @Nullable String rootCertName,
        Iterable<String> alpnProtocols,
        CertificateValidationContext staticCertValidationContext) {
  CommonTlsContext commonTlsContext =
      buildCommonTlsContextForCertProviderInstance(
          certInstanceName,
          certName,
          rootInstanceName,
          rootCertName,
          alpnProtocols,
          staticCertValidationContext);
  return buildUpstreamTlsContext(commonTlsContext);
}
/**
 * Builds an internal upstream TLS context for CertProvider tests from a common TLS context
 * produced by {@code buildNewCommonTlsContextForCertProviderInstance}.
 *
 * @param certInstanceName identity provider instance name, may be {@code null}
 * @param certName identity certificate name, may be {@code null}
 * @param rootInstanceName root provider instance name, may be {@code null}
 * @param rootCertName root certificate name, may be {@code null}
 * @param alpnProtocols ALPN protocol names, forwarded unchanged
 * @param staticCertValidationContext static validation context, forwarded unchanged
 * @return the internal upstream TLS context
 */
public static EnvoyServerProtoData.UpstreamTlsContext
    buildNewUpstreamTlsContextForCertProviderInstance(
        @Nullable String certInstanceName,
        @Nullable String certName,
        @Nullable String rootInstanceName,
        @Nullable String rootCertName,
        Iterable<String> alpnProtocols,
        CertificateValidationContext staticCertValidationContext) {
  CommonTlsContext commonTlsContext =
      buildNewCommonTlsContextForCertProviderInstance(
          certInstanceName,
          certName,
          rootInstanceName,
          rootCertName,
          alpnProtocols,
          staticCertValidationContext);
  return buildUpstreamTlsContext(commonTlsContext);
}
/**
 * Builds an internal downstream TLS context for CertProvider tests from a common TLS context
 * produced by {@code buildCommonTlsContextForCertProviderInstance}.
 *
 * @param certInstanceName identity provider instance name, may be {@code null}
 * @param certName identity certificate name, may be {@code null}
 * @param rootInstanceName root provider instance name, may be {@code null}
 * @param rootCertName root certificate name, may be {@code null}
 * @param alpnProtocols ALPN protocol names, forwarded unchanged
 * @param staticCertValidationContext static validation context, forwarded unchanged
 * @param requireClientCert value for the require-client-certificate field
 * @return the internal downstream TLS context
 */
public static EnvoyServerProtoData.DownstreamTlsContext
    buildDownstreamTlsContextForCertProviderInstance(
        @Nullable String certInstanceName,
        @Nullable String certName,
        @Nullable String rootInstanceName,
        @Nullable String rootCertName,
        Iterable<String> alpnProtocols,
        CertificateValidationContext staticCertValidationContext,
        boolean requireClientCert) {
  CommonTlsContext commonTlsContext =
      buildCommonTlsContextForCertProviderInstance(
          certInstanceName,
          certName,
          rootInstanceName,
          rootCertName,
          alpnProtocols,
          staticCertValidationContext);
  return buildInternalDownstreamTlsContext(commonTlsContext, requireClientCert);
}
/**
 * Builds an internal downstream TLS context for CertProvider tests from a common TLS context
 * produced by {@code buildNewCommonTlsContextForCertProviderInstance}.
 *
 * @param certInstanceName identity provider instance name, may be {@code null}
 * @param certName identity certificate name, may be {@code null}
 * @param rootInstanceName root provider instance name, may be {@code null}
 * @param rootCertName root certificate name, may be {@code null}
 * @param alpnProtocols ALPN protocol names, forwarded unchanged
 * @param staticCertValidationContext static validation context, forwarded unchanged
 * @param requireClientCert value for the require-client-certificate field
 * @return the internal downstream TLS context
 */
public static EnvoyServerProtoData.DownstreamTlsContext
    buildNewDownstreamTlsContextForCertProviderInstance(
        @Nullable String certInstanceName,
        @Nullable String certName,
        @Nullable String rootInstanceName,
        @Nullable String rootCertName,
        Iterable<String> alpnProtocols,
        CertificateValidationContext staticCertValidationContext,
        boolean requireClientCert) {
  CommonTlsContext commonTlsContext =
      buildNewCommonTlsContextForCertProviderInstance(
          certInstanceName,
          certName,
          rootInstanceName,
          rootCertName,
          alpnProtocols,
          staticCertValidationContext);
  return buildInternalDownstreamTlsContext(commonTlsContext, requireClientCert);
}
/**
 * Runs basic assertions against {@code sslContext}: that it reports the expected role
 * (server or client) and that its negotiated application protocols are non-null and either
 * equal {@code expectedApnProtos} or, when that is {@code null}, contain {@code "h2"}.
 *
 * @param server {@code true} to expect a server context, {@code false} for a client context
 * @param sslContext the context under test
 * @param expectedApnProtos exact protocol list to expect, or {@code null} to only require "h2"
 */
public static void doChecksOnSslContext(boolean server, SslContext sslContext,
    List<String> expectedApnProtos) {
  boolean roleMatches = server ? sslContext.isServer() : sslContext.isClient();
  assertThat(roleMatches).isTrue();

  List<String> negotiatedProtos = sslContext.applicationProtocolNegotiator().protocols();
  assertThat(negotiatedProtos).isNotNull();
  if (expectedApnProtos == null) {
    assertThat(negotiatedProtos).contains("h2");
    return;
  }
  assertThat(negotiatedProtos).isEqualTo(expectedApnProtos);
}
/**
 * Registers a new {@link TestCallback} on {@code provider} using
 * {@code MoreExecutors.directExecutor()} as the callback executor and returns it.
 *
 * @param provider the provider to register the callback with
 * @return the registered callback
 */
public static TestCallback getValueThruCallback(SslContextProvider provider) {
  Executor direct = MoreExecutors.directExecutor();
  return getValueThruCallback(provider, direct);
}
/**
 * Creates a {@link TestCallback} bound to the caller-supplied executor, registers it on
 * {@code provider} via {@code addCallback} and returns it.
 *
 * @param provider the provider to register the callback with
 * @param executor executor handed to the callback's constructor
 * @return the registered callback
 */
public static TestCallback getValueThruCallback(SslContextProvider provider, Executor executor) {
  final TestCallback callback = new TestCallback(executor);
  provider.addCallback(callback);
  return callback;
}
/**
 * {@link SslContextProvider.Callback} that stores the values it receives in public fields so
 * tests can inspect them afterwards.
 */
public static class TestCallback extends SslContextProvider.Callback {
  /** Last {@link SslContext} passed to {@link #updateSecret(SslContext)}; null until then. */
  public SslContext updatedSslContext;
  /** Last {@link Throwable} passed to {@link #onException(Throwable)}; null until then. */
  public Throwable updatedThrowable;

  /**
   * Creates a callback bound to the given executor.
   *
   * @param executor executor forwarded to the superclass constructor
   */
  public TestCallback(Executor executor) {
    super(executor);
  }

  /** Stores the reported failure in {@link #updatedThrowable}. */
  @Override
  public void onException(Throwable throwable) {
    this.updatedThrowable = throwable;
  }

  /** Stores the delivered context in {@link #updatedSslContext}. */
  @Override
  public void updateSecret(SslContext sslContext) {
    this.updatedSslContext = sslContext;
  }
}
}