xds: implement MeshCACertificateProvider (#7274)

diff --git a/build.gradle b/build.gradle
index 8eb98f2..f541539 100644
--- a/build.gradle
+++ b/build.gradle
@@ -177,6 +177,8 @@
             conscrypt: 'org.conscrypt:conscrypt-openjdk-uber:2.2.1',
             re2j: 'com.google.re2j:re2j:1.2',
 
+            bouncycastle: 'org.bouncycastle:bcpkix-jdk15on:1.61',
+
             // Test dependencies.
             junit: 'junit:junit:4.12',
             mockito: 'org.mockito:mockito-core:3.3.3',
diff --git a/xds/build.gradle b/xds/build.gradle
index 43902d8..2903366 100644
--- a/xds/build.gradle
+++ b/xds/build.gradle
@@ -26,7 +26,8 @@
             project(':grpc-auth'),
             project(path: ':grpc-alts', configuration: 'shadow'),
             libraries.gson,
-            libraries.re2j
+            libraries.re2j,
+            libraries.bouncycastle
     def nettyDependency = implementation project(':grpc-netty')
 
     implementation (libraries.opencensus_proto) {
diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java
index 9b96f99..b5d1497 100644
--- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java
+++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProvider.java
@@ -49,42 +49,42 @@
 
   @VisibleForTesting
   static final class DistributorWatcher implements Watcher {
-    private PrivateKey lastKey;
-    private List<X509Certificate> lastCertChain;
-    private List<X509Certificate> lastTrustedRoots;
+    private PrivateKey privateKey;
+    private List<X509Certificate> certChain;
+    private List<X509Certificate> trustedRoots;
 
     @VisibleForTesting
-    final Set<Watcher> downsstreamWatchers = new HashSet<>();
+    final Set<Watcher> downstreamWatchers = new HashSet<>();
 
     synchronized void addWatcher(Watcher watcher) {
-      downsstreamWatchers.add(watcher);
-      if (lastKey != null && lastCertChain != null) {
+      downstreamWatchers.add(watcher);
+      if (privateKey != null && certChain != null) {
         sendLastCertificateUpdate(watcher);
       }
-      if (lastTrustedRoots != null) {
+      if (trustedRoots != null) {
         sendLastTrustedRootsUpdate(watcher);
       }
     }
 
     synchronized void removeWatcher(Watcher watcher) {
-      downsstreamWatchers.remove(watcher);
+      downstreamWatchers.remove(watcher);
     }
 
     private void sendLastCertificateUpdate(Watcher watcher) {
-      watcher.updateCertificate(lastKey, lastCertChain);
+      watcher.updateCertificate(privateKey, certChain);
     }
 
     private void sendLastTrustedRootsUpdate(Watcher watcher) {
-      watcher.updateTrustedRoots(lastTrustedRoots);
+      watcher.updateTrustedRoots(trustedRoots);
     }
 
     @Override
     public synchronized void updateCertificate(PrivateKey key, List<X509Certificate> certChain) {
       checkNotNull(key, "key");
       checkNotNull(certChain, "certChain");
-      lastKey = key;
-      lastCertChain = certChain;
-      for (Watcher watcher : downsstreamWatchers) {
+      privateKey = key;
+      this.certChain = certChain;
+      for (Watcher watcher : downstreamWatchers) {
         sendLastCertificateUpdate(watcher);
       }
     }
@@ -92,18 +92,36 @@
     @Override
     public synchronized void updateTrustedRoots(List<X509Certificate> trustedRoots) {
       checkNotNull(trustedRoots, "trustedRoots");
-      lastTrustedRoots = trustedRoots;
-      for (Watcher watcher : downsstreamWatchers) {
+      this.trustedRoots = trustedRoots;
+      for (Watcher watcher : downstreamWatchers) {
         sendLastTrustedRootsUpdate(watcher);
       }
     }
 
     @Override
     public synchronized void onError(Status errorStatus) {
-      for (Watcher watcher : downsstreamWatchers) {
+      for (Watcher watcher : downstreamWatchers) {
         watcher.onError(errorStatus);
       }
     }
+
+    X509Certificate getLastIdentityCert() {
+      if (certChain != null && !certChain.isEmpty()) {
+        return certChain.get(0);
+      }
+      return null;
+    }
+
+    void close() {
+      downstreamWatchers.clear();
+      clearValues();
+    }
+
+    void clearValues() {
+      privateKey = null;
+      certChain = null;
+      trustedRoots = null;
+    }
   }
 
   /**
diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java
index 348f701..c038297 100644
--- a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java
+++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java
@@ -17,35 +17,336 @@
 package io.grpc.xds.internal.certprovider;
 
 import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static io.grpc.Status.Code.ABORTED;
+import static io.grpc.Status.Code.CANCELLED;
+import static io.grpc.Status.Code.DEADLINE_EXCEEDED;
+import static io.grpc.Status.Code.INTERNAL;
+import static io.grpc.Status.Code.RESOURCE_EXHAUSTED;
+import static io.grpc.Status.Code.UNAVAILABLE;
+import static io.grpc.Status.Code.UNKNOWN;
+import static java.nio.charset.StandardCharsets.UTF_8;
 
 import com.google.auth.oauth2.GoogleCredentials;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.Duration;
+import google.security.meshca.v1.MeshCertificateServiceGrpc;
+import google.security.meshca.v1.Meshca;
+import io.grpc.CallOptions;
+import io.grpc.Channel;
+import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
+import io.grpc.ForwardingClientCall;
+import io.grpc.InternalLogId;
 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
+import io.grpc.Metadata;
+import io.grpc.MethodDescriptor;
+import io.grpc.Status;
+import io.grpc.SynchronizationContext;
+import io.grpc.auth.MoreCallCredentials;
 import io.grpc.internal.BackoffPolicy;
+import io.grpc.internal.TimeProvider;
+import io.grpc.xds.internal.sds.trust.CertificateUtils;
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.StringWriter;
+import java.security.KeyPair;
+import java.security.KeyPairGenerator;
+import java.security.NoSuchAlgorithmException;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+import java.util.ArrayList;
+import java.util.EnumSet;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 import java.util.logging.Level;
 import java.util.logging.Logger;
+import javax.security.auth.x500.X500Principal;
+import org.bouncycastle.openssl.jcajce.JcaPEMWriter;
+import org.bouncycastle.operator.ContentSigner;
+import org.bouncycastle.operator.OperatorCreationException;
+import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder;
+import org.bouncycastle.pkcs.PKCS10CertificationRequest;
+import org.bouncycastle.pkcs.PKCS10CertificationRequestBuilder;
+import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder;
+import org.bouncycastle.util.io.pem.PemObject;
 
 /** Implementation of {@link CertificateProvider} for the Google Mesh CA. */
 final class MeshCaCertificateProvider extends CertificateProvider {
   private static final Logger logger = Logger.getLogger(MeshCaCertificateProvider.class.getName());
 
-  MeshCaCertificateProvider(DistributorWatcher watcher, boolean notifyCertUpdates,
-      String meshCaUrl, String zone, long validitySeconds,
-      int keySize, String alg, String signatureAlg, MeshCaChannelFactory meshCaChannelFactory,
-      BackoffPolicy.Provider backoffPolicyProvider, long renewalGracePeriodSeconds,
-      int maxRetryAttempts, GoogleCredentials oauth2Creds) {
+  MeshCaCertificateProvider(
+      DistributorWatcher watcher,
+      boolean notifyCertUpdates,
+      String meshCaUrl,
+      String zone,
+      long validitySeconds,
+      int keySize,
+      String alg,
+      String signatureAlg, MeshCaChannelFactory meshCaChannelFactory,
+      BackoffPolicy.Provider backoffPolicyProvider,
+      long renewalGracePeriodSeconds,
+      int maxRetryAttempts,
+      GoogleCredentials oauth2Creds,
+      ScheduledExecutorService scheduledExecutorService,
+      TimeProvider timeProvider,
+      long rpcTimeoutMillis) {
     super(watcher, notifyCertUpdates);
+    this.meshCaUrl = checkNotNull(meshCaUrl, "meshCaUrl");
+    checkArgument(
+        validitySeconds > INITIAL_DELAY_SECONDS,
+        "validitySeconds must be greater than " + INITIAL_DELAY_SECONDS);
+    this.validitySeconds = validitySeconds;
+    this.keySize = keySize;
+    this.alg = checkNotNull(alg, "alg");
+    this.signatureAlg = checkNotNull(signatureAlg, "signatureAlg");
+    this.meshCaChannelFactory = checkNotNull(meshCaChannelFactory, "meshCaChannelFactory");
+    this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
+    checkArgument(
+        renewalGracePeriodSeconds > 0L && renewalGracePeriodSeconds < validitySeconds,
+        "renewalGracePeriodSeconds should be between 0 and " + validitySeconds);
+    this.renewalGracePeriodSeconds = renewalGracePeriodSeconds;
+    checkArgument(maxRetryAttempts >= 0, "maxRetryAttempts must be >= 0");
+    this.maxRetryAttempts = maxRetryAttempts;
+    this.oauth2Creds = checkNotNull(oauth2Creds, "oauth2Creds");
+    this.scheduledExecutorService =
+        checkNotNull(scheduledExecutorService, "scheduledExecutorService");
+    this.timeProvider = checkNotNull(timeProvider, "timeProvider");
+    this.headerInterceptor = new ZoneInfoClientInterceptor(checkNotNull(zone, "zone"));
+    this.syncContext = createSynchronizationContext(meshCaUrl);
+    this.rpcTimeoutMillis = rpcTimeoutMillis;
+  }
+
+  private SynchronizationContext createSynchronizationContext(String details) {
+    final InternalLogId logId = InternalLogId.allocate("MeshCaCertificateProvider", details);
+    return new SynchronizationContext(
+        new Thread.UncaughtExceptionHandler() {
+          private boolean panicMode;
+
+          @Override
+          public void uncaughtException(Thread t, Throwable e) {
+            logger.log(
+                Level.SEVERE,
+                "[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!",
+                e);
+            panic(e);
+          }
+
+          void panic(final Throwable t) {
+            if (panicMode) {
+              // Preserve the first panic information
+              return;
+            }
+            panicMode = true;
+            close();
+          }
+        });
   }
 
   @Override
   public void start() {
-    // TODO implement
+    scheduleNextRefreshCertificate(INITIAL_DELAY_SECONDS);
   }
 
   @Override
   public void close() {
-    // TODO implement
+    if (scheduledHandle != null) {
+      scheduledHandle.cancel();
+      scheduledHandle = null;
+    }
+    getWatcher().close();
+  }
+
+  private void scheduleNextRefreshCertificate(long delayInSeconds) {
+    if (scheduledHandle != null && scheduledHandle.isPending()) {
+      logger.log(Level.SEVERE, "Pending task found: inconsistent state in scheduledHandle!");
+      scheduledHandle.cancel();
+    }
+    RefreshCertificateTask runnable = new RefreshCertificateTask();
+    scheduledHandle = syncContext.schedule(
+            runnable, delayInSeconds, TimeUnit.SECONDS, scheduledExecutorService);
+  }
+
+  @VisibleForTesting
+  void refreshCertificate()
+      throws NoSuchAlgorithmException, IOException, OperatorCreationException {
+    long refreshDelaySeconds = computeRefreshSecondsFromCurrentCertExpiry();
+    ManagedChannel channel = meshCaChannelFactory.createChannel(meshCaUrl);
+    try {
+      String uniqueReqIdForAllRetries = UUID.randomUUID().toString();
+      Duration duration = Duration.newBuilder().setSeconds(validitySeconds).build();
+      KeyPair keyPair = generateKeyPair();
+      String csr = generateCsr(keyPair);
+      MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub stub =
+          createStubToMeshCa(channel);
+      List<X509Certificate> x509Chain = makeRequestWithRetries(stub, uniqueReqIdForAllRetries,
+          duration, csr);
+      if (x509Chain != null) {
+        refreshDelaySeconds =
+            computeDelaySecondsToCertExpiry(x509Chain.get(0)) - renewalGracePeriodSeconds;
+        getWatcher().updateCertificate(keyPair.getPrivate(), x509Chain);
+        getWatcher().updateTrustedRoots(ImmutableList.of(x509Chain.get(x509Chain.size() - 1)));
+      }
+    } finally {
+      shutdownChannel(channel);
+      scheduleNextRefreshCertificate(refreshDelaySeconds);
+    }
+  }
+
+  private MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub createStubToMeshCa(
+      ManagedChannel channel) {
+    return MeshCertificateServiceGrpc
+        .newBlockingStub(channel)
+        .withCallCredentials(MoreCallCredentials.from(oauth2Creds))
+        .withInterceptors(headerInterceptor);
+  }
+
+  private List<X509Certificate> makeRequestWithRetries(
+      MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub stub,
+      String reqId,
+      Duration duration,
+      String csr) {
+    Meshca.MeshCertificateRequest request =
+        Meshca.MeshCertificateRequest.newBuilder()
+            .setValidity(duration)
+            .setCsr(csr)
+            .setRequestId(reqId)
+            .build();
+
+    BackoffPolicy backoffPolicy = backoffPolicyProvider.get();
+    Throwable lastException = null;
+    for (int i = 0; i <= maxRetryAttempts; i++) {
+      try {
+        Meshca.MeshCertificateResponse response =
+            stub.withDeadlineAfter(rpcTimeoutMillis, TimeUnit.MILLISECONDS)
+                .createCertificate(request);
+        return getX509CertificatesFromResponse(response);
+      } catch (Throwable t) {
+        if (!retriable(t)) {
+          generateErrorIfCurrentCertExpired(t);
+          return null;
+        }
+        lastException = t;
+        sleepForNanos(backoffPolicy.nextBackoffNanos());
+      }
+    }
+    generateErrorIfCurrentCertExpired(lastException);
+    return null;
+  }
+
+  private void sleepForNanos(long nanos) {
+    ScheduledFuture<?> future = scheduledExecutorService.schedule(new Runnable() {
+      @Override
+      public void run() {
+        // do nothing
+      }
+    }, nanos, TimeUnit.NANOSECONDS);
+    try {
+      future.get(nanos, TimeUnit.NANOSECONDS);
+    } catch (InterruptedException ie) {
+      logger.log(Level.SEVERE, "Inside sleep", ie);
+      Thread.currentThread().interrupt();
+    } catch (ExecutionException | TimeoutException ex) {
+      logger.log(Level.SEVERE, "Inside sleep", ex);
+    }
+  }
+
+  private static boolean retriable(Throwable t) {
+    return RETRIABLE_CODES.contains(Status.fromThrowable(t).getCode());
+  }
+
+  private void generateErrorIfCurrentCertExpired(Throwable t) {
+    X509Certificate currentCert = getWatcher().getLastIdentityCert();
+    if (currentCert != null) {
+      long delaySeconds = computeDelaySecondsToCertExpiry(currentCert);
+      if (delaySeconds > INITIAL_DELAY_SECONDS) {
+        return;
+      }
+      getWatcher().clearValues();
+    }
+    getWatcher().onError(Status.fromThrowable(t));
+  }
+
+  private KeyPair generateKeyPair() throws NoSuchAlgorithmException {
+    KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance(alg);
+    keyPairGenerator.initialize(keySize);
+    return keyPairGenerator.generateKeyPair();
+  }
+
+  private String generateCsr(KeyPair pair) throws IOException, OperatorCreationException {
+    PKCS10CertificationRequestBuilder p10Builder =
+        new JcaPKCS10CertificationRequestBuilder(
+            new X500Principal("CN=EXAMPLE.COM"), pair.getPublic());
+    JcaContentSignerBuilder csBuilder = new JcaContentSignerBuilder(signatureAlg);
+    ContentSigner signer = csBuilder.build(pair.getPrivate());
+    PKCS10CertificationRequest csr = p10Builder.build(signer);
+    PemObject pemObject = new PemObject("NEW CERTIFICATE REQUEST", csr.getEncoded());
+    try (StringWriter str = new StringWriter()) {
+      try (JcaPEMWriter pemWriter = new JcaPEMWriter(str)) {
+        pemWriter.writeObject(pemObject);
+      }
+      return str.toString();
+    }
+  }
+
+  /** Compute refresh interval as half of interval to current cert expiry. */
+  private long computeRefreshSecondsFromCurrentCertExpiry() {
+    X509Certificate lastCert = getWatcher().getLastIdentityCert();
+    if (lastCert == null) {
+      return INITIAL_DELAY_SECONDS;
+    }
+    long delayToCertExpirySeconds = computeDelaySecondsToCertExpiry(lastCert) / 2;
+    return Math.max(delayToCertExpirySeconds, INITIAL_DELAY_SECONDS);
+  }
+
+  @SuppressWarnings("JdkObsolete")
+  private long computeDelaySecondsToCertExpiry(X509Certificate lastCert) {
+    checkNotNull(lastCert, "lastCert");
+    return TimeUnit.NANOSECONDS.toSeconds(
+        TimeUnit.MILLISECONDS.toNanos(lastCert.getNotAfter().getTime()) - timeProvider
+            .currentTimeNanos());
+  }
+
+  private static void shutdownChannel(ManagedChannel channel) {
+    channel.shutdown();
+    try {
+      channel.awaitTermination(10, TimeUnit.SECONDS);
+    } catch (InterruptedException ex) {
+      logger.log(Level.SEVERE, "awaiting channel Termination", ex);
+      channel.shutdownNow();
+      Thread.currentThread().interrupt();
+    }
+  }
+
+  private List<X509Certificate> getX509CertificatesFromResponse(
+      Meshca.MeshCertificateResponse response) throws CertificateException, IOException {
+    List<String> certChain = response.getCertChainList();
+    List<X509Certificate> x509Chain = new ArrayList<>(certChain.size());
+    for (String certString : certChain) {
+      try (ByteArrayInputStream bais = new ByteArrayInputStream(certString.getBytes(UTF_8))) {
+        x509Chain.add(CertificateUtils.toX509Certificate(bais));
+      }
+    }
+    return x509Chain;
+  }
+
+  @VisibleForTesting
+  class RefreshCertificateTask implements Runnable {
+    @Override
+    public void run() {
+      try {
+        refreshCertificate();
+      } catch (NoSuchAlgorithmException | OperatorCreationException | IOException ex) {
+        logger.log(Level.SEVERE, "refreshing certificate", ex);
+      }
+    }
   }
 
   /** Factory for creating channels to MeshCA sever. */
@@ -94,7 +395,10 @@
               BackoffPolicy.Provider backoffPolicyProvider,
               long renewalGracePeriodSeconds,
               int maxRetryAttempts,
-              GoogleCredentials oauth2Creds) {
+              GoogleCredentials oauth2Creds,
+              ScheduledExecutorService scheduledExecutorService,
+              TimeProvider timeProvider,
+              long rpcTimeoutMillis) {
             return new MeshCaCertificateProvider(
                 watcher,
                 notifyCertUpdates,
@@ -108,7 +412,10 @@
                 backoffPolicyProvider,
                 renewalGracePeriodSeconds,
                 maxRetryAttempts,
-                oauth2Creds);
+                oauth2Creds,
+                scheduledExecutorService,
+                timeProvider,
+                rpcTimeoutMillis);
           }
         };
 
@@ -129,6 +436,64 @@
         BackoffPolicy.Provider backoffPolicyProvider,
         long renewalGracePeriodSeconds,
         int maxRetryAttempts,
-        GoogleCredentials oauth2Creds);
+        GoogleCredentials oauth2Creds,
+        ScheduledExecutorService scheduledExecutorService,
+        TimeProvider timeProvider,
+        long rpcTimeoutMillis);
   }
+
+  private class ZoneInfoClientInterceptor implements ClientInterceptor {
+    private final String zone;
+
+    ZoneInfoClientInterceptor(String zone) {
+      this.zone = zone;
+    }
+
+    @Override
+    public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
+        MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
+      return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(
+          next.newCall(method, callOptions)) {
+
+        @Override
+        public void start(Listener<RespT> responseListener, Metadata headers) {
+          headers.put(KEY_FOR_ZONE_INFO, zone);
+          super.start(responseListener, headers);
+        }
+      };
+    }
+  }
+
+  @VisibleForTesting
+  static final Metadata.Key<String> KEY_FOR_ZONE_INFO =
+      Metadata.Key.of("x-goog-request-params", Metadata.ASCII_STRING_MARSHALLER);
+  @VisibleForTesting
+  static final long INITIAL_DELAY_SECONDS = 4L;
+
+  private static final EnumSet<Status.Code> RETRIABLE_CODES =
+      EnumSet.of(
+          CANCELLED,
+          UNKNOWN,
+          DEADLINE_EXCEEDED,
+          RESOURCE_EXHAUSTED,
+          ABORTED,
+          INTERNAL,
+          UNAVAILABLE);
+
+  private final SynchronizationContext syncContext;
+  private final ScheduledExecutorService scheduledExecutorService;
+  private final int maxRetryAttempts;
+  private final ZoneInfoClientInterceptor headerInterceptor;
+  private final BackoffPolicy.Provider backoffPolicyProvider;
+  private final String meshCaUrl;
+  private final long validitySeconds;
+  private final long renewalGracePeriodSeconds;
+  private final int keySize;
+  private final String alg;
+  private final String signatureAlg;
+  private final GoogleCredentials oauth2Creds;
+  private final TimeProvider timeProvider;
+  private final MeshCaChannelFactory meshCaChannelFactory;
+  @VisibleForTesting SynchronizationContext.ScheduledHandle scheduledHandle;
+  private final long rpcTimeoutMillis;
 }
diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java
index a9c1b01..669b0f2 100644
--- a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java
+++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java
@@ -21,10 +21,15 @@
 import static com.google.common.base.Preconditions.checkState;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
 import io.grpc.internal.BackoffPolicy;
 import io.grpc.internal.ExponentialBackoffPolicy;
+import io.grpc.internal.TimeProvider;
 import io.grpc.xds.internal.sts.StsCredentials;
 import java.util.Map;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 
@@ -46,16 +51,20 @@
   private static final String STS_URL_KEY = "stsUrl";
   private static final String GKE_SA_JWT_LOCATION_KEY = "gkeSaJwtLocation";
 
-  static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com";
-  static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L;
-  static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L; // 9 hours
-  static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L; // 1 hour
-  static final String KEY_ALGO_DEFAULT = "RSA";  // aka keyType
-  static final int KEY_SIZE_DEFAULT = 2048;
-  static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA";
-  static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3;
+  @VisibleForTesting static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com";
+  @VisibleForTesting static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L;
+  @VisibleForTesting static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L;
+  @VisibleForTesting static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L;
+  @VisibleForTesting static final String KEY_ALGO_DEFAULT = "RSA";  // aka keyType
+  @VisibleForTesting static final int KEY_SIZE_DEFAULT = 2048;
+  @VisibleForTesting static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA";
+  @VisibleForTesting static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3;
+  @VisibleForTesting
   static final String STS_URL_DEFAULT = "https://securetoken.googleapis.com/v1/identitybindingtoken";
 
+  @VisibleForTesting
+  static final long RPC_TIMEOUT_SECONDS = 10L;
+
   private static final Pattern CLUSTER_URL_PATTERN = Pattern
       .compile(".*/projects/(.*)/locations/(.*)/clusters/.*");
 
@@ -70,23 +79,32 @@
                 StsCredentials.Factory.getInstance(),
                 MeshCaCertificateProvider.MeshCaChannelFactory.getInstance(),
                 new ExponentialBackoffPolicy.Provider(),
-                MeshCaCertificateProvider.Factory.getInstance()));
+                MeshCaCertificateProvider.Factory.getInstance(),
+                ScheduledExecutorServiceFactory.DEFAULT_INSTANCE,
+                TimeProvider.SYSTEM_TIME_PROVIDER));
   }
 
   final StsCredentials.Factory stsCredentialsFactory;
   final MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory;
   final BackoffPolicy.Provider backoffPolicyProvider;
   final MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
