xds: refactor XdsLrsClient and XdsLoadReportStore for integrating backend metric data in load report (#5728)

* make ClientLoadCounter as a separate class, added unit tests for it as it now counts quite many stats

* add MetricListener class that takes in a ClientLoadCounter and updates metric counts from received OrcaLoadReport

* refactor XdsClientLoadRecorder into XdsLoadReportStore for better integrity

* move interceptPickResult implementation to XdsLrsClient, no delegated call

* added unit test annotation

* created a StatsStore interface for better modularize LrsClient and LoadReportStore

* add more tests to ClientLoadCounter to increase coverage

* added tests for add/get/remove locality counter

* refactored tests for XdsLoadReportStore, with newly added abstract base class for ClientLoadCounter, real counter data is not involved, only stubbed snapshot is needed

* comparing doubles doing arithmetic is not recommended, but we are fine here as we are manually repeating the computation exactly

* added test case for two metric listeners with the same counter, metric values should be aggregated to the same counter

* fixed exception message and comment to only refer to interface

* removed unused variables

* cleaned up unused mock init

* removed unnecessary ClusterStats comparison helper method, as we are really comparing with the object manually created, order is deterministic

* trashed stuff for backend metrics, it should be in a separate PR

* added toString test

* remove Duration dependency in LoadReportStore

* use ThreadLocalRandom to generate positive double randoms directly

* rename XdsLoadReportStore to XdsLoadStatsStore

* rename XdsLrsClient to XdsLoadReportClient

* refactor ClientLoadSnapshot to be an exact snapshoht of ClientLoadCounter, use getters for ClientLoadSnapshot and avoid touching fields directly

* renamed XdsLoadStatsManager to XdsLoadReportClient and XdsLoadReportClient to XdsLoadReportClientImpl

* make fields final in ClientLoadSnapshot

* use a constant noop client stream tracer instead of creating new one for each noop client stream tracer factory

* rename loadReportStore for abstraction
diff --git a/xds/src/main/java/io/grpc/xds/ClientLoadCounter.java b/xds/src/main/java/io/grpc/xds/ClientLoadCounter.java
new file mode 100644
index 0000000..d4db4a9
--- /dev/null
+++ b/xds/src/main/java/io/grpc/xds/ClientLoadCounter.java
@@ -0,0 +1,171 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds;
+
+import static com.google.common.base.Preconditions.checkNotNull;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.MoreObjects;
+import io.grpc.ClientStreamTracer;
+import io.grpc.ClientStreamTracer.StreamInfo;
+import io.grpc.Metadata;
+import io.grpc.Status;
+import io.grpc.util.ForwardingClientStreamTracer;
+import io.grpc.xds.XdsLoadStatsStore.StatsCounter;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.annotation.concurrent.NotThreadSafe;
+import javax.annotation.concurrent.ThreadSafe;
+
+/**
+ * Client side aggregator for load stats.
+ *
+ * <p>All methods except {@link #snapshot()} in this class are thread-safe.
+ */
+@NotThreadSafe
+final class ClientLoadCounter extends XdsLoadStatsStore.StatsCounter {
+  private final AtomicLong callsInProgress = new AtomicLong();
+  private final AtomicLong callsFinished = new AtomicLong();
+  private final AtomicLong callsFailed = new AtomicLong();
+
+  ClientLoadCounter() {
+  }
+
+  /**
+   * Must only be used for testing.
+   */
+  @VisibleForTesting
+  ClientLoadCounter(long callsFinished, long callsInProgress, long callsFailed) {
+    this.callsFinished.set(callsFinished);
+    this.callsInProgress.set(callsInProgress);
+    this.callsFailed.set(callsFailed);
+  }
+
+  @Override
+  void incrementCallsInProgress() {
+    callsInProgress.getAndIncrement();
+  }
+
+  @Override
+  void decrementCallsInProgress() {
+    callsInProgress.getAndDecrement();
+  }
+
+  @Override
+  void incrementCallsFinished() {
+    callsFinished.getAndIncrement();
+  }
+
+  @Override
+  void incrementCallsFailed() {
+    callsFailed.getAndIncrement();
+  }
+
+  /**
+   * Generate snapshot for recorded query counts and metrics since previous snapshot.
+   *
+   * <p>This method is not thread-safe and must be called from {@link
+   * io.grpc.LoadBalancer.Helper#getSynchronizationContext()}.
+   */
+  @Override
+  public ClientLoadSnapshot snapshot() {
+    return new ClientLoadSnapshot(callsFinished.getAndSet(0),
+        callsInProgress.get(),
+        callsFailed.getAndSet(0));
+  }
+
+  /**
+   * A {@link ClientLoadSnapshot} represents a snapshot of {@link ClientLoadCounter} to be sent as
+   * part of {@link io.envoyproxy.envoy.api.v2.endpoint.ClusterStats} to the balancer.
+   */
+  static final class ClientLoadSnapshot {
+
+    @VisibleForTesting
+    static final ClientLoadSnapshot EMPTY_SNAPSHOT = new ClientLoadSnapshot(0, 0, 0);
+    private final long callsFinished;
+    private final long callsInProgress;
+    private final long callsFailed;
+
+    /**
+     * External usage must only be for testing.
+     */
+    @VisibleForTesting
+    ClientLoadSnapshot(long callsFinished, long callsInProgress, long callsFailed) {
+      this.callsFinished = callsFinished;
+      this.callsInProgress = callsInProgress;
+      this.callsFailed = callsFailed;
+    }
+
+    long getCallsFinished() {
+      return callsFinished;
+    }
+
+    long getCallsInProgress() {
+      return callsInProgress;
+    }
+
+    long getCallsFailed() {
+      return callsFailed;
+    }
+
+    @Override
+    public String toString() {
+      return MoreObjects.toStringHelper(this)
+          .add("callsFinished", callsFinished)
+          .add("callsInProgress", callsInProgress)
+          .add("callsFailed", callsFailed)
+          .toString();
+    }
+  }
+
+  /**
+   * An {@link XdsClientLoadRecorder} instance records and aggregates client-side load data into an
+   * {@link ClientLoadCounter} object.
+   */
+  @ThreadSafe
+  static final class XdsClientLoadRecorder extends ClientStreamTracer.Factory {
+
+    private final ClientStreamTracer.Factory delegate;
+    private final StatsCounter counter;
+
+    XdsClientLoadRecorder(StatsCounter counter, ClientStreamTracer.Factory delegate) {
+      this.counter = checkNotNull(counter, "counter");
+      this.delegate = checkNotNull(delegate, "delegate");
+    }
+
+    @Override
+    public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) {
+      counter.incrementCallsInProgress();
+      final ClientStreamTracer delegateTracer = delegate.newClientStreamTracer(info, headers);
+      return new ForwardingClientStreamTracer() {
+        @Override
+        protected ClientStreamTracer delegate() {
+          return delegateTracer;
+        }
+
+        @Override
+        public void streamClosed(Status status) {
+          counter.incrementCallsFinished();
+          counter.decrementCallsInProgress();
+          if (!status.isOk()) {
+            counter.incrementCallsFailed();
+          }
+          delegate().streamClosed(status);
+        }
+      };
+    }
+  }
+}
diff --git a/xds/src/main/java/io/grpc/xds/XdsClientLoadRecorder.java b/xds/src/main/java/io/grpc/xds/XdsClientLoadRecorder.java
deleted file mode 100644
index c2738b5..0000000
--- a/xds/src/main/java/io/grpc/xds/XdsClientLoadRecorder.java
+++ /dev/null
@@ -1,129 +0,0 @@
-/*
- * Copyright 2019 The gRPC Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.grpc.xds;
-
-import static com.google.common.base.Preconditions.checkNotNull;
-
-import com.google.common.annotations.VisibleForTesting;
-import io.grpc.ClientStreamTracer;
-import io.grpc.ClientStreamTracer.StreamInfo;
-import io.grpc.Metadata;
-import io.grpc.Status;
-import io.grpc.util.ForwardingClientStreamTracer;
-import java.util.concurrent.atomic.AtomicLong;
-import javax.annotation.concurrent.ThreadSafe;
-
-/**
- * An {@link XdsClientLoadRecorder} instance records and aggregates client-side load data into an
- * {@link ClientLoadCounter} object.
- */
-@ThreadSafe
-final class XdsClientLoadRecorder extends ClientStreamTracer.Factory {
-
-  private final ClientStreamTracer.Factory delegate;
-  private final ClientLoadCounter counter;
-
-  XdsClientLoadRecorder(ClientLoadCounter counter, ClientStreamTracer.Factory delegate) {
-    this.counter = checkNotNull(counter, "counter");
-    this.delegate = checkNotNull(delegate, "delegate");
-  }
-
-  @Override
-  public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) {
-    counter.callsInProgress.getAndIncrement();
-    final ClientStreamTracer delegateTracer = delegate.newClientStreamTracer(info, headers);
-    return new StreamTracer(delegateTracer);
-  }
-
-  /**
-   * A {@link ClientLoadSnapshot} represents a snapshot of {@link ClientLoadCounter} to be sent as
-   * part of {@link io.envoyproxy.envoy.api.v2.endpoint.ClusterStats} to the balancer.
-   */
-  static final class ClientLoadSnapshot {
-
-    final long callsSucceed;
-    final long callsInProgress;
-    final long callsFailed;
-
-    ClientLoadSnapshot(long callsSucceed, long callsInProgress, long callsFailed) {
-      this.callsSucceed = callsSucceed;
-      this.callsInProgress = callsInProgress;
-      this.callsFailed = callsFailed;
-    }
-  }
-
-  static final class ClientLoadCounter {
-
-    private final AtomicLong callsInProgress = new AtomicLong();
-    private final AtomicLong callsFinished = new AtomicLong();
-    private final AtomicLong callsFailed = new AtomicLong();
-    private boolean active = true;
-
-    ClientLoadCounter() {
-    }
-
-    @VisibleForTesting
-    ClientLoadCounter(long callsInProgress, long callsFinished, long callsFailed) {
-      this.callsInProgress.set(callsInProgress);
-      this.callsFinished.set(callsFinished);
-      this.callsFailed.set(callsFailed);
-    }
-
-    /**
-     * Generate a query count snapshot and reset counts for next snapshot.
-     */
-    ClientLoadSnapshot snapshot() {
-      long numFailed = callsFailed.getAndSet(0);
-      return new ClientLoadSnapshot(
-          callsFinished.getAndSet(0) - numFailed,
-          callsInProgress.get(),
-          numFailed);
-    }
-
-    boolean isActive() {
-      return active;
-    }
-
-    void setActive(boolean value) {
-      active = value;
-    }
-  }
-
-  private class StreamTracer extends ForwardingClientStreamTracer {
-
-    private final ClientStreamTracer delegate;
-
-    private StreamTracer(ClientStreamTracer delegate) {
-      this.delegate = checkNotNull(delegate, "delegate");
-    }
-
-    @Override
-    protected ClientStreamTracer delegate() {
-      return delegate;
-    }
-
-    @Override
-    public void streamClosed(Status status) {
-      counter.callsFinished.getAndIncrement();
-      counter.callsInProgress.getAndDecrement();
-      if (!status.isOk()) {
-        counter.callsFailed.getAndIncrement();
-      }
-      delegate().streamClosed(status);
-    }
-  }
-}
diff --git a/xds/src/main/java/io/grpc/xds/XdsLoadStatsManager.java b/xds/src/main/java/io/grpc/xds/XdsLoadReportClient.java
similarity index 93%
rename from xds/src/main/java/io/grpc/xds/XdsLoadStatsManager.java
rename to xds/src/main/java/io/grpc/xds/XdsLoadReportClient.java
index 6e2e533..7eec425 100644
--- a/xds/src/main/java/io/grpc/xds/XdsLoadStatsManager.java
+++ b/xds/src/main/java/io/grpc/xds/XdsLoadReportClient.java
@@ -21,19 +21,19 @@
 import javax.annotation.concurrent.NotThreadSafe;
 
 /**
- * An {@link XdsLoadStatsManager} is in charge of recording client side load stats, collecting
+ * An {@link XdsLoadReportClient} is in charge of recording client side load stats, collecting
  * backend cost metrics and sending load reports to the remote balancer. It shares the same
  * channel with {@link XdsLoadBalancer} and its lifecycle is managed by {@link XdsLoadBalancer}.
  */
 @NotThreadSafe
