xds: redesign client load recording and backend metrics receiving interface (#5903)

* Redefined StatsStore interface.
- Removed interface method StatsStore#interceptPickResult and implementation do not take the resposibility intercepting PickResult with locality-level load recording.
- Introduce a wrapper class for SubchannelPicker to let users wrap SubchannelPicker by themselves, with client side load recording logic.
- Associate the corresponding locality counter with child helper when it is created, child helper will intercept the SubchannelPicker it creates.

* Renamed backend metrics listener class to be more abstract, hides the implementation detail of doing locality-level aggregation.

* Integrate client load recording and backend metrics recording with xDS load balancer.

- Created LoadRecordingSubchannelPicker class for applying XdsClientLoadRecorder that records client load to PickResult.
- Created MetricsObservingSubchannel class for applying OrcaReportingTracerFactory that takes listener to receive ORCA reports to PickResult.
- In xDS load balancer LocalityStore, the original picker is wrapped two layers inside the above wrappers.

* Renamed XdsClientLoadRecorder to ClientLoadRecorder. It should only be used for testing, xDS load balancer should use SubchannelPicker wrappers instead of this load recorder directly.

* Removed redudent layer of wrapping for SubchannelPicker in LocalityStore

* Added toString for SubchannelPicker wrapper classes.

* Rename ClientLoadRecorder to LoadRecordingStreamTracerFactory.

* Renamed StreamInstrumentedSubchannelPicker to TracerWrappingSubchannelPicker.

* Eliminate duplicated code in LocalityStoreTest, put them into a loop.
diff --git a/xds/src/main/java/io/grpc/xds/ClientLoadCounter.java b/xds/src/main/java/io/grpc/xds/ClientLoadCounter.java
index e9c445a..c574a7a 100644
--- a/xds/src/main/java/io/grpc/xds/ClientLoadCounter.java
+++ b/xds/src/main/java/io/grpc/xds/ClientLoadCounter.java
@@ -23,6 +23,9 @@
 import io.envoyproxy.udpa.data.orca.v1.OrcaLoadReport;
 import io.grpc.ClientStreamTracer;
 import io.grpc.ClientStreamTracer.StreamInfo;
+import io.grpc.LoadBalancer.PickResult;
+import io.grpc.LoadBalancer.PickSubchannelArgs;
+import io.grpc.LoadBalancer.SubchannelPicker;
 import io.grpc.Metadata;
 import io.grpc.Status;
 import io.grpc.util.ForwardingClientStreamTracer;
@@ -256,16 +259,18 @@
   }
 
   /**
-   * An {@link XdsClientLoadRecorder} instance records and aggregates client-side load data into an
-   * {@link ClientLoadCounter} object.
+   * An {@link LoadRecordingStreamTracerFactory} instance records and aggregates client-side load
+   * data into an {@link ClientLoadCounter} object.
    */
   @ThreadSafe
-  static final class XdsClientLoadRecorder extends ClientStreamTracer.Factory {
+  @VisibleForTesting
+  static final class LoadRecordingStreamTracerFactory extends ClientStreamTracer.Factory {
 
     private final ClientStreamTracer.Factory delegate;
     private final ClientLoadCounter counter;
 
-    XdsClientLoadRecorder(ClientLoadCounter counter, ClientStreamTracer.Factory delegate) {
+    LoadRecordingStreamTracerFactory(ClientLoadCounter counter,
+        ClientStreamTracer.Factory delegate) {
       this.counter = checkNotNull(counter, "counter");
       this.delegate = checkNotNull(delegate, "delegate");
     }
@@ -287,18 +292,29 @@
         }
       };
     }
+
+    @VisibleForTesting
+    ClientLoadCounter getCounter() {
+      return counter;
+    }
+
+    @VisibleForTesting
+    ClientStreamTracer.Factory delegate() {
+      return delegate;
+    }
   }
 
   /**
-   * Listener implementation to receive backend metrics with locality-level aggregation.
+   * Listener implementation to receive backend metrics and record metric values in the provided
+   * {@link ClientLoadCounter}.
    */
   @ThreadSafe