+  final ScheduledExecutorServiceFactory scheduledExecutorServiceFactory;
+  final TimeProvider timeProvider;
 
   @VisibleForTesting
-  MeshCaCertificateProviderProvider(StsCredentials.Factory stsCredentialsFactory,
+  MeshCaCertificateProviderProvider(
+      StsCredentials.Factory stsCredentialsFactory,
       MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory,
       BackoffPolicy.Provider backoffPolicyProvider,
-      MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory) {
+      MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory,
+      ScheduledExecutorServiceFactory scheduledExecutorServiceFactory,
+      TimeProvider timeProvider) {
     this.stsCredentialsFactory = stsCredentialsFactory;
     this.meshCaChannelFactory = meshCaChannelFactory;
     this.backoffPolicyProvider = backoffPolicyProvider;
     this.meshCaCertificateProviderFactory = meshCaCertificateProviderFactory;
+    this.scheduledExecutorServiceFactory = scheduledExecutorServiceFactory;
+    this.timeProvider = timeProvider;
   }
 
   @Override
@@ -106,12 +124,23 @@
     StsCredentials stsCredentials = stsCredentialsFactory
         .create(configObj.stsUrl, audience, configObj.gkeSaJwtLocation);
 
-    return meshCaCertificateProviderFactory.create(watcher, notifyCertUpdates, configObj.meshCaUrl,
+    return meshCaCertificateProviderFactory.create(
+        watcher,
+        notifyCertUpdates,
+        configObj.meshCaUrl,
         configObj.zone,
-        configObj.certValiditySeconds, configObj.keySize, configObj.keyAlgo,
+        configObj.certValiditySeconds,
+        configObj.keySize,
+        configObj.keyAlgo,
         configObj.signatureAlgo,
-        meshCaChannelFactory, backoffPolicyProvider,
-        configObj.renewalGracePeriodSeconds, configObj.maxRetryAttempts, stsCredentials);
+        meshCaChannelFactory,
+        backoffPolicyProvider,
+        configObj.renewalGracePeriodSeconds,
+        configObj.maxRetryAttempts,
+        stsCredentials,
+        scheduledExecutorServiceFactory.create(configObj.meshCaUrl),
+        timeProvider,
+        TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS));
   }
 
   private static Config validateAndTranslateConfig(Object config) {
@@ -177,6 +206,28 @@
     configObj.zone = matcher.group(2);
   }
 
