blob: 783ce2b11f7fba10abc6f1efefe2827b315841ec [file] [log] [blame]
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.xds.internal.certprovider.CommonCertProviderTestUtils.getCertFromResourceName;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.doChecksOnSslContext;
import static org.junit.Assert.fail;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.MoreExecutors;
import io.envoyproxy.envoy.config.core.v3.DataSource;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.xds.Bootstrapper;
import io.grpc.xds.CommonBootstrapperTestUtils;
import io.grpc.xds.EnvoyServerProtoData;
import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProviderTest.QueuedExecutor;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.TestCallback;
import java.util.Arrays;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Unit tests for {@link CertProviderServerSslContextProvider}. */
@RunWith(JUnit4.class)
public class CertProviderServerSslContextProviderTest {

  CertificateProviderRegistry certificateProviderRegistry;
  CertificateProviderStore certificateProviderStore;
  private CertProviderServerSslContextProvider.Factory certProviderServerSslContextProviderFactory;

  @Before
  public void setUp() throws Exception {
    certificateProviderRegistry = new CertificateProviderRegistry();
    certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry);
    certProviderServerSslContextProviderFactory =
        new CertProviderServerSslContextProvider.Factory(certificateProviderStore);
  }

  /**
   * Builds a provider whose {@code DownstreamTlsContext} comes from {@link
   * CommonTlsContextTestsUtil#buildDownstreamTlsContextForCertProviderInstance}, using
   * "cert-default" / "root-default" as the certificate names.
   */
  private CertProviderServerSslContextProvider getSslContextProvider(
      String certInstanceName,
      String rootInstanceName,
      Bootstrapper.BootstrapInfo bootstrapInfo,
      Iterable<String> alpnProtocols,
      CertificateValidationContext staticCertValidationContext,
      boolean requireClientCert) {
    EnvoyServerProtoData.DownstreamTlsContext downstreamTlsContext =
        CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance(
            certInstanceName,
            "cert-default",
            rootInstanceName,
            "root-default",
            alpnProtocols,
            staticCertValidationContext,
            requireClientCert);
    return getProviderFor(downstreamTlsContext, bootstrapInfo);
  }

  /**
   * Same as {@link #getSslContextProvider} but the {@code DownstreamTlsContext} comes from {@link
   * CommonTlsContextTestsUtil#buildNewDownstreamTlsContextForCertProviderInstance} (used by the
   * {@code _newXds} test).
   */
  private CertProviderServerSslContextProvider getNewSslContextProvider(
      String certInstanceName,
      String rootInstanceName,
      Bootstrapper.BootstrapInfo bootstrapInfo,
      Iterable<String> alpnProtocols,
      CertificateValidationContext staticCertValidationContext,
      boolean requireClientCert) {
    EnvoyServerProtoData.DownstreamTlsContext downstreamTlsContext =
        CommonTlsContextTestsUtil.buildNewDownstreamTlsContextForCertProviderInstance(
            certInstanceName,
            "cert-default",
            rootInstanceName,
            "root-default",
            alpnProtocols,
            staticCertValidationContext,
            requireClientCert);
    return getProviderFor(downstreamTlsContext, bootstrapInfo);
  }

  /**
   * Asks the factory for a provider for {@code downstreamTlsContext}, passing the bootstrap's node
   * (as an Envoy proto) and its cert-provider configuration.
   */
  private CertProviderServerSslContextProvider getProviderFor(
      EnvoyServerProtoData.DownstreamTlsContext downstreamTlsContext,
      Bootstrapper.BootstrapInfo bootstrapInfo) {
    return certProviderServerSslContextProviderFactory.getProvider(
        downstreamTlsContext,
        bootstrapInfo.getNode().toEnvoyProtoNode(),
        bootstrapInfo.getCertProviders());
  }

  /**
   * Registers a "testca" {@link TestCertificateProvider} with {@link
   * #certificateProviderRegistry} and returns the single-element array into which it is expected
   * to place its watcher. Element 0 is read only after a provider has been created.
   */
  private CertificateProvider.DistributorWatcher[] registerTestCertificateProvider()
      throws Exception {
    CertificateProvider.DistributorWatcher[] watcherCaptor =
        new CertificateProvider.DistributorWatcher[1];
    TestCertificateProvider.createAndRegisterProviderProvider(
        certificateProviderRegistry, watcherCaptor, "testca", 0);
    return watcherCaptor;
  }

  /**
   * Drives a freshly created mTLS {@code provider} through identity-cert, root-cert, root-only and
   * identity-only updates delivered via {@code watcherCaptor[0]}. At each step it checks the saved
   * key/chain/roots fields and whether the resulting SslContext instance is reused or replaced.
   */
  private static void verifyMtlsUpdateSequence(
      CertProviderServerSslContextProvider provider,
      CertificateProvider.DistributorWatcher[] watcherCaptor)
      throws Exception {
    // Nothing delivered yet: no saved state and no SslContext.
    assertThat(provider.savedKey).isNull();
    assertThat(provider.savedCertChain).isNull();
    assertThat(provider.savedTrustedRoots).isNull();
    assertThat(provider.getSslContext()).isNull();
    // now generate cert update: key/chain are saved, but no SslContext until roots arrive
    watcherCaptor[0].updateCertificate(
        CommonCertProviderTestUtils.getPrivateKey(SERVER_0_KEY_FILE),
        ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE)));
    assertThat(provider.savedKey).isNotNull();
    assertThat(provider.savedCertChain).isNotNull();
    assertThat(provider.getSslContext()).isNull();
    // now generate root cert update: SslContext is built and saved state is cleared
    watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE)));
    assertThat(provider.getSslContext()).isNotNull();
    assertThat(provider.savedKey).isNull();
    assertThat(provider.savedCertChain).isNull();
    assertThat(provider.savedTrustedRoots).isNull();
    TestCallback testCallback = CommonTlsContextTestsUtil.getValueThruCallback(provider);
    doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
    // A second callback must observe the very same SslContext instance.
    TestCallback testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider);
    assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext);
    // just do root cert update: roots are saved and sslContext should still be the same
    watcherCaptor[0].updateTrustedRoots(
        ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE)));
    assertThat(provider.savedKey).isNull();
    assertThat(provider.savedCertChain).isNull();
    assertThat(provider.savedTrustedRoots).isNotNull();
    testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider);
    assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext);
    // now update id cert: sslContext should be updated i.e. different from the previous one
    watcherCaptor[0].updateCertificate(
        CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE),
        ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE)));
    assertThat(provider.savedKey).isNull();
    assertThat(provider.savedCertChain).isNull();
    assertThat(provider.savedTrustedRoots).isNull();
    assertThat(provider.getSslContext()).isNotNull();
    testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider);
    assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext);
  }

  @Test
  public void testProviderForServer_mtls() throws Exception {
    CertificateProvider.DistributorWatcher[] watcherCaptor = registerTestCertificateProvider();
    CertProviderServerSslContextProvider provider =
        getSslContextProvider(
            "gcp_id",
            "gcp_id",
            CommonBootstrapperTestUtils.getTestBootstrapInfo(),
            /* alpnProtocols= */ null,
            /* staticCertValidationContext= */ null,
            /* requireClientCert= */ true);
    verifyMtlsUpdateSequence(provider, watcherCaptor);
  }

  @Test
  public void testProviderForServer_mtls_newXds() throws Exception {
    CertificateProvider.DistributorWatcher[] watcherCaptor = registerTestCertificateProvider();
    CertificateValidationContext staticCertValidationContext =
        CertificateValidationContext.newBuilder()
            .addAllMatchSubjectAltNames(
                Arrays.asList(
                    StringMatcher.newBuilder().setExact("foo.com").build(),
                    StringMatcher.newBuilder().setExact("bar.com").build()))
            .build();
    CertProviderServerSslContextProvider provider =
        getNewSslContextProvider(
            "gcp_id",
            "gcp_id",
            CommonBootstrapperTestUtils.getTestBootstrapInfo(),
            /* alpnProtocols= */ null,
            staticCertValidationContext,
            /* requireClientCert= */ true);
    verifyMtlsUpdateSequence(provider, watcherCaptor);
  }

  @Test
  public void testProviderForServer_queueExecutor() throws Exception {
    CertificateProvider.DistributorWatcher[] watcherCaptor = registerTestCertificateProvider();
    CertProviderServerSslContextProvider provider =
        getSslContextProvider(
            "gcp_id",
            "gcp_id",
            CommonBootstrapperTestUtils.getTestBootstrapInfo(),
            /* alpnProtocols= */ null,
            /* staticCertValidationContext= */ null,
            /* requireClientCert= */ true);
    QueuedExecutor queuedExecutor = new QueuedExecutor();
    TestCallback testCallback =
        CommonTlsContextTestsUtil.getValueThruCallback(provider, queuedExecutor);
    assertThat(queuedExecutor.runQueue).isEmpty();
    // now generate cert update
    watcherCaptor[0].updateCertificate(
        CommonCertProviderTestUtils.getPrivateKey(SERVER_0_KEY_FILE),
        ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE)));
    assertThat(queuedExecutor.runQueue).isEmpty(); // still empty
    // now generate root cert update: the callback is queued on the executor, not run inline
    watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE)));
    assertThat(queuedExecutor.runQueue).hasSize(1);
    queuedExecutor.drain();
    doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
  }

  @Test
  public void testProviderForServer_tls() throws Exception {
    CertificateProvider.DistributorWatcher[] watcherCaptor = registerTestCertificateProvider();
    CertProviderServerSslContextProvider provider =
        getSslContextProvider(
            "gcp_id",
            /* rootInstanceName= */ null,
            CommonBootstrapperTestUtils.getTestBootstrapInfo(),
            /* alpnProtocols= */ null,
            /* staticCertValidationContext= */ null,
            /* requireClientCert= */ false);
    assertThat(provider.savedKey).isNull();
    assertThat(provider.savedCertChain).isNull();
    assertThat(provider.savedTrustedRoots).isNull();
    assertThat(provider.getSslContext()).isNull();
    // now generate cert update: with no root instance, the identity cert alone suffices
    watcherCaptor[0].updateCertificate(
        CommonCertProviderTestUtils.getPrivateKey(SERVER_0_KEY_FILE),
        ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE)));
    assertThat(provider.getSslContext()).isNotNull();
    assertThat(provider.savedKey).isNull();
    assertThat(provider.savedCertChain).isNull();
    assertThat(provider.savedTrustedRoots).isNull();
    TestCallback testCallback = CommonTlsContextTestsUtil.getValueThruCallback(provider);
    doChecksOnSslContext(true, testCallback.updatedSslContext, /* expectedApnProtos= */ null);
  }

  @Test
  public void testProviderForServer_sslContextException_onError() throws Exception {
    // A static trusted CA alongside a root cert-provider instance is rejected when the
    // SslContext is built.
    CertificateValidationContext staticCertValidationContext =
        CertificateValidationContext.newBuilder()
            .setTrustedCa(DataSource.newBuilder().setInlineString("foo"))
            .build();
    CertificateProvider.DistributorWatcher[] watcherCaptor = registerTestCertificateProvider();
    CertProviderServerSslContextProvider provider =
        getSslContextProvider(
            /* certInstanceName= */ "gcp_id",
            /* rootInstanceName= */ "gcp_id",
            CommonBootstrapperTestUtils.getTestBootstrapInfo(),
            /* alpnProtocols= */ null,
            staticCertValidationContext,
            /* requireClientCert= */ true);
    // now generate cert update
    watcherCaptor[0].updateCertificate(
        CommonCertProviderTestUtils.getPrivateKey(SERVER_0_KEY_FILE),
        ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE)));
    TestCallback testCallback = new TestCallback(MoreExecutors.directExecutor());
    provider.addCallback(testCallback);
    try {
      watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE)));
      fail("exception expected");
    } catch (RuntimeException expected) {
      assertThat(expected)
          .hasMessageThat()
          .contains("only static certificateValidationContext expected");
    }
    // The same failure must also have been reported to the registered callback.
    assertThat(testCallback.updatedThrowable).isNotNull();
    assertThat(testCallback.updatedThrowable)
        .hasCauseThat()
        .hasMessageThat()
        .contains("only static certificateValidationContext expected");
  }

  @Test
  public void testProviderForServer_certInstanceNull_expectError() throws Exception {
    registerTestCertificateProvider();
    try {
      getSslContextProvider(
          /* certInstanceName= */ null,
          /* rootInstanceName= */ null,
          CommonBootstrapperTestUtils.getTestBootstrapInfo(),
          /* alpnProtocols= */ null,
          /* staticCertValidationContext= */ null,
          /* requireClientCert= */ false);
      fail("exception expected");
    } catch (NullPointerException expected) {
      assertThat(expected).hasMessageThat().contains("Server SSL requires certInstance");
    }
  }
}