xds: implement MeshCACertificateProvider (#7274)
diff --git a/build.gradle b/build.gradle
index 8eb98f2..f541539 100644
--- a/build.gradle
+++ b/build.gradle
@@ -177,6 +177,8 @@
conscrypt: 'org.conscrypt:conscrypt-openjdk-uber:2.2.1',
re2j: 'com.google.re2j:re2j:1.2',
+ bouncycastle: 'org.bouncycastle:bcpkix-jdk15on:1.61',
+
// Test dependencies.
junit: 'junit:junit:4.12',
mockito: 'org.mockito:mockito-core:3.3.3',
diff --git a/xds/build.gradle b/xds/build.gradle
index 43902d8..2903366 100644
--- a/xds/build.gradle
+++ b/xds/build.gradle
@@ -26,7 +26,8 @@
project(':grpc-auth'),
project(path: ':grpc-alts', configuration: 'shadow'),
libraries.gson,
- libraries.re2j
+ libraries.re2j,
+ libraries.bouncycastle
def nettyDependency = implementation project(':grpc-netty')
implementation (libraries.opencensus_proto) {
diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java
index 9b96f99..b5d1497 100644
--- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java
+++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java
@@ -49,42 +49,42 @@
@VisibleForTesting
static final class DistributorWatcher implements Watcher {
- private PrivateKey lastKey;
- private List<X509Certificate> lastCertChain;
- private List<X509Certificate> lastTrustedRoots;
+ private PrivateKey privateKey;
+ private List<X509Certificate> certChain;
+ private List<X509Certificate> trustedRoots;
@VisibleForTesting
- final Set<Watcher> downsstreamWatchers = new HashSet<>();
+ final Set<Watcher> downstreamWatchers = new HashSet<>();
synchronized void addWatcher(Watcher watcher) {
- downsstreamWatchers.add(watcher);
- if (lastKey != null && lastCertChain != null) {
+ downstreamWatchers.add(watcher);
+ if (privateKey != null && certChain != null) {
sendLastCertificateUpdate(watcher);
}
- if (lastTrustedRoots != null) {
+ if (trustedRoots != null) {
sendLastTrustedRootsUpdate(watcher);
}
}
synchronized void removeWatcher(Watcher watcher) {
- downsstreamWatchers.remove(watcher);
+ downstreamWatchers.remove(watcher);
}
private void sendLastCertificateUpdate(Watcher watcher) {
- watcher.updateCertificate(lastKey, lastCertChain);
+ watcher.updateCertificate(privateKey, certChain);
}
private void sendLastTrustedRootsUpdate(Watcher watcher) {
- watcher.updateTrustedRoots(lastTrustedRoots);
+ watcher.updateTrustedRoots(trustedRoots);
}
@Override
public synchronized void updateCertificate(PrivateKey key, List<X509Certificate> certChain) {
checkNotNull(key, "key");
checkNotNull(certChain, "certChain");
- lastKey = key;
- lastCertChain = certChain;
- for (Watcher watcher : downsstreamWatchers) {
+ privateKey = key;
+ this.certChain = certChain;
+ for (Watcher watcher : downstreamWatchers) {
sendLastCertificateUpdate(watcher);
}
}
@@ -92,18 +92,36 @@
@Override
public synchronized void updateTrustedRoots(List<X509Certificate> trustedRoots) {
checkNotNull(trustedRoots, "trustedRoots");
- lastTrustedRoots = trustedRoots;
- for (Watcher watcher : downsstreamWatchers) {
+ this.trustedRoots = trustedRoots;
+ for (Watcher watcher : downstreamWatchers) {
sendLastTrustedRootsUpdate(watcher);
}
}
@Override
public synchronized void onError(Status errorStatus) {
- for (Watcher watcher : downsstreamWatchers) {
+ for (Watcher watcher : downstreamWatchers) {
watcher.onError(errorStatus);
}
}
+
+ X509Certificate getLastIdentityCert() {
+ if (certChain != null && !certChain.isEmpty()) {
+ return certChain.get(0);
+ }
+ return null;
+ }
+
+ void close() {
+ downstreamWatchers.clear();
+ clearValues();
+ }
+
+ void clearValues() {
+ privateKey = null;
+ certChain = null;
+ trustedRoots = null;
+ }
}
/**
diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java
index 348f701..c038297 100644
--- a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java
+++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java
@@ -17,35 +17,336 @@
package io.grpc.xds.internal.certprovider;
import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static io.grpc.Status.Code.ABORTED;
+import static io.grpc.Status.Code.CANCELLED;
+import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
+import static io.grpc.Status.Code.INTERNAL;
+import static io.grpc.Status.Code.RESOURCE_EXHAUSTED;
+import static io.grpc.Status.Code.UNAVAILABLE;
+import static io.grpc.Status.Code.UNKNOWN;
+import static java.nio.charset.StandardCharsets.UTF_8;
import com.google.auth.oauth2.GoogleCredentials;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.Duration;
+import google.security.meshca.v1.MeshCertificateServiceGrpc;
+import google.security.meshca.v1.Meshca;
+import io.grpc.CallOptions;
+import io.grpc.Channel;
+import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
+import io.grpc.ForwardingClientCall;
+import io.grpc.InternalLogId;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
+import io.grpc.Metadata;
+import io.grpc.MethodDescriptor;
+import io.grpc.Status;
+import io.grpc.SynchronizationContext;
+import io.grpc.auth.MoreCallCredentials;
import io.grpc.internal.BackoffPolicy;
+import io.grpc.internal.TimeProvider;
+import io.grpc.xds.internal.sds.trust.CertificateUtils;
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.StringWriter;
+import java.security.KeyPair;
+import java.security.KeyPairGenerator;
+import java.security.NoSuchAlgorithmException;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+import java.util.ArrayList;
+import java.util.EnumSet;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
import java.util.logging.Level;
import java.util.logging.Logger;
+import javax.security.auth.x500.X500Principal;
+import org.bouncycastle.openssl.jcajce.JcaPEMWriter;
+import org.bouncycastle.operator.ContentSigner;
+import org.bouncycastle.operator.OperatorCreationException;
+import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
+import org.bouncycastle.pkcs.PKCS10CertificationRequest;
+import org.bouncycastle.pkcs.PKCS10CertificationRequestBuilder;
+import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder;
+import org.bouncycastle.util.io.pem.PemObject;
/** Implementation of {@link CertificateProvider} for the Google Mesh CA. */
final class MeshCaCertificateProvider extends CertificateProvider {
private static final Logger logger = Logger.getLogger(MeshCaCertificateProvider.class.getName());
- MeshCaCertificateProvider(DistributorWatcher watcher, boolean notifyCertUpdates,
- String meshCaUrl, String zone, long validitySeconds,
- int keySize, String alg, String signatureAlg, MeshCaChannelFactory meshCaChannelFactory,
- BackoffPolicy.Provider backoffPolicyProvider, long renewalGracePeriodSeconds,
- int maxRetryAttempts, GoogleCredentials oauth2Creds) {
+ MeshCaCertificateProvider(
+ DistributorWatcher watcher,
+ boolean notifyCertUpdates,
+ String meshCaUrl,
+ String zone,
+ long validitySeconds,
+ int keySize,
+ String alg,
+ String signatureAlg, MeshCaChannelFactory meshCaChannelFactory,
+ BackoffPolicy.Provider backoffPolicyProvider,
+ long renewalGracePeriodSeconds,
+ int maxRetryAttempts,
+ GoogleCredentials oauth2Creds,
+ ScheduledExecutorService scheduledExecutorService,
+ TimeProvider timeProvider,
+ long rpcTimeoutMillis) {
super(watcher, notifyCertUpdates);
+ this.meshCaUrl = checkNotNull(meshCaUrl, "meshCaUrl");
+ checkArgument(
+ validitySeconds > INITIAL_DELAY_SECONDS,
+ "validitySeconds must be greater than " + INITIAL_DELAY_SECONDS);
+ this.validitySeconds = validitySeconds;
+ this.keySize = keySize;
+ this.alg = checkNotNull(alg, "alg");
+ this.signatureAlg = checkNotNull(signatureAlg, "signatureAlg");
+ this.meshCaChannelFactory = checkNotNull(meshCaChannelFactory, "meshCaChannelFactory");
+ this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
+ checkArgument(
+ renewalGracePeriodSeconds > 0L && renewalGracePeriodSeconds < validitySeconds,
+ "renewalGracePeriodSeconds should be between 0 and " + validitySeconds);
+ this.renewalGracePeriodSeconds = renewalGracePeriodSeconds;
+ checkArgument(maxRetryAttempts >= 0, "maxRetryAttempts must be >= 0");
+ this.maxRetryAttempts = maxRetryAttempts;
+ this.oauth2Creds = checkNotNull(oauth2Creds, "oauth2Creds");
+ this.scheduledExecutorService =
+ checkNotNull(scheduledExecutorService, "scheduledExecutorService");
+ this.timeProvider = checkNotNull(timeProvider, "timeProvider");
+ this.headerInterceptor = new ZoneInfoClientInterceptor(checkNotNull(zone, "zone"));
+ this.syncContext = createSynchronizationContext(meshCaUrl);
+ this.rpcTimeoutMillis = rpcTimeoutMillis;
+ }
+
+ private SynchronizationContext createSynchronizationContext(String details) {
+ final InternalLogId logId = InternalLogId.allocate("MeshCaCertificateProvider", details);
+ return new SynchronizationContext(
+ new Thread.UncaughtExceptionHandler() {
+ private boolean panicMode;
+
+ @Override
+ public void uncaughtException(Thread t, Throwable e) {
+ logger.log(
+ Level.SEVERE,
+ "[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!",
+ e);
+ panic(e);
+ }
+
+ void panic(final Throwable t) {
+ if (panicMode) {
+ // Preserve the first panic information
+ return;
+ }
+ panicMode = true;
+ close();
+ }
+ });
}
@Override
public void start() {
- // TODO implement
+ scheduleNextRefreshCertificate(INITIAL_DELAY_SECONDS);
}
@Override
public void close() {
- // TODO implement
+ if (scheduledHandle != null) {
+ scheduledHandle.cancel();
+ scheduledHandle = null;
+ }
+ getWatcher().close();
+ }
+
+ private void scheduleNextRefreshCertificate(long delayInSeconds) {
+ if (scheduledHandle != null && scheduledHandle.isPending()) {
+ logger.log(Level.SEVERE, "Pending task found: inconsistent state in scheduledHandle!");
+ scheduledHandle.cancel();
+ }
+ RefreshCertificateTask runnable = new RefreshCertificateTask();
+ scheduledHandle = syncContext.schedule(
+ runnable, delayInSeconds, TimeUnit.SECONDS, scheduledExecutorService);
+ }
+
+ @VisibleForTesting
+ void refreshCertificate()
+ throws NoSuchAlgorithmException, IOException, OperatorCreationException {
+ long refreshDelaySeconds = computeRefreshSecondsFromCurrentCertExpiry();
+ ManagedChannel channel = meshCaChannelFactory.createChannel(meshCaUrl);
+ try {
+ String uniqueReqIdForAllRetries = UUID.randomUUID().toString();
+ Duration duration = Duration.newBuilder().setSeconds(validitySeconds).build();
+ KeyPair keyPair = generateKeyPair();
+ String csr = generateCsr(keyPair);
+ MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub stub =
+ createStubToMeshCa(channel);
+ List<X509Certificate> x509Chain = makeRequestWithRetries(stub, uniqueReqIdForAllRetries,
+ duration, csr);
+ if (x509Chain != null) {
+ refreshDelaySeconds =
+ computeDelaySecondsToCertExpiry(x509Chain.get(0)) - renewalGracePeriodSeconds;
+ getWatcher().updateCertificate(keyPair.getPrivate(), x509Chain);
+ getWatcher().updateTrustedRoots(ImmutableList.of(x509Chain.get(x509Chain.size() - 1)));
+ }
+ } finally {
+ shutdownChannel(channel);
+ scheduleNextRefreshCertificate(refreshDelaySeconds);
+ }
+ }
+
+ private MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub createStubToMeshCa(
+ ManagedChannel channel) {
+ return MeshCertificateServiceGrpc
+ .newBlockingStub(channel)
+ .withCallCredentials(MoreCallCredentials.from(oauth2Creds))
+ .withInterceptors(headerInterceptor);
+ }
+
+ private List<X509Certificate> makeRequestWithRetries(
+ MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub stub,
+ String reqId,
+ Duration duration,
+ String csr) {
+ Meshca.MeshCertificateRequest request =
+ Meshca.MeshCertificateRequest.newBuilder()
+ .setValidity(duration)
+ .setCsr(csr)
+ .setRequestId(reqId)
+ .build();
+
+ BackoffPolicy backoffPolicy = backoffPolicyProvider.get();
+ Throwable lastException = null;
+ for (int i = 0; i <= maxRetryAttempts; i++) {
+ try {
+ Meshca.MeshCertificateResponse response =
+ stub.withDeadlineAfter(rpcTimeoutMillis, TimeUnit.MILLISECONDS)
+ .createCertificate(request);
+ return getX509CertificatesFromResponse(response);
+ } catch (Throwable t) {
+ if (!retriable(t)) {
+ generateErrorIfCurrentCertExpired(t);
+ return null;
+ }
+ lastException = t;
+ sleepForNanos(backoffPolicy.nextBackoffNanos());
+ }
+ }
+ generateErrorIfCurrentCertExpired(lastException);
+ return null;
+ }
+
+ private void sleepForNanos(long nanos) {
+ ScheduledFuture<?> future = scheduledExecutorService.schedule(new Runnable() {
+ @Override
+ public void run() {
+ // do nothing
+ }
+ }, nanos, TimeUnit.NANOSECONDS);
+ try {
+ future.get(nanos, TimeUnit.NANOSECONDS);
+ } catch (InterruptedException ie) {
+ logger.log(Level.SEVERE, "Inside sleep", ie);
+ Thread.currentThread().interrupt();
+ } catch (ExecutionException | TimeoutException ex) {
+ logger.log(Level.SEVERE, "Inside sleep", ex);
+ }
+ }
+
+ private static boolean retriable(Throwable t) {
+ return RETRIABLE_CODES.contains(Status.fromThrowable(t).getCode());
+ }
+
+ private void generateErrorIfCurrentCertExpired(Throwable t) {
+ X509Certificate currentCert = getWatcher().getLastIdentityCert();
+ if (currentCert != null) {
+ long delaySeconds = computeDelaySecondsToCertExpiry(currentCert);
+ if (delaySeconds > INITIAL_DELAY_SECONDS) {
+ return;
+ }
+ getWatcher().clearValues();
+ }
+ getWatcher().onError(Status.fromThrowable(t));
+ }
+
+ private KeyPair generateKeyPair() throws NoSuchAlgorithmException {
+ KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance(alg);
+ keyPairGenerator.initialize(keySize);
+ return keyPairGenerator.generateKeyPair();
+ }
+
+ private String generateCsr(KeyPair pair) throws IOException, OperatorCreationException {
+ PKCS10CertificationRequestBuilder p10Builder =
+ new JcaPKCS10CertificationRequestBuilder(
+ new X500Principal("CN=EXAMPLE.COM"), pair.getPublic());
+ JcaContentSignerBuilder csBuilder = new JcaContentSignerBuilder(signatureAlg);
+ ContentSigner signer = csBuilder.build(pair.getPrivate());
+ PKCS10CertificationRequest csr = p10Builder.build(signer);
+ PemObject pemObject = new PemObject("NEW CERTIFICATE REQUEST", csr.getEncoded());
+ try (StringWriter str = new StringWriter()) {
+ try (JcaPEMWriter pemWriter = new JcaPEMWriter(str)) {
+ pemWriter.writeObject(pemObject);
+ }
+ return str.toString();
+ }
+ }
+
+ /** Compute refresh interval as half of interval to current cert expiry. */
+ private long computeRefreshSecondsFromCurrentCertExpiry() {
+ X509Certificate lastCert = getWatcher().getLastIdentityCert();
+ if (lastCert == null) {
+ return INITIAL_DELAY_SECONDS;
+ }
+ long delayToCertExpirySeconds = computeDelaySecondsToCertExpiry(lastCert) / 2;
+ return Math.max(delayToCertExpirySeconds, INITIAL_DELAY_SECONDS);
+ }
+
+ @SuppressWarnings("JdkObsolete")
+ private long computeDelaySecondsToCertExpiry(X509Certificate lastCert) {
+ checkNotNull(lastCert, "lastCert");
+ return TimeUnit.NANOSECONDS.toSeconds(
+ TimeUnit.MILLISECONDS.toNanos(lastCert.getNotAfter().getTime()) - timeProvider
+ .currentTimeNanos());
+ }
+
+ private static void shutdownChannel(ManagedChannel channel) {
+ channel.shutdown();
+ try {
+ channel.awaitTermination(10, TimeUnit.SECONDS);
+ } catch (InterruptedException ex) {
+ logger.log(Level.SEVERE, "awaiting channel Termination", ex);
+ channel.shutdownNow();
+ Thread.currentThread().interrupt();
+ }
+ }
+
+ private List<X509Certificate> getX509CertificatesFromResponse(
+ Meshca.MeshCertificateResponse response) throws CertificateException, IOException {
+ List<String> certChain = response.getCertChainList();
+ List<X509Certificate> x509Chain = new ArrayList<>(certChain.size());
+ for (String certString : certChain) {
+ try (ByteArrayInputStream bais = new ByteArrayInputStream(certString.getBytes(UTF_8))) {
+ x509Chain.add(CertificateUtils.toX509Certificate(bais));
+ }
+ }
+ return x509Chain;
+ }
+
+ @VisibleForTesting
+ class RefreshCertificateTask implements Runnable {
+ @Override
+ public void run() {
+ try {
+ refreshCertificate();
+ } catch (NoSuchAlgorithmException | OperatorCreationException | IOException ex) {
+ logger.log(Level.SEVERE, "refreshing certificate", ex);
+ }
+ }
}
/** Factory for creating channels to MeshCA sever. */
@@ -94,7 +395,10 @@
BackoffPolicy.Provider backoffPolicyProvider,
long renewalGracePeriodSeconds,
int maxRetryAttempts,
- GoogleCredentials oauth2Creds) {
+ GoogleCredentials oauth2Creds,
+ ScheduledExecutorService scheduledExecutorService,
+ TimeProvider timeProvider,
+ long rpcTimeoutMillis) {
return new MeshCaCertificateProvider(
watcher,
notifyCertUpdates,
@@ -108,7 +412,10 @@
backoffPolicyProvider,
renewalGracePeriodSeconds,
maxRetryAttempts,
- oauth2Creds);
+ oauth2Creds,
+ scheduledExecutorService,
+ timeProvider,
+ rpcTimeoutMillis);
}
};
@@ -129,6 +436,64 @@
BackoffPolicy.Provider backoffPolicyProvider,
long renewalGracePeriodSeconds,
int maxRetryAttempts,
- GoogleCredentials oauth2Creds);
+ GoogleCredentials oauth2Creds,
+ ScheduledExecutorService scheduledExecutorService,
+ TimeProvider timeProvider,
+ long rpcTimeoutMillis);
}
+
+ private class ZoneInfoClientInterceptor implements ClientInterceptor {
+ private final String zone;
+
+ ZoneInfoClientInterceptor(String zone) {
+ this.zone = zone;
+ }
+
+ @Override
+ public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
+ MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
+ return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(
+ next.newCall(method, callOptions)) {
+
+ @Override
+ public void start(Listener<RespT> responseListener, Metadata headers) {
+ headers.put(KEY_FOR_ZONE_INFO, zone);
+ super.start(responseListener, headers);
+ }
+ };
+ }
+ }
+
+ @VisibleForTesting
+ static final Metadata.Key<String> KEY_FOR_ZONE_INFO =
+ Metadata.Key.of("x-goog-request-params", Metadata.ASCII_STRING_MARSHALLER);
+ @VisibleForTesting
+ static final long INITIAL_DELAY_SECONDS = 4L;
+
+ private static final EnumSet<Status.Code> RETRIABLE_CODES =
+ EnumSet.of(
+ CANCELLED,
+ UNKNOWN,
+ DEADLINE_EXCEEDED,
+ RESOURCE_EXHAUSTED,
+ ABORTED,
+ INTERNAL,
+ UNAVAILABLE);
+
+ private final SynchronizationContext syncContext;
+ private final ScheduledExecutorService scheduledExecutorService;
+ private final int maxRetryAttempts;
+ private final ZoneInfoClientInterceptor headerInterceptor;
+ private final BackoffPolicy.Provider backoffPolicyProvider;
+ private final String meshCaUrl;
+ private final long validitySeconds;
+ private final long renewalGracePeriodSeconds;
+ private final int keySize;
+ private final String alg;
+ private final String signatureAlg;
+ private final GoogleCredentials oauth2Creds;
+ private final TimeProvider timeProvider;
+ private final MeshCaChannelFactory meshCaChannelFactory;
+ @VisibleForTesting SynchronizationContext.ScheduledHandle scheduledHandle;
+ private final long rpcTimeoutMillis;
}
diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java
index a9c1b01..669b0f2 100644
--- a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java
+++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java
@@ -21,10 +21,15 @@
import static com.google.common.base.Preconditions.checkState;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.ExponentialBackoffPolicy;
+import io.grpc.internal.TimeProvider;
import io.grpc.xds.internal.sts.StsCredentials;
import java.util.Map;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -46,16 +51,20 @@
private static final String STS_URL_KEY = "stsUrl";
private static final String GKE_SA_JWT_LOCATION_KEY = "gkeSaJwtLocation";
- static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com";
- static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L;
- static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L; // 9 hours
- static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L; // 1 hour
- static final String KEY_ALGO_DEFAULT = "RSA"; // aka keyType
- static final int KEY_SIZE_DEFAULT = 2048;
- static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA";
- static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3;
+ @VisibleForTesting static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com";
+ @VisibleForTesting static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L;
+ @VisibleForTesting static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L;
+ @VisibleForTesting static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L;
+ @VisibleForTesting static final String KEY_ALGO_DEFAULT = "RSA"; // aka keyType
+ @VisibleForTesting static final int KEY_SIZE_DEFAULT = 2048;
+ @VisibleForTesting static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA";
+ @VisibleForTesting static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3;
+ @VisibleForTesting
static final String STS_URL_DEFAULT = "https://securetoken.googleapis.com/v1/identitybindingtoken";
+ @VisibleForTesting
+ static final long RPC_TIMEOUT_SECONDS = 10L;
+
private static final Pattern CLUSTER_URL_PATTERN = Pattern
.compile(".*/projects/(.*)/locations/(.*)/clusters/.*");
@@ -70,23 +79,32 @@
StsCredentials.Factory.getInstance(),
MeshCaCertificateProvider.MeshCaChannelFactory.getInstance(),
new ExponentialBackoffPolicy.Provider(),
- MeshCaCertificateProvider.Factory.getInstance()));
+ MeshCaCertificateProvider.Factory.getInstance(),
+ ScheduledExecutorServiceFactory.DEFAULT_INSTANCE,
+ TimeProvider.SYSTEM_TIME_PROVIDER));
}
final StsCredentials.Factory stsCredentialsFactory;
final MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory;
final BackoffPolicy.Provider backoffPolicyProvider;
final MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
+ final ScheduledExecutorServiceFactory scheduledExecutorServiceFactory;
+ final TimeProvider timeProvider;
@VisibleForTesting
- MeshCaCertificateProviderProvider(StsCredentials.Factory stsCredentialsFactory,
+ MeshCaCertificateProviderProvider(
+ StsCredentials.Factory stsCredentialsFactory,
MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory,
BackoffPolicy.Provider backoffPolicyProvider,
- MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory) {
+ MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory,
+ ScheduledExecutorServiceFactory scheduledExecutorServiceFactory,
+ TimeProvider timeProvider) {
this.stsCredentialsFactory = stsCredentialsFactory;
this.meshCaChannelFactory = meshCaChannelFactory;
this.backoffPolicyProvider = backoffPolicyProvider;
this.meshCaCertificateProviderFactory = meshCaCertificateProviderFactory;
+ this.scheduledExecutorServiceFactory = scheduledExecutorServiceFactory;
+ this.timeProvider = timeProvider;
}
@Override
@@ -106,12 +124,23 @@
StsCredentials stsCredentials = stsCredentialsFactory
.create(configObj.stsUrl, audience, configObj.gkeSaJwtLocation);
- return meshCaCertificateProviderFactory.create(watcher, notifyCertUpdates, configObj.meshCaUrl,
+ return meshCaCertificateProviderFactory.create(
+ watcher,
+ notifyCertUpdates,
+ configObj.meshCaUrl,
configObj.zone,
- configObj.certValiditySeconds, configObj.keySize, configObj.keyAlgo,
+ configObj.certValiditySeconds,
+ configObj.keySize,
+ configObj.keyAlgo,
configObj.signatureAlgo,
- meshCaChannelFactory, backoffPolicyProvider,
- configObj.renewalGracePeriodSeconds, configObj.maxRetryAttempts, stsCredentials);
+ meshCaChannelFactory,
+ backoffPolicyProvider,
+ configObj.renewalGracePeriodSeconds,
+ configObj.maxRetryAttempts,
+ stsCredentials,
+ scheduledExecutorServiceFactory.create(configObj.meshCaUrl),
+ timeProvider,
+ TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS));
}
private static Config validateAndTranslateConfig(Object config) {
@@ -177,6 +206,28 @@
configObj.zone = matcher.group(2);
}
+ abstract static class ScheduledExecutorServiceFactory {
+
+ private static final ScheduledExecutorServiceFactory DEFAULT_INSTANCE =
+ new ScheduledExecutorServiceFactory() {
+
+ @Override
+ ScheduledExecutorService create(String serverUri) {
+ return Executors.newSingleThreadScheduledExecutor(
+ new ThreadFactoryBuilder()
+ .setNameFormat("meshca-" + serverUri + "-%d")
+ .setDaemon(true)
+ .build());
+ }
+ };
+
+ static ScheduledExecutorServiceFactory getInstance() {
+ return DEFAULT_INSTANCE;
+ }
+
+ abstract ScheduledExecutorService create(String serverUri);
+ }
+
/** POJO class for storing various config values. */
@VisibleForTesting
static class Config {
diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/trust/CertificateUtils.java b/xds/src/main/java/io/grpc/xds/internal/sds/trust/CertificateUtils.java
index 834065a..a85ea3d 100644
--- a/xds/src/main/java/io/grpc/xds/internal/sds/trust/CertificateUtils.java
+++ b/xds/src/main/java/io/grpc/xds/internal/sds/trust/CertificateUtils.java
@@ -30,7 +30,7 @@
/**
* Contains certificate utility method(s).
*/
-final class CertificateUtils {
+public final class CertificateUtils {
private static CertificateFactory factory;
@@ -46,19 +46,26 @@
* @param file a {@link File} containing the cert data
*/
static X509Certificate[] toX509Certificates(File file) throws CertificateException, IOException {
- FileInputStream fis = new FileInputStream(file);
- return toX509Certificates(new BufferedInputStream(fis));
+ try (FileInputStream fis = new FileInputStream(file);
+ BufferedInputStream bis = new BufferedInputStream(fis)) {
+ return toX509Certificates(bis);
+ }
}
static synchronized X509Certificate[] toX509Certificates(InputStream inputStream)
throws CertificateException, IOException {
initInstance();
- try {
- Collection<? extends Certificate> certs = factory.generateCertificates(inputStream);
- return certs.toArray(new X509Certificate[0]);
- } finally {
- inputStream.close();
- }
+ Collection<? extends Certificate> certs = factory.generateCertificates(inputStream);
+ return certs.toArray(new X509Certificate[0]);
+
+ }
+
+ /** See {@link CertificateFactory#generateCertificate(InputStream)}. */
+ public static synchronized X509Certificate toX509Certificate(InputStream inputStream)
+ throws CertificateException, IOException {
+ initInstance();
+ Certificate cert = factory.generateCertificate(inputStream);
+ return (X509Certificate) cert;
}
private CertificateUtils() {}
diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java b/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java
index b495570..6ff63c0 100644
--- a/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java
+++ b/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java
@@ -27,6 +27,7 @@
import io.netty.handler.ssl.util.SimpleTrustManagerFactory;
import java.io.File;
import java.io.IOException;
+import java.io.InputStream;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
@@ -69,8 +70,10 @@
"trustedCa.file-name in certificateValidationContext cannot be empty");
return CertificateUtils.toX509Certificates(new File(certsFile));
} else if (specifierCase == SpecifierCase.INLINE_BYTES) {
- return CertificateUtils.toX509Certificates(
- certificateValidationContext.getTrustedCa().getInlineBytes().newInput());
+ try (InputStream is =
+ certificateValidationContext.getTrustedCa().getInlineBytes().newInput()) {
+ return CertificateUtils.toX509Certificates(is);
+ }
} else {
throw new IllegalArgumentException("Not supported: " + specifierCase);
}
diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java
index 8a0dfeb..569d72b 100644
--- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java
+++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java
@@ -169,7 +169,7 @@
(TestCertificateProvider) handle1.certProvider;
assertThat(testCertificateProvider.startCalled).isEqualTo(1);
CertificateProvider.DistributorWatcher distWatcher = testCertificateProvider.getWatcher();
- assertThat(distWatcher.downsstreamWatchers).hasSize(2);
+ assertThat(distWatcher.downstreamWatchers).hasSize(2);
PrivateKey testKey = mock(PrivateKey.class);
X509Certificate cert = mock(X509Certificate.class);
List<X509Certificate> testList = ImmutableList.of(cert);
@@ -185,7 +185,7 @@
reset(mockWatcher2);
handle1.close();
assertThat(testCertificateProvider.closeCalled).isEqualTo(0);
- assertThat(distWatcher.downsstreamWatchers).hasSize(1);
+ assertThat(distWatcher.downstreamWatchers).hasSize(1);
testCertificateProvider.getWatcher().updateCertificate(testKey, testList);
verify(mockWatcher1, never())
.updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class));
@@ -221,7 +221,7 @@
CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class);
CertificateProviderStore.Handle unused = certificateProviderStore.createOrGetProvider(
"cert-name1", "plugin1", "config", mockWatcher2, true);
- assertThat(distWatcher.downsstreamWatchers).hasSize(2);
+ assertThat(distWatcher.downstreamWatchers).hasSize(2);
// updates sent to the second watcher
verify(mockWatcher2, times(1)).updateCertificate(eq(testKey), eq(testList));
verify(mockWatcher2, times(1)).updateTrustedRoots(eq(testList));
@@ -323,9 +323,9 @@
assertThat(testCertificateProvider2.certProviderProvider)
.isSameInstanceAs(certProviderProvider2);
CertificateProvider.DistributorWatcher distWatcher1 = testCertificateProvider1.getWatcher();
- assertThat(distWatcher1.downsstreamWatchers).hasSize(1);
+ assertThat(distWatcher1.downstreamWatchers).hasSize(1);
CertificateProvider.DistributorWatcher distWatcher2 = testCertificateProvider2.getWatcher();
- assertThat(distWatcher2.downsstreamWatchers).hasSize(1);
+ assertThat(distWatcher2.downstreamWatchers).hasSize(1);
PrivateKey testKey1 = mock(PrivateKey.class);
X509Certificate cert1 = mock(X509Certificate.class);
List<X509Certificate> testList1 = ImmutableList.of(cert1);
diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java
index d9d4da9..b28bd49 100644
--- a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java
+++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java
@@ -17,19 +17,25 @@
package io.grpc.xds.internal.certprovider;
import static com.google.common.truth.Truth.assertThat;
+import static io.grpc.xds.internal.certprovider.MeshCaCertificateProviderProvider.RPC_TIMEOUT_SECONDS;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isNull;
+import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
import com.google.auth.oauth2.GoogleCredentials;
import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.ExponentialBackoffPolicy;
+import io.grpc.internal.TimeProvider;
import io.grpc.xds.internal.sts.StsCredentials;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
@@ -59,6 +65,14 @@
@Mock
MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
+
+ @Mock
+ private MeshCaCertificateProviderProvider.ScheduledExecutorServiceFactory
+ scheduledExecutorServiceFactory;
+
+ @Mock
+ private TimeProvider timeProvider;
+
private MeshCaCertificateProviderProvider provider;
@Before
@@ -69,7 +83,9 @@
stsCredentialsFactory,
meshCaChannelFactory,
backoffPolicyProvider,
- meshCaCertificateProviderFactory);
+ meshCaCertificateProviderFactory,
+ scheduledExecutorServiceFactory,
+ timeProvider);
}
@Test
@@ -94,6 +110,10 @@
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
Map<String, String> map = buildMinimalMap();
+ ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
+ when(scheduledExecutorServiceFactory.create(
+ eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT)))
+ .thenReturn(mockService);
provider.createCertificateProvider(map, distWatcher, true);
verify(stsCredentialsFactory, times(1))
.create(
@@ -114,7 +134,10 @@
eq(backoffPolicyProvider),
eq(MeshCaCertificateProviderProvider.RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT),
eq(MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT),
- (GoogleCredentials) isNull());
+ (GoogleCredentials) isNull(),
+ eq(mockService),
+ eq(timeProvider),
+ eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS)));
}
@Test
@@ -150,7 +173,9 @@
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
Map<String, String> map = buildMinimalMap();
- map.put("gkeClusterUrl", "https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3");
+ map.put(
+ "gkeClusterUrl",
+ "https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3");
try {
provider.createCertificateProvider(map, distWatcher, true);
fail("exception expected");
@@ -164,6 +189,9 @@
CertificateProvider.DistributorWatcher distWatcher =
new CertificateProvider.DistributorWatcher();
Map<String, String> map = buildFullMap();
+ ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
+ when(scheduledExecutorServiceFactory.create(eq(NON_DEFAULT_MESH_CA_URL)))
+ .thenReturn(mockService);
provider.createCertificateProvider(map, distWatcher, true);
verify(stsCredentialsFactory, times(1))
.create(
@@ -184,7 +212,10 @@
eq(backoffPolicyProvider),
eq(4321L),
eq(9),
- (GoogleCredentials) isNull());
+ (GoogleCredentials) isNull(),
+ eq(mockService),
+ eq(timeProvider),
+ eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS)));
}
private Map<String, String> buildFullMap() {
diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderTest.java
new file mode 100644
index 0000000..7117b8c
--- /dev/null
+++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderTest.java
@@ -0,0 +1,538 @@
+/*
+ * Copyright 2020 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds.internal.certprovider;
+
+import static com.google.common.truth.Truth.assertThat;
+import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
+import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
+import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
+import static org.mockito.AdditionalAnswers.delegatesTo;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.auth.http.AuthHttpConstants;
+import com.google.auth.oauth2.AccessToken;
+import com.google.auth.oauth2.GoogleCredentials;
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.MoreExecutors;
+import google.security.meshca.v1.MeshCertificateServiceGrpc;
+import google.security.meshca.v1.Meshca;
+import io.grpc.Context;
+import io.grpc.ManagedChannel;
+import io.grpc.Metadata;
+import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import io.grpc.SynchronizationContext;
+import io.grpc.inprocess.InProcessChannelBuilder;
+import io.grpc.inprocess.InProcessServerBuilder;
+import io.grpc.internal.BackoffPolicy;
+import io.grpc.internal.TimeProvider;
+import io.grpc.testing.GrpcCleanupRule;
+import io.grpc.xds.internal.certprovider.CertificateProvider.DistributorWatcher;
+import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
+import java.io.IOException;
+import java.security.NoSuchAlgorithmException;
+import java.security.PrivateKey;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+import java.util.ArrayDeque;
+import java.util.Date;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicBoolean;
+import org.bouncycastle.operator.OperatorCreationException;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.ArgumentMatchers;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.Spy;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+/** Unit tests for {@link MeshCaCertificateProvider}. */
+@RunWith(JUnit4.class)
+public class MeshCaCertificateProviderTest {
+
+ private static final String TEST_STS_TOKEN = "test-stsToken";
+ private static final long RENEWAL_GRACE_PERIOD_SECONDS = TimeUnit.HOURS.toSeconds(1L);
+ private static final Metadata.Key<String> KEY_FOR_AUTHORIZATION =
+ Metadata.Key.of(AuthHttpConstants.AUTHORIZATION, Metadata.ASCII_STRING_MARSHALLER);
+ private static final String ZONE = "us-west2-a";
+ private static final long START_DELAY = 200_000_000L; // 0.2 seconds
+ private static final long[] DELAY_VALUES = {START_DELAY, START_DELAY * 2, START_DELAY * 4};
+ private static final long RPC_TIMEOUT_MILLIS = 100L;
+ /**
+ * Expire time of cert SERVER_0_PEM_FILE.
+ */
+ private static final long CERT0_EXPIRY_TIME_MILLIS = 1899853658000L;
+ /**
+ * Cert validity of 12 hours for the above cert.
+ */
+ private static final long CERT0_VALIDITY_MILLIS = TimeUnit.MILLISECONDS
+ .convert(12, TimeUnit.HOURS);
+ /**
+ * Compute current time based on cert expiry and cert validity.
+ */
+ private static final long CURRENT_TIME_NANOS =
+ TimeUnit.MILLISECONDS.toNanos(CERT0_EXPIRY_TIME_MILLIS - CERT0_VALIDITY_MILLIS);
+ @Rule
+ public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
+
+ private static class ResponseToSend {
+ Throwable getThrowable() {
+ throw new UnsupportedOperationException("Called on " + getClass().getCanonicalName());
+ }
+
+ List<String> getList() {
+ throw new UnsupportedOperationException("Called on " + getClass().getCanonicalName());
+ }
+ }
+
+ private static class ResponseThrowable extends ResponseToSend {
+ final Throwable throwableToSend;
+
+ ResponseThrowable(Throwable throwable) {
+ throwableToSend = throwable;
+ }
+
+ @Override
+ Throwable getThrowable() {
+ return throwableToSend;
+ }
+ }
+
+ private static class ResponseList extends ResponseToSend {
+ final List<String> listToSend;
+
+ ResponseList(List<String> list) {
+ listToSend = list;
+ }
+
+ @Override
+ List<String> getList() {
+ return listToSend;
+ }
+ }
+
+ private final Queue<Meshca.MeshCertificateRequest> receivedRequests = new ArrayDeque<>();
+ private final Queue<String> receivedStsCreds = new ArrayDeque<>();
+ private final Queue<String> receivedZoneValues = new ArrayDeque<>();
+ private final Queue<ResponseToSend> responsesToSend = new ArrayDeque<>();
+ private final Queue<String> oauth2Tokens = new ArrayDeque<>();
+ private final AtomicBoolean callEnded = new AtomicBoolean(true);
+
+ @Mock private MeshCertificateServiceGrpc.MeshCertificateServiceImplBase mockedMeshCaService;
+ @Mock private CertificateProvider.Watcher mockWatcher;
+ @Mock private BackoffPolicy.Provider backoffPolicyProvider;
+ @Mock private BackoffPolicy backoffPolicy;
+ @Spy private GoogleCredentials oauth2Creds;
+ @Mock private ScheduledExecutorService timeService;
+ @Mock private TimeProvider timeProvider;
+
+ private ManagedChannel channel;
+ private MeshCaCertificateProvider provider;
+
+ @Before
+ public void setUp() throws IOException {
+ MockitoAnnotations.initMocks(this);
+ when(backoffPolicyProvider.get()).thenReturn(backoffPolicy);
+ when(backoffPolicy.nextBackoffNanos())
+ .thenReturn(DELAY_VALUES[0], DELAY_VALUES[1], DELAY_VALUES[2]);
+ doAnswer(
+ new Answer<AccessToken>() {
+ @Override
+ public AccessToken answer(InvocationOnMock invocation) throws Throwable {
+ return new AccessToken(
+ oauth2Tokens.poll(), new Date(System.currentTimeMillis() + 1000L));
+ }
+ })
+ .when(oauth2Creds)
+ .refreshAccessToken();
+ final String meshCaUri = InProcessServerBuilder.generateName();
+ MeshCertificateServiceGrpc.MeshCertificateServiceImplBase meshCaServiceImpl =
+ new MeshCertificateServiceGrpc.MeshCertificateServiceImplBase() {
+
+ @Override
+ public void createCertificate(
+ google.security.meshca.v1.Meshca.MeshCertificateRequest request,
+ io.grpc.stub.StreamObserver<google.security.meshca.v1.Meshca.MeshCertificateResponse>
+ responseObserver) {
+ assertThat(callEnded.get()).isTrue(); // ensure previous call was ended
+ callEnded.set(false);
+ Context.current()
+ .addListener(
+ new Context.CancellationListener() {
+ @Override
+ public void cancelled(Context context) {
+ callEnded.set(true);
+ }
+ },
+ MoreExecutors.directExecutor());
+ receivedRequests.offer(request);
+ ResponseToSend response = responsesToSend.poll();
+ if (response instanceof ResponseThrowable) {
+ responseObserver.onError(response.getThrowable());
+ } else if (response instanceof ResponseList) {
+ List<String> certChainInResponse = response.getList();
+ Meshca.MeshCertificateResponse responseToSend =
+ Meshca.MeshCertificateResponse.newBuilder()
+ .addAllCertChain(certChainInResponse)
+ .build();
+ responseObserver.onNext(responseToSend);
+ responseObserver.onCompleted();
+ } else {
+ callEnded.set(true);
+ }
+ }
+ };
+ mockedMeshCaService =
+ mock(
+ MeshCertificateServiceGrpc.MeshCertificateServiceImplBase.class,
+ delegatesTo(meshCaServiceImpl));
+ ServerInterceptor interceptor =
+ new ServerInterceptor() {
+ @Override
+ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
+ ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
+ receivedStsCreds.offer(headers.get(KEY_FOR_AUTHORIZATION));
+ receivedZoneValues.offer(headers.get(MeshCaCertificateProvider.KEY_FOR_ZONE_INFO));
+ return next.startCall(call, headers);
+ }
+ };
+ cleanupRule.register(
+ InProcessServerBuilder.forName(meshCaUri)
+ .addService(mockedMeshCaService)
+ .intercept(interceptor)
+ .directExecutor()
+ .build()
+ .start());
+ channel =
+ cleanupRule.register(InProcessChannelBuilder.forName(meshCaUri).directExecutor().build());
+ MeshCaCertificateProvider.MeshCaChannelFactory channelFactory =
+ new MeshCaCertificateProvider.MeshCaChannelFactory() {
+ @Override
+ ManagedChannel createChannel(String serverUri) {
+ assertThat(serverUri).isEqualTo(meshCaUri);
+ return channel;
+ }
+ };
+ CertificateProvider.DistributorWatcher watcher = new CertificateProvider.DistributorWatcher();
+ watcher.addWatcher(mockWatcher); //
+ provider =
+ new MeshCaCertificateProvider(
+ watcher,
+ true,
+ meshCaUri,
+ ZONE,
+ TimeUnit.HOURS.toSeconds(9L),
+ 2048,
+ "RSA",
+ "SHA256withRSA",
+ channelFactory,
+ backoffPolicyProvider,
+ RENEWAL_GRACE_PERIOD_SECONDS,
+ MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT,
+ oauth2Creds,
+ timeService,
+ timeProvider,
+ RPC_TIMEOUT_MILLIS);
+ }
+
+ @Test
+ public void startAndClose() {
+ ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+ doReturn(scheduledFuture)
+ .when(timeService)
+ .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+ provider.start();
+ SynchronizationContext.ScheduledHandle savedScheduledHandle = provider.scheduledHandle;
+ assertThat(savedScheduledHandle).isNotNull();
+ assertThat(savedScheduledHandle.isPending()).isTrue();
+ verify(timeService, times(1))
+ .schedule(
+ any(Runnable.class),
+ eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
+ eq(TimeUnit.SECONDS));
+ DistributorWatcher distWatcher = provider.getWatcher();
+ assertThat(distWatcher.downstreamWatchers).hasSize(1);
+ PrivateKey mockKey = mock(PrivateKey.class);
+ X509Certificate mockCert = mock(X509Certificate.class);
+ distWatcher.updateCertificate(mockKey, ImmutableList.of(mockCert));
+ distWatcher.updateTrustedRoots(ImmutableList.of(mockCert));
+ provider.close();
+ assertThat(provider.scheduledHandle).isNull();
+ assertThat(savedScheduledHandle.isPending()).isFalse();
+ assertThat(distWatcher.downstreamWatchers).isEmpty();
+ assertThat(distWatcher.getLastIdentityCert()).isNull();
+ }
+
+ @Test
+ public void startTwice_noException() {
+ ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+ doReturn(scheduledFuture)
+ .when(timeService)
+ .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+ provider.start();
+ SynchronizationContext.ScheduledHandle savedScheduledHandle1 = provider.scheduledHandle;
+ provider.start();
+ SynchronizationContext.ScheduledHandle savedScheduledHandle2 = provider.scheduledHandle;
+ assertThat(savedScheduledHandle2).isNotSameInstanceAs(savedScheduledHandle1);
+ assertThat(savedScheduledHandle2.isPending()).isTrue();
+ }
+
+ @Test
+ public void getCertificate()
+ throws IOException, CertificateException, OperatorCreationException,
+ NoSuchAlgorithmException {
+ oauth2Tokens.offer(TEST_STS_TOKEN + "0");
+ responsesToSend.offer(
+ new ResponseList(ImmutableList.of(
+ CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
+ CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
+ CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
+ when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
+ ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+ doReturn(scheduledFuture)
+ .when(timeService)
+ .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+ provider.refreshCertificate();
+ Meshca.MeshCertificateRequest receivedReq = receivedRequests.poll();
+ assertThat(receivedReq.getValidity().getSeconds()).isEqualTo(TimeUnit.HOURS.toSeconds(9L));
+ // cannot decode CSR: just check the PEM format delimiters
+ String csr = receivedReq.getCsr();
+ assertThat(csr).startsWith("-----BEGIN NEW CERTIFICATE REQUEST-----");
+ verifyReceivedMetadataValues(1);
+ verify(timeService, times(1))
+ .schedule(
+ any(Runnable.class),
+ eq(
+ TimeUnit.MILLISECONDS.toSeconds(
+ CERT0_VALIDITY_MILLIS
+ - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
+ eq(TimeUnit.SECONDS));
+ verifyMockWatcher();
+ }
+
+ @Test
+ public void getCertificate_withError()
+ throws IOException, OperatorCreationException, NoSuchAlgorithmException {
+ oauth2Tokens.offer(TEST_STS_TOKEN + "0");
+ responsesToSend
+ .offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
+ ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+ doReturn(scheduledFuture).when(timeService)
+ .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+ provider.refreshCertificate();
+ verify(mockWatcher, never())
+ .updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
+ verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
+ verify(mockWatcher, times(1)).onError(Status.FAILED_PRECONDITION);
+ verify(timeService, times(1)).schedule(any(Runnable.class),
+ eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
+ eq(TimeUnit.SECONDS));
+ verifyReceivedMetadataValues(1);
+ }
+
+ @Test
+ public void getCertificate_withError_withExistingCert()
+ throws IOException, OperatorCreationException, NoSuchAlgorithmException {
+ PrivateKey mockKey = mock(PrivateKey.class);
+ X509Certificate mockCert = mock(X509Certificate.class);
+ // have current cert expire in 3 hours from current time
+ long threeHoursFromNowMillis = TimeUnit.NANOSECONDS
+ .toMillis(CURRENT_TIME_NANOS + TimeUnit.HOURS.toNanos(3));
+ when(mockCert.getNotAfter()).thenReturn(new Date(threeHoursFromNowMillis));
+ provider.getWatcher().updateCertificate(mockKey, ImmutableList.of(mockCert));
+ reset(mockWatcher);
+ oauth2Tokens.offer(TEST_STS_TOKEN + "0");
+ responsesToSend
+ .offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
+ when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
+ ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+ doReturn(scheduledFuture).when(timeService)
+ .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+ provider.refreshCertificate();
+ verify(mockWatcher, never())
+ .updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
+ verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
+ verify(mockWatcher, never()).onError(any(Status.class));
+ verify(timeService, times(1)).schedule(any(Runnable.class),
+ eq(5400L),
+ eq(TimeUnit.SECONDS));
+ assertThat(provider.getWatcher().getLastIdentityCert()).isNotNull();
+ verifyReceivedMetadataValues(1);
+ }
+
+ @Test
+ public void getCertificate_withError_withExistingExpiredCert()
+ throws IOException, OperatorCreationException, NoSuchAlgorithmException {
+ PrivateKey mockKey = mock(PrivateKey.class);
+ X509Certificate mockCert = mock(X509Certificate.class);
+ // have current cert expire in 3 seconds from current time
+ long threeSecondsFromNowMillis = TimeUnit.NANOSECONDS
+ .toMillis(CURRENT_TIME_NANOS + TimeUnit.SECONDS.toNanos(3));
+ when(mockCert.getNotAfter()).thenReturn(new Date(threeSecondsFromNowMillis));
+ provider.getWatcher().updateCertificate(mockKey, ImmutableList.of(mockCert));
+ reset(mockWatcher);
+ oauth2Tokens.offer(TEST_STS_TOKEN + "0");
+ responsesToSend
+ .offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
+ when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
+ ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+ doReturn(scheduledFuture).when(timeService)
+ .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+ provider.refreshCertificate();
+ verify(mockWatcher, never())
+ .updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
+ verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
+ verify(mockWatcher, times(1)).onError(Status.FAILED_PRECONDITION);
+ verify(timeService, times(1)).schedule(any(Runnable.class),
+ eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
+ eq(TimeUnit.SECONDS));
+ assertThat(provider.getWatcher().getLastIdentityCert()).isNull();
+ verifyReceivedMetadataValues(1);
+ }
+
+ @Test
+ public void getCertificate_retriesWithErrors()
+ throws IOException, CertificateException, OperatorCreationException,
+ NoSuchAlgorithmException, InterruptedException, ExecutionException, TimeoutException {
+ oauth2Tokens.offer(TEST_STS_TOKEN + "0");
+ oauth2Tokens.offer(TEST_STS_TOKEN + "1");
+ oauth2Tokens.offer(TEST_STS_TOKEN + "2");
+ responsesToSend.offer(new ResponseThrowable(new StatusRuntimeException(Status.UNKNOWN)));
+ responsesToSend.offer(
+ new ResponseThrowable(
+ new Exception(new StatusRuntimeException(Status.RESOURCE_EXHAUSTED))));
+ responsesToSend.offer(new ResponseList(ImmutableList.of(
+ CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
+ CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
+ CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
+ when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
+ ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+ doReturn(scheduledFuture).when(timeService)
+ .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+ ScheduledFuture<?> scheduledFutureSleep = mock(ScheduledFuture.class);
+ doReturn(scheduledFutureSleep).when(timeService)
+ .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.NANOSECONDS));
+ provider.refreshCertificate();
+ assertThat(receivedRequests.size()).isEqualTo(3);
+ verify(timeService, times(1)).schedule(any(Runnable.class),
+ eq(TimeUnit.MILLISECONDS.toSeconds(
+ CERT0_VALIDITY_MILLIS - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
+ eq(TimeUnit.SECONDS));
+ verifyRetriesWithBackoff(scheduledFutureSleep, 2);
+ verifyMockWatcher();
+ verifyReceivedMetadataValues(3);
+ }
+
+ @Test
+ public void getCertificate_retriesWithTimeouts()
+ throws IOException, CertificateException, OperatorCreationException,
+ NoSuchAlgorithmException, InterruptedException, ExecutionException, TimeoutException {
+ oauth2Tokens.offer(TEST_STS_TOKEN + "0");
+ oauth2Tokens.offer(TEST_STS_TOKEN + "1");
+ oauth2Tokens.offer(TEST_STS_TOKEN + "2");
+ oauth2Tokens.offer(TEST_STS_TOKEN + "3");
+ responsesToSend.offer(new ResponseToSend());
+ responsesToSend.offer(new ResponseToSend());
+ responsesToSend.offer(new ResponseToSend());
+ responsesToSend.offer(new ResponseList(ImmutableList.of(
+ CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
+ CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
+ CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
+ when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
+ ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+ doReturn(scheduledFuture).when(timeService)
+ .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+ ScheduledFuture<?> scheduledFutureSleep = mock(ScheduledFuture.class);
+ doReturn(scheduledFutureSleep).when(timeService)
+ .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.NANOSECONDS));
+ provider.refreshCertificate();
+ assertThat(receivedRequests.size()).isEqualTo(4);
+ verify(timeService, times(1)).schedule(any(Runnable.class),
+ eq(TimeUnit.MILLISECONDS.toSeconds(
+ CERT0_VALIDITY_MILLIS - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
+ eq(TimeUnit.SECONDS));
+ verifyRetriesWithBackoff(scheduledFutureSleep, 3);
+ verifyMockWatcher();
+ verifyReceivedMetadataValues(4);
+ }
+
+ private void verifyRetriesWithBackoff(ScheduledFuture<?> scheduledFutureSleep, int numOfRetries)
+ throws InterruptedException, ExecutionException, TimeoutException {
+ for (int i = 0; i < numOfRetries; i++) {
+ long delayValue = DELAY_VALUES[i];
+ verify(timeService, times(1)).schedule(any(Runnable.class),
+ eq(delayValue),
+ eq(TimeUnit.NANOSECONDS));
+ verify(scheduledFutureSleep, times(1)).get(eq(delayValue), eq(TimeUnit.NANOSECONDS));
+ }
+ }
+
+ private void verifyMockWatcher() throws IOException, CertificateException {
+ ArgumentCaptor<List<X509Certificate>> certChainCaptor = ArgumentCaptor.forClass(null);
+ verify(mockWatcher, times(1))
+ .updateCertificate(any(PrivateKey.class), certChainCaptor.capture());
+ List<X509Certificate> certChain = certChainCaptor.getValue();
+ assertThat(certChain).hasSize(3);
+ assertThat(certChain.get(0))
+ .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(SERVER_0_PEM_FILE));
+ assertThat(certChain.get(1))
+ .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(SERVER_1_PEM_FILE));
+ assertThat(certChain.get(2))
+ .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(CA_PEM_FILE));
+
+ ArgumentCaptor<List<X509Certificate>> rootsCaptor = ArgumentCaptor.forClass(null);
+ verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture());
+ List<X509Certificate> roots = rootsCaptor.getValue();
+ assertThat(roots).hasSize(1);
+ assertThat(roots.get(0))
+ .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(CA_PEM_FILE));
+ verify(mockWatcher, never()).onError(any(Status.class));
+ }
+
+ private void verifyReceivedMetadataValues(int count) {
+ assertThat(receivedStsCreds).hasSize(count);
+ assertThat(receivedZoneValues).hasSize(count);
+ for (int i = 0; i < count; i++) {
+ assertThat(receivedStsCreds.poll()).isEqualTo("Bearer " + TEST_STS_TOKEN + i);
+ assertThat(receivedZoneValues.poll()).isEqualTo("us-west2-a");
+ }
+ }
+}
diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java
index b6782b2..afa57f9 100644
--- a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java
+++ b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java
@@ -16,7 +16,10 @@
package io.grpc.xds.internal.sds;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
import com.google.common.base.Strings;
+import com.google.common.io.CharStreams;
import com.google.protobuf.BoolValue;
import com.google.protobuf.Struct;
import com.google.protobuf.Value;
@@ -36,7 +39,14 @@
import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
import io.grpc.internal.testing.TestUtils;
import io.grpc.xds.EnvoyServerProtoData;
+import io.grpc.xds.internal.sds.trust.CertificateUtils;
+import java.io.ByteArrayInputStream;
import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.Reader;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
import java.util.Arrays;
import javax.annotation.Nullable;
@@ -432,4 +442,23 @@
return EnvoyServerProtoData.UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext(
upstreamTlsContext);
}
+
+ /** Gets a cert from contents of a resource. */
+ public static X509Certificate getCertFromResourceName(String resourceName)
+ throws IOException, CertificateException {
+ try (ByteArrayInputStream bais =
+ new ByteArrayInputStream(getResourceContents(resourceName).getBytes(UTF_8))) {
+ return CertificateUtils.toX509Certificate(bais);
+ }
+ }
+
+ /** Gets contents of a resource from TestUtils.class loader. */
+ public static String getResourceContents(String resourceName) throws IOException {
+ InputStream inputStream = TestUtils.class.getResourceAsStream("/certs/" + resourceName);
+ String text = null;
+ try (Reader reader = new InputStreamReader(inputStream, UTF_8)) {
+ text = CharStreams.toString(reader);
+ }
+ return text;
+ }
}