+  abstract static class ScheduledExecutorServiceFactory {
+
+    private static final ScheduledExecutorServiceFactory DEFAULT_INSTANCE =
+        new ScheduledExecutorServiceFactory() {
+
+          @Override
+          ScheduledExecutorService create(String serverUri) {
+            return Executors.newSingleThreadScheduledExecutor(
+                new ThreadFactoryBuilder()
+                    .setNameFormat("meshca-" + serverUri + "-%d")
+                    .setDaemon(true)
+                    .build());
+          }
+        };
+
+    static ScheduledExecutorServiceFactory getInstance() {
+      return DEFAULT_INSTANCE;
+    }
+
+    abstract ScheduledExecutorService create(String serverUri);
+  }
+
   /** POJO class for storing various config values. */
   @VisibleForTesting
   static class Config {
diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/trust/CertificateUtils.java b/xds/src/main/java/io/grpc/xds/internal/sds/trust/CertificateUtils.java
index 834065a..a85ea3d 100644
--- a/xds/src/main/java/io/grpc/xds/internal/sds/trust/CertificateUtils.java
+++ b/xds/src/main/java/io/grpc/xds/internal/sds/trust/CertificateUtils.java
@@ -30,7 +30,7 @@
 /**
  * Contains certificate utility method(s).
  */
-final class CertificateUtils {
+public final class CertificateUtils {
 
   private static CertificateFactory factory;
 
@@ -46,19 +46,26 @@
    * @param file a {@link File} containing the cert data
    */
   static X509Certificate[] toX509Certificates(File file) throws CertificateException, IOException {
-    FileInputStream fis = new FileInputStream(file);
-    return toX509Certificates(new BufferedInputStream(fis));
+    try (FileInputStream fis = new FileInputStream(file);
+        BufferedInputStream bis = new BufferedInputStream(fis)) {
+      return toX509Certificates(bis);
+    }
   }
 
   static synchronized X509Certificate[] toX509Certificates(InputStream inputStream)
       throws CertificateException, IOException {
     initInstance();
-    try {
-      Collection<? extends Certificate> certs = factory.generateCertificates(inputStream);
-      return certs.toArray(new X509Certificate[0]);
-    } finally {
-      inputStream.close();
-    }
+    Collection<? extends Certificate> certs = factory.generateCertificates(inputStream);
+    return certs.toArray(new X509Certificate[0]);
+
+  }
+
+  /** See {@link CertificateFactory#generateCertificate(InputStream)}. */
+  public static synchronized X509Certificate toX509Certificate(InputStream inputStream)
+          throws CertificateException, IOException {
+    initInstance();
+    Certificate cert = factory.generateCertificate(inputStream);
+    return (X509Certificate) cert;
   }
 
   private CertificateUtils() {}
diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java b/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java
index b495570..6ff63c0 100644
--- a/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java
+++ b/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsTrustManagerFactory.java
@@ -27,6 +27,7 @@
 import io.netty.handler.ssl.util.SimpleTrustManagerFactory;
 import java.io.File;
 import java.io.IOException;
+import java.io.InputStream;
 import java.security.KeyStore;
 import java.security.KeyStoreException;
 import java.security.NoSuchAlgorithmException;
@@ -69,8 +70,10 @@
           "trustedCa.file-name in certificateValidationContext cannot be empty");
       return CertificateUtils.toX509Certificates(new File(certsFile));
     } else if (specifierCase == SpecifierCase.INLINE_BYTES) {
-      return CertificateUtils.toX509Certificates(
-          certificateValidationContext.getTrustedCa().getInlineBytes().newInput());
+      try (InputStream is =
+          certificateValidationContext.getTrustedCa().getInlineBytes().newInput()) {
+        return CertificateUtils.toX509Certificates(is);
+      }
     } else {
       throw new IllegalArgumentException("Not supported: " + specifierCase);
     }
diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java
index 8a0dfeb..569d72b 100644
--- a/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java
+++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/CertificateProviderStoreTest.java
@@ -169,7 +169,7 @@
         (TestCertificateProvider) handle1.certProvider;
     assertThat(testCertificateProvider.startCalled).isEqualTo(1);
     CertificateProvider.DistributorWatcher distWatcher = testCertificateProvider.getWatcher();
-    assertThat(distWatcher.downsstreamWatchers).hasSize(2);
+    assertThat(distWatcher.downstreamWatchers).hasSize(2);
     PrivateKey testKey = mock(PrivateKey.class);
     X509Certificate cert = mock(X509Certificate.class);
     List<X509Certificate> testList = ImmutableList.of(cert);
@@ -185,7 +185,7 @@
     reset(mockWatcher2);
     handle1.close();
     assertThat(testCertificateProvider.closeCalled).isEqualTo(0);
-    assertThat(distWatcher.downsstreamWatchers).hasSize(1);
+    assertThat(distWatcher.downstreamWatchers).hasSize(1);
     testCertificateProvider.getWatcher().updateCertificate(testKey, testList);
     verify(mockWatcher1, never())
         .updateCertificate(any(PrivateKey.class), anyListOf(X509Certificate.class));
@@ -221,7 +221,7 @@
     CertificateProvider.Watcher mockWatcher2 = mock(CertificateProvider.Watcher.class);
     CertificateProviderStore.Handle unused = certificateProviderStore.createOrGetProvider(
             "cert-name1", "plugin1", "config", mockWatcher2, true);
-    assertThat(distWatcher.downsstreamWatchers).hasSize(2);
+    assertThat(distWatcher.downstreamWatchers).hasSize(2);
     // updates sent to the second watcher
     verify(mockWatcher2, times(1)).updateCertificate(eq(testKey), eq(testList));
     verify(mockWatcher2, times(1)).updateTrustedRoots(eq(testList));
@@ -323,9 +323,9 @@
     assertThat(testCertificateProvider2.certProviderProvider)
         .isSameInstanceAs(certProviderProvider2);
     CertificateProvider.DistributorWatcher distWatcher1 = testCertificateProvider1.getWatcher();
-    assertThat(distWatcher1.downsstreamWatchers).hasSize(1);
+    assertThat(distWatcher1.downstreamWatchers).hasSize(1);
     CertificateProvider.DistributorWatcher distWatcher2 = testCertificateProvider2.getWatcher();
-    assertThat(distWatcher2.downsstreamWatchers).hasSize(1);
+    assertThat(distWatcher2.downstreamWatchers).hasSize(1);
     PrivateKey testKey1 = mock(PrivateKey.class);
     X509Certificate cert1 = mock(X509Certificate.class);
     List<X509Certificate> testList1 = ImmutableList.of(cert1);
diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java
index d9d4da9..b28bd49 100644
--- a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java
+++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java
@@ -17,19 +17,25 @@
 package io.grpc.xds.internal.certprovider;
 
 import static com.google.common.truth.Truth.assertThat;
+import static io.grpc.xds.internal.certprovider.MeshCaCertificateProviderProvider.RPC_TIMEOUT_SECONDS;
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.isNull;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 import com.google.auth.oauth2.GoogleCredentials;
 import io.grpc.internal.BackoffPolicy;
 import io.grpc.internal.ExponentialBackoffPolicy;
+import io.grpc.internal.TimeProvider;
 import io.grpc.xds.internal.sts.StsCredentials;
 import java.io.IOException;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -59,6 +65,14 @@
 
   @Mock
   MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory;
+
+  @Mock
+  private MeshCaCertificateProviderProvider.ScheduledExecutorServiceFactory
+      scheduledExecutorServiceFactory;
+
+  @Mock
+  private TimeProvider timeProvider;
+
   private MeshCaCertificateProviderProvider provider;
 
   @Before
@@ -69,7 +83,9 @@
             stsCredentialsFactory,
             meshCaChannelFactory,
             backoffPolicyProvider,
-            meshCaCertificateProviderFactory);
+            meshCaCertificateProviderFactory,
+            scheduledExecutorServiceFactory,
+            timeProvider);
   }
 
   @Test
