interop-testing: add flags to xds test client

diff --git a/buildscripts/kokoro/xds.sh b/buildscripts/kokoro/xds.sh
index 9d66245..f911f25 100755
--- a/buildscripts/kokoro/xds.sh
+++ b/buildscripts/kokoro/xds.sh
@@ -20,9 +20,15 @@
 git clone -b "${branch}" --single-branch --depth=1 https://github.com/grpc/grpc.git
 
 grpc/tools/run_tests/helper_scripts/prep_xds.sh
+
+# Test cases "path_matching" and "header_matching" are not included in "all",
+# because not all interop clients in all languages support these new tests.
+#
+# TODO(ericgribkoff): remove "path_matching" and "header_matching" from
+# --test_case after they are added into "all".
 JAVA_OPTS=-Djava.util.logging.config.file=grpc-java/buildscripts/xds_logging.properties \
   python3 grpc/tools/run_tests/run_xds_tests.py \
-    --test_case=all \
+    --test_case="all,path_matching,header_matching" \
     --project_id=grpc-testing \
     --source_image=projects/grpc-testing/global/images/xds-test-server \
     --path_to_server_binary=/java_server/grpc-java/interop-testing/build/install/grpc-interop-testing/bin/xds-test-server \
diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java
index 0bd2fbe..73af159 100644
--- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java
@@ -16,6 +16,9 @@
 
 package io.grpc.testing.integration;
 
+import com.google.common.base.CaseFormat;
+import com.google.common.base.Splitter;
+import com.google.common.collect.ImmutableList;
 import com.google.common.primitives.Ints;
 import com.google.common.util.concurrent.FutureCallback;
 import com.google.common.util.concurrent.Futures;
@@ -24,12 +27,16 @@
 import com.google.common.util.concurrent.MoreExecutors;
 import com.google.common.util.concurrent.SettableFuture;
 import io.grpc.CallOptions;
+import io.grpc.Channel;
 import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
+import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
+import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
 import io.grpc.Grpc;
 import io.grpc.ManagedChannel;
 import io.grpc.Metadata;
+import io.grpc.MethodDescriptor;
 import io.grpc.Server;
-import io.grpc.Status;
 import io.grpc.netty.NettyChannelBuilder;
 import io.grpc.netty.NettyServerBuilder;
 import io.grpc.stub.StreamObserver;
@@ -38,6 +45,7 @@
 import io.grpc.testing.integration.Messages.SimpleRequest;
 import io.grpc.testing.integration.Messages.SimpleResponse;
 import java.util.ArrayList;
+import java.util.EnumMap;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -47,6 +55,7 @@
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
 import java.util.logging.Level;
 import java.util.logging.Logger;
 import javax.annotation.Nullable;
@@ -62,13 +71,24 @@
   private int numChannels = 1;
   private boolean printResponse = false;
   private int qps = 1;
-  private int rpcTimeoutSec = 2;
+  private List<RpcType> rpcTypes = ImmutableList.of(RpcType.UNARY_CALL);
+  private EnumMap<RpcType, Metadata> metadata = new EnumMap<>(RpcType.class);
+  private int rpcTimeoutSec = 20;
   private String server = "localhost:8080";
   private int statsPort = 8081;
   private Server statsServer;
   private long currentRequestId;
   private ListeningScheduledExecutorService exec;
 
+  private enum RpcType {
+    EMPTY_CALL,
+    UNARY_CALL;
+
+    public String toCamelCase() {
+      return CaseFormat.UPPER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, toString());
+    }
+  }
+
   /**
    * The main application allowing this client to be launched from the command line.
    */
@@ -111,12 +131,16 @@
         break;
       }
       String value = parts[1];
