interop-testing: support dynamic configuration and accumulated stats for xDS test client (#7549)

diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java
index 73af159..aba91dc 100644
--- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java
@@ -40,6 +40,11 @@
 import io.grpc.netty.NettyChannelBuilder;
 import io.grpc.netty.NettyServerBuilder;
 import io.grpc.stub.StreamObserver;
+import io.grpc.testing.integration.Messages.ClientConfigureRequest;
+import io.grpc.testing.integration.Messages.ClientConfigureRequest.RpcType;
+import io.grpc.testing.integration.Messages.ClientConfigureResponse;
+import io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsRequest;
+import io.grpc.testing.integration.Messages.LoadBalancerAccumulatedStatsResponse;
 import io.grpc.testing.integration.Messages.LoadBalancerStatsRequest;
 import io.grpc.testing.integration.Messages.LoadBalancerStatsResponse;
 import io.grpc.testing.integration.Messages.SimpleRequest;
@@ -55,6 +60,7 @@
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.logging.Level;
 import java.util.logging.Logger;
@@ -67,12 +73,14 @@
   private final Set<XdsStatsWatcher> watchers = new HashSet<>();
   private final Object lock = new Object();
   private final List<ManagedChannel> channels = new ArrayList<>();
+  private final AtomicInteger rpcsStarted = new AtomicInteger();
+  private final AtomicInteger rpcsFailed = new AtomicInteger();
+  private final AtomicInteger rpcsSucceeded = new AtomicInteger();
 
   private int numChannels = 1;
   private boolean printResponse = false;
   private int qps = 1;
-  private List<RpcType> rpcTypes = ImmutableList.of(RpcType.UNARY_CALL);
-  private EnumMap<RpcType, Metadata> metadata = new EnumMap<>(RpcType.class);
+  private volatile RpcConfig rpcConfig;
   private int rpcTimeoutSec = 20;
   private String server = "localhost:8080";
   private int statsPort = 8081;
@@ -80,15 +88,6 @@
   private long currentRequestId;
   private ListeningScheduledExecutorService exec;
 
-  private enum RpcType {
-    EMPTY_CALL,
-    UNARY_CALL;
-
-    public String toCamelCase() {
-      return CaseFormat.UPPER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, toString());
-    }
-  }
-
   /**
    * The main application allowing this client to be launched from the command line.
    */