-interface XdsLoadStatsManager {
+interface XdsLoadReportClient {
 
   /**
    * Establishes load reporting communication and negotiates with the remote balancer to report load
    * stats periodically.
    *
    * <p>This method should be the first method to be called in the lifecycle of {@link
-   * XdsLoadStatsManager} and should only be called once.
+   * XdsLoadReportClient} and should only be called once.
    *
    * <p>This method is not thread-safe and should be called from the same synchronized context
    * returned by {@link XdsLoadBalancer#helper#getSynchronizationContext}.
@@ -43,7 +43,7 @@
   /**
    * Terminates load reporting.
    *
-   * <p>No method in {@link XdsLoadStatsManager} should be called after calling this method.
+   * <p>No method in {@link XdsLoadReportClient} should be called after calling this method.
    *
    * <p>This method is not thread-safe and should be called from the same synchronized context
    * returned by {@link XdsLoadBalancer#helper#getSynchronizationContext}.
diff --git a/xds/src/main/java/io/grpc/xds/XdsLrsClient.java b/xds/src/main/java/io/grpc/xds/XdsLoadReportClientImpl.java
similarity index 83%
rename from xds/src/main/java/io/grpc/xds/XdsLrsClient.java
rename to xds/src/main/java/io/grpc/xds/XdsLoadReportClientImpl.java
index 65bbe29..7c1ca35 100644
--- a/xds/src/main/java/io/grpc/xds/XdsLrsClient.java
+++ b/xds/src/main/java/io/grpc/xds/XdsLoadReportClientImpl.java
@@ -28,20 +28,26 @@
 import com.google.protobuf.util.Durations;
 import io.envoyproxy.envoy.api.v2.core.Locality;
 import io.envoyproxy.envoy.api.v2.core.Node;
+import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats;
 import io.envoyproxy.envoy.service.load_stats.v2.LoadReportingServiceGrpc;
 import io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest;
 import io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse;
 import io.grpc.ChannelLogger;
 import io.grpc.ChannelLogger.ChannelLogLevel;
+import io.grpc.ClientStreamTracer;
+import io.grpc.ClientStreamTracer.StreamInfo;
 import io.grpc.LoadBalancer.Helper;
 import io.grpc.LoadBalancer.PickResult;
 import io.grpc.ManagedChannel;
+import io.grpc.Metadata;
 import io.grpc.Status;
 import io.grpc.SynchronizationContext;
 import io.grpc.SynchronizationContext.ScheduledHandle;
 import io.grpc.internal.BackoffPolicy;
 import io.grpc.internal.GrpcUtil;
 import io.grpc.stub.StreamObserver;
+import io.grpc.xds.ClientLoadCounter.XdsClientLoadRecorder;
+import io.grpc.xds.XdsLoadStatsStore.StatsCounter;
 import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.ScheduledExecutorService;
@@ -55,11 +61,22 @@
  * returns.
  */
 @NotThreadSafe
-final class XdsLrsClient implements XdsLoadStatsManager {
+final class XdsLoadReportClientImpl implements XdsLoadReportClient {
 
   @VisibleForTesting
   static final String TRAFFICDIRECTOR_HOSTNAME_FIELD
       = "com.googleapis.trafficdirector.grpc_hostname";
+  private static final ClientStreamTracer NOOP_CLIENT_STREAM_TRACER =
+      new ClientStreamTracer() {
+      };
+  private static final ClientStreamTracer.Factory NOOP_CLIENT_STREAM_TRACER_FACTORY =
+      new ClientStreamTracer.Factory() {
+        @Override
+        public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) {
+          return NOOP_CLIENT_STREAM_TRACER;
+        }
+      };
+
   private final String serviceName;
   private final ManagedChannel channel;
   private final SynchronizationContext syncContext;
@@ -68,7 +85,7 @@
   private final Stopwatch retryStopwatch;
   private final ChannelLogger logger;
   private final BackoffPolicy.Provider backoffPolicyProvider;
-  private final XdsLoadReportStore loadReportStore;
+  private final StatsStore statsStore;
   private boolean started;
 
   @Nullable
@@ -79,19 +96,19 @@
   @Nullable
   private LrsStream lrsStream;
 
-  XdsLrsClient(ManagedChannel channel,
+  XdsLoadReportClientImpl(ManagedChannel channel,
       Helper helper,
       BackoffPolicy.Provider backoffPolicyProvider) {
     this(channel, helper, GrpcUtil.STOPWATCH_SUPPLIER, backoffPolicyProvider,
-        new XdsLoadReportStore(checkNotNull(helper, "helper").getAuthority()));
+        new XdsLoadStatsStore(checkNotNull(helper, "helper").getAuthority()));
   }
 
   @VisibleForTesting
-  XdsLrsClient(ManagedChannel channel,
+  XdsLoadReportClientImpl(ManagedChannel channel,
       Helper helper,
       Supplier<Stopwatch> stopwatchSupplier,
       BackoffPolicy.Provider backoffPolicyProvider,
-      XdsLoadReportStore loadReportStore) {
+      StatsStore statsStore) {
     this.channel = checkNotNull(channel, "channel");
     this.serviceName = checkNotNull(helper.getAuthority(), "serviceName");
     this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext");
@@ -100,7 +117,7 @@
     this.logger = checkNotNull(helper.getChannelLogger(), "logger");
     this.timerService = checkNotNull(helper.getScheduledExecutorService(), "timeService");
     this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider");
-    this.loadReportStore = checkNotNull(loadReportStore, "loadReportStore");
+    this.statsStore = checkNotNull(statsStore, "statsStore");
     started = false;
   }
 
@@ -126,26 +143,38 @@
   public void addLocality(Locality locality) {
     checkState(started, "load reporting must be started first");
     syncContext.throwIfNotInThisSynchronizationContext();
-    loadReportStore.addLocality(locality);
+    statsStore.addLocality(locality);
   }
 
   @Override
   public void removeLocality(final Locality locality) {
     checkState(started, "load reporting must be started first");
     syncContext.throwIfNotInThisSynchronizationContext();
-    loadReportStore.removeLocality(locality);
+    statsStore.removeLocality(locality);
   }
 
   @Override
   public void recordDroppedRequest(String category) {
     checkState(started, "load reporting must be started first");
-    loadReportStore.recordDroppedRequest(category);
+    statsStore.recordDroppedRequest(category);
   }
 
   @Override
   public PickResult interceptPickResult(PickResult pickResult, Locality locality) {
     checkState(started, "load reporting must be started first");
-    return loadReportStore.interceptPickResult(pickResult, locality);
+    if (!pickResult.getStatus().isOk()) {
+      return pickResult;
+    }
+    StatsCounter counter = statsStore.getLocalityCounter(locality);
+    if (counter == null) {
+      return pickResult;
+    }
+    ClientStreamTracer.Factory originFactory = pickResult.getStreamTracerFactory();
+    if (originFactory == null) {
+      originFactory = NOOP_CLIENT_STREAM_TRACER_FACTORY;
+    }
+    XdsClientLoadRecorder recorder = new XdsClientLoadRecorder(counter, originFactory);
+    return PickResult.withSubchannel(pickResult.getSubchannel(), recorder);
   }
 
   @VisibleForTesting
@@ -245,13 +274,18 @@
     private void sendLoadReport() {
       long interval = reportStopwatch.elapsed(TimeUnit.NANOSECONDS);
       reportStopwatch.reset().start();
+      ClusterStats report =
+          statsStore.generateLoadReport()
+              .toBuilder()
+              .setLoadReportInterval(Durations.fromNanos(interval))
+              .build();
       lrsRequestWriter.onNext(LoadStatsRequest.newBuilder()
           .setNode(Node.newBuilder()
               .setMetadata(Struct.newBuilder()
                   .putFields(
                       TRAFFICDIRECTOR_HOSTNAME_FIELD,
                       Value.newBuilder().setStringValue(serviceName).build())))
-          .addClusterStats(loadReportStore.generateLoadReport(Durations.fromNanos(interval)))
+          .addClusterStats(report)
           .build());
       scheduleNextLoadReport();
     }
@@ -348,4 +382,19 @@
       }
     }
   }
+
+  /**
+   * Interface for client side load stats store.
+   */
+  interface StatsStore {
+    ClusterStats generateLoadReport();
+
+    void addLocality(Locality locality);
+
+    void removeLocality(Locality locality);
+
+    StatsCounter getLocalityCounter(Locality locality);
+
+    void recordDroppedRequest(String category);
+  }
 }