-      if ("num_channels".equals(key)) {
+      if ("metadata".equals(key)) {
+        metadata = parseMetadata(value);
+      } else if ("num_channels".equals(key)) {
         numChannels = Integer.valueOf(value);
       } else if ("print_response".equals(key)) {
         printResponse = Boolean.valueOf(value);
       } else if ("qps".equals(key)) {
         qps = Integer.valueOf(value);
+      } else if ("rpc".equals(key)) {
+        rpcTypes = parseRpcs(value);
       } else if ("rpc_timeout_sec".equals(key)) {
         rpcTimeoutSec = Integer.valueOf(value);
       } else if ("server".equals(key)) {
@@ -139,8 +163,12 @@
               + c.numChannels
               + "\n  --print_response=BOOL  Write RPC response to stdout. Default: "
               + c.printResponse
-              + "\n  --qps=INT              Qps per channel. Default: "
+              + "\n  --qps=INT              Qps per channel, for each type of RPC. Default: "
               + c.qps
+              + "\n  --rpc=STR              Types of RPCs to make, ',' separated string. RPCs can "
+              + "be EmptyCall or UnaryCall. Default: UnaryCall"
+              + "\n  --metadata=STR         The metadata to send with each RPC, in the format "
+              + "EmptyCall:key1:value1,UnaryCall:key2:value2."
               + "\n  --rpc_timeout_sec=INT  Per RPC timeout seconds. Default: "
               + c.rpcTimeoutSec
               + "\n  --server=host:port     Address of server. Default: "
@@ -152,6 +180,45 @@
     }
   }
 