@@ -113,6 +112,8 @@
 
   private void parseArgs(String[] args) {
     boolean usage = false;
+    List<RpcType> rpcTypes = ImmutableList.of(RpcType.UNARY_CALL);
+    EnumMap<RpcType, Metadata> metadata = new EnumMap<>(RpcType.class);
     for (String arg : args) {
       if (!arg.startsWith("--")) {
         System.err.println("All arguments must start with '--': " + arg);
@@ -153,6 +154,7 @@
         break;
       }
     }
+    rpcConfig = new RpcConfig(rpcTypes, metadata);
 
     if (usage) {
       XdsTestClient c = new XdsTestClient();
@@ -167,8 +169,10 @@
               + c.qps
               + "\n  --rpc=STR              Types of RPCs to make, ',' separated string. RPCs can "
               + "be EmptyCall or UnaryCall. Default: UnaryCall"
+              + "\n[deprecated] Use XdsUpdateClientConfigureService"
               + "\n  --metadata=STR         The metadata to send with each RPC, in the format "
               + "EmptyCall:key1:value1,UnaryCall:key2:value2."
+              + "\n[deprecated] Use XdsUpdateClientConfigureService"
               + "\n  --rpc_timeout_sec=INT  Per RPC timeout seconds. Default: "
               + c.rpcTimeoutSec
               + "\n  --server=host:port     Address of server. Default: "
@@ -220,7 +224,11 @@
   }
 
   private void run() {
-    statsServer = NettyServerBuilder.forPort(statsPort).addService(new XdsStatsImpl()).build();
+    statsServer =
+        NettyServerBuilder.forPort(statsPort)
+            .addService(new XdsStatsImpl())
+            .addService(new ConfigureUpdateServiceImpl())
+            .build();
     try {
       statsServer.start();
       for (int i = 0; i < numChannels; i++) {
@@ -253,14 +261,20 @@
   private void runQps() throws InterruptedException, ExecutionException {
     final SettableFuture<Void> failure = SettableFuture.create();
     final class PeriodicRpc implements Runnable {
-      private final RpcType rpcType;
-
-      private PeriodicRpc(RpcType rpcType) {
-        this.rpcType = rpcType;
-      }
 
       @Override
       public void run() {
+        RpcConfig config = rpcConfig;
+        for (RpcType type : config.rpcTypes) {
+          Metadata headers = config.metadata.get(type);
+          if (headers == null)  {
+            headers = new Metadata();
+          }
+          makeRpc(type, headers);
+        }
+      }
+
+      private void makeRpc(final RpcType rpcType, final Metadata headersToSend) {
         final long requestId;
         final Set<XdsStatsWatcher> savedWatchers = new HashSet<>();
         synchronized (lock) {
@@ -269,12 +283,6 @@
           savedWatchers.addAll(watchers);
         }
 
-        final Metadata headersToSend;
-        if (metadata.containsKey(rpcType)) {
-          headersToSend = metadata.get(rpcType);
-        } else {
-          headersToSend = new Metadata();
-        }
         ManagedChannel channel = channels.get((int) (requestId % channels.size()));
         TestServiceGrpc.TestServiceStub stub = TestServiceGrpc.newStub(channel);
         final AtomicReference<ClientCall<?, ?>> clientCallRef = new AtomicReference<>();
@@ -314,17 +322,18 @@
               new StreamObserver<EmptyProtos.Empty>() {
                 @Override
                 public void onCompleted() {
-                  notifyWatchers(savedWatchers, rpcType, requestId, hostnameRef.get());
+                  handleRpcCompleted(requestId, rpcType, hostnameRef.get(), savedWatchers);
                 }
 
                 @Override
                 public void onError(Throwable t) {
-                  notifyWatchers(savedWatchers, rpcType, requestId, hostnameRef.get());
+                  handleRpcError(requestId, rpcType, hostnameRef.get(), savedWatchers);
                 }
 
                 @Override
                 public void onNext(EmptyProtos.Empty response) {}
               });
+          rpcsStarted.getAndIncrement();
         } else if (rpcType == RpcType.UNARY_CALL) {
           SimpleRequest request = SimpleRequest.newBuilder().setFillServerId(true).build();
           stub.unaryCall(
@@ -332,7 +341,7 @@
               new StreamObserver<SimpleResponse>() {
                 @Override
                 public void onCompleted() {
-                  notifyWatchers(savedWatchers, rpcType, requestId, hostnameRef.get());
+                  handleRpcCompleted(requestId, rpcType, hostnameRef.get(), savedWatchers);
                 }
 
                 @Override
@@ -340,7 +349,7 @@
                   if (printResponse) {
                     logger.log(Level.WARNING, "Rpc failed: {0}", t);
                   }
-                  notifyWatchers(savedWatchers, rpcType, requestId, hostnameRef.get());
+                  handleRpcError(requestId, rpcType, hostnameRef.get(), savedWatchers);
                 }
 
                 @Override
@@ -364,31 +373,39 @@
                   }
                 }
               });
+          rpcsStarted.getAndIncrement();
         }
       }
+
+      private void handleRpcCompleted(long requestId, RpcType rpcType, String hostname,
+          Set<XdsStatsWatcher> watchers) {
+        rpcsSucceeded.getAndIncrement();
+        notifyWatchers(watchers, rpcType, requestId, hostname);
+      }
+
+      private void handleRpcError(long requestId, RpcType rpcType, String hostname,
+          Set<XdsStatsWatcher> watchers) {
+        rpcsFailed.getAndIncrement();
+        notifyWatchers(watchers, rpcType, requestId, hostname);
+      }
     }
 
     long nanosPerQuery = TimeUnit.SECONDS.toNanos(1) / qps;