diff --git a/xds/src/main/java/io/grpc/xds/XdsLoadReportStore.java b/xds/src/main/java/io/grpc/xds/XdsLoadReportStore.java
deleted file mode 100644
index 103b8fc..0000000
--- a/xds/src/main/java/io/grpc/xds/XdsLoadReportStore.java
+++ /dev/null
@@ -1,169 +0,0 @@
-/*
- * Copyright 2019 The gRPC Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.grpc.xds;
-
-import static com.google.common.base.Preconditions.checkNotNull;
-import static com.google.common.base.Preconditions.checkState;
-
-import com.google.common.annotations.VisibleForTesting;
-import com.google.protobuf.Duration;
-import io.envoyproxy.envoy.api.v2.core.Locality;
-import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats;
-import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats.DroppedRequests;
-import io.envoyproxy.envoy.api.v2.endpoint.UpstreamLocalityStats;
-import io.grpc.ClientStreamTracer;
-import io.grpc.ClientStreamTracer.StreamInfo;
-import io.grpc.LoadBalancer.PickResult;
-import io.grpc.Metadata;
-import io.grpc.xds.XdsClientLoadRecorder.ClientLoadCounter;
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentMap;
-import java.util.concurrent.atomic.AtomicLong;
-
-/**
- * An {@link XdsLoadReportStore} instance holds the client side load stats for a cluster.
- */
-final class XdsLoadReportStore {
-
-  private static final ClientStreamTracer NOOP_CLIENT_STREAM_TRACER =
-      new ClientStreamTracer() {
-      };
-  private static final ClientStreamTracer.Factory NOOP_CLIENT_STREAM_TRACER_FACTORY =
-      new ClientStreamTracer.Factory() {
-        @Override
-        public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) {
-          return NOOP_CLIENT_STREAM_TRACER;
-        }
-      };
-  private final String clusterName;
-  private final ConcurrentMap<Locality, ClientLoadCounter> localityLoadCounters;
-  // Cluster level dropped request counts for each category specified in the DropOverload policy.
-  private final ConcurrentMap<String, AtomicLong> dropCounters;
-
-  XdsLoadReportStore(String clusterName) {
-    this(clusterName, new ConcurrentHashMap<Locality, ClientLoadCounter>(),
-        new ConcurrentHashMap<String, AtomicLong>());
-  }
-
-  @VisibleForTesting
-  XdsLoadReportStore(String clusterName,
-      ConcurrentMap<Locality, ClientLoadCounter> localityLoadCounters,
-      ConcurrentMap<String, AtomicLong> dropCounters) {
-    this.clusterName = checkNotNull(clusterName, "clusterName");
-    this.localityLoadCounters = checkNotNull(localityLoadCounters, "localityLoadCounters");
-    this.dropCounters = checkNotNull(dropCounters, "dropCounters");
-  }
-
-  /**
-   * Generates a {@link ClusterStats} containing load stats in locality granularity.
-   * This method should be called in the same synchronized context that
-   * {@link XdsLoadBalancer#helper#getSynchronizationContext} returns.
-   */
-  ClusterStats generateLoadReport(Duration interval) {
-    ClusterStats.Builder statsBuilder = ClusterStats.newBuilder().setClusterName(clusterName)
-        .setLoadReportInterval(interval);
-    for (Map.Entry<Locality, XdsClientLoadRecorder.ClientLoadCounter> entry : localityLoadCounters
-        .entrySet()) {
-      XdsClientLoadRecorder.ClientLoadSnapshot snapshot = entry.getValue().snapshot();
-      statsBuilder
-          .addUpstreamLocalityStats(UpstreamLocalityStats.newBuilder()
-              .setLocality(entry.getKey())
-              .setTotalSuccessfulRequests(snapshot.callsSucceed)
-              .setTotalErrorRequests(snapshot.callsFailed)
-              .setTotalRequestsInProgress(snapshot.callsInProgress));
-      // Discard counters for localities that are no longer exposed by the remote balancer and
-      // no RPCs ongoing.
-      if (!entry.getValue().isActive() && snapshot.callsInProgress == 0) {
-        localityLoadCounters.remove(entry.getKey());
-      }
-    }
-    for (Map.Entry<String, AtomicLong> entry : dropCounters.entrySet()) {
-      statsBuilder.addDroppedRequests(DroppedRequests.newBuilder()
-          .setCategory(entry.getKey())
-          .setDroppedCount(entry.getValue().getAndSet(0)));
-    }
-    return statsBuilder.build();
-  }
-
-  /**
-   * Create a {@link ClientLoadCounter} for the provided locality or make it active if already in
-   * this {@link XdsLoadReportStore}. This method needs to be called at locality updates only for
-   * newly assigned localities in balancer discovery responses.
-   * This method should be called in the same synchronized context that
-   * {@link XdsLoadBalancer#helper#getSynchronizationContext} returns.
-   */
-  void addLocality(final Locality locality) {
-    ClientLoadCounter counter = localityLoadCounters.get(locality);
-    checkState(counter == null || !counter.isActive(),
-        "An active ClientLoadCounter for locality %s already exists", locality);
-    if (counter == null) {
-      localityLoadCounters.put(locality, new ClientLoadCounter());
-    } else {
-      counter.setActive(true);
-    }
-  }
-
-  /**
-   * Deactivate the {@link ClientLoadCounter} for the provided locality in by this
-   * {@link XdsLoadReportStore}. Inactive {@link ClientLoadCounter}s are for localities
-   * no longer exposed by the remote balancer. This method needs to be called at
-   * locality updates only for localities newly removed from balancer discovery responses.
-   * This method should be called in the same synchronized context that
-   * {@link XdsLoadBalancer#helper#getSynchronizationContext} returns.
-   */
-  void removeLocality(final Locality locality) {
-    ClientLoadCounter counter = localityLoadCounters.get(locality);
-    checkState(counter != null && counter.isActive(),
-        "No active ClientLoadCounter for locality %s exists", locality);
-    counter.setActive(false);
-  }
-
-  /**
-   * Intercepts a in-locality PickResult with load recording {@link ClientStreamTracer.Factory}.
-   */
-  PickResult interceptPickResult(PickResult pickResult, Locality locality) {
-    if (!pickResult.getStatus().isOk()) {
-      return pickResult;
-    }
-    XdsClientLoadRecorder.ClientLoadCounter counter = localityLoadCounters.get(locality);
-    if (counter == null) {
-      return pickResult;
-    }
-    ClientStreamTracer.Factory originFactory = pickResult.getStreamTracerFactory();
-    if (originFactory == null) {
-      originFactory = NOOP_CLIENT_STREAM_TRACER_FACTORY;
-    }
-    XdsClientLoadRecorder recorder = new XdsClientLoadRecorder(counter, originFactory);
-    return PickResult.withSubchannel(pickResult.getSubchannel(), recorder);
-  }
-
-  /**
-   * Record that a request has been dropped by drop overload policy with the provided category
-   * instructed by the remote balancer.
-   */
-  void recordDroppedRequest(String category) {
-    AtomicLong counter = dropCounters.get(category);
-    if (counter == null) {
-      counter = dropCounters.putIfAbsent(category, new AtomicLong());
-      if (counter == null) {
-        counter = dropCounters.get(category);
-      }
-    }
-    counter.getAndIncrement();
-  }
-}
diff --git a/xds/src/main/java/io/grpc/xds/XdsLoadStatsStore.java b/xds/src/main/java/io/grpc/xds/XdsLoadStatsStore.java
new file mode 100644
index 0000000..318deda
--- /dev/null
+++ b/xds/src/main/java/io/grpc/xds/XdsLoadStatsStore.java
@@ -0,0 +1,177 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds;
+
+import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Preconditions.checkState;
+
+import com.google.common.annotations.VisibleForTesting;
+import io.envoyproxy.envoy.api.v2.core.Locality;
+import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats;
+import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats.DroppedRequests;
+import io.envoyproxy.envoy.api.v2.endpoint.UpstreamLocalityStats;
+import io.grpc.xds.ClientLoadCounter.ClientLoadSnapshot;
+import io.grpc.xds.XdsLoadReportClientImpl.StatsStore;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.annotation.concurrent.NotThreadSafe;
+
+/**
+ * An {@link XdsLoadStatsStore} instance holds the client side load stats for a cluster.
+ */
+@NotThreadSafe
+final class XdsLoadStatsStore implements StatsStore {
+
+  private final String clusterName;
+  private final ConcurrentMap<Locality, StatsCounter> localityLoadCounters;
+  // Cluster level dropped request counts for each category specified in the DropOverload policy.
+  private final ConcurrentMap<String, AtomicLong> dropCounters;
+
+  XdsLoadStatsStore(String clusterName) {
+    this(clusterName, new ConcurrentHashMap<Locality, StatsCounter>(),
+        new ConcurrentHashMap<String, AtomicLong>());
+  }
+
+  @VisibleForTesting
+  XdsLoadStatsStore(String clusterName,
+      ConcurrentMap<Locality, StatsCounter> localityLoadCounters,
+      ConcurrentMap<String, AtomicLong> dropCounters) {
+    this.clusterName = checkNotNull(clusterName, "clusterName");
+    this.localityLoadCounters = checkNotNull(localityLoadCounters, "localityLoadCounters");
+    this.dropCounters = checkNotNull(dropCounters, "dropCounters");
+  }
+
+  /**
+   * Generates a {@link ClusterStats} containing client side load stats and backend metrics
+   * (if any) in locality granularity.
+   * This method should be called in the same synchronized context that
+   * {@link XdsLoadBalancer#helper#getSynchronizationContext} returns.
+   */
+  @Override
+  public ClusterStats generateLoadReport() {
+    ClusterStats.Builder statsBuilder = ClusterStats.newBuilder().setClusterName(clusterName);
+    for (Map.Entry<Locality, StatsCounter> entry : localityLoadCounters.entrySet()) {
+      ClientLoadSnapshot snapshot = entry.getValue().snapshot();
+      UpstreamLocalityStats.Builder localityStatsBuilder =
+          UpstreamLocalityStats.newBuilder().setLocality(entry.getKey());
+      localityStatsBuilder
+          .setTotalSuccessfulRequests(snapshot.getCallsFinished() - snapshot.getCallsFailed())
+          .setTotalErrorRequests(snapshot.getCallsFailed())
+          .setTotalRequestsInProgress(snapshot.getCallsInProgress());
+      statsBuilder.addUpstreamLocalityStats(localityStatsBuilder);
+      // Discard counters for localities that are no longer exposed by the remote balancer and
+      // no RPCs ongoing.
+      if (!entry.getValue().isActive() && snapshot.getCallsInProgress() == 0) {
+        localityLoadCounters.remove(entry.getKey());
+      }
+    }
+    for (Map.Entry<String, AtomicLong> entry : dropCounters.entrySet()) {
+      statsBuilder.addDroppedRequests(DroppedRequests.newBuilder()
+          .setCategory(entry.getKey())
+          .setDroppedCount(entry.getValue().getAndSet(0)));
+    }
+    return statsBuilder.build();
+  }
+
+  /**
+   * Create a {@link ClientLoadCounter} for the provided locality or make it active if already in
+   * this {@link XdsLoadStatsStore}. This method needs to be called at locality updates only for
+   * newly assigned localities in balancer discovery responses.
+   * This method should be called in the same synchronized context that
+   * {@link XdsLoadBalancer#helper#getSynchronizationContext} returns.
+   */
+  @Override
+  public void addLocality(final Locality locality) {
+    StatsCounter counter = localityLoadCounters.get(locality);
+    checkState(counter == null || !counter.isActive(),
+        "An active counter for locality %s already exists", locality);
+    if (counter == null) {
+      localityLoadCounters.put(locality, new ClientLoadCounter());
+    } else {
+      counter.setActive(true);
+    }
+  }
+
+  /**
+   * Deactivate the {@link StatsCounter} for the provided locality in by this
+   * {@link XdsLoadStatsStore}. Inactive {@link StatsCounter}s are for localities
+   * no longer exposed by the remote balancer. This method needs to be called at
+   * locality updates only for localities newly removed from balancer discovery responses.
+   * This method should be called in the same synchronized context that
+   * {@link XdsLoadBalancer#helper#getSynchronizationContext} returns.
+   */
+  @Override
+  public void removeLocality(final Locality locality) {
+    StatsCounter counter = localityLoadCounters.get(locality);
+    checkState(counter != null && counter.isActive(),
+        "No active counter for locality %s exists", locality);
+    counter.setActive(false);
+  }
+
+  /**
+   * Returns the {@link StatsCounter} instance that is responsible for aggregating load
+   * stats for the provided locality, or {@code null} if the locality is untracked.
+   */
+  @Override
+  public StatsCounter getLocalityCounter(final Locality locality) {
+    return localityLoadCounters.get(locality);
+  }
+
+  /**
+   * Record that a request has been dropped by drop overload policy with the provided category
+   * instructed by the remote balancer.
+   */
+  @Override
+  public void recordDroppedRequest(String category) {
+    AtomicLong counter = dropCounters.get(category);
+    if (counter == null) {
+      counter = dropCounters.putIfAbsent(category, new AtomicLong());
+      if (counter == null) {
+        counter = dropCounters.get(category);
+      }
+    }
+    counter.getAndIncrement();
+  }
+
+  /**
+   * Blueprint for counters that can can record number of calls in-progress, finished, failed.
+   */
+  abstract static class StatsCounter {
+
+    private boolean active = true;
+
+    abstract void incrementCallsInProgress();
+
+    abstract void decrementCallsInProgress();
+
+    abstract void incrementCallsFinished();
+
+    abstract void incrementCallsFailed();
+
+    abstract ClientLoadSnapshot snapshot();
+
+    boolean isActive() {
+      return active;
+    }
+
+    void setActive(boolean value) {
+      active = value;
+    }
+  }
+}
diff --git a/xds/src/test/java/io/grpc/xds/ClientLoadCounterTest.java b/xds/src/test/java/io/grpc/xds/ClientLoadCounterTest.java
new file mode 100644
index 0000000..7bff11e
--- /dev/null
+++ b/xds/src/test/java/io/grpc/xds/ClientLoadCounterTest.java
@@ -0,0 +1,129 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import io.grpc.ClientStreamTracer;
+import io.grpc.ClientStreamTracer.StreamInfo;
+import io.grpc.Metadata;
+import io.grpc.Status;
+import io.grpc.xds.ClientLoadCounter.ClientLoadSnapshot;
+import io.grpc.xds.ClientLoadCounter.XdsClientLoadRecorder;
+import java.util.concurrent.ThreadLocalRandom;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link ClientLoadCounter}. */
+@RunWith(JUnit4.class)
+public class ClientLoadCounterTest {
+
+  private static final ClientStreamTracer.StreamInfo STREAM_INFO =
+      ClientStreamTracer.StreamInfo.newBuilder().build();
+  private static final ClientStreamTracer.Factory NOOP_CLIENT_STREAM_TRACER_FACTORY =
+      new ClientStreamTracer.Factory() {
+        @Override
+        public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) {
+          return new ClientStreamTracer() {
+          };
+        }
+      };
+  private ClientLoadCounter counter;
+
+  @Before
+  public void setUp() {
+    counter = new ClientLoadCounter();
+    ClientLoadSnapshot emptySnapshot = counter.snapshot();
+    assertThat(emptySnapshot.getCallsInProgress()).isEqualTo(0);
+    assertThat(emptySnapshot.getCallsFinished()).isEqualTo(0);
+    assertThat(emptySnapshot.getCallsFailed()).isEqualTo(0);
+  }
+
+  @Test
+  public void snapshotContainsEverything() {
+    long numFinishedCalls = ThreadLocalRandom.current().nextLong(Long.MAX_VALUE);
+    long numInProgressCalls = ThreadLocalRandom.current().nextLong(Long.MAX_VALUE);
+    long numFailedCalls = ThreadLocalRandom.current().nextLong(numFinishedCalls);
+    counter = new ClientLoadCounter(numFinishedCalls, numInProgressCalls, numFailedCalls);
+    ClientLoadSnapshot snapshot = counter.snapshot();
+    assertThat(snapshot.getCallsFinished()).isEqualTo(numFinishedCalls);
+    assertThat(snapshot.getCallsInProgress()).isEqualTo(numInProgressCalls);
+    assertThat(snapshot.getCallsFailed()).isEqualTo(numFailedCalls);
+    String snapshotStr = snapshot.toString();
+    assertThat(snapshotStr).contains("callsFinished=" + numFinishedCalls);
+    assertThat(snapshotStr).contains("callsInProgress=" + numInProgressCalls);
+    assertThat(snapshotStr).contains("callsFailed=" + numFailedCalls);
+
+    // Snapshot only accounts for stats happening after previous snapshot.
+    snapshot = counter.snapshot();
+    assertThat(snapshot.getCallsFinished()).isEqualTo(0);
+    assertThat(snapshot.getCallsInProgress()).isEqualTo(numInProgressCalls);
+    assertThat(snapshot.getCallsFailed()).isEqualTo(0);
+
+    snapshotStr = snapshot.toString();
+    assertThat(snapshotStr).contains("callsFinished=0");
+    assertThat(snapshotStr).contains("callsInProgress=" + numInProgressCalls);
+    assertThat(snapshotStr).contains("callsFailed=0");
+  }
+
+  @Test
+  public void normalCountingOperations() {
+    ClientLoadSnapshot preSnapshot = counter.snapshot();
+    counter.incrementCallsInProgress();
+    ClientLoadSnapshot afterSnapshot = counter.snapshot();
+    assertThat(afterSnapshot.getCallsInProgress()).isEqualTo(preSnapshot.getCallsInProgress() + 1);
+    counter.decrementCallsInProgress();
+    afterSnapshot = counter.snapshot();
+    assertThat(afterSnapshot.getCallsInProgress()).isEqualTo(preSnapshot.getCallsInProgress());
+
+    counter.incrementCallsFinished();
+    afterSnapshot = counter.snapshot();
+    assertThat(afterSnapshot.getCallsFinished()).isEqualTo(1);
+
+    counter.incrementCallsFailed();
+    afterSnapshot = counter.snapshot();
+    assertThat(afterSnapshot.getCallsFailed()).isEqualTo(1);
+  }
+
+  @Test
+  public void xdsClientLoadRecorder_clientSideQueryCountsAggregation() {
+    XdsClientLoadRecorder recorder1 =
+        new XdsClientLoadRecorder(counter, NOOP_CLIENT_STREAM_TRACER_FACTORY);
+    ClientStreamTracer tracer = recorder1.newClientStreamTracer(STREAM_INFO, new Metadata());
+    ClientLoadSnapshot snapshot = counter.snapshot();
+    assertThat(snapshot.getCallsFinished()).isEqualTo(0);
+    assertThat(snapshot.getCallsInProgress()).isEqualTo(1);
+    assertThat(snapshot.getCallsFailed()).isEqualTo(0);
+    tracer.streamClosed(Status.OK);
+    snapshot = counter.snapshot();
+    assertThat(snapshot.getCallsFinished()).isEqualTo(1);
+    assertThat(snapshot.getCallsInProgress()).isEqualTo(0);
+    assertThat(snapshot.getCallsFailed()).isEqualTo(0);
+
+    // Create a second XdsClientLoadRecorder with the same counter, stats are aggregated together.
+    XdsClientLoadRecorder recorder2 =
+        new XdsClientLoadRecorder(counter, NOOP_CLIENT_STREAM_TRACER_FACTORY);
+    recorder1.newClientStreamTracer(STREAM_INFO, new Metadata()).streamClosed(Status.ABORTED);
+    recorder2.newClientStreamTracer(STREAM_INFO, new Metadata()).streamClosed(Status.CANCELLED);
+    snapshot = counter.snapshot();
+    assertThat(snapshot.getCallsFinished()).isEqualTo(2);
+    assertThat(snapshot.getCallsInProgress()).isEqualTo(0);
+    assertThat(snapshot.getCallsFailed()).isEqualTo(2);
+  }
+}
diff --git a/xds/src/test/java/io/grpc/xds/XdsLrsClientTest.java b/xds/src/test/java/io/grpc/xds/XdsLoadReportClientImplTest.java
similarity index 78%
rename from xds/src/test/java/io/grpc/xds/XdsLrsClientTest.java
rename to xds/src/test/java/io/grpc/xds/XdsLoadReportClientImplTest.java
index 7610da4..4773a2c 100644
--- a/xds/src/test/java/io/grpc/xds/XdsLrsClientTest.java
+++ b/xds/src/test/java/io/grpc/xds/XdsLoadReportClientImplTest.java
@@ -21,6 +21,7 @@
 import static org.mockito.AdditionalAnswers.delegatesTo;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.same;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.mock;