-  static final class LocalityMetricsListener implements OrcaPerRequestReportListener,
-      OrcaOobReportListener {
+  static final class MetricsRecordingListener
+      implements OrcaPerRequestReportListener, OrcaOobReportListener {
 
     private final ClientLoadCounter counter;
 
-    LocalityMetricsListener(ClientLoadCounter counter) {
+    MetricsRecordingListener(ClientLoadCounter counter) {
       this.counter = checkNotNull(counter, "counter");
     }
 
@@ -310,5 +326,118 @@
         counter.recordMetric(entry.getKey(), entry.getValue());
       }
     }
+
+    @VisibleForTesting
+    ClientLoadCounter getCounter() {
+      return counter;
+    }
+  }
+
+  /**
+   * Base class for {@link SubchannelPicker} wrapper classes that intercept "RPC-capable"
+   * {@link PickResult}s with applying a custom {@link ClientStreamTracer.Factory} for stream
+   * instrumenting purposes.
+   */
+  @VisibleForTesting
+  abstract static class TracerWrappingSubchannelPicker extends SubchannelPicker {
+
+    private static final ClientStreamTracer NOOP_CLIENT_STREAM_TRACER =
+        new ClientStreamTracer() {
+        };
+    private static final ClientStreamTracer.Factory NOOP_CLIENT_STREAM_TRACER_FACTORY =
+        new ClientStreamTracer.Factory() {
+          @Override
+          public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) {
+            return NOOP_CLIENT_STREAM_TRACER;
+          }
+        };
+
+    protected abstract SubchannelPicker delegate();
+
+    protected abstract ClientStreamTracer.Factory wrapTracerFactory(
+        ClientStreamTracer.Factory originFactory);
+
+    @Override
+    public PickResult pickSubchannel(PickSubchannelArgs args) {
+      PickResult result = delegate().pickSubchannel(args);
+      if (!result.getStatus().isOk()) {
+        return result;
+      }
+      if (result.getSubchannel() == null) {
+        return result;
+      }
+      ClientStreamTracer.Factory originFactory = result.getStreamTracerFactory();
+      if (originFactory == null) {
+        originFactory = NOOP_CLIENT_STREAM_TRACER_FACTORY;
+      }
+      return PickResult.withSubchannel(result.getSubchannel(), wrapTracerFactory(originFactory));
+    }
+
+    @Override
+    public String toString() {
+      return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString();
+    }
+  }
+
+  /**
+   * A wrapper class that wraps a {@link SubchannelPicker} instance and associate it with a {@link
+   * ClientLoadCounter}. All "RPC-capable" {@link PickResult}s picked will be intercepted with
+   * client side load recording logic such that RPC activities occurring in the {@link PickResult}'s
+   * {@link io.grpc.LoadBalancer.Subchannel} will be recorded in the associated {@link
+   * ClientLoadCounter}.
+   */
+  @ThreadSafe
+  static final class LoadRecordingSubchannelPicker extends TracerWrappingSubchannelPicker {
+
+    private final ClientLoadCounter counter;
+    private final SubchannelPicker delegate;
+
+    LoadRecordingSubchannelPicker(ClientLoadCounter counter, SubchannelPicker delegate) {
+      this.counter = checkNotNull(counter, "counter");
+      this.delegate = checkNotNull(delegate, "delegate");
+    }
+
+    @Override
+    protected SubchannelPicker delegate() {
+      return delegate;
+    }
+
+    @Override
+    protected ClientStreamTracer.Factory wrapTracerFactory(
+        ClientStreamTracer.Factory originFactory) {
+      return new LoadRecordingStreamTracerFactory(counter, originFactory);
+    }
+  }
+
+  /**
+   * A wrapper class that wraps {@link SubchannelPicker} instance and associate it with an {@link
+   * OrcaPerRequestReportListener}. All "RPC-capable" {@link PickResult}s picked will be intercepted
+   * with the logic of registering the listener for observing backend metrics.
+   */
+  @ThreadSafe
+  static final class MetricsObservingSubchannelPicker extends TracerWrappingSubchannelPicker {
+
+    private final OrcaPerRequestReportListener listener;
+    private final SubchannelPicker delegate;
+    private final OrcaPerRequestUtil orcaPerRequestUtil;
+
+    MetricsObservingSubchannelPicker(OrcaPerRequestReportListener listener,
+        SubchannelPicker delegate,
+        OrcaPerRequestUtil orcaPerRequestUtil) {
+      this.listener = checkNotNull(listener, "listener");
+      this.delegate = checkNotNull(delegate, "delegate");
+      this.orcaPerRequestUtil = checkNotNull(orcaPerRequestUtil, "orcaPerRequestUtil");
+    }
+
+    @Override
+    protected SubchannelPicker delegate() {
+      return delegate;
+    }
+
+    @Override
+    protected ClientStreamTracer.Factory wrapTracerFactory(
+        ClientStreamTracer.Factory originFactory) {
+      return orcaPerRequestUtil.newOrcaClientStreamTracerFactory(originFactory, listener);
+    }
   }
 }
diff --git a/xds/src/main/java/io/grpc/xds/LocalityStore.java b/xds/src/main/java/io/grpc/xds/LocalityStore.java
index 878fb78..27da4a8 100644
--- a/xds/src/main/java/io/grpc/xds/LocalityStore.java
+++ b/xds/src/main/java/io/grpc/xds/LocalityStore.java
@@ -40,6 +40,9 @@
 import io.grpc.LoadBalancerRegistry;
 import io.grpc.Status;
 import io.grpc.util.ForwardingLoadBalancerHelper;
+import io.grpc.xds.ClientLoadCounter.LoadRecordingSubchannelPicker;
+import io.grpc.xds.ClientLoadCounter.MetricsObservingSubchannelPicker;
+import io.grpc.xds.ClientLoadCounter.MetricsRecordingListener;
 import io.grpc.xds.InterLocalityPicker.WeightedChildPicker;
 import io.grpc.xds.XdsComms.DropOverload;
 import io.grpc.xds.XdsComms.LocalityInfo;
@@ -79,19 +82,24 @@
     private final LoadBalancerProvider loadBalancerProvider;
     private final ThreadSafeRandom random;
     private final StatsStore statsStore;
+    private final OrcaPerRequestUtil orcaPerRequestUtil;
 
     private Map<XdsLocality, LocalityLbInfo> localityMap = new HashMap<>();
     private ImmutableList<DropOverload> dropOverloads = ImmutableList.of();
 
     LocalityStoreImpl(Helper helper, LoadBalancerRegistry lbRegistry) {
       this(helper, pickerFactoryImpl, lbRegistry, ThreadSafeRandom.ThreadSafeRandomImpl.instance,
-          new XdsLoadStatsStore());
+          new XdsLoadStatsStore(), OrcaPerRequestUtil.getInstance());
     }
 
     @VisibleForTesting
     LocalityStoreImpl(
-        Helper helper, PickerFactory pickerFactory, LoadBalancerRegistry lbRegistry,
-        ThreadSafeRandom random, StatsStore statsStore) {
+        Helper helper,
+        PickerFactory pickerFactory,
+        LoadBalancerRegistry lbRegistry,
+        ThreadSafeRandom random,
+        StatsStore statsStore,
+        OrcaPerRequestUtil orcaPerRequestUtil) {
       this.helper = checkNotNull(helper, "helper");
       this.pickerFactory = checkNotNull(pickerFactory, "pickerFactory");
       loadBalancerProvider = checkNotNull(
@@ -99,6 +107,7 @@
           "Unable to find '%s' LoadBalancer", ROUND_ROBIN);
       this.random = checkNotNull(random, "random");
       this.statsStore = checkNotNull(statsStore, "statsStore");
+      this.orcaPerRequestUtil = checkNotNull(orcaPerRequestUtil, "orcaPerRequestUtil");
     }
 
     @VisibleForTesting // Introduced for testing only.
