xds: redesign client load recording and backend metrics receiving interface (#5903)
* Redefined StatsStore interface.
- Removed interface method StatsStore#interceptPickResult and implementation do not take the resposibility intercepting PickResult with locality-level load recording.
- Introduce a wrapper class for SubchannelPicker to let users wrap SubchannelPicker by themselves, with client side load recording logic.
- Associate the corresponding locality counter with child helper when it is created, child helper will intercept the SubchannelPicker it creates.
* Renamed backend metrics listener class to be more abstract, hides the implementation detail of doing locality-level aggregation.
* Integrate client load recording and backend metrics recording with xDS load balancer.
- Created LoadRecordingSubchannelPicker class for applying XdsClientLoadRecorder that records client load to PickResult.
- Created MetricsObservingSubchannel class for applying OrcaReportingTracerFactory that takes listener to receive ORCA reports to PickResult.
- In xDS load balancer LocalityStore, the original picker is wrapped two layers inside the above wrappers.
* Renamed XdsClientLoadRecorder to ClientLoadRecorder. It should only be used for testing, xDS load balancer should use SubchannelPicker wrappers instead of this load recorder directly.
* Removed redudent layer of wrapping for SubchannelPicker in LocalityStore
* Added toString for SubchannelPicker wrapper classes.
* Rename ClientLoadRecorder to LoadRecordingStreamTracerFactory.
* Renamed StreamInstrumentedSubchannelPicker to TracerWrappingSubchannelPicker.
* Eliminate duplicated code in LocalityStoreTest, put them into a loop.
diff --git a/xds/src/main/java/io/grpc/xds/ClientLoadCounter.java b/xds/src/main/java/io/grpc/xds/ClientLoadCounter.java
index e9c445a..c574a7a 100644
--- a/xds/src/main/java/io/grpc/xds/ClientLoadCounter.java
+++ b/xds/src/main/java/io/grpc/xds/ClientLoadCounter.java
@@ -23,6 +23,9 @@
import io.envoyproxy.udpa.data.orca.v1.OrcaLoadReport;
import io.grpc.ClientStreamTracer;
import io.grpc.ClientStreamTracer.StreamInfo;
+import io.grpc.LoadBalancer.PickResult;
+import io.grpc.LoadBalancer.PickSubchannelArgs;
+import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.util.ForwardingClientStreamTracer;
@@ -256,16 +259,18 @@
}
/**
- * An {@link XdsClientLoadRecorder} instance records and aggregates client-side load data into an
- * {@link ClientLoadCounter} object.
+ * An {@link LoadRecordingStreamTracerFactory} instance records and aggregates client-side load
+ * data into an {@link ClientLoadCounter} object.
*/
@ThreadSafe
- static final class XdsClientLoadRecorder extends ClientStreamTracer.Factory {
+ @VisibleForTesting
+ static final class LoadRecordingStreamTracerFactory extends ClientStreamTracer.Factory {
private final ClientStreamTracer.Factory delegate;
private final ClientLoadCounter counter;
- XdsClientLoadRecorder(ClientLoadCounter counter, ClientStreamTracer.Factory delegate) {
+ LoadRecordingStreamTracerFactory(ClientLoadCounter counter,
+ ClientStreamTracer.Factory delegate) {
this.counter = checkNotNull(counter, "counter");
this.delegate = checkNotNull(delegate, "delegate");
}
@@ -287,18 +292,29 @@
}
};
}
+
+ @VisibleForTesting
+ ClientLoadCounter getCounter() {
+ return counter;
+ }
+
+ @VisibleForTesting
+ ClientStreamTracer.Factory delegate() {
+ return delegate;
+ }
}
/**
- * Listener implementation to receive backend metrics with locality-level aggregation.
+ * Listener implementation to receive backend metrics and record metric values in the provided
+ * {@link ClientLoadCounter}.
*/
@ThreadSafe
- static final class LocalityMetricsListener implements OrcaPerRequestReportListener,
- OrcaOobReportListener {
+ static final class MetricsRecordingListener
+ implements OrcaPerRequestReportListener, OrcaOobReportListener {
private final ClientLoadCounter counter;
- LocalityMetricsListener(ClientLoadCounter counter) {
+ MetricsRecordingListener(ClientLoadCounter counter) {
this.counter = checkNotNull(counter, "counter");
}
@@ -310,5 +326,118 @@
counter.recordMetric(entry.getKey(), entry.getValue());
}
}
+
+ @VisibleForTesting
+ ClientLoadCounter getCounter() {
+ return counter;
+ }
+ }
+
+ /**
+ * Base class for {@link SubchannelPicker} wrapper classes that intercept "RPC-capable"
+ * {@link PickResult}s with applying a custom {@link ClientStreamTracer.Factory} for stream
+ * instrumenting purposes.
+ */
+ @VisibleForTesting
+ abstract static class TracerWrappingSubchannelPicker extends SubchannelPicker {
+
+ private static final ClientStreamTracer NOOP_CLIENT_STREAM_TRACER =
+ new ClientStreamTracer() {
+ };
+ private static final ClientStreamTracer.Factory NOOP_CLIENT_STREAM_TRACER_FACTORY =
+ new ClientStreamTracer.Factory() {
+ @Override
+ public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) {
+ return NOOP_CLIENT_STREAM_TRACER;
+ }
+ };
+
+ protected abstract SubchannelPicker delegate();
+
+ protected abstract ClientStreamTracer.Factory wrapTracerFactory(
+ ClientStreamTracer.Factory originFactory);
+
+ @Override
+ public PickResult pickSubchannel(PickSubchannelArgs args) {
+ PickResult result = delegate().pickSubchannel(args);
+ if (!result.getStatus().isOk()) {
+ return result;
+ }
+ if (result.getSubchannel() == null) {
+ return result;
+ }
+ ClientStreamTracer.Factory originFactory = result.getStreamTracerFactory();
+ if (originFactory == null) {
+ originFactory = NOOP_CLIENT_STREAM_TRACER_FACTORY;
+ }
+ return PickResult.withSubchannel(result.getSubchannel(), wrapTracerFactory(originFactory));
+ }
+
+ @Override
+ public String toString() {
+ return MoreObjects.toStringHelper(this).add("delegate", delegate()).toString();
+ }
+ }
+
+ /**
+ * A wrapper class that wraps a {@link SubchannelPicker} instance and associate it with a {@link
+ * ClientLoadCounter}. All "RPC-capable" {@link PickResult}s picked will be intercepted with
+ * client side load recording logic such that RPC activities occurring in the {@link PickResult}'s
+ * {@link io.grpc.LoadBalancer.Subchannel} will be recorded in the associated {@link
+ * ClientLoadCounter}.
+ */
+ @ThreadSafe
+ static final class LoadRecordingSubchannelPicker extends TracerWrappingSubchannelPicker {
+
+ private final ClientLoadCounter counter;
+ private final SubchannelPicker delegate;
+
+ LoadRecordingSubchannelPicker(ClientLoadCounter counter, SubchannelPicker delegate) {
+ this.counter = checkNotNull(counter, "counter");
+ this.delegate = checkNotNull(delegate, "delegate");
+ }
+
+ @Override
+ protected SubchannelPicker delegate() {
+ return delegate;
+ }
+
+ @Override
+ protected ClientStreamTracer.Factory wrapTracerFactory(
+ ClientStreamTracer.Factory originFactory) {
+ return new LoadRecordingStreamTracerFactory(counter, originFactory);
+ }
+ }
+
+ /**
+ * A wrapper class that wraps {@link SubchannelPicker} instance and associate it with an {@link
+ * OrcaPerRequestReportListener}. All "RPC-capable" {@link PickResult}s picked will be intercepted
+ * with the logic of registering the listener for observing backend metrics.
+ */
+ @ThreadSafe
+ static final class MetricsObservingSubchannelPicker extends TracerWrappingSubchannelPicker {
+
+ private final OrcaPerRequestReportListener listener;
+ private final SubchannelPicker delegate;
+ private final OrcaPerRequestUtil orcaPerRequestUtil;
+
+ MetricsObservingSubchannelPicker(OrcaPerRequestReportListener listener,
+ SubchannelPicker delegate,
+ OrcaPerRequestUtil orcaPerRequestUtil) {
+ this.listener = checkNotNull(listener, "listener");
+ this.delegate = checkNotNull(delegate, "delegate");
+ this.orcaPerRequestUtil = checkNotNull(orcaPerRequestUtil, "orcaPerRequestUtil");
+ }
+
+ @Override
+ protected SubchannelPicker delegate() {
+ return delegate;
+ }
+
+ @Override
+ protected ClientStreamTracer.Factory wrapTracerFactory(
+ ClientStreamTracer.Factory originFactory) {
+ return orcaPerRequestUtil.newOrcaClientStreamTracerFactory(originFactory, listener);
+ }
}
}
diff --git a/xds/src/main/java/io/grpc/xds/LocalityStore.java b/xds/src/main/java/io/grpc/xds/LocalityStore.java
index 878fb78..27da4a8 100644
--- a/xds/src/main/java/io/grpc/xds/LocalityStore.java
+++ b/xds/src/main/java/io/grpc/xds/LocalityStore.java
@@ -40,6 +40,9 @@
import io.grpc.LoadBalancerRegistry;
import io.grpc.Status;
import io.grpc.util.ForwardingLoadBalancerHelper;
+import io.grpc.xds.ClientLoadCounter.LoadRecordingSubchannelPicker;
+import io.grpc.xds.ClientLoadCounter.MetricsObservingSubchannelPicker;
+import io.grpc.xds.ClientLoadCounter.MetricsRecordingListener;
import io.grpc.xds.InterLocalityPicker.WeightedChildPicker;
import io.grpc.xds.XdsComms.DropOverload;
import io.grpc.xds.XdsComms.LocalityInfo;
@@ -79,19 +82,24 @@
private final LoadBalancerProvider loadBalancerProvider;
private final ThreadSafeRandom random;
private final StatsStore statsStore;
+ private final OrcaPerRequestUtil orcaPerRequestUtil;
private Map<XdsLocality, LocalityLbInfo> localityMap = new HashMap<>();
private ImmutableList<DropOverload> dropOverloads = ImmutableList.of();
LocalityStoreImpl(Helper helper, LoadBalancerRegistry lbRegistry) {
this(helper, pickerFactoryImpl, lbRegistry, ThreadSafeRandom.ThreadSafeRandomImpl.instance,
- new XdsLoadStatsStore());
+ new XdsLoadStatsStore(), OrcaPerRequestUtil.getInstance());
}
@VisibleForTesting
LocalityStoreImpl(
- Helper helper, PickerFactory pickerFactory, LoadBalancerRegistry lbRegistry,
- ThreadSafeRandom random, StatsStore statsStore) {
+ Helper helper,
+ PickerFactory pickerFactory,
+ LoadBalancerRegistry lbRegistry,
+ ThreadSafeRandom random,
+ StatsStore statsStore,
+ OrcaPerRequestUtil orcaPerRequestUtil) {
this.helper = checkNotNull(helper, "helper");
this.pickerFactory = checkNotNull(pickerFactory, "pickerFactory");
loadBalancerProvider = checkNotNull(
@@ -99,6 +107,7 @@
"Unable to find '%s' LoadBalancer", ROUND_ROBIN);
this.random = checkNotNull(random, "random");
this.statsStore = checkNotNull(statsStore, "statsStore");
+ this.orcaPerRequestUtil = checkNotNull(orcaPerRequestUtil, "orcaPerRequestUtil");
}
@VisibleForTesting // Introduced for testing only.
@@ -211,7 +220,7 @@
childHelper);
} else {
statsStore.addLocality(newLocality);
- childHelper = new ChildHelper(newLocality);
+ childHelper = new ChildHelper(newLocality, statsStore.getLocalityCounter(newLocality));
localityLbInfo =
new LocalityLbInfo(
localityInfoMap.get(newLocality).localityWeight,
@@ -360,12 +369,14 @@
class ChildHelper extends ForwardingLoadBalancerHelper {
private final XdsLocality locality;
+ private final ClientLoadCounter counter;
private SubchannelPicker currentChildPicker = XdsSubchannelPickers.BUFFER_PICKER;
private ConnectivityState currentChildState = null;
- ChildHelper(XdsLocality locality) {
+ ChildHelper(XdsLocality locality, ClientLoadCounter counter) {
this.locality = checkNotNull(locality, "locality");
+ this.counter = checkNotNull(counter, "counter");
}
@Override
@@ -375,33 +386,16 @@
// This is triggered by child balancer
@Override
- public void updateBalancingState(ConnectivityState newState, SubchannelPicker newPicker) {
+ public void updateBalancingState(ConnectivityState newState,
+ final SubchannelPicker newPicker) {
checkNotNull(newState, "newState");
checkNotNull(newPicker, "newPicker");
- class LoadRecordPicker extends SubchannelPicker {
- private final SubchannelPicker delegate;
-
- private LoadRecordPicker(SubchannelPicker delegate) {
- this.delegate = delegate;
- }
-
- @Override
- public PickResult pickSubchannel(PickSubchannelArgs args) {
- return statsStore.interceptPickResult(delegate.pickSubchannel(args), locality);
- }
-
- @Override
- public String toString() {
- return MoreObjects.toStringHelper(this)
- .add("delegate", delegate)
- .add("locality", locality)
- .toString();
- }
- }
-
currentChildState = newState;
- currentChildPicker = new LoadRecordPicker(newPicker);
+ currentChildPicker =
+ new LoadRecordingSubchannelPicker(counter,
+ new MetricsObservingSubchannelPicker(new MetricsRecordingListener(counter),
+ newPicker, orcaPerRequestUtil));
// delegate to parent helper
updateChildState(locality, newState, currentChildPicker);
diff --git a/xds/src/main/java/io/grpc/xds/StatsStore.java b/xds/src/main/java/io/grpc/xds/StatsStore.java
index bb81f25..4e6d949 100644
--- a/xds/src/main/java/io/grpc/xds/StatsStore.java
+++ b/xds/src/main/java/io/grpc/xds/StatsStore.java
@@ -17,7 +17,6 @@
package io.grpc.xds;
import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats;
-import io.grpc.LoadBalancer.PickResult;
import javax.annotation.Nullable;
/**
@@ -61,15 +60,6 @@
void removeLocality(XdsLocality locality);
/**
- * Applies client side load recording to {@link PickResult}s picked by the intra-locality picker
- * for the provided locality. If the provided locality is not tracked, the original
- * {@link PickResult} will be returned.
- *
- * <p>This method is thread-safe.
- */
- PickResult interceptPickResult(PickResult pickResult, XdsLocality locality);
-
- /**
* Returns the {@link ClientLoadCounter} that does locality level stats aggregation for the
* provided locality. If the provided locality is not tracked, {@code null} will be returned.
*
diff --git a/xds/src/main/java/io/grpc/xds/XdsLoadStatsStore.java b/xds/src/main/java/io/grpc/xds/XdsLoadStatsStore.java
index 088939a..9d40a3e 100644
--- a/xds/src/main/java/io/grpc/xds/XdsLoadStatsStore.java
+++ b/xds/src/main/java/io/grpc/xds/XdsLoadStatsStore.java
@@ -24,13 +24,8 @@
import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats.DroppedRequests;
import io.envoyproxy.envoy.api.v2.endpoint.EndpointLoadMetricStats;
import io.envoyproxy.envoy.api.v2.endpoint.UpstreamLocalityStats;
-import io.grpc.ClientStreamTracer;
-import io.grpc.ClientStreamTracer.StreamInfo;
-import io.grpc.LoadBalancer.PickResult;
-import io.grpc.Metadata;
import io.grpc.xds.ClientLoadCounter.ClientLoadSnapshot;
import io.grpc.xds.ClientLoadCounter.MetricValue;
-import io.grpc.xds.ClientLoadCounter.XdsClientLoadRecorder;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
@@ -43,17 +38,6 @@
@NotThreadSafe
final class XdsLoadStatsStore implements StatsStore {
- private static final ClientStreamTracer NOOP_CLIENT_STREAM_TRACER =
- new ClientStreamTracer() {
- };
- private static final ClientStreamTracer.Factory NOOP_CLIENT_STREAM_TRACER_FACTORY =
- new ClientStreamTracer.Factory() {
- @Override
- public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) {
- return NOOP_CLIENT_STREAM_TRACER;
- }
- };
-
private final ConcurrentMap<XdsLocality, ClientLoadCounter> localityLoadCounters;
// Cluster level dropped request counts for each category specified in the DropOverload policy.
private final ConcurrentMap<String, AtomicLong> dropCounters;
@@ -156,26 +140,4 @@
}
counter.getAndIncrement();
}
-
- @Override
- public PickResult interceptPickResult(PickResult pickResult, XdsLocality locality) {
- if (!pickResult.getStatus().isOk()) {
- return pickResult;
- }
- if (pickResult.getSubchannel() == null) {
- return pickResult;
- }
- ClientLoadCounter counter = localityLoadCounters.get(locality);
- if (counter == null) {
- // TODO (chengyuanzhang): this should not happen if this method is called in a correct
- // order with other methods in this class, but we might want to have some logs or warnings.
- return pickResult;
- }
- ClientStreamTracer.Factory originFactory = pickResult.getStreamTracerFactory();
- if (originFactory == null) {
- originFactory = NOOP_CLIENT_STREAM_TRACER_FACTORY;
- }
- XdsClientLoadRecorder recorder = new XdsClientLoadRecorder(counter, originFactory);
- return PickResult.withSubchannel(pickResult.getSubchannel(), recorder);
- }
}
diff --git a/xds/src/test/java/io/grpc/xds/ClientLoadCounterTest.java b/xds/src/test/java/io/grpc/xds/ClientLoadCounterTest.java
index 69ac5ef..3204c6a 100644
--- a/xds/src/test/java/io/grpc/xds/ClientLoadCounterTest.java
+++ b/xds/src/test/java/io/grpc/xds/ClientLoadCounterTest.java
@@ -17,16 +17,30 @@
package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.same;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
import io.envoyproxy.udpa.data.orca.v1.OrcaLoadReport;
import io.grpc.ClientStreamTracer;
+import io.grpc.ClientStreamTracer.Factory;
import io.grpc.ClientStreamTracer.StreamInfo;
+import io.grpc.LoadBalancer.PickResult;
+import io.grpc.LoadBalancer.PickSubchannelArgs;
+import io.grpc.LoadBalancer.Subchannel;
+import io.grpc.LoadBalancer.SubchannelPicker;
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.xds.ClientLoadCounter.ClientLoadSnapshot;
-import io.grpc.xds.ClientLoadCounter.LocalityMetricsListener;
+import io.grpc.xds.ClientLoadCounter.LoadRecordingStreamTracerFactory;
+import io.grpc.xds.ClientLoadCounter.LoadRecordingSubchannelPicker;
import io.grpc.xds.ClientLoadCounter.MetricValue;
-import io.grpc.xds.ClientLoadCounter.XdsClientLoadRecorder;
+import io.grpc.xds.ClientLoadCounter.MetricsObservingSubchannelPicker;
+import io.grpc.xds.ClientLoadCounter.MetricsRecordingListener;
+import io.grpc.xds.ClientLoadCounter.TracerWrappingSubchannelPicker;
+import io.grpc.xds.OrcaPerRequestUtil.OrcaPerRequestReportListener;
import java.util.concurrent.ThreadLocalRandom;
import org.junit.Before;
import org.junit.Test;
@@ -120,28 +134,29 @@
}
@Test
- public void xdsClientLoadRecorder_clientSideQueryCountsAggregation() {
- XdsClientLoadRecorder recorder1 =
- new XdsClientLoadRecorder(counter, NOOP_CLIENT_STREAM_TRACER_FACTORY);
- ClientStreamTracer tracer = recorder1.newClientStreamTracer(STREAM_INFO, new Metadata());
+ public void loadRecordingStreamTracerFactory_clientSideQueryCountsAggregation() {
+ LoadRecordingStreamTracerFactory factory1 =
+ new LoadRecordingStreamTracerFactory(counter, NOOP_CLIENT_STREAM_TRACER_FACTORY);
+ ClientStreamTracer tracer = factory1.newClientStreamTracer(STREAM_INFO, new Metadata());
ClientLoadSnapshot snapshot = counter.snapshot();
assertQueryCounts(snapshot, 0, 1, 0, 1);
tracer.streamClosed(Status.OK);
snapshot = counter.snapshot();
assertQueryCounts(snapshot, 1, 0, 0, 0);
- // Create a second XdsClientLoadRecorder with the same counter, stats are aggregated together.
- XdsClientLoadRecorder recorder2 =
- new XdsClientLoadRecorder(counter, NOOP_CLIENT_STREAM_TRACER_FACTORY);
- recorder1.newClientStreamTracer(STREAM_INFO, new Metadata()).streamClosed(Status.ABORTED);
- recorder2.newClientStreamTracer(STREAM_INFO, new Metadata()).streamClosed(Status.CANCELLED);
+ // Create a second LoadRecordingStreamTracerFactory with the same counter, stats are aggregated
+ // together.
+ LoadRecordingStreamTracerFactory factory2 =
+ new LoadRecordingStreamTracerFactory(counter, NOOP_CLIENT_STREAM_TRACER_FACTORY);
+ factory1.newClientStreamTracer(STREAM_INFO, new Metadata()).streamClosed(Status.ABORTED);
+ factory2.newClientStreamTracer(STREAM_INFO, new Metadata()).streamClosed(Status.CANCELLED);
snapshot = counter.snapshot();
assertQueryCounts(snapshot, 0, 0, 2, 2);
}
@Test
- public void metricListener_backendMetricsAggregation() {
- LocalityMetricsListener listener1 = new LocalityMetricsListener(counter);
+ public void metricsRecordingListener_backendMetricsAggregation() {
+ MetricsRecordingListener listener1 = new MetricsRecordingListener(counter);
OrcaLoadReport report =
OrcaLoadReport.newBuilder()
.setCpuUtilization(0.5345)
@@ -174,7 +189,7 @@
snapshot = counter.snapshot();
assertThat(snapshot.getMetricValues()).isEmpty();
- LocalityMetricsListener listener2 = new LocalityMetricsListener(counter);
+ MetricsRecordingListener listener2 = new MetricsRecordingListener(counter);
report =
OrcaLoadReport.newBuilder()
.setCpuUtilization(0.3423)
@@ -199,6 +214,130 @@
assertThat(namedMetric.getTotalValue()).isEqualTo(3534.0 + 3534.0);
}
+ @Test
+ public void tracerWrappingSubchannelPicker_interceptPickResult_invalidPickResultNotIntercepted() {
+ final SubchannelPicker picker = mock(SubchannelPicker.class);
+ SubchannelPicker streamInstrSubchannelPicker = new TracerWrappingSubchannelPicker() {
+ @Override
+ protected SubchannelPicker delegate() {
+ return picker;
+ }
+
+ @Override
+ protected Factory wrapTracerFactory(Factory originFactory) {
+ // NO-OP
+ return originFactory;
+ }
+ };
+ PickResult errorResult = PickResult.withError(Status.UNAVAILABLE.withDescription("Error"));
+ PickResult droppedResult = PickResult.withDrop(Status.UNAVAILABLE.withDescription("Dropped"));
+ PickResult emptyResult = PickResult.withNoResult();
+ when(picker.pickSubchannel(any(PickSubchannelArgs.class)))
+ .thenReturn(errorResult, droppedResult, emptyResult);
+ PickSubchannelArgs args = mock(PickSubchannelArgs.class);
+
+ PickResult interceptedErrorResult = streamInstrSubchannelPicker.pickSubchannel(args);
+ PickResult interceptedDroppedResult = streamInstrSubchannelPicker.pickSubchannel(args);
+ PickResult interceptedEmptyResult = streamInstrSubchannelPicker.pickSubchannel(args);
+ assertThat(interceptedErrorResult).isSameInstanceAs(errorResult);
+ assertThat(interceptedDroppedResult).isSameInstanceAs(droppedResult);
+ assertThat(interceptedEmptyResult).isSameInstanceAs(emptyResult);
+ }
+
+ @Test
+ public void loadRecordingSubchannelPicker_interceptPickResult_applyLoadRecorderToPickResult() {
+ ClientStreamTracer.Factory mockFactory = mock(ClientStreamTracer.Factory.class);
+ ClientStreamTracer mockTracer = mock(ClientStreamTracer.class);
+ when(mockFactory
+ .newClientStreamTracer(any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)))
+ .thenReturn(mockTracer);
+
+ ClientLoadCounter localityCounter1 = new ClientLoadCounter();
+ ClientLoadCounter localityCounter2 = new ClientLoadCounter();
+ final PickResult pickResult1 = PickResult.withSubchannel(mock(Subchannel.class), mockFactory);
+ final PickResult pickResult2 = PickResult.withSubchannel(mock(Subchannel.class));
+ SubchannelPicker picker1 = new SubchannelPicker() {
+ @Override
+ public PickResult pickSubchannel(PickSubchannelArgs args) {
+ return pickResult1;
+ }
+ };
+ SubchannelPicker picker2 = new SubchannelPicker() {
+ @Override
+ public PickResult pickSubchannel(PickSubchannelArgs args) {
+ return pickResult2;
+ }
+ };
+ SubchannelPicker loadRecordingPicker1 =
+ new LoadRecordingSubchannelPicker(localityCounter1, picker1);
+ SubchannelPicker loadRecordingPicker2 =
+ new LoadRecordingSubchannelPicker(localityCounter2, picker2);
+ PickSubchannelArgs args = mock(PickSubchannelArgs.class);
+ PickResult interceptedPickResult1 = loadRecordingPicker1.pickSubchannel(args);
+ PickResult interceptedPickResult2 = loadRecordingPicker2.pickSubchannel(args);
+
+ LoadRecordingStreamTracerFactory recorder1 =
+ (LoadRecordingStreamTracerFactory) interceptedPickResult1.getStreamTracerFactory();
+ LoadRecordingStreamTracerFactory recorder2 =
+ (LoadRecordingStreamTracerFactory) interceptedPickResult2.getStreamTracerFactory();
+ assertThat(recorder1.getCounter()).isSameInstanceAs(localityCounter1);
+ assertThat(recorder2.getCounter()).isSameInstanceAs(localityCounter2);
+
+ // Stream tracing is propagated to downstream tracers, which preserves PickResult's original
+ // tracing functionality.
+ Metadata metadata = new Metadata();
+ interceptedPickResult1.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, metadata)
+ .streamClosed(Status.OK);
+ verify(mockFactory).newClientStreamTracer(same(STREAM_INFO), same(metadata));
+ verify(mockTracer).streamClosed(Status.OK);
+ }
+
+ @Test
+ public void metricsObservingSubchannelPicker_interceptPickResult_applyOrcaListenerToPickResult() {
+ ClientStreamTracer.Factory mockFactory = mock(ClientStreamTracer.Factory.class);
+ ClientStreamTracer mockTracer = mock(ClientStreamTracer.class);
+ when(mockFactory
+ .newClientStreamTracer(any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)))
+ .thenReturn(mockTracer);
+
+ final PickResult pickResult1 = PickResult.withSubchannel(mock(Subchannel.class), mockFactory);
+ final PickResult pickResult2 = PickResult.withSubchannel(mock(Subchannel.class));
+ SubchannelPicker picker1 = new SubchannelPicker() {
+ @Override
+ public PickResult pickSubchannel(PickSubchannelArgs args) {
+ return pickResult1;
+ }
+ };
+ SubchannelPicker picker2 = new SubchannelPicker() {
+ @Override
+ public PickResult pickSubchannel(PickSubchannelArgs args) {
+ return pickResult2;
+ }
+ };
+ OrcaPerRequestUtil orcaPerRequestUtil = mock(OrcaPerRequestUtil.class);
+ ClientStreamTracer.Factory metricsRecorder1 = mock(ClientStreamTracer.Factory.class);
+ ClientStreamTracer.Factory metricsRecorder2 = mock(ClientStreamTracer.Factory.class);
+ when(orcaPerRequestUtil.newOrcaClientStreamTracerFactory(any(ClientStreamTracer.Factory.class),
+ any(OrcaPerRequestReportListener.class))).thenReturn(metricsRecorder1, metricsRecorder2);
+ OrcaPerRequestReportListener listener1 = mock(OrcaPerRequestReportListener.class);
+ OrcaPerRequestReportListener listener2 = mock(OrcaPerRequestReportListener.class);
+ PickSubchannelArgs args = mock(PickSubchannelArgs.class);
+
+ SubchannelPicker metricsObservingPicker1 =
+ new MetricsObservingSubchannelPicker(listener1, picker1, orcaPerRequestUtil);
+ SubchannelPicker metricsObservingPicker2 =
+ new MetricsObservingSubchannelPicker(listener2, picker2, orcaPerRequestUtil);
+ PickResult interceptedPickResult1 = metricsObservingPicker1.pickSubchannel(args);
+ PickResult interceptedPickResult2 = metricsObservingPicker2.pickSubchannel(args);
+
+ verify(orcaPerRequestUtil)
+ .newOrcaClientStreamTracerFactory(any(ClientStreamTracer.Factory.class), same(listener1));
+ verify(orcaPerRequestUtil)
+ .newOrcaClientStreamTracerFactory(any(ClientStreamTracer.Factory.class), same(listener2));
+ assertThat(interceptedPickResult1.getStreamTracerFactory()).isSameInstanceAs(metricsRecorder1);
+ assertThat(interceptedPickResult2.getStreamTracerFactory()).isSameInstanceAs(metricsRecorder2);
+ }
+
private void assertQueryCounts(ClientLoadSnapshot snapshot,
long callsSucceeded,
long callsInProgress,
diff --git a/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java b/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java
index dd4ae37..910dfa7 100644
--- a/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java
+++ b/xds/src/test/java/io/grpc/xds/LocalityStoreTest.java
@@ -20,24 +20,24 @@
import static io.grpc.ConnectivityState.CONNECTING;
import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.READY;
-import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.AdditionalAnswers.delegatesTo;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.same;
-import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats;
import io.grpc.ChannelLogger;
+import io.grpc.ClientStreamTracer;
import io.grpc.ConnectivityState;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer;
@@ -51,9 +51,12 @@
import io.grpc.LoadBalancerProvider;
import io.grpc.LoadBalancerRegistry;
import io.grpc.SynchronizationContext;
+import io.grpc.xds.ClientLoadCounter.LoadRecordingStreamTracerFactory;
+import io.grpc.xds.ClientLoadCounter.MetricsRecordingListener;
import io.grpc.xds.InterLocalityPicker.WeightedChildPicker;
import io.grpc.xds.LocalityStore.LocalityStoreImpl;
import io.grpc.xds.LocalityStore.LocalityStoreImpl.PickerFactory;
+import io.grpc.xds.OrcaPerRequestUtil.OrcaPerRequestReportListener;
import io.grpc.xds.XdsComms.DropOverload;
import io.grpc.xds.XdsComms.LbEndpoint;
import io.grpc.xds.XdsComms.LocalityInfo;
@@ -62,6 +65,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import javax.annotation.Nullable;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -176,7 +180,9 @@
@Mock
private ThreadSafeRandom random;
@Mock
- private StatsStore statsStore;
+ private OrcaPerRequestUtil orcaPerRequestUtil;
+ private final FakeLoadStatsStore fakeLoadStatsStore = new FakeLoadStatsStore();
+ private StatsStore statsStore = mock(StatsStore.class, delegatesTo(fakeLoadStatsStore));
private LocalityStore localityStore;
@@ -185,10 +191,10 @@
doReturn(mock(ChannelLogger.class)).when(helper).getChannelLogger();
doReturn(mock(Subchannel.class)).when(helper).createSubchannel(any(CreateSubchannelArgs.class));
doReturn(syncContext).when(helper).getSynchronizationContext();
- doAnswer(returnsFirstArg())
- .when(statsStore).interceptPickResult(any(PickResult.class), any(XdsLocality.class));
lbRegistry.register(lbProvider);
- localityStore = new LocalityStoreImpl(helper, pickerFactory, lbRegistry, random, statsStore);
+ localityStore =
+ new LocalityStoreImpl(helper, pickerFactory, lbRegistry, random, statsStore,
+ orcaPerRequestUtil);
}
@Test
@@ -218,11 +224,10 @@
localityStore.updateLocalityStore(Collections.EMPTY_MAP);
verify(statsStore).removeLocality(locality4);
- verifyNoMoreInteractions(statsStore);
}
@Test
- public void updateLocalityStore_interceptPickResultUponPickReadySubchannel() {
+ public void updateLocalityStore_pickResultInterceptedForLoadRecordingWhenSubchannelReady() {
// Simulate receiving two localities.
LocalityInfo localityInfo1 =
new LocalityInfo(ImmutableList.of(lbEndpoint11, lbEndpoint12), 1);
@@ -235,15 +240,23 @@
assertThat(loadBalancers).hasSize(2);
assertThat(pickerFactory.totalReadyLocalities).isEqualTo(0);
- // Simulate picker updates for each of the two localities with dummy pickers.
- final PickResult result1 = PickResult.withNoResult();
- final PickResult result2 = PickResult.withNoResult();
+ ClientStreamTracer.Factory metricsTracingFactory1 = mock(ClientStreamTracer.Factory.class);
+ ClientStreamTracer.Factory metricsTracingFactory2 = mock(ClientStreamTracer.Factory.class);
+ when(orcaPerRequestUtil.newOrcaClientStreamTracerFactory(any(ClientStreamTracer.Factory.class),
+ any(OrcaPerRequestReportListener.class)))
+ .thenReturn(metricsTracingFactory1, metricsTracingFactory2);
+
+ final PickResult result1 = PickResult.withSubchannel(mock(Subchannel.class));
+ final PickResult result2 =
+ PickResult.withSubchannel(mock(Subchannel.class), mock(ClientStreamTracer.Factory.class));
SubchannelPicker subchannelPicker1 = mock(SubchannelPicker.class);
SubchannelPicker subchannelPicker2 = mock(SubchannelPicker.class);
when(subchannelPicker1.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(result1);
when(subchannelPicker2.pickSubchannel(any(PickSubchannelArgs.class)))
.thenReturn(result2);
+
+ // Simulate picker updates for the two localities with dummy pickers.
childHelpers.get("sz1").updateBalancingState(READY, subchannelPicker1);
childHelpers.get("sz2").updateBalancingState(READY, subchannelPicker2);
@@ -251,12 +264,31 @@
ArgumentCaptor<SubchannelPicker> interLocalityPickerCaptor = ArgumentCaptor.forClass(null);
verify(helper, times(2)).updateBalancingState(eq(READY), interLocalityPickerCaptor.capture());
SubchannelPicker interLocalityPicker = interLocalityPickerCaptor.getValue();
+
+ // Verify each PickResult picked is intercepted with client stream tracer factory for
+ // recording load and backend metrics.
+ List<XdsLocality> localities = ImmutableList.of(locality1, locality2);
+ List<ClientStreamTracer.Factory> metricsTracingFactories =
+ ImmutableList.of(metricsTracingFactory1, metricsTracingFactory2);
for (int i = 0; i < pickerFactory.totalReadyLocalities; i++) {
pickerFactory.nextIndex = i;
- interLocalityPicker.pickSubchannel(pickSubchannelArgs);
+ PickResult pickResult = interLocalityPicker.pickSubchannel(pickSubchannelArgs);
+ ArgumentCaptor<OrcaPerRequestReportListener> listenerCaptor = ArgumentCaptor.forClass(null);
+ verify(orcaPerRequestUtil, times(i + 1))
+ .newOrcaClientStreamTracerFactory(any(ClientStreamTracer.Factory.class),
+ listenerCaptor.capture());
+ assertThat(listenerCaptor.getValue()).isInstanceOf(MetricsRecordingListener.class);
+ MetricsRecordingListener listener = (MetricsRecordingListener) listenerCaptor.getValue();
+ assertThat(listener.getCounter())
+ .isSameInstanceAs(fakeLoadStatsStore.localityCounters.get(localities.get(i)));
+ assertThat(pickResult.getStreamTracerFactory())
+ .isInstanceOf(LoadRecordingStreamTracerFactory.class);
+ LoadRecordingStreamTracerFactory loadRecordingFactory =
+ (LoadRecordingStreamTracerFactory) pickResult.getStreamTracerFactory();
+ assertThat(loadRecordingFactory.getCounter())
+ .isSameInstanceAs(fakeLoadStatsStore.localityCounters.get(localities.get(i)));
+ assertThat(loadRecordingFactory.delegate()).isSameInstanceAs(metricsTracingFactories.get(i));
}
- verify(statsStore).interceptPickResult(same(result1), eq(locality1));
- verify(statsStore).interceptPickResult(same(result2), eq(locality2));
}
@Test
@@ -510,4 +542,37 @@
verify(statsStore).removeLocality(locality1);
verify(statsStore).removeLocality(locality2);
}
+
+ private static final class FakeLoadStatsStore implements StatsStore {
+
+ Map<XdsLocality, ClientLoadCounter> localityCounters = new HashMap<>();
+
+ @Override
+ public ClusterStats generateLoadReport() {
+ throw new AssertionError("Should not be called");
+ }
+
+ @Override
+ public void addLocality(XdsLocality locality) {
+ assertThat(localityCounters).doesNotContainKey(locality);
+ localityCounters.put(locality, new ClientLoadCounter());
+ }
+
+ @Override
+ public void removeLocality(XdsLocality locality) {
+ assertThat(localityCounters).containsKey(locality);
+ localityCounters.remove(locality);
+ }
+
+ @Nullable
+ @Override
+ public ClientLoadCounter getLocalityCounter(XdsLocality locality) {
+ return localityCounters.get(locality);
+ }
+
+ @Override
+ public void recordDroppedRequest(String category) {
+ // NO-OP, verify by invocations.
+ }
+ }
}
diff --git a/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java b/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java
index 13ca628..9a9429b 100644
--- a/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java
+++ b/xds/src/test/java/io/grpc/xds/XdsLbStateTest.java
@@ -17,9 +17,6 @@
package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
-import static org.mockito.AdditionalAnswers.returnsFirstArg;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
@@ -32,7 +29,6 @@
import io.grpc.ChannelLogger;
import io.grpc.EquivalentAddressGroup;
import io.grpc.LoadBalancer.Helper;
-import io.grpc.LoadBalancer.PickResult;
import io.grpc.ManagedChannel;
import io.grpc.SynchronizationContext;
import io.grpc.inprocess.InProcessChannelBuilder;
@@ -63,8 +59,6 @@
@Mock
private AdsStreamCallback adsStreamCallback;
@Mock
- private StatsStore statsStore;
- @Mock
private LocalityStore localityStore;
private final FakeClock fakeClock = new FakeClock();
@@ -89,8 +83,6 @@
doReturn(fakeClock.getScheduledExecutorService()).when(helper).getScheduledExecutorService();
doReturn("fake_authority").when(helper).getAuthority();
doReturn(mock(ChannelLogger.class)).when(helper).getChannelLogger();
- doAnswer(returnsFirstArg())
- .when(statsStore).interceptPickResult(any(PickResult.class), any(XdsLocality.class));
String serverName = InProcessServerBuilder.generateName();
diff --git a/xds/src/test/java/io/grpc/xds/XdsLoadStatsStoreTest.java b/xds/src/test/java/io/grpc/xds/XdsLoadStatsStoreTest.java
index 074b563..a18763c 100644
--- a/xds/src/test/java/io/grpc/xds/XdsLoadStatsStoreTest.java
+++ b/xds/src/test/java/io/grpc/xds/XdsLoadStatsStoreTest.java
@@ -17,22 +17,12 @@
package io.grpc.xds;
import static com.google.common.truth.Truth.assertThat;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.same;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
import com.google.common.collect.ImmutableMap;
import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats;
import io.envoyproxy.envoy.api.v2.endpoint.ClusterStats.DroppedRequests;
import io.envoyproxy.envoy.api.v2.endpoint.EndpointLoadMetricStats;
import io.envoyproxy.envoy.api.v2.endpoint.UpstreamLocalityStats;
-import io.grpc.ClientStreamTracer;
-import io.grpc.LoadBalancer.PickResult;
-import io.grpc.LoadBalancer.Subchannel;
-import io.grpc.Metadata;
-import io.grpc.Status;
import io.grpc.xds.ClientLoadCounter.MetricValue;
import java.util.ArrayList;
import java.util.Arrays;
@@ -57,9 +47,6 @@
new XdsLocality("test_region1", "test_zone", "test_subzone");
private static final XdsLocality LOCALITY2 =
new XdsLocality("test_region2", "test_zone", "test_subzone");
- private static final ClientStreamTracer.StreamInfo STREAM_INFO =
- ClientStreamTracer.StreamInfo.newBuilder().build();
- private Subchannel mockSubchannel = mock(Subchannel.class);
private ConcurrentMap<XdsLocality, ClientLoadCounter> localityLoadCounters;
private ConcurrentMap<String, AtomicLong> dropCounters;
private XdsLoadStatsStore loadStore;
@@ -286,44 +273,4 @@
assertThat(dropCounters.get("lb").get()).isEqualTo(0);
assertThat(dropCounters.get("throttle").get()).isEqualTo(0);
}
-
- @Test
- public void loadNotRecordedForUntrackedLocality() {
- PickResult pickResult = PickResult.withSubchannel(mockSubchannel);
- // If the per-locality counter does not exist, nothing should happen.
- PickResult interceptedPickResult = loadStore.interceptPickResult(pickResult, LOCALITY1);
- assertThat(interceptedPickResult.getStreamTracerFactory()).isNull();
- }
-
- @Test
- public void invalidPickResultNotIntercepted() {
- localityLoadCounters.put(LOCALITY1, new ClientLoadCounter());
- PickResult errorResult = PickResult.withError(Status.UNAVAILABLE.withDescription("Error"));
- PickResult droppedResult = PickResult.withDrop(Status.UNAVAILABLE.withDescription("Dropped"));
- PickResult emptyResult = PickResult.withNoResult();
- PickResult interceptedErrorResult = loadStore.interceptPickResult(errorResult, LOCALITY1);
- PickResult interceptedDroppedResult =
- loadStore.interceptPickResult(droppedResult, LOCALITY1);
- PickResult interceptedEmptyResult = loadStore.interceptPickResult(emptyResult, LOCALITY1);
- assertThat(interceptedErrorResult).isSameInstanceAs(errorResult);
- assertThat(interceptedDroppedResult).isSameInstanceAs(droppedResult);
- assertThat(interceptedEmptyResult).isSameInstanceAs(emptyResult);
- }
-
- @Test
- public void interceptPreservesOriginStreamTracer() {
- ClientStreamTracer.Factory mockFactory = mock(ClientStreamTracer.Factory.class);
- ClientStreamTracer mockTracer = mock(ClientStreamTracer.class);
- when(mockFactory
- .newClientStreamTracer(any(ClientStreamTracer.StreamInfo.class), any(Metadata.class)))
- .thenReturn(mockTracer);
- localityLoadCounters.put(LOCALITY1, new ClientLoadCounter());
- PickResult pickResult = PickResult.withSubchannel(mockSubchannel, mockFactory);
- PickResult interceptedPickResult = loadStore.interceptPickResult(pickResult, LOCALITY1);
- Metadata metadata = new Metadata();
- interceptedPickResult.getStreamTracerFactory().newClientStreamTracer(STREAM_INFO, metadata)
- .streamClosed(Status.OK);
- verify(mockFactory).newClientStreamTracer(same(STREAM_INFO), same(metadata));
- verify(mockTracer).streamClosed(Status.OK);
- }
}