@@ -28,6 +29,7 @@
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
 
 import com.google.common.collect.Iterables;
@@ -43,8 +45,12 @@
 import io.envoyproxy.envoy.service.load_stats.v2.LoadStatsRequest;
 import io.envoyproxy.envoy.service.load_stats.v2.LoadStatsResponse;
 import io.grpc.ChannelLogger;
+import io.grpc.ClientStreamTracer;
 import io.grpc.LoadBalancer.Helper;
+import io.grpc.LoadBalancer.PickResult;
+import io.grpc.LoadBalancer.Subchannel;
 import io.grpc.ManagedChannel;
+import io.grpc.Metadata;
 import io.grpc.Status;
 import io.grpc.SynchronizationContext;
 import io.grpc.inprocess.InProcessChannelBuilder;
@@ -53,15 +59,11 @@
 import io.grpc.internal.FakeClock;
 import io.grpc.stub.StreamObserver;
 import io.grpc.testing.GrpcCleanupRule;
-import io.grpc.xds.XdsClientLoadRecorder.ClientLoadCounter;
+import io.grpc.xds.XdsLoadReportClientImpl.StatsStore;
 import java.text.MessageFormat;
 import java.util.ArrayDeque;
-import java.util.HashSet;
-import java.util.Random;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicLong;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
@@ -76,25 +78,37 @@
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
-/** Unit tests for {@link XdsLrsClient}. */
+/**
+ * Unit tests for {@link XdsLoadReportClientImpl}.
+ */
 @RunWith(JUnit4.class)