@@ -94,6 +110,10 @@
     CertificateProvider.DistributorWatcher distWatcher =
         new CertificateProvider.DistributorWatcher();
     Map<String, String> map = buildMinimalMap();
+    ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
+    when(scheduledExecutorServiceFactory.create(
+            eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT)))
+        .thenReturn(mockService);
     provider.createCertificateProvider(map, distWatcher, true);
     verify(stsCredentialsFactory, times(1))
         .create(
@@ -114,7 +134,10 @@
             eq(backoffPolicyProvider),
             eq(MeshCaCertificateProviderProvider.RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT),
             eq(MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT),
-            (GoogleCredentials) isNull());
+            (GoogleCredentials) isNull(),
+            eq(mockService),
+            eq(timeProvider),
+            eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS)));
   }
 
   @Test
@@ -150,7 +173,9 @@
     CertificateProvider.DistributorWatcher distWatcher =
             new CertificateProvider.DistributorWatcher();
     Map<String, String> map = buildMinimalMap();
-    map.put("gkeClusterUrl", "https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3");
+    map.put(
+        "gkeClusterUrl",
+        "https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3");
     try {
       provider.createCertificateProvider(map, distWatcher, true);
       fail("exception expected");
@@ -164,6 +189,9 @@
     CertificateProvider.DistributorWatcher distWatcher =
             new CertificateProvider.DistributorWatcher();
     Map<String, String> map = buildFullMap();
+    ScheduledExecutorService mockService = mock(ScheduledExecutorService.class);
+    when(scheduledExecutorServiceFactory.create(eq(NON_DEFAULT_MESH_CA_URL)))
+        .thenReturn(mockService);
     provider.createCertificateProvider(map, distWatcher, true);
     verify(stsCredentialsFactory, times(1))
             .create(
@@ -184,7 +212,10 @@
                     eq(backoffPolicyProvider),
                     eq(4321L),
                     eq(9),
-                    (GoogleCredentials) isNull());
+                    (GoogleCredentials) isNull(),
+                    eq(mockService),
+                    eq(timeProvider),
+                    eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS)));
   }
 
   private Map<String, String> buildFullMap() {
diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderTest.java
new file mode 100644
index 0000000..7117b8c
--- /dev/null
+++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderTest.java
@@ -0,0 +1,538 @@
+/*
+ * Copyright 2020 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds.internal.certprovider;
+
+import static com.google.common.truth.Truth.assertThat;
+import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE;
+import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE;
+import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE;
+import static org.mockito.AdditionalAnswers.delegatesTo;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import com.google.auth.http.AuthHttpConstants;
+import com.google.auth.oauth2.AccessToken;
+import com.google.auth.oauth2.GoogleCredentials;
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.MoreExecutors;
+import google.security.meshca.v1.MeshCertificateServiceGrpc;
+import google.security.meshca.v1.Meshca;
+import io.grpc.Context;
+import io.grpc.ManagedChannel;
+import io.grpc.Metadata;
+import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+import io.grpc.SynchronizationContext;
+import io.grpc.inprocess.InProcessChannelBuilder;
+import io.grpc.inprocess.InProcessServerBuilder;
+import io.grpc.internal.BackoffPolicy;
+import io.grpc.internal.TimeProvider;
+import io.grpc.testing.GrpcCleanupRule;
+import io.grpc.xds.internal.certprovider.CertificateProvider.DistributorWatcher;
+import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil;
+import java.io.IOException;
+import java.security.NoSuchAlgorithmException;
+import java.security.PrivateKey;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
+import java.util.ArrayDeque;
+import java.util.Date;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicBoolean;
+import org.bouncycastle.operator.OperatorCreationException;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+import org.mockito.ArgumentCaptor;
+import org.mockito.ArgumentMatchers;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.Spy;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+/** Unit tests for {@link MeshCaCertificateProvider}. */
+@RunWith(JUnit4.class)
+public class MeshCaCertificateProviderTest {
+
+  private static final String TEST_STS_TOKEN = "test-stsToken";
+  private static final long RENEWAL_GRACE_PERIOD_SECONDS = TimeUnit.HOURS.toSeconds(1L);
+  private static final Metadata.Key<String> KEY_FOR_AUTHORIZATION =
+      Metadata.Key.of(AuthHttpConstants.AUTHORIZATION, Metadata.ASCII_STRING_MARSHALLER);
+  private static final String ZONE = "us-west2-a";
+  private static final long START_DELAY = 200_000_000L;  // 0.2 seconds
+  private static final long[] DELAY_VALUES = {START_DELAY, START_DELAY * 2, START_DELAY * 4};
+  private static final long RPC_TIMEOUT_MILLIS = 100L;
+  /**
+   * Expire time of cert SERVER_0_PEM_FILE.
+   */
+  private static final long CERT0_EXPIRY_TIME_MILLIS = 1899853658000L;
+  /**
+   * Cert validity of 12 hours for the above cert.
+   */
+  private static final long CERT0_VALIDITY_MILLIS = TimeUnit.MILLISECONDS
+      .convert(12, TimeUnit.HOURS);
+  /**
+   * Compute current time based on cert expiry and cert validity.
+   */
+  private static final long CURRENT_TIME_NANOS =
+      TimeUnit.MILLISECONDS.toNanos(CERT0_EXPIRY_TIME_MILLIS - CERT0_VALIDITY_MILLIS);
+  @Rule
+  public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
+
+  private static class ResponseToSend {
+    Throwable getThrowable() {
+      throw new UnsupportedOperationException("Called on " + getClass().getCanonicalName());
+    }
+
+    List<String> getList() {
+      throw new UnsupportedOperationException("Called on " + getClass().getCanonicalName());
+    }
+  }
+
+  private static class ResponseThrowable extends ResponseToSend {
+    final Throwable throwableToSend;
+
+    ResponseThrowable(Throwable throwable) {
+      throwableToSend = throwable;
+    }
+
+    @Override
+    Throwable getThrowable() {
+      return throwableToSend;
+    }
+  }
+
+  private static class ResponseList extends ResponseToSend {
+    final List<String> listToSend;
+
+    ResponseList(List<String> list) {
+      listToSend = list;
+    }
+
+    @Override
+    List<String> getList() {
+      return listToSend;
+    }
+  }
+
+  private final Queue<Meshca.MeshCertificateRequest> receivedRequests = new ArrayDeque<>();
+  private final Queue<String> receivedStsCreds = new ArrayDeque<>();
+  private final Queue<String> receivedZoneValues = new ArrayDeque<>();
+  private final Queue<ResponseToSend> responsesToSend = new ArrayDeque<>();
+  private final Queue<String> oauth2Tokens = new ArrayDeque<>();
+  private final AtomicBoolean callEnded = new AtomicBoolean(true);
+
+  @Mock private MeshCertificateServiceGrpc.MeshCertificateServiceImplBase mockedMeshCaService;
+  @Mock private CertificateProvider.Watcher mockWatcher;
+  @Mock private BackoffPolicy.Provider backoffPolicyProvider;
+  @Mock private BackoffPolicy backoffPolicy;
+  @Spy private GoogleCredentials oauth2Creds;
+  @Mock private ScheduledExecutorService timeService;
+  @Mock private TimeProvider timeProvider;
+
+  private ManagedChannel channel;
+  private MeshCaCertificateProvider provider;
+
+  @Before
+  public void setUp() throws IOException {
+    MockitoAnnotations.initMocks(this);
+    when(backoffPolicyProvider.get()).thenReturn(backoffPolicy);
+    when(backoffPolicy.nextBackoffNanos())
+        .thenReturn(DELAY_VALUES[0], DELAY_VALUES[1], DELAY_VALUES[2]);
+    doAnswer(
+        new Answer<AccessToken>() {
+          @Override
+          public AccessToken answer(InvocationOnMock invocation) throws Throwable {
+            return new AccessToken(
+                oauth2Tokens.poll(), new Date(System.currentTimeMillis() + 1000L));
+          }
+        })
+        .when(oauth2Creds)
+        .refreshAccessToken();
+    final String meshCaUri = InProcessServerBuilder.generateName();
+    MeshCertificateServiceGrpc.MeshCertificateServiceImplBase meshCaServiceImpl =
+        new MeshCertificateServiceGrpc.MeshCertificateServiceImplBase() {
+
+          @Override
+          public void createCertificate(
+              google.security.meshca.v1.Meshca.MeshCertificateRequest request,
+              io.grpc.stub.StreamObserver<google.security.meshca.v1.Meshca.MeshCertificateResponse>
+                  responseObserver) {
+            assertThat(callEnded.get()).isTrue(); // ensure previous call was ended
+            callEnded.set(false);
+            Context.current()
+                .addListener(
+                    new Context.CancellationListener() {
+                      @Override
+                      public void cancelled(Context context) {
+                        callEnded.set(true);
+                      }
+                    },
+                    MoreExecutors.directExecutor());
+            receivedRequests.offer(request);
+            ResponseToSend response = responsesToSend.poll();
+            if (response instanceof ResponseThrowable) {
+              responseObserver.onError(response.getThrowable());
+            } else if (response instanceof ResponseList) {
+              List<String> certChainInResponse = response.getList();
+              Meshca.MeshCertificateResponse responseToSend =
+                  Meshca.MeshCertificateResponse.newBuilder()
+                      .addAllCertChain(certChainInResponse)
+                      .build();
+              responseObserver.onNext(responseToSend);
+              responseObserver.onCompleted();
+            } else {
+              callEnded.set(true);
+            }
+          }
+        };
+    mockedMeshCaService =
+        mock(
+            MeshCertificateServiceGrpc.MeshCertificateServiceImplBase.class,
+            delegatesTo(meshCaServiceImpl));
+    ServerInterceptor interceptor =
+        new ServerInterceptor() {
+          @Override
+          public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
+              ServerCall<ReqT, RespT> call, Metadata headers, ServerCallHandler<ReqT, RespT> next) {
+            receivedStsCreds.offer(headers.get(KEY_FOR_AUTHORIZATION));
+            receivedZoneValues.offer(headers.get(MeshCaCertificateProvider.KEY_FOR_ZONE_INFO));
+            return next.startCall(call, headers);
+          }
+        };
+    cleanupRule.register(
+        InProcessServerBuilder.forName(meshCaUri)
+            .addService(mockedMeshCaService)
+            .intercept(interceptor)
+            .directExecutor()
+            .build()
+            .start());
+    channel =
+        cleanupRule.register(InProcessChannelBuilder.forName(meshCaUri).directExecutor().build());
+    MeshCaCertificateProvider.MeshCaChannelFactory channelFactory =
+        new MeshCaCertificateProvider.MeshCaChannelFactory() {
+          @Override
+          ManagedChannel createChannel(String serverUri) {
+            assertThat(serverUri).isEqualTo(meshCaUri);
+            return channel;
+          }
+        };
+    CertificateProvider.DistributorWatcher watcher = new CertificateProvider.DistributorWatcher();
+    watcher.addWatcher(mockWatcher); //
+    provider =
+        new MeshCaCertificateProvider(
+            watcher,
+            true,
+            meshCaUri,
+            ZONE,
+            TimeUnit.HOURS.toSeconds(9L),
+            2048,
+            "RSA",
+            "SHA256withRSA",
+            channelFactory,
+            backoffPolicyProvider,
+            RENEWAL_GRACE_PERIOD_SECONDS,
+            MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT,
+            oauth2Creds,
+            timeService,
+            timeProvider,
+            RPC_TIMEOUT_MILLIS);
+  }
+
+  @Test
+  public void startAndClose() {
+    ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+    doReturn(scheduledFuture)
+        .when(timeService)
+        .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+    provider.start();
+    SynchronizationContext.ScheduledHandle savedScheduledHandle = provider.scheduledHandle;
+    assertThat(savedScheduledHandle).isNotNull();
+    assertThat(savedScheduledHandle.isPending()).isTrue();
+    verify(timeService, times(1))
+        .schedule(
+            any(Runnable.class),
+            eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
+            eq(TimeUnit.SECONDS));
+    DistributorWatcher distWatcher = provider.getWatcher();
+    assertThat(distWatcher.downstreamWatchers).hasSize(1);
+    PrivateKey mockKey = mock(PrivateKey.class);
+    X509Certificate mockCert = mock(X509Certificate.class);
+    distWatcher.updateCertificate(mockKey, ImmutableList.of(mockCert));
+    distWatcher.updateTrustedRoots(ImmutableList.of(mockCert));
+    provider.close();
+    assertThat(provider.scheduledHandle).isNull();
+    assertThat(savedScheduledHandle.isPending()).isFalse();
+    assertThat(distWatcher.downstreamWatchers).isEmpty();
+    assertThat(distWatcher.getLastIdentityCert()).isNull();
+  }
+
+  @Test
+  public void startTwice_noException() {
+    ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+    doReturn(scheduledFuture)
+        .when(timeService)
+        .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+    provider.start();
+    SynchronizationContext.ScheduledHandle savedScheduledHandle1 = provider.scheduledHandle;
+    provider.start();
+    SynchronizationContext.ScheduledHandle savedScheduledHandle2 = provider.scheduledHandle;
+    assertThat(savedScheduledHandle2).isNotSameInstanceAs(savedScheduledHandle1);
+    assertThat(savedScheduledHandle2.isPending()).isTrue();
+  }
+
+  @Test
+  public void getCertificate()
+      throws IOException, CertificateException, OperatorCreationException,
+      NoSuchAlgorithmException {
+    oauth2Tokens.offer(TEST_STS_TOKEN + "0");
+    responsesToSend.offer(
+        new ResponseList(ImmutableList.of(
+            CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
+            CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
+            CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
+    when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
+    ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+    doReturn(scheduledFuture)
+        .when(timeService)
+        .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+    provider.refreshCertificate();
+    Meshca.MeshCertificateRequest receivedReq = receivedRequests.poll();
+    assertThat(receivedReq.getValidity().getSeconds()).isEqualTo(TimeUnit.HOURS.toSeconds(9L));
+    // cannot decode CSR: just check the PEM format delimiters
+    String csr = receivedReq.getCsr();
+    assertThat(csr).startsWith("-----BEGIN NEW CERTIFICATE REQUEST-----");
+    verifyReceivedMetadataValues(1);
+    verify(timeService, times(1))
+        .schedule(
+            any(Runnable.class),
+            eq(
+                TimeUnit.MILLISECONDS.toSeconds(
+                    CERT0_VALIDITY_MILLIS
+                        - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
+            eq(TimeUnit.SECONDS));
+    verifyMockWatcher();
+  }
+
+  @Test
+  public void getCertificate_withError()
+      throws IOException, OperatorCreationException, NoSuchAlgorithmException {
+    oauth2Tokens.offer(TEST_STS_TOKEN + "0");
+    responsesToSend
+        .offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
+    ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+    doReturn(scheduledFuture).when(timeService)
+        .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+    provider.refreshCertificate();
+    verify(mockWatcher, never())
+        .updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
+    verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
+    verify(mockWatcher, times(1)).onError(Status.FAILED_PRECONDITION);
+    verify(timeService, times(1)).schedule(any(Runnable.class),
+        eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
+        eq(TimeUnit.SECONDS));
+    verifyReceivedMetadataValues(1);
+  }
+
+  @Test
+  public void getCertificate_withError_withExistingCert()
+      throws IOException, OperatorCreationException, NoSuchAlgorithmException {
+    PrivateKey mockKey = mock(PrivateKey.class);
+    X509Certificate mockCert = mock(X509Certificate.class);
+    // have current cert expire in 3 hours from current time
+    long threeHoursFromNowMillis = TimeUnit.NANOSECONDS
+        .toMillis(CURRENT_TIME_NANOS + TimeUnit.HOURS.toNanos(3));
+    when(mockCert.getNotAfter()).thenReturn(new Date(threeHoursFromNowMillis));
+    provider.getWatcher().updateCertificate(mockKey, ImmutableList.of(mockCert));
+    reset(mockWatcher);
+    oauth2Tokens.offer(TEST_STS_TOKEN + "0");
+    responsesToSend
+        .offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
+    when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
+    ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+    doReturn(scheduledFuture).when(timeService)
+        .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+    provider.refreshCertificate();
+    verify(mockWatcher, never())
+        .updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
+    verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
+    verify(mockWatcher, never()).onError(any(Status.class));
+    verify(timeService, times(1)).schedule(any(Runnable.class),
+        eq(5400L),
+        eq(TimeUnit.SECONDS));
+    assertThat(provider.getWatcher().getLastIdentityCert()).isNotNull();
+    verifyReceivedMetadataValues(1);
+  }
+
+  @Test
+  public void getCertificate_withError_withExistingExpiredCert()
+      throws IOException, OperatorCreationException, NoSuchAlgorithmException {
+    PrivateKey mockKey = mock(PrivateKey.class);
+    X509Certificate mockCert = mock(X509Certificate.class);
+    // have current cert expire in 3 seconds from current time
+    long threeSecondsFromNowMillis = TimeUnit.NANOSECONDS
+        .toMillis(CURRENT_TIME_NANOS + TimeUnit.SECONDS.toNanos(3));
+    when(mockCert.getNotAfter()).thenReturn(new Date(threeSecondsFromNowMillis));
+    provider.getWatcher().updateCertificate(mockKey, ImmutableList.of(mockCert));
+    reset(mockWatcher);
+    oauth2Tokens.offer(TEST_STS_TOKEN + "0");
+    responsesToSend
+        .offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION)));
+    when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
+    ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+    doReturn(scheduledFuture).when(timeService)
+        .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+    provider.refreshCertificate();
+    verify(mockWatcher, never())
+        .updateCertificate(any(PrivateKey.class), ArgumentMatchers.<X509Certificate>anyList());
+    verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.<X509Certificate>anyList());
+    verify(mockWatcher, times(1)).onError(Status.FAILED_PRECONDITION);
+    verify(timeService, times(1)).schedule(any(Runnable.class),
+        eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS),
+        eq(TimeUnit.SECONDS));
+    assertThat(provider.getWatcher().getLastIdentityCert()).isNull();
+    verifyReceivedMetadataValues(1);
+  }
+
+  @Test
+  public void getCertificate_retriesWithErrors()
+      throws IOException, CertificateException, OperatorCreationException,
+      NoSuchAlgorithmException, InterruptedException, ExecutionException, TimeoutException {
+    oauth2Tokens.offer(TEST_STS_TOKEN + "0");
+    oauth2Tokens.offer(TEST_STS_TOKEN + "1");
+    oauth2Tokens.offer(TEST_STS_TOKEN + "2");
+    responsesToSend.offer(new ResponseThrowable(new StatusRuntimeException(Status.UNKNOWN)));
+    responsesToSend.offer(
+        new ResponseThrowable(
+            new Exception(new StatusRuntimeException(Status.RESOURCE_EXHAUSTED))));
+    responsesToSend.offer(new ResponseList(ImmutableList.of(
+        CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
+        CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
+        CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
+    when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
+    ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+    doReturn(scheduledFuture).when(timeService)
+        .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+    ScheduledFuture<?> scheduledFutureSleep = mock(ScheduledFuture.class);
+    doReturn(scheduledFutureSleep).when(timeService)
+        .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.NANOSECONDS));
+    provider.refreshCertificate();
+    assertThat(receivedRequests.size()).isEqualTo(3);
+    verify(timeService, times(1)).schedule(any(Runnable.class),
+        eq(TimeUnit.MILLISECONDS.toSeconds(
+            CERT0_VALIDITY_MILLIS - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
+        eq(TimeUnit.SECONDS));
+    verifyRetriesWithBackoff(scheduledFutureSleep, 2);
+    verifyMockWatcher();
+    verifyReceivedMetadataValues(3);
+  }
+
+  @Test
+  public void getCertificate_retriesWithTimeouts()
+      throws IOException, CertificateException, OperatorCreationException,
+      NoSuchAlgorithmException, InterruptedException, ExecutionException, TimeoutException {
+    oauth2Tokens.offer(TEST_STS_TOKEN + "0");
+    oauth2Tokens.offer(TEST_STS_TOKEN + "1");
+    oauth2Tokens.offer(TEST_STS_TOKEN + "2");
+    oauth2Tokens.offer(TEST_STS_TOKEN + "3");
+    responsesToSend.offer(new ResponseToSend());
+    responsesToSend.offer(new ResponseToSend());
+    responsesToSend.offer(new ResponseToSend());
+    responsesToSend.offer(new ResponseList(ImmutableList.of(
+        CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE),
+        CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE),
+        CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE))));
+    when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS);
+    ScheduledFuture<?> scheduledFuture = mock(ScheduledFuture.class);
+    doReturn(scheduledFuture).when(timeService)
+        .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS));
+    ScheduledFuture<?> scheduledFutureSleep = mock(ScheduledFuture.class);
+    doReturn(scheduledFutureSleep).when(timeService)
+        .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.NANOSECONDS));
+    provider.refreshCertificate();
+    assertThat(receivedRequests.size()).isEqualTo(4);
+    verify(timeService, times(1)).schedule(any(Runnable.class),
+        eq(TimeUnit.MILLISECONDS.toSeconds(
+            CERT0_VALIDITY_MILLIS - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))),
+        eq(TimeUnit.SECONDS));
+    verifyRetriesWithBackoff(scheduledFutureSleep, 3);
+    verifyMockWatcher();
+    verifyReceivedMetadataValues(4);
+  }
+
+  private void verifyRetriesWithBackoff(ScheduledFuture<?> scheduledFutureSleep, int numOfRetries)
+      throws InterruptedException, ExecutionException, TimeoutException {
+    for (int i = 0; i < numOfRetries; i++) {
+      long delayValue = DELAY_VALUES[i];
+      verify(timeService, times(1)).schedule(any(Runnable.class),
+          eq(delayValue),
+          eq(TimeUnit.NANOSECONDS));
+      verify(scheduledFutureSleep, times(1)).get(eq(delayValue), eq(TimeUnit.NANOSECONDS));
+    }
+  }
+
+  private void verifyMockWatcher() throws IOException, CertificateException {
+    ArgumentCaptor<List<X509Certificate>> certChainCaptor = ArgumentCaptor.forClass(null);
+    verify(mockWatcher, times(1))
+        .updateCertificate(any(PrivateKey.class), certChainCaptor.capture());
+    List<X509Certificate> certChain = certChainCaptor.getValue();
+    assertThat(certChain).hasSize(3);
+    assertThat(certChain.get(0))
+        .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(SERVER_0_PEM_FILE));
+    assertThat(certChain.get(1))
+        .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(SERVER_1_PEM_FILE));
+    assertThat(certChain.get(2))
+        .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(CA_PEM_FILE));
+
+    ArgumentCaptor<List<X509Certificate>> rootsCaptor = ArgumentCaptor.forClass(null);
+    verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture());
+    List<X509Certificate> roots = rootsCaptor.getValue();
+    assertThat(roots).hasSize(1);
+    assertThat(roots.get(0))
+        .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(CA_PEM_FILE));
+    verify(mockWatcher, never()).onError(any(Status.class));
+  }
+
+  private void verifyReceivedMetadataValues(int count) {
+    assertThat(receivedStsCreds).hasSize(count);
+    assertThat(receivedZoneValues).hasSize(count);
+    for (int i = 0; i < count; i++) {
+      assertThat(receivedStsCreds.poll()).isEqualTo("Bearer " + TEST_STS_TOKEN + i);
+      assertThat(receivedZoneValues.poll()).isEqualTo("us-west2-a");
+    }
+  }
+}
diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java
index b6782b2..afa57f9 100644
--- a/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java
+++ b/xds/src/test/java/io/grpc/xds/internal/sds/CommonTlsContextTestsUtil.java
@@ -16,7 +16,10 @@
 
 package io.grpc.xds.internal.sds;
 
