blob: 569d72bf43df71840a606432fb3afdba605e8439 [file] [log] [blame]
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.xds.internal.certprovider;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyListOf;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList;
import io.grpc.Status;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.List;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/** Unit tests for {@link CertificateProviderStore}. */
@RunWith(JUnit4.class)
public class CertificateProviderStoreTest {
  private CertificateProviderRegistry certificateProviderRegistry;
  private CertificateProviderStore certificateProviderStore;
  /**
   * When true, every {@link TestCertificateProvider} constructed with {@code notifyCertUpdates ==
   * true} throws {@link UnsupportedOperationException}. Reset to false in {@link #setUp()}.
   */
  private boolean throwExceptionForCertUpdates;

  /**
   * Fake provider that records the arguments it was built with and counts calls to {@link
   * #start()} and {@link #close()}.
   *
   * <p>Deliberately non-static: its constructor reads {@code throwExceptionForCertUpdates} from the
   * enclosing test instance.
   */
  private class TestCertificateProvider extends CertificateProvider {
    Object config;
    CertificateProviderProvider certProviderProvider;
    int closeCalled = 0;
    int startCalled = 0;

    protected TestCertificateProvider(
        CertificateProvider.DistributorWatcher watcher,
        boolean notifyCertUpdates,
        Object config,
        CertificateProviderProvider certificateProviderProvider) {
      super(watcher, notifyCertUpdates);
      if (throwExceptionForCertUpdates && notifyCertUpdates) {
        throw new UnsupportedOperationException("Provider does not support Certificate Updates.");
      }
      this.config = config;
      this.certProviderProvider = certificateProviderProvider;
    }

    @Override
    public void close() {
      closeCalled++;
    }

    @Override
    public void start() {
      startCalled++;
    }
  }