-public class XdsLrsClientTest {
+public class XdsLoadReportClientImplTest {
 
   private static final String SERVICE_AUTHORITY = "api.google.com";
   private static final FakeClock.TaskFilter LOAD_REPORTING_TASK_FILTER =
       new FakeClock.TaskFilter() {
         @Override
         public boolean shouldAccept(Runnable command) {
-          return command.toString().contains(XdsLrsClient.LoadReportingTask.class.getSimpleName());
+          return command.toString()
+              .contains(XdsLoadReportClientImpl.LoadReportingTask.class.getSimpleName());
         }
       };
   private static final FakeClock.TaskFilter LRS_RPC_RETRY_TASK_FILTER =
       new FakeClock.TaskFilter() {
         @Override
         public boolean shouldAccept(Runnable command) {
-          return command.toString().contains(XdsLrsClient.LrsRpcRetryTask.class.getSimpleName());
+          return command.toString()
+              .contains(XdsLoadReportClientImpl.LrsRpcRetryTask.class.getSimpleName());
         }
       };
+  private static final Locality TEST_LOCALITY =
+      Locality.newBuilder()
+          .setRegion("test_region")
+          .setZone("test_zone")
+          .setSubZone("test_subzone")
+          .build();
+  private static final ClientStreamTracer.StreamInfo STREAM_INFO =
+      ClientStreamTracer.StreamInfo.newBuilder().build();
   @Rule
   public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule();
   private final SynchronizationContext syncContext = new SynchronizationContext(
@@ -131,15 +145,19 @@
       .setNode(Node.newBuilder()
           .setMetadata(Struct.newBuilder()
               .putFields(
-                  XdsLrsClient.TRAFFICDIRECTOR_HOSTNAME_FIELD,
+                  XdsLoadReportClientImpl.TRAFFICDIRECTOR_HOSTNAME_FIELD,
                   Value.newBuilder().setStringValue(SERVICE_AUTHORITY).build())))
       .build();
   @Mock
   private BackoffPolicy backoffPolicy1;
   private ManagedChannel channel;
-  private XdsLrsClient lrsClient;
+  private XdsLoadReportClientImpl lrsClient;
   @Mock
   private BackoffPolicy backoffPolicy2;
+  @Mock
+  private Subchannel mockSubchannel;
+  @Mock
+  private StatsStore statsStore;
 
   private static ClusterStats buildEmptyClusterStats(long loadReportIntervalNanos) {
     return ClusterStats.newBuilder()
@@ -191,9 +209,10 @@
         .thenReturn(TimeUnit.SECONDS.toNanos(1L), TimeUnit.SECONDS.toNanos(10L));
     when(backoffPolicy2.nextBackoffNanos())
         .thenReturn(TimeUnit.SECONDS.toNanos(1L), TimeUnit.SECONDS.toNanos(10L));
-    logs.clear();
-    lrsClient = new XdsLrsClient(channel, helper, fakeClock.getStopwatchSupplier(),
-        backoffPolicyProvider, new XdsLoadReportStore(SERVICE_AUTHORITY));
+    lrsClient =
+        new XdsLoadReportClientImpl(channel, helper, fakeClock.getStopwatchSupplier(),
+            backoffPolicyProvider,
+            statsStore);
     lrsClient.startLoadReporting();
   }
 
@@ -210,28 +229,59 @@
     assertEquals(1, fakeClock.forwardTime(1, TimeUnit.NANOSECONDS));
     // A second load report is scheduled upon the first is sent.
     assertEquals(1, fakeClock.numPendingTasks(LOAD_REPORTING_TASK_FILTER));
+    inOrder.verify(statsStore).generateLoadReport();
     ArgumentCaptor<LoadStatsRequest> reportCaptor = ArgumentCaptor.forClass(null);
     inOrder.verify(requestObserver).onNext(reportCaptor.capture());
     LoadStatsRequest report = reportCaptor.getValue();
     assertEquals(report.getNode(), Node.newBuilder()
         .setMetadata(Struct.newBuilder()
             .putFields(
-                XdsLrsClient.TRAFFICDIRECTOR_HOSTNAME_FIELD,
+                XdsLoadReportClientImpl.TRAFFICDIRECTOR_HOSTNAME_FIELD,
                 Value.newBuilder().setStringValue(SERVICE_AUTHORITY).build()))
         .build());
     assertEquals(1, report.getClusterStatsCount());
-    assertClusterStatsEqual(expectedStats, report.getClusterStats(0));
+    assertThat(report.getClusterStats(0)).isEqualTo(expectedStats);
   }
 
-  private void assertClusterStatsEqual(ClusterStats stats1, ClusterStats stats2) {
-    assertEquals(stats1.getClusterName(), stats2.getClusterName());
-    assertEquals(stats1.getLoadReportInterval(), stats2.getLoadReportInterval());
-    assertEquals(stats1.getUpstreamLocalityStatsCount(), stats2.getUpstreamLocalityStatsCount());
-    assertEquals(stats1.getDroppedRequestsCount(), stats2.getDroppedRequestsCount());
-    assertEquals(new HashSet<>(stats1.getUpstreamLocalityStatsList()),
-        new HashSet<>(stats2.getUpstreamLocalityStatsList()));
-    assertEquals(new HashSet<>(stats1.getDroppedRequestsList()),
-        new HashSet<>(stats2.getDroppedRequestsList()));
+  @Test
+  public void loadNotRecordedForUntrackedLocality() {
+    when(statsStore.getLocalityCounter(TEST_LOCALITY)).thenReturn(null);
+    PickResult pickResult = PickResult.withSubchannel(mockSubchannel);
+    // If the per-locality counter does not exist, nothing should happen.
+    PickResult interceptedPickResult = lrsClient.interceptPickResult(pickResult, TEST_LOCALITY);
+    verify(statsStore).getLocalityCounter(TEST_LOCALITY);
+    assertThat(interceptedPickResult.getStreamTracerFactory()).isNull();
+  }
+
+  @Test
+  public void invalidPickResultNotIntercepted() {
+    PickResult errorResult = PickResult.withError(Status.UNAVAILABLE.withDescription("Error"));
+    PickResult droppedResult = PickResult.withDrop(Status.UNAVAILABLE.withDescription("Dropped"));
+    // TODO (chengyuanzhang): for NoResult PickResult, do we still intercept?
+    PickResult interceptedErrorResult = lrsClient.interceptPickResult(errorResult, TEST_LOCALITY);
+    PickResult interceptedDroppedResult =
+        lrsClient.interceptPickResult(droppedResult, TEST_LOCALITY);
+    assertThat(interceptedErrorResult.getStreamTracerFactory()).isNull();
+    assertThat(interceptedDroppedResult.getStreamTracerFactory()).isNull();
+    verifyZeroInteractions(statsStore);
+  }
+
+  @Test
+  public void interceptPreservesOriginStreamTracer() {
+    ClientStreamTracer.Factory mockFactory = mock(ClientStreamTracer.Factory.class);
+    ClientStreamTracer mockTracer = mock(ClientStreamTracer.class);
+    when(mockFactory
+        .newClientStreamTracer(any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)))
+        .thenReturn(mockTracer);
+    when(statsStore.getLocalityCounter(TEST_LOCALITY)).thenReturn(new ClientLoadCounter());
+    PickResult pickResult = PickResult.withSubchannel(mockSubchannel, mockFactory);
+    PickResult interceptedPickResult = lrsClient.interceptPickResult(pickResult, TEST_LOCALITY);
+    verify(statsStore).getLocalityCounter(TEST_LOCALITY);
+    Metadata metadata = new Metadata();
+    interceptedPickResult.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, metadata)
+        .streamClosed(Status.OK);
+    verify(mockFactory).newClientStreamTracer(same(STREAM_INFO), same(metadata));
+    verify(mockTracer).streamClosed(Status.OK);
   }
 
   @Test
@@ -252,7 +302,9 @@
     StreamObserver<LoadStatsResponse> responseObserver = lrsResponseObserverCaptor.getValue();
     assertThat(lrsRequestObservers).hasSize(1);
     StreamObserver<LoadStatsRequest> requestObserver = lrsRequestObservers.poll();
-    InOrder inOrder = inOrder(requestObserver);
+    when(statsStore.generateLoadReport())
+        .thenReturn(ClusterStats.newBuilder().setClusterName(SERVICE_AUTHORITY).build());
+    InOrder inOrder = inOrder(requestObserver, statsStore);
     inOrder.verify(requestObserver).onNext(EXPECTED_INITIAL_REQ);
     assertThat(logs).containsExactly("DEBUG: Initial LRS request sent: " + EXPECTED_INITIAL_REQ);
     logs.poll();
@@ -269,7 +321,11 @@
     StreamObserver<LoadStatsResponse> responseObserver = lrsResponseObserverCaptor.getValue();
     assertThat(lrsRequestObservers).hasSize(1);
     StreamObserver<LoadStatsRequest> requestObserver = lrsRequestObservers.poll();
-    InOrder inOrder = inOrder(requestObserver);
+
+    when(statsStore.generateLoadReport())
+        .thenReturn(ClusterStats.newBuilder().setClusterName(SERVICE_AUTHORITY).build());
+
+    InOrder inOrder = inOrder(requestObserver, statsStore);
     inOrder.verify(requestObserver).onNext(EXPECTED_INITIAL_REQ);
     assertThat(logs).containsExactly("DEBUG: Initial LRS request sent: " + EXPECTED_INITIAL_REQ);
     logs.poll();
@@ -289,54 +345,27 @@
 
   @Test
   public void reportRecordedLoadData() {
-    lrsClient.stopLoadReporting();
     verify(mockLoadReportingService).streamLoadStats(lrsResponseObserverCaptor.capture());
-    lrsRequestObservers.clear();
-
-    ConcurrentMap<Locality, XdsClientLoadRecorder.ClientLoadCounter> localityCounters =
-        new ConcurrentHashMap<>();
-    ConcurrentMap<String, AtomicLong> dropCounters = new ConcurrentHashMap<>();
-    XdsLoadReportStore loadReportStore =
-        new XdsLoadReportStore(SERVICE_AUTHORITY, localityCounters, dropCounters);
-    lrsClient = new XdsLrsClient(channel, helper, fakeClock.getStopwatchSupplier(),
-        backoffPolicyProvider, loadReportStore);
-    lrsClient.startLoadReporting();
-
-    verify(mockLoadReportingService, times(2)).streamLoadStats(lrsResponseObserverCaptor.capture());
     StreamObserver<LoadStatsResponse> responseObserver = lrsResponseObserverCaptor.getValue();
     assertThat(lrsRequestObservers).hasSize(1);
     StreamObserver<LoadStatsRequest> requestObserver = lrsRequestObservers.poll();
-    InOrder inOrder = inOrder(requestObserver);
+    InOrder inOrder = inOrder(requestObserver, statsStore);
     inOrder.verify(requestObserver).onNext(EXPECTED_INITIAL_REQ);
 
-    Locality locality = Locality.newBuilder()
-        .setRegion("test_region")
-        .setZone("test_zone")
-        .setSubZone("test_subzone")
-        .build();
-    Random rand = new Random();
-    // Integer range is large enough for testing.
-    long callsInProgress1 = rand.nextInt(Integer.MAX_VALUE);
-    long callsFinished1 = rand.nextInt(Integer.MAX_VALUE);
-    long callsFailed1 = callsFinished1 - rand.nextInt((int) callsFinished1);
-    localityCounters.put(locality,
-        new ClientLoadCounter(callsInProgress1, callsFinished1, callsFailed1));
+    long callsInProgress = ThreadLocalRandom.current().nextLong(Long.MAX_VALUE);
+    long callsFinished = ThreadLocalRandom.current().nextLong(Long.MAX_VALUE);
+    long callsFailed = callsFinished - ThreadLocalRandom.current().nextLong(callsFinished);
+    long numLbDrops = ThreadLocalRandom.current().nextLong(Long.MAX_VALUE);
+    long numThrottleDrops = ThreadLocalRandom.current().nextLong(Long.MAX_VALUE);
 
-    long numLbDrops = rand.nextLong();
-    long numThrottleDrops = rand.nextLong();
-    dropCounters.put("lb", new AtomicLong(numLbDrops));
-    dropCounters.put("throttle", new AtomicLong(numThrottleDrops));
-
-    responseObserver.onNext(buildLrsResponse(1362));
-
-    ClusterStats expectedStats = ClusterStats.newBuilder()
+    ClusterStats expectedStats1 = ClusterStats.newBuilder()
         .setClusterName(SERVICE_AUTHORITY)
         .setLoadReportInterval(Durations.fromNanos(1362))
         .addUpstreamLocalityStats(UpstreamLocalityStats.newBuilder()
-            .setLocality(locality)
-            .setTotalRequestsInProgress(callsInProgress1)
-            .setTotalSuccessfulRequests(callsFinished1 - callsFailed1)
-            .setTotalErrorRequests(callsFailed1))
+            .setLocality(TEST_LOCALITY)
+            .setTotalRequestsInProgress(callsInProgress)
+            .setTotalSuccessfulRequests(callsFinished - callsFailed)
+            .setTotalErrorRequests(callsFailed))
         .addDroppedRequests(DroppedRequests.newBuilder()
             .setCategory("lb")
             .setDroppedCount(numLbDrops))
@@ -344,16 +373,12 @@
             .setCategory("throttle")
             .setDroppedCount(numThrottleDrops))
         .build();