+    ListenableScheduledFuture<?> future =
+        exec.scheduleAtFixedRate(new PeriodicRpc(), 0, nanosPerQuery, TimeUnit.NANOSECONDS);
+    Futures.addCallback(
+        future,
+        new FutureCallback<Object>() {
 
-    for (RpcType rpcType : rpcTypes) {
-      ListenableScheduledFuture<?> future =
-          exec.scheduleAtFixedRate(
-              new PeriodicRpc(rpcType), 0, nanosPerQuery, TimeUnit.NANOSECONDS);
+          @Override
+          public void onFailure(Throwable t) {
+            failure.setException(t);
+          }
 
-      Futures.addCallback(
-          future,
-          new FutureCallback<Object>() {
-
-            @Override
-            public void onFailure(Throwable t) {
-              failure.setException(t);
-            }
-
-            @Override
-            public void onSuccess(Object o) {}
-          },
-          MoreExecutors.directExecutor());
-    }
+          @Override
+          public void onSuccess(Object o) {}
+        },
+        MoreExecutors.directExecutor());
 
     failure.get();
   }
@@ -400,6 +417,22 @@
     }
   }
 
+  private final class ConfigureUpdateServiceImpl extends
+      XdsUpdateClientConfigureServiceGrpc.XdsUpdateClientConfigureServiceImplBase {
+    @Override
+    public void configure(ClientConfigureRequest request,
+        StreamObserver<ClientConfigureResponse> responseObserver) {
+      EnumMap<RpcType, Metadata> newMetadata = new EnumMap<>(RpcType.class);
+      for (ClientConfigureRequest.Metadata metadata : request.getMetadataList()) {
+        Metadata md = new Metadata();
+        md.put(Metadata.Key.of(metadata.getKey(), Metadata.ASCII_STRING_MARSHALLER),
+            metadata.getValue());
+        newMetadata.put(metadata.getType(), md);
+      }
+      rpcConfig = new RpcConfig(request.getTypesList(), newMetadata);
+    }
+  }
+
   private class XdsStatsImpl extends LoadBalancerStatsServiceGrpc.LoadBalancerStatsServiceImplBase {
     @Override
     public void getClientStats(
@@ -418,6 +451,29 @@
       responseObserver.onNext(response);
       responseObserver.onCompleted();
     }
+
+    @Override
+    public void getClientAccumulatedStats(LoadBalancerAccumulatedStatsRequest request,
+        StreamObserver<LoadBalancerAccumulatedStatsResponse> responseObserver) {
+      responseObserver.onNext(
+          LoadBalancerAccumulatedStatsResponse.newBuilder()
+              .setNumRpcsStarted(rpcsStarted.get())
+              .setNumRpcsSucceeded(rpcsSucceeded.get())
+              .setNumRpcsFailed(rpcsFailed.get())
+              .build());
+      responseObserver.onCompleted();
+    }
+  }
+
+  /** RPC configurations that can be dynamically updated. */
+  private static final class RpcConfig {
+    private final List<RpcType> rpcTypes;
+    private final EnumMap<RpcType, Metadata> metadata;
+
+    private RpcConfig(List<RpcType> rpcTypes, EnumMap<RpcType, Metadata> metadata) {
+      this.rpcTypes = rpcTypes;
+      this.metadata = metadata;
+    }
   }
 
   /** Records the remote peer distribution for a given range of RPCs. */
@@ -484,11 +540,15 @@
           LoadBalancerStatsResponse.RpcsByPeer.Builder rpcs =
               LoadBalancerStatsResponse.RpcsByPeer.newBuilder();
           rpcs.putAllRpcsByPeer(entry.getValue());
-          builder.putRpcsByMethod(entry.getKey().toCamelCase(), rpcs.build());
+          builder.putRpcsByMethod(getRpcTypeString(entry.getKey()), rpcs.build());
         }
         builder.setNumFailures(noRemotePeer + (int) latch.getCount());
       }
       return builder.build();
     }
+
+    private static String getRpcTypeString(RpcType rpcType) {
+      return CaseFormat.UPPER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, rpcType.name());
+    }
   }
 }