blob: 8e81c241705d0fad626194dabb2695eddef436db [file] [log] [blame]
/*
* Copyright 2019 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.security;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableSet;
import io.envoyproxy.envoy.config.core.v3.DataSource;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate;
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.CommonBootstrapperTestUtils;
import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext;
import io.grpc.xds.XdsInitializationException;
import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProvider;
import io.grpc.xds.internal.certprovider.CertificateProvider;
import io.grpc.xds.internal.certprovider.CertificateProviderProvider;
import io.grpc.xds.internal.certprovider.CertificateProviderRegistry;
import io.grpc.xds.internal.certprovider.CertificateProviderStore;
import io.grpc.xds.internal.certprovider.TestCertificateProvider;
import java.io.IOException;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/** Unit tests for {@link ClientSslContextProviderFactory}. */
@RunWith(JUnit4.class)
public class ClientSslContextProviderFactoryTest {
CertificateProviderRegistry certificateProviderRegistry;
CertificateProviderStore certificateProviderStore;
CertProviderClientSslContextProvider.Factory certProviderClientSslContextProviderFactory;
ClientSslContextProviderFactory clientSslContextProviderFactory;
@Before
public void setUp() {
certificateProviderRegistry = new CertificateProviderRegistry();
certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry);
certProviderClientSslContextProviderFactory =
new CertProviderClientSslContextProvider.Factory(certificateProviderStore);
}
@Test
public void createCertProviderClientSslContextProvider() throws XdsInitializationException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
bootstrapInfo, certProviderClientSslContextProviderFactory);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
// verify that bootstrapInfo is cached...
sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
}
@Test
public void bothPresent_expectCertProviderClientSslContextProvider()
throws XdsInitializationException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder();
builder = addFilenames(builder, "foo.pem", "foo.key", "root.pem");
upstreamTlsContext = new UpstreamTlsContext(builder.build());
Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
bootstrapInfo, certProviderClientSslContextProviderFactory);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderClientSslContextProvider_onlyRootCert()
throws XdsInitializationException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
/* certInstanceName= */ null,
/* certName= */ null,
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
bootstrapInfo, certProviderClientSslContextProviderFactory);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderClientSslContextProvider_withStaticContext()
throws XdsInitializationException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
@SuppressWarnings("deprecation")
CertificateValidationContext staticCertValidationContext =
CertificateValidationContext.newBuilder()
.addAllMatchSubjectAltNames(
ImmutableSet.of(
StringMatcher.newBuilder().setExact("foo").build(),
StringMatcher.newBuilder().setExact("bar").build()))
.build();
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
/* certInstanceName= */ null,
/* certName= */ null,
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
staticCertValidationContext);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(bootstrapInfo,
certProviderClientSslContextProviderFactory);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createCertProviderClientSslContextProvider_2providers()
throws XdsInitializationException {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[2];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "file_watcher", 1);
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"file_provider",
"root-default",
/* alpnProtocols= */ null,
/* staticCertValidationContext= */ null);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
bootstrapInfo, certProviderClientSslContextProviderFactory);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
verifyWatcher(sslContextProvider, watcherCaptor[1]);
}
@Test
public void createNewCertProviderClientSslContextProvider_withSans() {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[2];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
createAndRegisterProviderProvider(
certificateProviderRegistry, watcherCaptor, "file_watcher", 1);
@SuppressWarnings("deprecation")
CertificateValidationContext staticCertValidationContext =
CertificateValidationContext.newBuilder()
.addAllMatchSubjectAltNames(
ImmutableSet.of(
StringMatcher.newBuilder().setExact("foo").build(),
StringMatcher.newBuilder().setExact("bar").build()))
.build();
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance(
"gcp_id",
"cert-default",
"file_provider",
"root-default",
/* alpnProtocols= */ null,
staticCertValidationContext);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
bootstrapInfo, certProviderClientSslContextProviderFactory);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
verifyWatcher(sslContextProvider, watcherCaptor[1]);
}
@Test
public void createNewCertProviderClientSslContextProvider_onlyRootCert() {
final CertificateProvider.DistributorWatcher[] watcherCaptor =
new CertificateProvider.DistributorWatcher[1];
createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0);
@SuppressWarnings("deprecation")
CertificateValidationContext staticCertValidationContext =
CertificateValidationContext.newBuilder()
.addAllMatchSubjectAltNames(
ImmutableSet.of(
StringMatcher.newBuilder().setExact("foo").build(),
StringMatcher.newBuilder().setExact("bar").build()))
.build();
UpstreamTlsContext upstreamTlsContext =
CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance(
/* certInstanceName= */ null,
/* certName= */ null,
"gcp_id",
"root-default",
/* alpnProtocols= */ null,
staticCertValidationContext);
Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo();
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
bootstrapInfo, certProviderClientSslContextProviderFactory);
SslContextProvider sslContextProvider =
clientSslContextProviderFactory.create(upstreamTlsContext);
assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class);
verifyWatcher(sslContextProvider, watcherCaptor[0]);
}
@Test
public void createNullCommonTlsContext_exception() throws IOException {
clientSslContextProviderFactory =
new ClientSslContextProviderFactory(
null, certProviderClientSslContextProviderFactory);
UpstreamTlsContext upstreamTlsContext = new UpstreamTlsContext(null);
try {
clientSslContextProviderFactory.create(upstreamTlsContext);
Assert.fail("no exception thrown");
} catch (NullPointerException expected) {
assertThat(expected)
.hasMessageThat()
.isEqualTo("upstreamTlsContext should have CommonTlsContext");
}
}
static void createAndRegisterProviderProvider(
CertificateProviderRegistry certificateProviderRegistry,
final CertificateProvider.DistributorWatcher[] watcherCaptor,
String testca,
final int i) {
final CertificateProviderProvider mockProviderProviderTestCa =
mock(CertificateProviderProvider.class);
when(mockProviderProviderTestCa.getName()).thenReturn(testca);
when(mockProviderProviderTestCa.createCertificateProvider(
any(Object.class), any(CertificateProvider.DistributorWatcher.class), eq(true)))
.thenAnswer(
new Answer<CertificateProvider>() {
@Override
public CertificateProvider answer(InvocationOnMock invocation) throws Throwable {
Object[] args = invocation.getArguments();
CertificateProvider.DistributorWatcher watcher =
(CertificateProvider.DistributorWatcher) args[1];
watcherCaptor[i] = watcher;
return new TestCertificateProvider(
watcher, true, args[0], mockProviderProviderTestCa, false);
}
});
certificateProviderRegistry.register(mockProviderProviderTestCa);
}
static void verifyWatcher(
SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor) {
assertThat(watcherCaptor).isNotNull();
assertThat(watcherCaptor.getDownstreamWatchers()).hasSize(1);
assertThat(watcherCaptor.getDownstreamWatchers().iterator().next())
.isSameInstanceAs(sslContextProvider);
}
@SuppressWarnings("deprecation")
static CommonTlsContext.Builder addFilenames(
CommonTlsContext.Builder builder, String certChain, String privateKey, String trustCa) {
TlsCertificate tlsCert =
TlsCertificate.newBuilder()
.setCertificateChain(DataSource.newBuilder().setFilename(certChain))
.setPrivateKey(DataSource.newBuilder().setFilename(privateKey))
.build();
CertificateValidationContext certContext =
CertificateValidationContext.newBuilder()
.setTrustedCa(DataSource.newBuilder().setFilename(trustCa))
.build();
CommonTlsContext.CertificateProviderInstance certificateProviderInstance =
builder.getValidationContextCertificateProviderInstance();
CommonTlsContext.CombinedCertificateValidationContext.Builder combinedBuilder =
CommonTlsContext.CombinedCertificateValidationContext.newBuilder();
combinedBuilder
.setDefaultValidationContext(certContext)
.setValidationContextCertificateProviderInstance(certificateProviderInstance);
return builder
.addTlsCertificates(tlsCert)
.setCombinedValidationContext(combinedBuilder.build());
}
}