-    assertNextReport(inOrder, requestObserver, expectedStats);
-
-    // No client load happens upon next load reporting, only number of in-progress
-    // calls are non-zero.
-    expectedStats = ClusterStats.newBuilder()
+    ClusterStats expectedStats2 = ClusterStats.newBuilder()
         .setClusterName(SERVICE_AUTHORITY)
         .setLoadReportInterval(Durations.fromNanos(1362))
         .addUpstreamLocalityStats(UpstreamLocalityStats.newBuilder()
-            .setLocality(locality)
-            .setTotalRequestsInProgress(callsInProgress1))
+            .setLocality(TEST_LOCALITY)
+            .setTotalRequestsInProgress(callsInProgress))
         .addDroppedRequests(DroppedRequests.newBuilder()
             .setCategory("lb")
             .setDroppedCount(0))
@@ -361,7 +386,13 @@
             .setCategory("throttle")
             .setDroppedCount(0))
         .build();
-    assertNextReport(inOrder, requestObserver, expectedStats);
+    when(statsStore.generateLoadReport())
+        .thenReturn(expectedStats1, expectedStats2);
+
+    responseObserver.onNext(buildLrsResponse(1362));
+    assertNextReport(inOrder, requestObserver, expectedStats1);
+
+    assertNextReport(inOrder, requestObserver, expectedStats2);
   }
 
   @Test