+  private static List<RpcType> parseRpcs(String rpcArg) {
+    List<RpcType> rpcs = new ArrayList<>();
+    for (String rpc : Splitter.on(',').split(rpcArg)) {
+      rpcs.add(parseRpc(rpc));
+    }
+    return rpcs;
+  }
+
+  private static EnumMap<RpcType, Metadata> parseMetadata(String metadataArg) {
+    EnumMap<RpcType, Metadata> rpcMetadata = new EnumMap<>(RpcType.class);
+    for (String metadata : Splitter.on(',').omitEmptyStrings().split(metadataArg)) {
+      List<String> parts = Splitter.on(':').splitToList(metadata);
+      if (parts.size() != 3) {
+        throw new IllegalArgumentException("Invalid metadata: '" + metadata + "'");
+      }
+      RpcType rpc = parseRpc(parts.get(0));
+      String key = parts.get(1);
+      String value = parts.get(2);
+      Metadata md = new Metadata();
+      md.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value);
+      if (rpcMetadata.containsKey(rpc)) {
+        rpcMetadata.get(rpc).merge(md);
+      } else {
+        rpcMetadata.put(rpc, md);
+      }
+    }
+    return rpcMetadata;
+  }
+
+  private static RpcType parseRpc(String rpc) {
+    if ("EmptyCall".equals(rpc)) {
+      return RpcType.EMPTY_CALL;
+    } else if ("UnaryCall".equals(rpc)) {
+      return RpcType.UNARY_CALL;
+    } else {
+      throw new IllegalArgumentException("Unknown RPC: '" + rpc + "'");
+    }
+  }
+
   private void run() {
     statsServer = NettyServerBuilder.forPort(statsPort).addService(new XdsStatsImpl()).build();
     try {
@@ -186,6 +253,11 @@
   private void runQps() throws InterruptedException, ExecutionException {
     final SettableFuture<Void> failure = SettableFuture.create();
     final class PeriodicRpc implements Runnable {
+      private final RpcType rpcType;
+
+      private PeriodicRpc(RpcType rpcType) {
+        this.rpcType = rpcType;
+      }
 
       @Override
       public void run() {
@@ -197,69 +269,137 @@
           savedWatchers.addAll(watchers);
         }
 
-        SimpleRequest request = SimpleRequest.newBuilder().setFillServerId(true).build();
+        final Metadata headersToSend;
+        if (metadata.containsKey(rpcType)) {
+          headersToSend = metadata.get(rpcType);
+        } else {
+          headersToSend = new Metadata();
+        }
         ManagedChannel channel = channels.get((int) (requestId % channels.size()));
-        final ClientCall<SimpleRequest, SimpleResponse> call =
-            channel.newCall(
-                TestServiceGrpc.getUnaryCallMethod(),
-                CallOptions.DEFAULT.withDeadlineAfter(rpcTimeoutSec, TimeUnit.SECONDS));
-        call.start(
-            new ClientCall.Listener<SimpleResponse>() {
-              private String hostname;
+        TestServiceGrpc.TestServiceStub stub = TestServiceGrpc.newStub(channel);
+        final AtomicReference<ClientCall<?, ?>> clientCallRef = new AtomicReference<>();
+        final AtomicReference<String> hostnameRef = new AtomicReference<>();
+        stub =
+            stub.withDeadlineAfter(rpcTimeoutSec, TimeUnit.SECONDS)
+                .withInterceptors(
+                    new ClientInterceptor() {
+                      @Override
+                      public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
+                          MethodDescriptor<ReqT, RespT> method,
+                          CallOptions callOptions,
+                          Channel next) {
+                        ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
+                        clientCallRef.set(call);
+                        return new SimpleForwardingClientCall<ReqT, RespT>(call) {
+                          @Override
+                          public void start(Listener<RespT> responseListener, Metadata headers) {
+                            headers.merge(headersToSend);
+                            super.start(
+                                new SimpleForwardingClientCallListener<RespT>(responseListener) {
+                                  @Override
+                                  public void onHeaders(Metadata headers) {
+                                    hostnameRef.set(headers.get(XdsTestServer.HOSTNAME_KEY));
+                                    super.onHeaders(headers);
+                                  }
+                                },
+                                headers);
+                          }
+                        };
+                      }
+                    });
 
-              @Override
-              public void onMessage(SimpleResponse response) {
-                hostname = response.getHostname();
-                // TODO(ericgribkoff) Currently some test environments cannot access the stats RPC
-                // service and rely on parsing stdout.
-                if (printResponse) {
-                  System.out.println(
-                      "Greeting: Hello world, this is "
-                          + hostname
-                          + ", from "
-                          + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
+        if (rpcType == RpcType.EMPTY_CALL) {
+          stub.emptyCall(
+              EmptyProtos.Empty.getDefaultInstance(),
+              new StreamObserver<EmptyProtos.Empty>() {
+                @Override
+                public void onCompleted() {
+                  notifyWatchers(savedWatchers, rpcType, requestId, hostnameRef.get());
                 }
-              }
 
-              @Override
-              public void onClose(Status status, Metadata trailers) {
-                if (printResponse && !status.isOk()) {
-                  logger.log(Level.WARNING, "Greeting RPC failed with status {0}", status);
+                @Override
+                public void onError(Throwable t) {
+                  notifyWatchers(savedWatchers, rpcType, requestId, hostnameRef.get());
                 }
-                for (XdsStatsWatcher watcher : savedWatchers) {
-                  watcher.rpcCompleted(requestId, hostname);
-                }
-              }
-            },
-            new Metadata());
 
-        call.sendMessage(request);
-        call.request(1);
-        call.halfClose();
+                @Override
+                public void onNext(EmptyProtos.Empty response) {}
+              });
+        } else if (rpcType == RpcType.UNARY_CALL) {
+          SimpleRequest request = SimpleRequest.newBuilder().setFillServerId(true).build();
+          stub.unaryCall(
+              request,
+              new StreamObserver<SimpleResponse>() {
+                @Override
+                public void onCompleted() {
+                  notifyWatchers(savedWatchers, rpcType, requestId, hostnameRef.get());
+                }
+
+                @Override
+                public void onError(Throwable t) {
+                  if (printResponse) {
+                    logger.log(Level.WARNING, "Rpc failed: {0}", t);
+                  }
+                  notifyWatchers(savedWatchers, rpcType, requestId, hostnameRef.get());
+                }
+
+                @Override
+                public void onNext(SimpleResponse response) {
+                  // TODO(ericgribkoff) Currently some test environments cannot access the stats RPC
+                  // service and rely on parsing stdout.
+                  if (printResponse) {
+                    System.out.println(
+                        "Greeting: Hello world, this is "
+                            + response.getHostname()
+                            + ", from "
+                            + clientCallRef
+                                .get()
+                                .getAttributes()
+                                .get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
+                  }
+                  // Use the hostname from the response if not present in the metadata.
+                  // TODO(ericgribkoff) Delete when server is deployed that sets metadata value.
+                  if (hostnameRef.get() == null) {
+                    hostnameRef.set(response.getHostname());
+                  }
+                }
+              });
+        }
       }
     }
 
     long nanosPerQuery = TimeUnit.SECONDS.toNanos(1) / qps;
-    ListenableScheduledFuture<?> future =
-        exec.scheduleAtFixedRate(new PeriodicRpc(), 0, nanosPerQuery, TimeUnit.NANOSECONDS);
 
-    Futures.addCallback(
-        future,
-        new FutureCallback<Object>() {
+    for (RpcType rpcType : rpcTypes) {
+      ListenableScheduledFuture<?> future =
+          exec.scheduleAtFixedRate(
+              new PeriodicRpc(rpcType), 0, nanosPerQuery, TimeUnit.NANOSECONDS);
 
-          @Override
-          public void onFailure(Throwable t) {
-            failure.setException(t);
-          }
+      Futures.addCallback(
+          future,
+          new FutureCallback<Object>() {
 
-          @Override
-          public void onSuccess(Object o) {}
-        },
-        MoreExecutors.directExecutor());
+            @Override
+            public void onFailure(Throwable t) {
+              failure.setException(t);
+            }
+
+            @Override
+            public void onSuccess(Object o) {}
+          },
+          MoreExecutors.directExecutor());
+    }
 
     failure.get();
   }
 
+  private void notifyWatchers(
+      Set<XdsStatsWatcher> watchers, RpcType rpcType, long requestId, String hostname) {
+    for (XdsStatsWatcher watcher : watchers) {
+      watcher.rpcCompleted(rpcType, requestId, hostname);
+    }
+  }
+
   private class XdsStatsImpl extends LoadBalancerStatsServiceGrpc.LoadBalancerStatsServiceImplBase {
     @Override
     public void getClientStats(
@@ -286,6 +426,8 @@
     private final long startId;
     private final long endId;
     private final Map<String, Integer> rpcsByPeer = new HashMap<>();
+    private final EnumMap<RpcType, Map<String, Integer>> rpcsByTypeAndPeer =
+        new EnumMap<>(RpcType.class);
     private final Object lock = new Object();
     private int noRemotePeer;
 
@@ -295,7 +437,7 @@
       this.endId = endId;
     }
 
-    void rpcCompleted(long requestId, @Nullable String hostname) {
+    void rpcCompleted(RpcType rpcType, long requestId, @Nullable String hostname) {
       synchronized (lock) {
         if (startId <= requestId && requestId < endId) {
           if (hostname != null) {
@@ -304,6 +446,19 @@
             } else {
               rpcsByPeer.put(hostname, 1);
             }
+            if (rpcsByTypeAndPeer.containsKey(rpcType)) {
+              if (rpcsByTypeAndPeer.get(rpcType).containsKey(hostname)) {
+                rpcsByTypeAndPeer
+                    .get(rpcType)
+                    .put(hostname, rpcsByTypeAndPeer.get(rpcType).get(hostname) + 1);
+              } else {
+                rpcsByTypeAndPeer.get(rpcType).put(hostname, 1);
+              }
+            } else {
+              Map<String, Integer> rpcMap = new HashMap<>();
+              rpcMap.put(hostname, 1);
+              rpcsByTypeAndPeer.put(rpcType, rpcMap);
+            }
           } else {
             noRemotePeer += 1;
           }
@@ -325,6 +480,12 @@
       LoadBalancerStatsResponse.Builder builder = LoadBalancerStatsResponse.newBuilder();
       synchronized (lock) {
         builder.putAllRpcsByPeer(rpcsByPeer);
+        for (Map.Entry<RpcType, Map<String, Integer>> entry : rpcsByTypeAndPeer.entrySet()) {
+          LoadBalancerStatsResponse.RpcsByPeer.Builder rpcs =
+              LoadBalancerStatsResponse.RpcsByPeer.newBuilder();
+          rpcs.putAllRpcsByPeer(entry.getValue());
+          builder.putRpcsByMethod(entry.getKey().toCamelCase(), rpcs.build());
+        }
         builder.setNumFailures(noRemotePeer + (int) latch.getCount());
       }
       return builder.build();
diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java
index a7fcbf7..b9c8e3a 100644
--- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java
@@ -16,7 +16,13 @@
 
 package io.grpc.testing.integration;
 
+import io.grpc.ForwardingServerCall.SimpleForwardingServerCall;
+import io.grpc.Metadata;
 import io.grpc.Server;
+import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import io.grpc.ServerInterceptors;
 import io.grpc.health.v1.HealthCheckResponse.ServingStatus;
 import io.grpc.netty.NettyServerBuilder;
 import io.grpc.protobuf.services.ProtoReflectionService;
@@ -32,12 +38,16 @@
 
 /** Interop test server that implements the xDS testing service. */
 public final class XdsTestServer {
+  static final Metadata.Key<String> HOSTNAME_KEY =
+      Metadata.Key.of("hostname", Metadata.ASCII_STRING_MARSHALLER);
+
   private static Logger logger = Logger.getLogger(XdsTestServer.class.getName());
 
   private int port = 8080;
   private String serverId = "java_server";
   private HealthStatusManager health;
   private Server server;
+  private String host;
 
   /**
    * The main application allowing this client to be launched from the command line.
@@ -111,10 +121,18 @@
   }
 
   private void start() throws Exception {
+    try {
+      host = InetAddress.getLocalHost().getHostName();
+    } catch (UnknownHostException e) {
+      logger.log(Level.SEVERE, "Failed to get host", e);
+      throw new RuntimeException(e);
+    }
     health = new HealthStatusManager();
     server =
         NettyServerBuilder.forPort(port)
-            .addService(new TestServiceImpl(serverId))
+            .addService(
+                ServerInterceptors.intercept(
+                    new TestServiceImpl(serverId, host), new HostnameInterceptor(host)))
             .addService(new XdsUpdateHealthServiceImpl(health))
             .addService(health.getHealthService())
             .addService(ProtoReflectionService.newInstance())
@@ -140,14 +158,16 @@
     private final String serverId;
     private final String host;
 
-    private TestServiceImpl(String serverId) {
+    private TestServiceImpl(String serverId, String host) {
       this.serverId = serverId;
-      try {
-        host = InetAddress.getLocalHost().getHostName();
-      } catch (UnknownHostException e) {
-        logger.log(Level.SEVERE, "Failed to get host", e);
-        throw new RuntimeException(e);
-      }
+      this.host = host;
+    }
+
+    @Override
+    public void emptyCall(
+        EmptyProtos.Empty req, StreamObserver<EmptyProtos.Empty> responseObserver) {
+      responseObserver.onNext(EmptyProtos.Empty.getDefaultInstance());
+      responseObserver.onCompleted();
     }
 
     @Override
@@ -182,4 +202,28 @@
       responseObserver.onCompleted();
     }
   }
+
+  private static class HostnameInterceptor implements ServerInterceptor {
+    private final String host;
+
+    private HostnameInterceptor(String host) {
+      this.host = host;
+    }
+
+    @Override
+    public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
+        ServerCall<ReqT, RespT> call,
+        final Metadata requestHeaders,
+        ServerCallHandler<ReqT, RespT> next) {
+      return next.startCall(
+          new SimpleForwardingServerCall<ReqT, RespT>(call) {
+            @Override
+            public void sendHeaders(Metadata responseHeaders) {
+              responseHeaders.put(HOSTNAME_KEY, host);
+              super.sendHeaders(responseHeaders);
+            }
+          },
+          requestHeaders);
+    }
+  }
 }
diff --git a/interop-testing/src/main/proto/grpc/testing/messages.proto b/interop-testing/src/main/proto/grpc/testing/messages.proto
index 5665de8..a84f708 100644
--- a/interop-testing/src/main/proto/grpc/testing/messages.proto
+++ b/interop-testing/src/main/proto/grpc/testing/messages.proto
@@ -195,4 +195,10 @@
   map<string, int32> rpcs_by_peer = 1;
   // The number of RPCs that failed to record a remote peer.
   int32 num_failures = 2;
+  message RpcsByPeer {
+    // The number of completed RPCs for each peer.
+    map<string, int32> rpcs_by_peer = 1;
+  }
+  // The number of completed RPCs for each type (UnaryCall or EmptyCall).
+  map<string, RpcsByPeer> rpcs_by_method = 3;
 }