@@ -211,7 +220,7 @@
               childHelper);
         } else {
           statsStore.addLocality(newLocality);
-          childHelper = new ChildHelper(newLocality);
+          childHelper = new ChildHelper(newLocality, statsStore.getLocalityCounter(newLocality));
           localityLbInfo =
               new LocalityLbInfo(
                   localityInfoMap.get(newLocality).localityWeight,
@@ -360,12 +369,14 @@
     class ChildHelper extends ForwardingLoadBalancerHelper {
 
       private final XdsLocality locality;
+      private final ClientLoadCounter counter;
 
       private SubchannelPicker currentChildPicker = XdsSubchannelPickers.BUFFER_PICKER;
       private ConnectivityState currentChildState = null;
 
-      ChildHelper(XdsLocality locality) {
+      ChildHelper(XdsLocality locality, ClientLoadCounter counter) {
         this.locality = checkNotNull(locality, "locality");
+        this.counter = checkNotNull(counter, "counter");
       }
 
       @Override
@@ -375,33 +386,16 @@
 
       // This is triggered by child balancer
       @Override
-      public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) {
+      public void updateBalancingState(ConnectivityState newState,
+          final SubchannelPicker newPicker) {
         checkNotNull(newState, "newState");
         checkNotNull(newPicker, "newPicker");
 
-        class LoadRecordPicker extends SubchannelPicker {
-          private final SubchannelPicker delegate;
-
-          private LoadRecordPicker(SubchannelPicker delegate) {
-            this.delegate = delegate;
-          }
-
-          @Override
-          public PickResult pickSubchannel(PickSubchannelArgs args) {
-            return statsStore.interceptPickResult(delegate.pickSubchannel(args), locality);
-          }
-
-          @Override
-          public String toString() {
-            return MoreObjects.toStringHelper(this)
-                .add("delegate", delegate)
-                .add("locality", locality)
-                .toString();
-          }
-        }
-
         currentChildState = newState;
-        currentChildPicker = new LoadRecordPicker(newPicker);
+        currentChildPicker =
+            new LoadRecordingSubchannelPicker(counter,
+                new MetricsObservingSubchannelPicker(new MetricsRecordingListener(counter),
+                    newPicker, orcaPerRequestUtil));
 
         // delegate to parent helper
         updateChildState(locality, newState, currentChildPicker);
diff --git a/xds/src/main/java/io/grpc/xds/StatsStore.java b/xds/src/main/java/io/grpc/xds/StatsStore.java
index bb81f25..4e6d949 100644
--- a/xds/src/main/java/io/grpc/xds/StatsStore.java
+++ b/xds/src/main/java/io/grpc/xds/StatsStore.java
@@ -17,7 +17,6 @@
 package io.grpc.xds;
 
 import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats;
-import io.grpc.LoadBalancer.PickResult;
 import javax.annotation.Nullable;
 
 /**
@@ -61,15 +60,6 @@
   void removeLocality(XdsLocality locality);
 
   /**
-   * Applies client side load recording to {@link PickResult}s picked by the intra-locality picker
-   * for the provided locality. If the provided locality is not tracked, the original
-   * {@link PickResult} will be returned.
-   *
-   * <p>This method is thread-safe.
-   */
-  PickResult interceptPickResult(PickResult pickResult, XdsLocality locality);
-
-  /**
    * Returns the {@link ClientLoadCounter} that does locality level stats aggregation for the
    * provided locality. If the provided locality is not tracked, {@code null} will be returned.
    *
diff --git a/xds/src/main/java/io/grpc/xds/XdsLoadStatsStore.java b/xds/src/main/java/io/grpc/xds/XdsLoadStatsStore.java
index 088939a..9d40a3e 100644
--- a/xds/src/main/java/io/grpc/xds/XdsLoadStatsStore.java
+++ b/xds/src/main/java/io/grpc/xds/XdsLoadStatsStore.java
@@ -24,13 +24,8 @@
 import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats.DroppedRequests;
 import io.envoyproxy.envoy.api.v2.endpoint.EndpointLoadMetricStats;
 import io.envoyproxy.envoy.api.v2.endpoint.UpstreamLocalityStats;
-import io.grpc.ClientStreamTracer;
-import io.grpc.ClientStreamTracer.StreamInfo;
-import io.grpc.LoadBalancer.PickResult;
-import io.grpc.Metadata;
 import io.grpc.xds.ClientLoadCounter.ClientLoadSnapshot;
 import io.grpc.xds.ClientLoadCounter.MetricValue;
-import io.grpc.xds.ClientLoadCounter.XdsClientLoadRecorder;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
@@ -43,17 +38,6 @@
 @NotThreadSafe
 final class XdsLoadStatsStore implements StatsStore {
 
-  private static final ClientStreamTracer NOOP_CLIENT_STREAM_TRACER =
-      new ClientStreamTracer() {
-      };
-  private static final ClientStreamTracer.Factory NOOP_CLIENT_STREAM_TRACER_FACTORY =
-      new ClientStreamTracer.Factory() {
-        @Override
-        public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) {
-          return NOOP_CLIENT_STREAM_TRACER;
-        }
-      };
-
   private final ConcurrentMap<XdsLocality, ClientLoadCounter> localityLoadCounters;
   // Cluster level dropped request counts for each category specified in the DropOverload policy.
   private final ConcurrentMap<String, AtomicLong> dropCounters;
@@ -156,26 +140,4 @@
     }
     counter.getAndIncrement();
   }
-
-  @Override
-  public PickResult interceptPickResult(PickResult pickResult, XdsLocality locality) {
-    if (!pickResult.getStatus().isOk()) {
-      return pickResult;
-    }
-    if (pickResult.getSubchannel() == null) {
-      return pickResult;
-    }
-    ClientLoadCounter counter = localityLoadCounters.get(locality);
-    if (counter == null) {
-      // TODO (chengyuanzhang): this should not happen if this method is called in a correct
-      //  order with other methods in this class, but we might want to have some logs or warnings.
-      return pickResult;
-    }
-    ClientStreamTracer.Factory originFactory = pickResult.getStreamTracerFactory();
-    if (originFactory == null) {
-      originFactory = NOOP_CLIENT_STREAM_TRACER_FACTORY;
-    }
-    XdsClientLoadRecorder recorder = new XdsClientLoadRecorder(counter, originFactory);
-    return PickResult.withSubchannel(pickResult.getSubchannel(), recorder);
-  }
 }
diff --git a/xds/src/test/java/io/grpc/xds/ClientLoadCounterTest.java b/xds/src/test/java/io/grpc/xds/ClientLoadCounterTest.java
index 69ac5ef..3204c6a 100644
--- a/xds/src/test/java/io/grpc/xds/ClientLoadCounterTest.java
+++ b/xds/src/test/java/io/grpc/xds/ClientLoadCounterTest.java
@@ -17,16 +17,30 @@
 package io.grpc.xds;
 
 import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.same;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 import io.envoyproxy.udpa.data.orca.v1.OrcaLoadReport;
 import io.grpc.ClientStreamTracer;
+import io.grpc.ClientStreamTracer.Factory;
 import io.grpc.ClientStreamTracer.StreamInfo;
+import io.grpc.LoadBalancer.PickResult;
+import io.grpc.LoadBalancer.PickSubchannelArgs;
+import io.grpc.LoadBalancer.Subchannel;
+import io.grpc.LoadBalancer.SubchannelPicker;
 import io.grpc.Metadata;
 import io.grpc.Status;
 import io.grpc.xds.ClientLoadCounter.ClientLoadSnapshot;
-import io.grpc.xds.ClientLoadCounter.LocalityMetricsListener;
+import io.grpc.xds.ClientLoadCounter.LoadRecordingStreamTracerFactory;
+import io.grpc.xds.ClientLoadCounter.LoadRecordingSubchannelPicker;
 import io.grpc.xds.ClientLoadCounter.MetricValue;
-import io.grpc.xds.ClientLoadCounter.XdsClientLoadRecorder;
+import io.grpc.xds.ClientLoadCounter.MetricsObservingSubchannelPicker;
+import io.grpc.xds.ClientLoadCounter.MetricsRecordingListener;
+import io.grpc.xds.ClientLoadCounter.TracerWrappingSubchannelPicker;
+import io.grpc.xds.OrcaPerRequestUtil.OrcaPerRequestReportListener;
 import java.util.concurrent.ThreadLocalRandom;
 import org.junit.Before;
 import org.junit.Test;
@@ -120,28 +134,29 @@
   }
 
   @Test
-  public void xdsClientLoadRecorder_clientSideQueryCountsAggregation() {
-    XdsClientLoadRecorder recorder1 =
-        new XdsClientLoadRecorder(counter, NOOP_CLIENT_STREAM_TRACER_FACTORY);
-    ClientStreamTracer tracer = recorder1.newClientStreamTracer(STREAM_INFO, new Metadata());
+  public void loadRecordingStreamTracerFactory_clientSideQueryCountsAggregation() {
+    LoadRecordingStreamTracerFactory factory1 =
+        new LoadRecordingStreamTracerFactory(counter, NOOP_CLIENT_STREAM_TRACER_FACTORY);
+    ClientStreamTracer tracer = factory1.newClientStreamTracer(STREAM_INFO, new Metadata());
     ClientLoadSnapshot snapshot = counter.snapshot();
     assertQueryCounts(snapshot, 0, 1, 0, 1);
     tracer.streamClosed(Status.OK);
     snapshot = counter.snapshot();
     assertQueryCounts(snapshot, 1, 0, 0, 0);
 
-    // Create a second XdsClientLoadRecorder with the same counter, stats are aggregated together.
-    XdsClientLoadRecorder recorder2 =
-        new XdsClientLoadRecorder(counter, NOOP_CLIENT_STREAM_TRACER_FACTORY);
-    recorder1.newClientStreamTracer(STREAM_INFO, new Metadata()).streamClosed(Status.ABORTED);
-    recorder2.newClientStreamTracer(STREAM_INFO, new Metadata()).streamClosed(Status.CANCELLED);
+    // Create a second LoadRecordingStreamTracerFactory with the same counter, stats are aggregated
+    // together.
+    LoadRecordingStreamTracerFactory factory2 =
+        new LoadRecordingStreamTracerFactory(counter, NOOP_CLIENT_STREAM_TRACER_FACTORY);
+    factory1.newClientStreamTracer(STREAM_INFO, new Metadata()).streamClosed(Status.ABORTED);
+    factory2.newClientStreamTracer(STREAM_INFO, new Metadata()).streamClosed(Status.CANCELLED);
     snapshot = counter.snapshot();
     assertQueryCounts(snapshot, 0, 0, 2, 2);
   }
 
   @Test
-  public void metricListener_backendMetricsAggregation() {
-    LocalityMetricsListener listener1 = new LocalityMetricsListener(counter);
+  public void metricsRecordingListener_backendMetricsAggregation() {
+    MetricsRecordingListener listener1 = new MetricsRecordingListener(counter);
     OrcaLoadReport report =
         OrcaLoadReport.newBuilder()
             .setCpuUtilization(0.5345)
@@ -174,7 +189,7 @@
     snapshot = counter.snapshot();
     assertThat(snapshot.getMetricValues()).isEmpty();
 
-    LocalityMetricsListener listener2 = new LocalityMetricsListener(counter);
+    MetricsRecordingListener listener2 = new MetricsRecordingListener(counter);
     report =
         OrcaLoadReport.newBuilder()
             .setCpuUtilization(0.3423)
@@ -199,6 +214,130 @@
     assertThat(namedMetric.getTotalValue()).isEqualTo(3534.0 + 3534.0);
   }
 
+  @Test
+  public void tracerWrappingSubchannelPicker_interceptPickResult_invalidPickResultNotIntercepted() {
+    final SubchannelPicker picker = mock(SubchannelPicker.class);
+    SubchannelPicker streamInstrSubchannelPicker = new TracerWrappingSubchannelPicker() {
+      @Override
+      protected SubchannelPicker delegate() {
+        return picker;
+      }
+
+      @Override
+      protected Factory wrapTracerFactory(Factory originFactory) {
+        // NO-OP
+        return originFactory;
+      }
+    };
+    PickResult errorResult = PickResult.withError(Status.UNAVAILABLE.withDescription("Error"));
+    PickResult droppedResult = PickResult.withDrop(Status.UNAVAILABLE.withDescription("Dropped"));
+    PickResult emptyResult = PickResult.withNoResult();
+    when(picker.pickSubchannel(any(PickSubchannelArgs.class)))
+        .thenReturn(errorResult, droppedResult, emptyResult);
+    PickSubchannelArgs args = mock(PickSubchannelArgs.class);
+
+    PickResult interceptedErrorResult = streamInstrSubchannelPicker.pickSubchannel(args);
+    PickResult interceptedDroppedResult = streamInstrSubchannelPicker.pickSubchannel(args);
+    PickResult interceptedEmptyResult = streamInstrSubchannelPicker.pickSubchannel(args);
+    assertThat(interceptedErrorResult).isSameInstanceAs(errorResult);
+    assertThat(interceptedDroppedResult).isSameInstanceAs(droppedResult);
+    assertThat(interceptedEmptyResult).isSameInstanceAs(emptyResult);
+  }
+
+  @Test
+  public void loadRecordingSubchannelPicker_interceptPickResult_applyLoadRecorderToPickResult() {
+    ClientStreamTracer.Factory mockFactory = mock(ClientStreamTracer.Factory.class);
+    ClientStreamTracer mockTracer = mock(ClientStreamTracer.class);
+    when(mockFactory
+        .newClientStreamTracer(any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)))
+        .thenReturn(mockTracer);
+
+    ClientLoadCounter localityCounter1 = new ClientLoadCounter();
+    ClientLoadCounter localityCounter2 = new ClientLoadCounter();
+    final PickResult pickResult1 = PickResult.withSubchannel(mock(Subchannel.class), mockFactory);
+    final PickResult pickResult2 = PickResult.withSubchannel(mock(Subchannel.class));
+    SubchannelPicker picker1 = new SubchannelPicker() {
+      @Override
+      public PickResult pickSubchannel(PickSubchannelArgs args) {
+        return pickResult1;
+      }
+    };
+    SubchannelPicker picker2 = new SubchannelPicker() {
+      @Override
+      public PickResult pickSubchannel(PickSubchannelArgs args) {
+        return pickResult2;
+      }
+    };
+    SubchannelPicker loadRecordingPicker1 =
+        new LoadRecordingSubchannelPicker(localityCounter1, picker1);
+    SubchannelPicker loadRecordingPicker2 =
+        new LoadRecordingSubchannelPicker(localityCounter2, picker2);
+    PickSubchannelArgs args = mock(PickSubchannelArgs.class);
+    PickResult interceptedPickResult1 = loadRecordingPicker1.pickSubchannel(args);
+    PickResult interceptedPickResult2 = loadRecordingPicker2.pickSubchannel(args);
+
+    LoadRecordingStreamTracerFactory recorder1 =
+        (LoadRecordingStreamTracerFactory) interceptedPickResult1.getStreamTracerFactory();
+    LoadRecordingStreamTracerFactory recorder2 =
+        (LoadRecordingStreamTracerFactory) interceptedPickResult2.getStreamTracerFactory();
+    assertThat(recorder1.getCounter()).isSameInstanceAs(localityCounter1);
+    assertThat(recorder2.getCounter()).isSameInstanceAs(localityCounter2);
+
+    // Stream tracing is propagated to downstream tracers, which preserves PickResult's original
+    // tracing functionality.
+    Metadata metadata = new Metadata();
+    interceptedPickResult1.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, metadata)
+        .streamClosed(Status.OK);
+    verify(mockFactory).newClientStreamTracer(same(STREAM_INFO), same(metadata));
+    verify(mockTracer).streamClosed(Status.OK);
+  }
+
+  @Test
+  public void metricsObservingSubchannelPicker_interceptPickResult_applyOrcaListenerToPickResult() {
+    ClientStreamTracer.Factory mockFactory = mock(ClientStreamTracer.Factory.class);
+    ClientStreamTracer mockTracer = mock(ClientStreamTracer.class);
+    when(mockFactory
+        .newClientStreamTracer(any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)))
+        .thenReturn(mockTracer);
+
+    final PickResult pickResult1 = PickResult.withSubchannel(mock(Subchannel.class), mockFactory);
+    final PickResult pickResult2 = PickResult.withSubchannel(mock(Subchannel.class));
+    SubchannelPicker picker1 = new SubchannelPicker() {
+      @Override
+      public PickResult pickSubchannel(PickSubchannelArgs args) {
+        return pickResult1;
+      }
+    };
+    SubchannelPicker picker2 = new SubchannelPicker() {
+      @Override
+      public PickResult pickSubchannel(PickSubchannelArgs args) {
+        return pickResult2;
+      }
+    };
+    OrcaPerRequestUtil orcaPerRequestUtil = mock(OrcaPerRequestUtil.class);
+    ClientStreamTracer.Factory metricsRecorder1 = mock(ClientStreamTracer.Factory.class);
+    ClientStreamTracer.Factory metricsRecorder2 = mock(ClientStreamTracer.Factory.class);
+    when(orcaPerRequestUtil.newOrcaClientStreamTracerFactory(any(ClientStreamTracer.Factory.class),
+        any(OrcaPerRequestReportListener.class))).thenReturn(metricsRecorder1, metricsRecorder2);
+    OrcaPerRequestReportListener listener1 = mock(OrcaPerRequestReportListener.class);
+    OrcaPerRequestReportListener listener2 = mock(OrcaPerRequestReportListener.class);
+    PickSubchannelArgs args = mock(PickSubchannelArgs.class);
+
+    SubchannelPicker metricsObservingPicker1 =
+        new MetricsObservingSubchannelPicker(listener1, picker1, orcaPerRequestUtil);
+    SubchannelPicker metricsObservingPicker2 =
+        new MetricsObservingSubchannelPicker(listener2, picker2, orcaPerRequestUtil);
+    PickResult interceptedPickResult1 = metricsObservingPicker1.pickSubchannel(args);
+    PickResult interceptedPickResult2 = metricsObservingPicker2.pickSubchannel(args);
+
+    verify(orcaPerRequestUtil)
+        .newOrcaClientStreamTracerFactory(any(ClientStreamTracer.Factory.class), same(listener1));
+    verify(orcaPerRequestUtil)
+        .newOrcaClientStreamTracerFactory(any(ClientStreamTracer.Factory.class), same(listener2));
+    assertThat(interceptedPickResult1.getStreamTracerFactory()).isSameInstanceAs(metricsRecorder1);
+    assertThat(interceptedPickResult2.getStreamTracerFactory()).isSameInstanceAs(metricsRecorder2);
+  }
+
   private void assertQueryCounts(ClientLoadSnapshot snapshot,
       long callsSucceeded,
       long callsInProgress,
diff --git a/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java b/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java
index dd4ae37..910dfa7 100644
--- a/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java
+++ b/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java
@@ -20,24 +20,24 @@
 import static io.grpc.ConnectivityState.CONNECTING;
 import static io.grpc.ConnectivityState.IDLE;
 import static io.grpc.ConnectivityState.READY;
-import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.AdditionalAnswers.delegatesTo;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.same;
-import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.mockito.Mockito.when;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
+import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats;
 import io.grpc.ChannelLogger;
+import io.grpc.ClientStreamTracer;
 import io.grpc.ConnectivityState;
 import io.grpc.EquivalentAddressGroup;
 import io.grpc.LoadBalancer;
@@ -51,9 +51,12 @@
 import io.grpc.LoadBalancerProvider;
 import io.grpc.LoadBalancerRegistry;
 import io.grpc.SynchronizationContext;
+import io.grpc.xds.ClientLoadCounter.LoadRecordingStreamTracerFactory;
+import io.grpc.xds.ClientLoadCounter.MetricsRecordingListener;
 import io.grpc.xds.InterLocalityPicker.WeightedChildPicker;
 import io.grpc.xds.LocalityStore.LocalityStoreImpl;
 import io.grpc.xds.LocalityStore.LocalityStoreImpl.PickerFactory;
+import io.grpc.xds.OrcaPerRequestUtil.OrcaPerRequestReportListener;
 import io.grpc.xds.XdsComms.DropOverload;
 import io.grpc.xds.XdsComms.LbEndpoint;
 import io.grpc.xds.XdsComms.LocalityInfo;
@@ -62,6 +65,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import javax.annotation.Nullable;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -176,7 +180,9 @@
   @Mock
   private ThreadSafeRandom random;
   @Mock
-  private StatsStore statsStore;
+  private OrcaPerRequestUtil orcaPerRequestUtil;
+  private final FakeLoadStatsStore fakeLoadStatsStore = new FakeLoadStatsStore();
+  private StatsStore statsStore = mock(StatsStore.class, delegatesTo(fakeLoadStatsStore));
 
   private LocalityStore localityStore;
 
@@ -185,10 +191,10 @@
     doReturn(mock(ChannelLogger.class)).when(helper).getChannelLogger();
     doReturn(mock(Subchannel.class)).when(helper).createSubchannel(any(CreateSubchannelArgs.class));
     doReturn(syncContext).when(helper).getSynchronizationContext();
-    doAnswer(returnsFirstArg())
-        .when(statsStore).interceptPickResult(any(PickResult.class), any(XdsLocality.class));
     lbRegistry.register(lbProvider);
-    localityStore = new LocalityStoreImpl(helper, pickerFactory, lbRegistry, random, statsStore);
+    localityStore =
+        new LocalityStoreImpl(helper, pickerFactory, lbRegistry, random, statsStore,
+            orcaPerRequestUtil);
   }
 
   @Test
@@ -218,11 +224,10 @@
 
     localityStore.updateLocalityStore(Collections.EMPTY_MAP);
     verify(statsStore).removeLocality(locality4);
-    verifyNoMoreInteractions(statsStore);
   }
 
   @Test
-  public void updateLocalityStore_interceptPickResultUponPickReadySubchannel() {
+  public void updateLocalityStore_pickResultInterceptedForLoadRecordingWhenSubchannelReady() {
     // Simulate receiving two localities.
     LocalityInfo localityInfo1 =
         new LocalityInfo(ImmutableList.of(lbEndpoint11, lbEndpoint12), 1);
@@ -235,15 +240,23 @@
     assertThat(loadBalancers).hasSize(2);
     assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0);
 
-    // Simulate picker updates for each of the two localities with dummy pickers.
-    final PickResult result1 = PickResult.withNoResult();
-    final PickResult result2 = PickResult.withNoResult();
+    ClientStreamTracer.Factory metricsTracingFactory1 = mock(ClientStreamTracer.Factory.class);
+    ClientStreamTracer.Factory metricsTracingFactory2 = mock(ClientStreamTracer.Factory.class);
+    when(orcaPerRequestUtil.newOrcaClientStreamTracerFactory(any(ClientStreamTracer.Factory.class),
+        any(OrcaPerRequestReportListener.class)))
+        .thenReturn(metricsTracingFactory1, metricsTracingFactory2);
+
+    final PickResult result1 = PickResult.withSubchannel(mock(Subchannel.class));
+    final PickResult result2 =
+        PickResult.withSubchannel(mock(Subchannel.class), mock(ClientStreamTracer.Factory.class));
     SubchannelPicker subchannelPicker1 = mock(SubchannelPicker.class);
     SubchannelPicker subchannelPicker2 = mock(SubchannelPicker.class);
     when(subchannelPicker1.pickSubchannel(any(PickSubchannelArgs.class)))
         .thenReturn(result1);
     when(subchannelPicker2.pickSubchannel(any(PickSubchannelArgs.class)))
         .thenReturn(result2);
+
+    // Simulate picker updates for the two localities with dummy pickers.
     childHelpers.get("sz1").updateBalancingState(READY, subchannelPicker1);
     childHelpers.get("sz2").updateBalancingState(READY, subchannelPicker2);
 
@@ -251,12 +264,31 @@
     ArgumentCaptor<SubchannelPicker> interLocalityPickerCaptor = ArgumentCaptor.forClass(null);
     verify(helper, times(2)).updateBalancingState(eq(READY), interLocalityPickerCaptor.capture());
     SubchannelPicker interLocalityPicker = interLocalityPickerCaptor.getValue();
+
+    // Verify each PickResult picked is intercepted with client stream tracer factory for
+    // recording load and backend metrics.
+    List<XdsLocality> localities = ImmutableList.of(locality1, locality2);
+    List<ClientStreamTracer.Factory> metricsTracingFactories =
+        ImmutableList.of(metricsTracingFactory1, metricsTracingFactory2);
     for (int i = 0; i < pickerFactory.totalReadyLocalities; i++) {
       pickerFactory.nextIndex = i;
-      interLocalityPicker.pickSubchannel(pickSubchannelArgs);
+      PickResult pickResult = interLocalityPicker.pickSubchannel(pickSubchannelArgs);
+      ArgumentCaptor<OrcaPerRequestReportListener> listenerCaptor = ArgumentCaptor.forClass(null);
+      verify(orcaPerRequestUtil, times(i + 1))
+          .newOrcaClientStreamTracerFactory(any(ClientStreamTracer.Factory.class),
+              listenerCaptor.capture());
+      assertThat(listenerCaptor.getValue()).isInstanceOf(MetricsRecordingListener.class);
+      MetricsRecordingListener listener = (MetricsRecordingListener) listenerCaptor.getValue();
+      assertThat(listener.getCounter())
+          .isSameInstanceAs(fakeLoadStatsStore.localityCounters.get(localities.get(i)));
+      assertThat(pickResult.getStreamTracerFactory())
+          .isInstanceOf(LoadRecordingStreamTracerFactory.class);
+      LoadRecordingStreamTracerFactory loadRecordingFactory =
+          (LoadRecordingStreamTracerFactory) pickResult.getStreamTracerFactory();
+      assertThat(loadRecordingFactory.getCounter())
+          .isSameInstanceAs(fakeLoadStatsStore.localityCounters.get(localities.get(i)));
+      assertThat(loadRecordingFactory.delegate()).isSameInstanceAs(metricsTracingFactories.get(i));
     }
-    verify(statsStore).interceptPickResult(same(result1), eq(locality1));
-    verify(statsStore).interceptPickResult(same(result2), eq(locality2));
   }
 
   @Test
@@ -510,4 +542,37 @@
     verify(statsStore).removeLocality(locality1);
     verify(statsStore).removeLocality(locality2);
   }
+
+  private static final class FakeLoadStatsStore implements StatsStore {
+
+    Map<XdsLocality, ClientLoadCounter> localityCounters = new HashMap<>();
+
+    @Override
+    public ClusterStats generateLoadReport() {
+      throw new AssertionError("Should not be called");
+    }
+
+    @Override
+    public void addLocality(XdsLocality locality) {
+      assertThat(localityCounters).doesNotContainKey(locality);
+      localityCounters.put(locality, new ClientLoadCounter());
+    }
+
+    @Override
+    public void removeLocality(XdsLocality locality) {
+      assertThat(localityCounters).containsKey(locality);
+      localityCounters.remove(locality);
+    }
+
+    @Nullable
+    @Override
+    public ClientLoadCounter getLocalityCounter(XdsLocality locality) {
+      return localityCounters.get(locality);
+    }
+
+    @Override
+    public void recordDroppedRequest(String category) {
+      // NO-OP, verify by invocations.
+    }
+  }
 }
diff --git a/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java b/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java
index 13ca628..9a9429b 100644
--- a/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java
+++ b/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java
@@ -17,9 +17,6 @@
 package io.grpc.xds;
 
 import static com.google.common.truth.Truth.assertThat;
-import static org.mockito.AdditionalAnswers.returnsFirstArg;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -32,7 +29,6 @@
 import io.grpc.ChannelLogger;
 import io.grpc.EquivalentAddressGroup;
 import io.grpc.LoadBalancer.Helper;
-import io.grpc.LoadBalancer.PickResult;
 import io.grpc.ManagedChannel;
 import io.grpc.SynchronizationContext;
 import io.grpc.inprocess.InProcessChannelBuilder;
@@ -63,8 +59,6 @@
   @Mock
   private AdsStreamCallback adsStreamCallback;
   @Mock
-  private StatsStore statsStore;
-  @Mock
   private LocalityStore localityStore;
 
   private final FakeClock fakeClock = new FakeClock();
@@ -89,8 +83,6 @@
     doReturn(fakeClock.getScheduledExecutorService()).when(helper).getScheduledExecutorService();
     doReturn("fake_authority").when(helper).getAuthority();
     doReturn(mock(ChannelLogger.class)).when(helper).getChannelLogger();
-    doAnswer(returnsFirstArg())
-        .when(statsStore).interceptPickResult(any(PickResult.class), any(XdsLocality.class));
 
     String serverName = InProcessServerBuilder.generateName();
 
diff --git a/xds/src/test/java/io/grpc/xds/XdsLoadStatsStoreTest.java b/xds/src/test/java/io/grpc/xds/XdsLoadStatsStoreTest.java
index 074b563..a18763c 100644
--- a/xds/src/test/java/io/grpc/xds/XdsLoadStatsStoreTest.java
+++ b/xds/src/test/java/io/grpc/xds/XdsLoadStatsStoreTest.java
@@ -17,22 +17,12 @@
 package io.grpc.xds;
 
 import static com.google.common.truth.Truth.assertThat;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.same;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
 
 import com.google.common.collect.ImmutableMap;
 import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats;
 import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats.DroppedRequests;
 import io.envoyproxy.envoy.api.v2.endpoint.EndpointLoadMetricStats;
 import io.envoyproxy.envoy.api.v2.endpoint.UpstreamLocalityStats;
-import io.grpc.ClientStreamTracer;
-import io.grpc.LoadBalancer.PickResult;
-import io.grpc.LoadBalancer.Subchannel;
-import io.grpc.Metadata;
-import io.grpc.Status;
 import io.grpc.xds.ClientLoadCounter.MetricValue;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -57,9 +47,6 @@
       new XdsLocality("test_region1", "test_zone", "test_subzone");
   private static final XdsLocality LOCALITY2 =
       new XdsLocality("test_region2", "test_zone", "test_subzone");
-  private static final ClientStreamTracer.StreamInfo STREAM_INFO =
-      ClientStreamTracer.StreamInfo.newBuilder().build();
-  private Subchannel mockSubchannel = mock(Subchannel.class);
   private ConcurrentMap<XdsLocality, ClientLoadCounter> localityLoadCounters;
   private ConcurrentMap<String, AtomicLong> dropCounters;
   private XdsLoadStatsStore loadStore;
@@ -286,44 +273,4 @@
     assertThat(dropCounters.get("lb").get()).isEqualTo(0);
     assertThat(dropCounters.get("throttle").get()).isEqualTo(0);
   }
-
-  @Test
-  public void loadNotRecordedForUntrackedLocality() {
-    PickResult pickResult = PickResult.withSubchannel(mockSubchannel);
-    // If the per-locality counter does not exist, nothing should happen.
-    PickResult interceptedPickResult = loadStore.interceptPickResult(pickResult, LOCALITY1);
-    assertThat(interceptedPickResult.getStreamTracerFactory()).isNull();
-  }
-
-  @Test
-  public void invalidPickResultNotIntercepted() {
-    localityLoadCounters.put(LOCALITY1, new ClientLoadCounter());
-    PickResult errorResult = PickResult.withError(Status.UNAVAILABLE.withDescription("Error"));
-    PickResult droppedResult = PickResult.withDrop(Status.UNAVAILABLE.withDescription("Dropped"));
-    PickResult emptyResult = PickResult.withNoResult();
-    PickResult interceptedErrorResult = loadStore.interceptPickResult(errorResult, LOCALITY1);
-    PickResult interceptedDroppedResult =
-        loadStore.interceptPickResult(droppedResult, LOCALITY1);
-    PickResult interceptedEmptyResult = loadStore.interceptPickResult(emptyResult, LOCALITY1);
-    assertThat(interceptedErrorResult).isSameInstanceAs(errorResult);
-    assertThat(interceptedDroppedResult).isSameInstanceAs(droppedResult);
-    assertThat(interceptedEmptyResult).isSameInstanceAs(emptyResult);
-  }
-
-  @Test
-  public void interceptPreservesOriginStreamTracer() {
-    ClientStreamTracer.Factory mockFactory = mock(ClientStreamTracer.Factory.class);
-    ClientStreamTracer mockTracer = mock(ClientStreamTracer.class);
-    when(mockFactory
-        .newClientStreamTracer(any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)))
-        .thenReturn(mockTracer);
-    localityLoadCounters.put(LOCALITY1, new ClientLoadCounter());
-    PickResult pickResult = PickResult.withSubchannel(mockSubchannel, mockFactory);
-    PickResult interceptedPickResult = loadStore.interceptPickResult(pickResult, LOCALITY1);
-    Metadata metadata = new Metadata();
-    interceptedPickResult.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, metadata)
-        .streamClosed(Status.OK);
-    verify(mockFactory).newClientStreamTracer(same(STREAM_INFO), same(metadata));
-    verify(mockTracer).streamClosed(Status.OK);
-  }
 }