diff --git a/xds/src/test/java/io/grpc/xds/XdsLoadReportStoreTest.java b/xds/src/test/java/io/grpc/xds/XdsLoadReportStoreTest.java
deleted file mode 100644
index 8f896f3..0000000
--- a/xds/src/test/java/io/grpc/xds/XdsLoadReportStoreTest.java
+++ /dev/null
@@ -1,299 +0,0 @@
-/*
- * Copyright 2019 The gRPC Authors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package io.grpc.xds;
-
-import static com.google.common.truth.Truth.assertThat;
-import static org.junit.Assert.assertEquals;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-import com.google.protobuf.Duration;
-import io.envoyproxy.envoy.api.v2.core.Locality;
-import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats;
-import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats.DroppedRequests;
-import io.envoyproxy.envoy.api.v2.endpoint.UpstreamLocalityStats;
-import io.grpc.ClientStreamTracer;
-import io.grpc.LoadBalancer.PickResult;
-import io.grpc.LoadBalancer.Subchannel;
-import io.grpc.Metadata;
-import io.grpc.Status;
-import io.grpc.xds.XdsClientLoadRecorder.ClientLoadCounter;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Random;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentMap;
-import java.util.concurrent.atomic.AtomicLong;
-import javax.annotation.Nullable;
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.JUnit4;
-import org.mockito.ArgumentCaptor;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-
-/** Unit tests for {@link XdsLoadReportStore}. */
-@RunWith(JUnit4.class)
-public class XdsLoadReportStoreTest {
-  private static final String SERVICE_NAME = "api.google.com";
-  private static final ClientStreamTracer.StreamInfo STREAM_INFO =
-      ClientStreamTracer.StreamInfo.newBuilder().build();
-  private static final Locality TEST_LOCALITY = Locality.newBuilder()
-      .setRegion("test_region")
-      .setZone("test_zone")
-      .setSubZone("test_subzone")
-      .build();
-  @Mock Subchannel fakeSubchannel;
-  private ConcurrentMap<Locality, ClientLoadCounter> localityLoadCounters;
-  private ConcurrentMap<String, AtomicLong> dropCounters;
-  private XdsLoadReportStore loadStore;
-
-  @Before
-  public void setUp() {
-    MockitoAnnotations.initMocks(this);
-    localityLoadCounters = new ConcurrentHashMap<>();
-    dropCounters = new ConcurrentHashMap<>();
-    loadStore = new XdsLoadReportStore(SERVICE_NAME, localityLoadCounters, dropCounters);
-  }
-
-  @Test
-  public void loadNotRecordedForUntrackedLocality() {
-    PickResult pickResult = PickResult.withSubchannel(fakeSubchannel);
-    // XdsClientLoadStore does not record loads for untracked localities.
-    PickResult interceptedPickResult = loadStore.interceptPickResult(pickResult, TEST_LOCALITY);
-    assertThat(localityLoadCounters).hasSize(0);
-    assertThat(interceptedPickResult.getStreamTracerFactory()).isNull();
-  }
-
-  @Test
-  public void invalidPickResultNotIntercepted() {
-    PickResult errorResult = PickResult.withError(Status.UNAVAILABLE.withDescription("Error"));
-    PickResult emptyResult = PickResult.withNoResult();
-    PickResult droppedResult = PickResult.withDrop(Status.UNAVAILABLE.withDescription("Dropped"));
-    PickResult interceptedErrorResult = loadStore.interceptPickResult(errorResult, TEST_LOCALITY);
-    PickResult interceptedEmptyResult = loadStore.interceptPickResult(emptyResult, TEST_LOCALITY);
-    PickResult interceptedDroppedResult = loadStore
-        .interceptPickResult(droppedResult, TEST_LOCALITY);
-    assertThat(localityLoadCounters).hasSize(0);
-    assertThat(interceptedErrorResult.getStreamTracerFactory()).isNull();
-    assertThat(interceptedEmptyResult.getStreamTracerFactory()).isNull();
-    assertThat(interceptedDroppedResult.getStreamTracerFactory()).isNull();
-  }
-
-  @Test
-  public void interceptPreservesOriginStreamTracer() {
-    loadStore.addLocality(TEST_LOCALITY);
-    ClientStreamTracer.Factory mockFactory = mock(ClientStreamTracer.Factory.class);
-    ClientStreamTracer mockTracer = mock(ClientStreamTracer.class);
-    when(mockFactory
-        .newClientStreamTracer(any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)))
-        .thenReturn(mockTracer);
-    PickResult pickResult = PickResult.withSubchannel(fakeSubchannel, mockFactory);
-    PickResult interceptedPickResult = loadStore.interceptPickResult(pickResult, TEST_LOCALITY);
-    Metadata metadata = new Metadata();
-    interceptedPickResult.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, metadata)
-        .streamClosed(Status.OK);
-    ArgumentCaptor<ClientStreamTracer.StreamInfo> streamInfoArgumentCaptor = ArgumentCaptor
-        .forClass(null);
-    ArgumentCaptor<Metadata> metadataArgumentCaptor = ArgumentCaptor.forClass(null);
-    verify(mockFactory).newClientStreamTracer(streamInfoArgumentCaptor.capture(),
-        metadataArgumentCaptor.capture());
-    assertThat(streamInfoArgumentCaptor.getValue()).isSameInstanceAs(STREAM_INFO);
-    assertThat(metadataArgumentCaptor.getValue()).isSameInstanceAs(metadata);
-    verify(mockTracer).streamClosed(Status.OK);
-  }
-
-  @Test
-  public void loadStatsRecording() {
-    Locality locality1 =
-        Locality.newBuilder()
-            .setRegion("test_region1")
-            .setZone("test_zone")
-            .setSubZone("test_subzone")
-            .build();
-    loadStore.addLocality(locality1);
-    PickResult pickResult1 = PickResult.withSubchannel(fakeSubchannel);
-    PickResult interceptedPickResult1 = loadStore.interceptPickResult(pickResult1, locality1);
-    assertThat(interceptedPickResult1.getSubchannel()).isSameInstanceAs(fakeSubchannel);
-    assertThat(localityLoadCounters).containsKey(locality1);
-    ClientStreamTracer tracer =
-        interceptedPickResult1
-            .getStreamTracerFactory()
-            .newClientStreamTracer(STREAM_INFO, new Metadata());
-    Duration interval = Duration.newBuilder().setNanos(342).build();
-    ClusterStats expectedLoadReport = buildClusterStats(interval,
-        Collections.singletonList(buildUpstreamLocalityStats(locality1, 0, 0, 1)), null);
-    assertClusterStatsEqual(expectedLoadReport, loadStore.generateLoadReport(interval));
-
-    // Make another load report should not reset count for calls in progress.
-    assertClusterStatsEqual(expectedLoadReport, loadStore.generateLoadReport(interval));
-
-    tracer.streamClosed(Status.OK);
-    expectedLoadReport = buildClusterStats(interval,
-        Collections.singletonList(buildUpstreamLocalityStats(locality1, 1, 0, 0)), null);
-    assertClusterStatsEqual(expectedLoadReport, loadStore.generateLoadReport(interval));
-
-    // Make another load report should reset finished calls count for calls finished.
-    expectedLoadReport = buildClusterStats(interval,
-        Collections.singletonList(buildUpstreamLocalityStats(locality1, 0, 0, 0)), null);
-    assertClusterStatsEqual(expectedLoadReport, loadStore.generateLoadReport(interval));
-
-    // PickResult within the same locality should aggregate to the same counter.
-    PickResult pickResult2 = PickResult.withSubchannel(fakeSubchannel);
-    PickResult interceptedPickResult2 = loadStore.interceptPickResult(pickResult2, locality1);
-    assertThat(localityLoadCounters).hasSize(1);
-    interceptedPickResult1
-        .getStreamTracerFactory()
-        .newClientStreamTracer(STREAM_INFO, new Metadata())
-        .streamClosed(Status.ABORTED);
-    interceptedPickResult2
-        .getStreamTracerFactory()
-        .newClientStreamTracer(STREAM_INFO, new Metadata())
-        .streamClosed(Status.CANCELLED);
-    expectedLoadReport = buildClusterStats(interval,
-        Collections.singletonList(buildUpstreamLocalityStats(locality1, 0, 2, 0)), null);
-    assertClusterStatsEqual(expectedLoadReport, loadStore.generateLoadReport(interval));
-
-    expectedLoadReport = buildClusterStats(interval,
-        Collections.singletonList(buildUpstreamLocalityStats(locality1, 0, 0, 0)), null);
-    assertClusterStatsEqual(expectedLoadReport, loadStore.generateLoadReport(interval));
-
-    Locality locality2 =
-        Locality.newBuilder()
-            .setRegion("test_region2")
-            .setZone("test_zone")
-            .setSubZone("test_subzone")
-            .build();
-    loadStore.addLocality(locality2);
-    PickResult pickResult3 = PickResult.withSubchannel(fakeSubchannel);
-    PickResult interceptedPickResult3 = loadStore.interceptPickResult(pickResult3, locality2);
-    assertThat(localityLoadCounters).containsKey(locality2);
-    assertThat(localityLoadCounters).hasSize(2);
-    interceptedPickResult3
-        .getStreamTracerFactory()
-        .newClientStreamTracer(STREAM_INFO, new Metadata());
-    List<UpstreamLocalityStats> upstreamLocalityStatsList =
-        Arrays.asList(buildUpstreamLocalityStats(locality1, 0, 0, 0),
-            buildUpstreamLocalityStats(locality2, 0, 0, 1));
-    expectedLoadReport = buildClusterStats(interval, upstreamLocalityStatsList, null);
-    assertClusterStatsEqual(expectedLoadReport, loadStore.generateLoadReport(interval));
-  }
-
-  @Test
-  public void loadRecordingForRemovedLocality() {
-    loadStore.addLocality(TEST_LOCALITY);
-    assertThat(localityLoadCounters).containsKey(TEST_LOCALITY);
-    PickResult pickResult = PickResult.withSubchannel(fakeSubchannel);
-    PickResult interceptedPickResult = loadStore.interceptPickResult(pickResult, TEST_LOCALITY);
-    ClientStreamTracer tracer = interceptedPickResult
-        .getStreamTracerFactory()
-        .newClientStreamTracer(STREAM_INFO, new Metadata());
-
-    Duration interval = Duration.newBuilder().setNanos(342).build();
-    ClusterStats expectedLoadReport = buildClusterStats(interval,
-        Collections.singletonList(buildUpstreamLocalityStats(TEST_LOCALITY, 0, 0, 1)), null);
-    assertClusterStatsEqual(expectedLoadReport, loadStore.generateLoadReport(interval));
-    // Remote balancer instructs to remove the locality while client has in-progress calls
-    // to backends in the locality, the XdsClientLoadStore continues tracking its load stats.
-    loadStore.removeLocality(TEST_LOCALITY);
-    assertThat(localityLoadCounters).containsKey(TEST_LOCALITY);
-    expectedLoadReport = buildClusterStats(interval,
-        Collections.singletonList(buildUpstreamLocalityStats(TEST_LOCALITY, 0, 0, 1)), null);
-    assertClusterStatsEqual(expectedLoadReport, loadStore.generateLoadReport(interval));
-
-    tracer.streamClosed(Status.OK);
-    expectedLoadReport = buildClusterStats(interval,
-        Collections.singletonList(buildUpstreamLocalityStats(TEST_LOCALITY, 1, 0, 0)), null);
-    assertClusterStatsEqual(expectedLoadReport, loadStore.generateLoadReport(interval));
-    assertThat(localityLoadCounters).doesNotContainKey(TEST_LOCALITY);
-  }
-
-  @Test
-  public void recordingDroppedRequests() {
-    Random rand = new Random();
-    int numLbDrop = rand.nextInt(1000);
-    int numThrottleDrop = rand.nextInt(1000);
-    for (int i = 0; i < numLbDrop; i++) {
-      loadStore.recordDroppedRequest("lb");
-    }
-    for (int i = 0; i < numThrottleDrop; i++) {
-      loadStore.recordDroppedRequest("throttle");
-    }
-    Duration interval = Duration.newBuilder().setNanos(342).build();
-    ClusterStats expectedLoadReport = buildClusterStats(interval,
-        null,
-        Arrays.asList(
-            DroppedRequests.newBuilder()
-                .setCategory("lb")
-                .setDroppedCount(numLbDrop)
-                .build(),
-            DroppedRequests.newBuilder()
-                .setCategory("throttle")
-                .setDroppedCount(numThrottleDrop)
-                .build()
-        ));
-    assertClusterStatsEqual(expectedLoadReport, loadStore.generateLoadReport(interval));
-    assertEquals(0, dropCounters.get("lb").get());
-    assertEquals(0, dropCounters.get("throttle").get());
-  }
-
-  private UpstreamLocalityStats buildUpstreamLocalityStats(Locality locality, long callsSucceed,
-      long callsFailed, long callsInProgress) {
-    return UpstreamLocalityStats.newBuilder()
-        .setLocality(locality)
-        .setTotalSuccessfulRequests(callsSucceed)
-        .setTotalErrorRequests(callsFailed)
-        .setTotalRequestsInProgress(callsInProgress)
-        .build();
-  }
-
-  private ClusterStats buildClusterStats(Duration interval,
-      @Nullable List<UpstreamLocalityStats> upstreamLocalityStatsList,
-      @Nullable List<DroppedRequests> droppedRequestsList) {
-    ClusterStats.Builder clusterStatsBuilder = ClusterStats.newBuilder()
-        .setClusterName(SERVICE_NAME)
-        .setLoadReportInterval(interval);
-    if (upstreamLocalityStatsList != null) {
-      clusterStatsBuilder.addAllUpstreamLocalityStats(upstreamLocalityStatsList);
-    }
-    if (droppedRequestsList != null) {
-      long dropCount = 0;
-      for (DroppedRequests drop : droppedRequestsList) {
-        dropCount += drop.getDroppedCount();
-        clusterStatsBuilder.addDroppedRequests(drop);
-      }
-      clusterStatsBuilder.setTotalDroppedRequests(dropCount);
-    }
-    return clusterStatsBuilder.build();
-  }
-
-  private void assertClusterStatsEqual(ClusterStats stats1, ClusterStats stats2) {
-    assertEquals(stats1.getClusterName(), stats2.getClusterName());
-    assertEquals(stats1.getLoadReportInterval(), stats2.getLoadReportInterval());
-    assertEquals(stats1.getUpstreamLocalityStatsCount(), stats2.getUpstreamLocalityStatsCount());
-    assertEquals(stats1.getDroppedRequestsCount(), stats2.getDroppedRequestsCount());
-    assertEquals(new HashSet<>(stats1.getUpstreamLocalityStatsList()),
-        new HashSet<>(stats2.getUpstreamLocalityStatsList()));
-    assertEquals(new HashSet<>(stats1.getDroppedRequestsList()),
-        new HashSet<>(stats2.getDroppedRequestsList()));
-  }
-}
diff --git a/xds/src/test/java/io/grpc/xds/XdsLoadStatsStoreTest.java b/xds/src/test/java/io/grpc/xds/XdsLoadStatsStoreTest.java
new file mode 100644
index 0000000..c2e1443
--- /dev/null
+++ b/xds/src/test/java/io/grpc/xds/XdsLoadStatsStoreTest.java
@@ -0,0 +1,275 @@
+/*
+ * Copyright 2019 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.xds;
+
+import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import io.envoyproxy.envoy.api.v2.core.Locality;
+import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats;
+import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats.DroppedRequests;
+import io.envoyproxy.envoy.api.v2.endpoint.EndpointLoadMetricStats;
+import io.envoyproxy.envoy.api.v2.endpoint.UpstreamLocalityStats;
+import io.grpc.xds.ClientLoadCounter.ClientLoadSnapshot;
+import io.grpc.xds.XdsLoadStatsStore.StatsCounter;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.atomic.AtomicLong;
+import javax.annotation.Nullable;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit tests for {@link XdsLoadStatsStore}. */
+@RunWith(JUnit4.class)
+public class XdsLoadStatsStoreTest {
+  private static final String SERVICE_NAME = "api.google.com";
+  private static final Locality LOCALITY1 =
+      Locality.newBuilder()
+          .setRegion("test_region1")
+          .setZone("test_zone")
+          .setSubZone("test_subzone")
+          .build();
+  private static final Locality LOCALITY2 =
+      Locality.newBuilder()
+          .setRegion("test_region2")
+          .setZone("test_zone")
+          .setSubZone("test_subzone")
+          .build();
+  private ConcurrentMap<Locality, StatsCounter> localityLoadCounters;
+  private ConcurrentMap<String, AtomicLong> dropCounters;
+  private XdsLoadStatsStore loadStore;
+
+  @Before
+  public void setUp() {
+    localityLoadCounters = new ConcurrentHashMap<>();
+    dropCounters = new ConcurrentHashMap<>();
+    loadStore = new XdsLoadStatsStore(SERVICE_NAME, localityLoadCounters, dropCounters);
+  }
+
+  private static UpstreamLocalityStats buildUpstreamLocalityStats(Locality locality,
+      long callsSucceed,
+      long callsInProgress,
+      long callsFailed,
+      @Nullable List<EndpointLoadMetricStats> metrics) {
+    UpstreamLocalityStats.Builder builder =
+        UpstreamLocalityStats.newBuilder()
+            .setLocality(locality)
+            .setTotalSuccessfulRequests(callsSucceed)
+            .setTotalErrorRequests(callsFailed)
+            .setTotalRequestsInProgress(callsInProgress);
+    if (metrics != null) {
+      builder.addAllLoadMetricStats(metrics);
+    }
+    return builder.build();
+  }
+
+  private static DroppedRequests buildDroppedRequests(String category, long counts) {
+    return DroppedRequests.newBuilder()
+        .setCategory(category)
+        .setDroppedCount(counts)
+        .build();
+  }
+
+  private static ClusterStats buildClusterStats(
+      @Nullable List<UpstreamLocalityStats> upstreamLocalityStatsList,
+      @Nullable List<DroppedRequests> droppedRequestsList) {
+    ClusterStats.Builder clusterStatsBuilder = ClusterStats.newBuilder()
+        .setClusterName(SERVICE_NAME);
+    if (upstreamLocalityStatsList != null) {
+      clusterStatsBuilder.addAllUpstreamLocalityStats(upstreamLocalityStatsList);
+    }
+    if (droppedRequestsList != null) {
+      long dropCount = 0;
+      for (DroppedRequests drop : droppedRequestsList) {
+        dropCount += drop.getDroppedCount();
+        clusterStatsBuilder.addDroppedRequests(drop);
+      }
+      clusterStatsBuilder.setTotalDroppedRequests(dropCount);
+    }
+    return clusterStatsBuilder.build();
+  }
+
+  private static void assertClusterStatsEqual(ClusterStats expected, ClusterStats actual) {
+    assertThat(actual.getClusterName()).isEqualTo(expected.getClusterName());
+    assertThat(actual.getLoadReportInterval()).isEqualTo(expected.getLoadReportInterval());
+    assertThat(actual.getDroppedRequestsCount()).isEqualTo(expected.getDroppedRequestsCount());
+    assertThat(new HashSet<>(actual.getDroppedRequestsList()))
+        .isEqualTo(new HashSet<>(expected.getDroppedRequestsList()));
+    assertUpstreamLocalityStatsListsEqual(actual.getUpstreamLocalityStatsList(),
+        expected.getUpstreamLocalityStatsList());
+  }
+
+  private static void assertUpstreamLocalityStatsListsEqual(List<UpstreamLocalityStats> expected,
+      List<UpstreamLocalityStats> actual) {
+    assertThat(actual.size()).isEqualTo(expected.size());
+    Map<Locality, UpstreamLocalityStats> expectedLocalityStats = new HashMap<>();
+    for (UpstreamLocalityStats stats : expected) {
+      expectedLocalityStats.put(stats.getLocality(), stats);
+    }
+    for (UpstreamLocalityStats stats : actual) {
+      UpstreamLocalityStats expectedStats = expectedLocalityStats.get(stats.getLocality());
+      assertThat(expectedStats).isNotNull();
+      assertUpstreamLocalityStatsEqual(stats, expectedStats);
+    }
+  }
+
+  private static void assertUpstreamLocalityStatsEqual(UpstreamLocalityStats expected,
+      UpstreamLocalityStats actual) {
+    assertThat(actual.getLocality()).isEqualTo(expected.getLocality());
+    assertThat(actual.getTotalSuccessfulRequests())
+        .isEqualTo(expected.getTotalSuccessfulRequests());
+    assertThat(actual.getTotalRequestsInProgress())
+        .isEqualTo(expected.getTotalRequestsInProgress());
+    assertThat(actual.getTotalErrorRequests()).isEqualTo(expected.getTotalErrorRequests());
+    assertThat(new HashSet<>(actual.getLoadMetricStatsList()))
+        .isEqualTo(new HashSet<>(expected.getLoadMetricStatsList()));
+  }
+
+  @Test
+  public void addAndGetAndRemoveLocality() {
+    loadStore.addLocality(LOCALITY1);
+    assertThat(localityLoadCounters).containsKey(LOCALITY1);
+
+    // Adding the same locality counter again causes an exception.
+    try {
+      loadStore.addLocality(LOCALITY1);
+      Assert.fail();
+    } catch (IllegalStateException expected) {
+      assertThat(expected).hasMessageThat()
+          .contains("An active counter for locality " + LOCALITY1 + " already exists");
+    }
+
+    assertThat(loadStore.getLocalityCounter(LOCALITY1))
+        .isSameInstanceAs(localityLoadCounters.get(LOCALITY1));
+    assertThat(loadStore.getLocalityCounter(LOCALITY2)).isNull();
+
+    // Removing an non-existing locality counter causes an exception.
+    try {
+      loadStore.removeLocality(LOCALITY2);
+      Assert.fail();
+    } catch (IllegalStateException expected) {
+      assertThat(expected).hasMessageThat()
+          .contains("No active counter for locality " + LOCALITY2 + " exists");
+    }
+
+    // Removing the locality counter only mark it as inactive, but not throw it away.
+    loadStore.removeLocality(LOCALITY1);
+    assertThat(localityLoadCounters.get(LOCALITY1).isActive()).isFalse();
+
+    // Removing an inactive locality counter causes an exception.
+    try {
+      loadStore.removeLocality(LOCALITY1);
+      Assert.fail();
+    } catch (IllegalStateException expected) {
+      assertThat(expected).hasMessageThat()
+          .contains("No active counter for locality " + LOCALITY1 + " exists");
+    }
+
+    // Adding it back simply mark it as active again.
+    loadStore.addLocality(LOCALITY1);
+    assertThat(localityLoadCounters.get(LOCALITY1).isActive()).isTrue();
+  }
+
+  @Test
+  public void removeInactiveCountersAfterGeneratingLoadReport() {
+    StatsCounter counter1 = mock(StatsCounter.class);
+    when(counter1.isActive()).thenReturn(true);
+    when(counter1.snapshot()).thenReturn(ClientLoadSnapshot.EMPTY_SNAPSHOT);
+    StatsCounter counter2 = mock(StatsCounter.class);
+    when(counter2.isActive()).thenReturn(false);
+    when(counter2.snapshot()).thenReturn(ClientLoadSnapshot.EMPTY_SNAPSHOT);
+    localityLoadCounters.put(LOCALITY1, counter1);
+    localityLoadCounters.put(LOCALITY2, counter2);
+    loadStore.generateLoadReport();
+    assertThat(localityLoadCounters).containsKey(LOCALITY1);
+    assertThat(localityLoadCounters).doesNotContainKey(LOCALITY2);
+  }
+
+  @Test
+  public void loadReportMatchesSnapshots() {
+    StatsCounter counter1 = mock(StatsCounter.class);
+    when(counter1.isActive()).thenReturn(true);
+    when(counter1.snapshot())
+        .thenReturn(new ClientLoadSnapshot(4315, 3421, 23),
+            new ClientLoadSnapshot(0, 543, 0));
+    StatsCounter counter2 = mock(StatsCounter.class);
+    when(counter2.snapshot()).thenReturn(new ClientLoadSnapshot(41234, 432, 431),
+        new ClientLoadSnapshot(0, 432, 0));
+    when(counter2.isActive()).thenReturn(true);
+    localityLoadCounters.put(LOCALITY1, counter1);
+    localityLoadCounters.put(LOCALITY2, counter2);
+
+    ClusterStats expectedReport =
+        buildClusterStats(
+            Arrays.asList(
+                buildUpstreamLocalityStats(LOCALITY1, 4315 - 23, 3421, 23, null),
+                buildUpstreamLocalityStats(LOCALITY2, 41234 - 431, 432, 431, null)
+            ),
+            null);
+
+    assertClusterStatsEqual(expectedReport, loadStore.generateLoadReport());
+    verify(counter1).snapshot();
+    verify(counter2).snapshot();
+
+    expectedReport =
+        buildClusterStats(
+            Arrays.asList(
+                buildUpstreamLocalityStats(LOCALITY1, 0, 543, 0,
+                    null),
+                buildUpstreamLocalityStats(LOCALITY2, 0, 432, 0,
+                    null)
+            ),
+            null);
+    assertClusterStatsEqual(expectedReport, loadStore.generateLoadReport());
+    verify(counter1, times(2)).snapshot();
+    verify(counter2, times(2)).snapshot();
+  }
+
+  @Test
+  public void recordingDroppedRequests() {
+    Random rand = new Random();
+    int numLbDrop = rand.nextInt(1000);
+    int numThrottleDrop = rand.nextInt(1000);
+    for (int i = 0; i < numLbDrop; i++) {
+      loadStore.recordDroppedRequest("lb");
+    }
+    for (int i = 0; i < numThrottleDrop; i++) {
+      loadStore.recordDroppedRequest("throttle");
+    }
+    assertThat(dropCounters.get("lb").get()).isEqualTo(numLbDrop);
+    assertThat(dropCounters.get("throttle").get()).isEqualTo(numThrottleDrop);
+    ClusterStats expectedLoadReport =
+        buildClusterStats(null,
+            Arrays.asList(buildDroppedRequests("lb", numLbDrop),
+                buildDroppedRequests("throttle", numThrottleDrop)));
+    assertClusterStatsEqual(expectedLoadReport, loadStore.generateLoadReport());
+    assertThat(dropCounters.get("lb").get()).isEqualTo(0);
+    assertThat(dropCounters.get("throttle").get()).isEqualTo(0);
+  }
+}