+import static java.nio.charset.StandardCharsets.UTF_8;
+
 import com.google.common.base.Strings;
+import com.google.common.io.CharStreams;
 import com.google.protobuf.BoolValue;
 import com.google.protobuf.Struct;
 import com.google.protobuf.Value;
@@ -36,7 +39,14 @@
 import io.envoyproxy.envoy.type.matcher.v3.StringMatcher;
 import io.grpc.internal.testing.TestUtils;
 import io.grpc.xds.EnvoyServerProtoData;
+import io.grpc.xds.internal.sds.trust.CertificateUtils;
+import java.io.ByteArrayInputStream;
 import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.Reader;
+import java.security.cert.CertificateException;
+import java.security.cert.X509Certificate;
 import java.util.Arrays;
 import javax.annotation.Nullable;
 
@@ -432,4 +442,23 @@
     return EnvoyServerProtoData.UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext(
         upstreamTlsContext);
   }
+
+  /** Gets a cert from contents of a resource. */
+  public static X509Certificate getCertFromResourceName(String resourceName)
+      throws IOException, CertificateException {
+    try (ByteArrayInputStream bais =
+        new ByteArrayInputStream(getResourceContents(resourceName).getBytes(UTF_8))) {
+      return CertificateUtils.toX509Certificate(bais);
+    }
+  }
+
+  /** Gets contents of a resource from TestUtils.class loader. */
+  public static String getResourceContents(String resourceName) throws IOException {
+    InputStream inputStream = TestUtils.class.getResourceAsStream("/certs/" + resourceName);
+    String text = null;
+    try (Reader reader = new InputStreamReader(inputStream, UTF_8)) {
+      text = CharStreams.toString(reader);
+    }
+    return text;
+  }
 }