  /** Creates a fresh, empty registry and store for every test. */
  @Before
  public void setUp() {
    certificateProviderRegistry = new CertificateProviderRegistry();
    certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry);
    throwExceptionForCertUpdates = false;
  }

  /** Requesting a plugin name that was never registered fails. */
  @Test
  public void pluginNotRegistered_expectException() {
    CertificateProvider.Watcher mockWatcher = mock(CertificateProvider.Watcher.class);
    try {
      CertificateProviderStore.Handle unused = certificateProviderStore.createOrGetProvider(
          "cert-name1", "plugin1", "config", mockWatcher, true);
      fail("exception expected");
    } catch (IllegalArgumentException expected) {
      assertThat(expected).hasMessageThat().isEqualTo("Provider not found.");
    }
  }

  /** A plugin that was registered, used, released and then deregistered can no longer be used. */
  @Test
  public void pluginUnregistered_expectException() {
    CertificateProviderProvider certificateProviderProvider = registerPlugin("plugin1");
    CertificateProvider.Watcher mockWatcher = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle = certificateProviderStore.createOrGetProvider(
        "cert-name1", "plugin1", "config", mockWatcher, true);
    handle.close();
    certificateProviderRegistry.deregister(certificateProviderProvider);
    try {
      CertificateProviderStore.Handle unused = certificateProviderStore.createOrGetProvider(
          "cert-name1", "plugin1", "config", mockWatcher, true);
      fail("exception expected");
    } catch (IllegalArgumentException expected) {
      assertThat(expected).hasMessageThat().isEqualTo("Provider not found.");
    }
  }

  /** The provider's own rejection of {@code notifyCertUpdates} propagates out of the store. */
  @Test
  public void notifyCertUpdatesNotSupported_expectException() {
    CertificateProviderProvider unused = registerPlugin("plugin1");
    throwExceptionForCertUpdates = true;
    CertificateProvider.Watcher mockWatcher = mock(CertificateProvider.Watcher.class);
    try {
      CertificateProviderStore.Handle unused1 =
          certificateProviderStore.createOrGetProvider(
              "cert-name1", "plugin1", "config", mockWatcher, true);
      fail("exception expected");
    } catch (UnsupportedOperationException expected) {
      assertThat(expected)
          .hasMessageThat()
          .isEqualTo("Provider does not support Certificate Updates.");
    }
  }

  /**
   * A first request with {@code notifyCertUpdates == false} succeeds. A second request for the same
   * key with {@code notifyCertUpdates == true} still surfaces the provider's rejection.
   */
  @Test
  public void notifyCertUpdatesNotSupported_expectExceptionOnSecondCall() {
    registerPlugin("plugin1");
    throwExceptionForCertUpdates = true;
    CertificateProvider.Watcher mockWatcher = mock(CertificateProvider.Watcher.class);
    try (CertificateProviderStore.Handle unused =
        certificateProviderStore
            .createOrGetProvider("cert-name1", "plugin1", "config", mockWatcher, false)) {
      try {
        certificateProviderStore.createOrGetProvider(
            "cert-name1", "plugin1", "config", mockWatcher, true);
        fail("exception expected");
      } catch (UnsupportedOperationException expected) {
        assertThat(expected)
            .hasMessageThat()
            .isEqualTo("Provider does not support Certificate Updates.");
      }
    }
  }

  /**
   * Identical (cert-name, plugin, config) requests share one provider. Updates fan out to every
   * live watcher, and the provider is closed only when the last handle is released.
   */
  @Test
  @SuppressWarnings("deprecation")
  public void onePluginSameConfig_sameInstance() {
    registerPlugin("plugin1");
    CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle1 = certificateProviderStore.createOrGetProvider(
        "cert-name1", "plugin1", "config", mockWatcher1, true);
    CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle2 = certificateProviderStore.createOrGetProvider(
        "cert-name1", "plugin1", "config", mockWatcher2, true);
    assertThat(handle1).isNotSameInstanceAs(handle2);
    assertThat(handle1.certProvider).isSameInstanceAs(handle2.certProvider);
    assertThat(handle1.certProvider).isInstanceOf(TestCertificateProvider.class);
    TestCertificateProvider testCertificateProvider =
        (TestCertificateProvider) handle1.certProvider;
    assertThat(testCertificateProvider.startCalled).isEqualTo(1);
    CertificateProvider.DistributorWatcher distWatcher = testCertificateProvider.getWatcher();
    assertThat(distWatcher.downstreamWatchers).hasSize(2);
    PrivateKey testKey = mock(PrivateKey.class);
    X509Certificate cert = mock(X509Certificate.class);
    List<X509Certificate> testList = ImmutableList.of(cert);
    testCertificateProvider.getWatcher().updateCertificate(testKey, testList);
    verify(mockWatcher1, times(1)).updateCertificate(eq(testKey), eq(testList));
    verify(mockWatcher2, times(1)).updateCertificate(eq(testKey), eq(testList));
    reset(mockWatcher1);
    reset(mockWatcher2);
    testCertificateProvider.getWatcher().updateTrustedRoots(testList);
    verify(mockWatcher1, times(1)).updateTrustedRoots(eq(testList));
    verify(mockWatcher2, times(1)).updateTrustedRoots(eq(testList));
    reset(mockWatcher1);
    reset(mockWatcher2);
    // Releasing one of two handles keeps the shared provider alive and detaches only its watcher.
    handle1.close();
    assertThat(testCertificateProvider.closeCalled).isEqualTo(0);
    assertThat(distWatcher.downstreamWatchers).hasSize(1);
    testCertificateProvider.getWatcher().updateCertificate(testKey, testList);
    verify(mockWatcher1, never())
        .updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class));
    verify(mockWatcher2, times(1)).updateCertificate(eq(testKey), eq(testList));
    testCertificateProvider.getWatcher().updateTrustedRoots(testList);
    verify(mockWatcher1, never()).updateTrustedRoots(anyListOf(X509Certificate.class));
    verify(mockWatcher2, times(1)).updateTrustedRoots(eq(testList));
    // Releasing the last handle closes the provider exactly once.
    handle2.close();
    assertThat(testCertificateProvider.closeCalled).isEqualTo(1);
  }

  /**
   * A watcher that joins after updates were distributed immediately receives the last certificate
   * and trusted roots, but not the last error. Earlier watchers are not re-notified.
   */
  @Test
  @SuppressWarnings("deprecation")
  public void onePluginSameConfig_secondWatcherAfterFirstNotify() {
    registerPlugin("plugin1");
    CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle1 = certificateProviderStore.createOrGetProvider(
        "cert-name1", "plugin1", "config", mockWatcher1, true);
    TestCertificateProvider testCertificateProvider =
        (TestCertificateProvider) handle1.certProvider;
    CertificateProvider.DistributorWatcher distWatcher = testCertificateProvider.getWatcher();
    PrivateKey testKey = mock(PrivateKey.class);
    X509Certificate cert = mock(X509Certificate.class);
    List<X509Certificate> testList = ImmutableList.of(cert);
    testCertificateProvider.getWatcher().updateCertificate(testKey, testList);
    verify(mockWatcher1, times(1)).updateCertificate(eq(testKey), eq(testList));
    testCertificateProvider.getWatcher().updateTrustedRoots(testList);
    verify(mockWatcher1, times(1)).updateTrustedRoots(eq(testList));
    testCertificateProvider.getWatcher().onError(Status.CANCELLED);
    verify(mockWatcher1, times(1)).onError(eq(Status.CANCELLED));
    reset(mockWatcher1);
    // now add the second watcher
    CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle2 = certificateProviderStore.createOrGetProvider(
        "cert-name1", "plugin1", "config", mockWatcher2, true);
    assertThat(distWatcher.downstreamWatchers).hasSize(2);
    // updates sent to the second watcher
    verify(mockWatcher2, times(1)).updateCertificate(eq(testKey), eq(testList));
    verify(mockWatcher2, times(1)).updateTrustedRoots(eq(testList));
    // but not errors!
    verify(mockWatcher2, never()).onError(eq(Status.CANCELLED));
    // and none to first one
    verify(mockWatcher1, never())
        .updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class));
    verify(mockWatcher1, never()).updateTrustedRoots(anyListOf(X509Certificate.class));
    verify(mockWatcher1, never()).onError(any(Status.class));
    // Release both handles so the shared provider is torn down exactly once.
    handle1.close();
    assertThat(testCertificateProvider.closeCalled).isEqualTo(0);
    handle2.close();
    assertThat(testCertificateProvider.closeCalled).isEqualTo(1);
  }

  /**
   * Errors reported by a provider reach every watcher attached to it. Despite the method name, both
   * handles share a single provider instance because the key is identical.
   */
  @Test
  public void onePluginTwoInstances_notifyError() {
    registerPlugin("plugin1");
    CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle1 = certificateProviderStore.createOrGetProvider(
        "cert-name1", "plugin1", "config", mockWatcher1, true);
    CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle2 = certificateProviderStore.createOrGetProvider(
        "cert-name1", "plugin1", "config", mockWatcher2, true);
    assertThat(handle1.certProvider).isSameInstanceAs(handle2.certProvider);
    TestCertificateProvider testCertificateProvider =
        (TestCertificateProvider) handle1.certProvider;
    testCertificateProvider.getWatcher().onError(Status.CANCELLED);
    verify(mockWatcher1, times(1)).onError(eq(Status.CANCELLED));
    verify(mockWatcher2, times(1)).onError(eq(Status.CANCELLED));
    handle1.close();
    handle2.close();
    assertThat(testCertificateProvider.closeCalled).isEqualTo(1);
  }

  /** Different configs for the same cert name and plugin yield independent providers. */
  @Test
  public void onePluginDifferentConfig_differentInstance() {
    CertificateProviderProvider certProviderProvider = registerPlugin("plugin1");
    CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle1 = certificateProviderStore.createOrGetProvider(
        "cert-name1", "plugin1", "config", mockWatcher1, true);
    CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle2 = certificateProviderStore.createOrGetProvider(
        "cert-name1", "plugin1", "config2", mockWatcher2, true);
    checkDifferentInstances(
        mockWatcher1, handle1, certProviderProvider, mockWatcher2, handle2, certProviderProvider);
  }

  /** Different cert names for the same plugin and config yield independent providers. */
  @Test
  public void onePluginDifferentCertName_differentInstance() {
    CertificateProviderProvider certProviderProvider = registerPlugin("plugin1");
    CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle1 = certificateProviderStore.createOrGetProvider(
        "cert-name1", "plugin1", "config", mockWatcher1, true);
    CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle2 = certificateProviderStore.createOrGetProvider(
        "cert-name2", "plugin1", "config", mockWatcher2, true);
    checkDifferentInstances(
        mockWatcher1, handle1, certProviderProvider, mockWatcher2, handle2, certProviderProvider);
  }

  /**
   * The {@code notifyCertUpdates} flag does not affect provider identity. Requests differing only
   * in that flag share one provider, which is closed once when both handles are released.
   */
  @Test
  public void onePluginDifferentNotifyValue_sameInstance() {
    CertificateProviderProvider unused = registerPlugin("plugin1");
    CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle1 = certificateProviderStore.createOrGetProvider(
        "cert-name1", "plugin1", "config", mockWatcher1, true);
    CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle2 = certificateProviderStore.createOrGetProvider(
        "cert-name1", "plugin1", "config", mockWatcher2, false);
    assertThat(handle1).isNotSameInstanceAs(handle2);
    assertThat(handle1.certProvider).isSameInstanceAs(handle2.certProvider);
    TestCertificateProvider testCertificateProvider =
        (TestCertificateProvider) handle1.certProvider;
    handle1.close();
    handle2.close();
    assertThat(testCertificateProvider.closeCalled).isEqualTo(1);
  }

  /** Different plugins yield independent providers, each created by its own plugin. */
  @Test
  public void twoPlugins_differentInstance() {
    CertificateProviderProvider certProviderProvider1 = registerPlugin("plugin1");
    CertificateProviderProvider certProviderProvider2 = registerPlugin("plugin2");
    CertificateProvider.Watcher mockWatcher1 = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle1 = certificateProviderStore.createOrGetProvider(
        "cert-name1", "plugin1", "config", mockWatcher1, true);
    CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class);
    CertificateProviderStore.Handle handle2 = certificateProviderStore.createOrGetProvider(
        "cert-name1", "plugin2", "config", mockWatcher2, true);
    checkDifferentInstances(
        mockWatcher1, handle1, certProviderProvider1, mockWatcher2, handle2, certProviderProvider2);
  }

  /**
   * Asserts that two handles are backed by distinct providers. Each provider must have been built
   * by the expected plugin, started once, and deliver updates only to its own watcher. Each
   * provider must also be closed once when its handle is released. Closes both handles.
   */
  @SuppressWarnings("deprecation")
  private static void checkDifferentInstances(
      CertificateProvider.Watcher mockWatcher1,
      CertificateProviderStore.Handle handle1,
      CertificateProviderProvider certProviderProvider1,
      CertificateProvider.Watcher mockWatcher2,
      CertificateProviderStore.Handle handle2,
      CertificateProviderProvider certProviderProvider2) {
    assertThat(handle1.certProvider).isNotSameInstanceAs(handle2.certProvider);
    TestCertificateProvider testCertificateProvider1 =
        (TestCertificateProvider) handle1.certProvider;
    TestCertificateProvider testCertificateProvider2 =
        (TestCertificateProvider) handle2.certProvider;
    assertThat(testCertificateProvider1.certProviderProvider)
        .isSameInstanceAs(certProviderProvider1);
    assertThat(testCertificateProvider2.certProviderProvider)
        .isSameInstanceAs(certProviderProvider2);
    CertificateProvider.DistributorWatcher distWatcher1 = testCertificateProvider1.getWatcher();
    assertThat(distWatcher1.downstreamWatchers).hasSize(1);
    CertificateProvider.DistributorWatcher distWatcher2 = testCertificateProvider2.getWatcher();
    assertThat(distWatcher2.downstreamWatchers).hasSize(1);
    PrivateKey testKey1 = mock(PrivateKey.class);
    X509Certificate cert1 = mock(X509Certificate.class);
    List<X509Certificate> testList1 = ImmutableList.of(cert1);
    testCertificateProvider1.getWatcher().updateCertificate(testKey1, testList1);
    verify(mockWatcher1, times(1)).updateCertificate(eq(testKey1), eq(testList1));
    verify(mockWatcher2, never())
        .updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class));
    reset(mockWatcher1);
    PrivateKey testKey2 = mock(PrivateKey.class);
    X509Certificate cert2 = mock(X509Certificate.class);
    List<X509Certificate> testList2 = ImmutableList.of(cert2);
    testCertificateProvider2.getWatcher().updateCertificate(testKey2, testList2);
    verify(mockWatcher2, times(1)).updateCertificate(eq(testKey2), eq(testList2));
    verify(mockWatcher1, never())
        .updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class));
    assertThat(testCertificateProvider1.startCalled).isEqualTo(1);
    assertThat(testCertificateProvider2.startCalled).isEqualTo(1);
    handle2.close();
    assertThat(testCertificateProvider2.closeCalled).isEqualTo(1);
    handle1.close();
    assertThat(testCertificateProvider1.closeCalled).isEqualTo(1);
  }

  /**
   * Registers a mock {@link CertificateProviderProvider} under {@code pluginName}. Its {@code
   * createCertificateProvider} returns a new {@link TestCertificateProvider} wired to the supplied
   * config, watcher and notify flag.
   *
   * @return the registered mock, so tests can assert which plugin built a provider
   */
  private CertificateProviderProvider registerPlugin(String pluginName) {
    final CertificateProviderProvider certProviderProvider =
        mock(CertificateProviderProvider.class);
    when(certProviderProvider.getName()).thenReturn(pluginName);
    when(certProviderProvider.createCertificateProvider(
            any(Object.class),
            any(CertificateProvider.DistributorWatcher.class),
            anyBoolean()))
        .then(
            new Answer<CertificateProvider>() {
              @Override
              public CertificateProvider answer(InvocationOnMock invocation) throws Throwable {
                Object[] args = invocation.getArguments();
                Object config = args[0];
                CertificateProvider.DistributorWatcher watcher =
                    (CertificateProvider.DistributorWatcher) args[1];
                boolean notifyCertUpdates = (Boolean) args[2];
                return new TestCertificateProvider(
                    watcher, notifyCertUpdates, config, certProviderProvider);
              }
            });
    certificateProviderRegistry.register(certProviderProvider);
    return certProviderProvider;
  }
}