interop-testing: add flags to xds test client
diff --git a/buildscripts/kokoro/xds.sh b/buildscripts/kokoro/xds.sh
index 9d66245..f911f25 100755
--- a/buildscripts/kokoro/xds.sh
+++ b/buildscripts/kokoro/xds.sh
@@ -20,9 +20,15 @@
git clone -b "${branch}" --single-branch --depth=1 https://github.com/grpc/grpc.git
grpc/tools/run_tests/helper_scripts/prep_xds.sh
+
+# Test cases "path_matching" and "header_matching" are not included in "all",
+# because not all interop clients in all languages support these new tests.
+#
+# TODO(ericgribkoff): remove "path_matching" and "header_matching" from
+# --test_case after they are added into "all".
JAVA_OPTS=-Djava.util.logging.config.file=grpc-java/buildscripts/xds_logging.properties \
python3 grpc/tools/run_tests/run_xds_tests.py \
- --test_case=all \
+ --test_case="all,path_matching,header_matching" \
--project_id=grpc-testing \
--source_image=projects/grpc-testing/global/images/xds-test-server \
--path_to_server_binary=/java_server/grpc-java/interop-testing/build/install/grpc-interop-testing/bin/xds-test-server \
diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java
index 0bd2fbe..73af159 100644
--- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java
@@ -16,6 +16,9 @@
package io.grpc.testing.integration;
+import com.google.common.base.CaseFormat;
+import com.google.common.base.Splitter;
+import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Ints;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
@@ -24,12 +27,16 @@
import com.google.common.util.concurrent.MoreExecutors;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallOptions;
+import io.grpc.Channel;
import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
+import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
+import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.Grpc;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
+import io.grpc.MethodDescriptor;
import io.grpc.Server;
-import io.grpc.Status;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.stub.StreamObserver;
@@ -38,6 +45,7 @@
import io.grpc.testing.integration.Messages.SimpleRequest;
import io.grpc.testing.integration.Messages.SimpleResponse;
import java.util.ArrayList;
+import java.util.EnumMap;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -47,6 +55,7 @@
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.Nullable;
@@ -62,13 +71,24 @@
private int numChannels = 1;
private boolean printResponse = false;
private int qps = 1;
- private int rpcTimeoutSec = 2;
+ private List<RpcType> rpcTypes = ImmutableList.of(RpcType.UNARY_CALL);
+ private EnumMap<RpcType, Metadata> metadata = new EnumMap<>(RpcType.class);
+ private int rpcTimeoutSec = 20;
private String server = "localhost:8080";
private int statsPort = 8081;
private Server statsServer;
private long currentRequestId;
private ListeningScheduledExecutorService exec;
+ private enum RpcType {
+ EMPTY_CALL,
+ UNARY_CALL;
+
+ public String toCamelCase() {
+ return CaseFormat.UPPER_UNDERSCORE.to(CaseFormat.UPPER_CAMEL, toString());
+ }
+ }
+
/**
* The main application allowing this client to be launched from the command line.
*/
@@ -111,12 +131,16 @@
break;
}
String value = parts[1];
- if ("num_channels".equals(key)) {
+ if ("metadata".equals(key)) {
+ metadata = parseMetadata(value);
+ } else if ("num_channels".equals(key)) {
numChannels = Integer.valueOf(value);
} else if ("print_response".equals(key)) {
printResponse = Boolean.valueOf(value);
} else if ("qps".equals(key)) {
qps = Integer.valueOf(value);
+ } else if ("rpc".equals(key)) {
+ rpcTypes = parseRpcs(value);
} else if ("rpc_timeout_sec".equals(key)) {
rpcTimeoutSec = Integer.valueOf(value);
} else if ("server".equals(key)) {
@@ -139,8 +163,12 @@
+ c.numChannels
+ "\n --print_response=BOOL Write RPC response to stdout. Default: "
+ c.printResponse
- + "\n --qps=INT Qps per channel. Default: "
+ + "\n --qps=INT Qps per channel, for each type of RPC. Default: "
+ c.qps
+ + "\n --rpc=STR Types of RPCs to make, ',' separated string. RPCs can "
+ + "be EmptyCall or UnaryCall. Default: UnaryCall"
+ + "\n --metadata=STR The metadata to send with each RPC, in the format "
+ + "EmptyCall:key1:value1,UnaryCall:key2:value2."
+ "\n --rpc_timeout_sec=INT Per RPC timeout seconds. Default: "
+ c.rpcTimeoutSec
+ "\n --server=host:port Address of server. Default: "
@@ -152,6 +180,45 @@
}
}
+ private static List<RpcType> parseRpcs(String rpcArg) {
+ List<RpcType> rpcs = new ArrayList<>();
+ for (String rpc : Splitter.on(',').split(rpcArg)) {
+ rpcs.add(parseRpc(rpc));
+ }
+ return rpcs;
+ }
+
+ private static EnumMap<RpcType, Metadata> parseMetadata(String metadataArg) {
+ EnumMap<RpcType, Metadata> rpcMetadata = new EnumMap<>(RpcType.class);
+ for (String metadata : Splitter.on(',').omitEmptyStrings().split(metadataArg)) {
+ List<String> parts = Splitter.on(':').splitToList(metadata);
+ if (parts.size() != 3) {
+ throw new IllegalArgumentException("Invalid metadata: '" + metadata + "'");
+ }
+ RpcType rpc = parseRpc(parts.get(0));
+ String key = parts.get(1);
+ String value = parts.get(2);
+ Metadata md = new Metadata();
+ md.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value);
+ if (rpcMetadata.containsKey(rpc)) {
+ rpcMetadata.get(rpc).merge(md);
+ } else {
+ rpcMetadata.put(rpc, md);
+ }
+ }
+ return rpcMetadata;
+ }
+
+ private static RpcType parseRpc(String rpc) {
+ if ("EmptyCall".equals(rpc)) {
+ return RpcType.EMPTY_CALL;
+ } else if ("UnaryCall".equals(rpc)) {
+ return RpcType.UNARY_CALL;
+ } else {
+ throw new IllegalArgumentException("Unknown RPC: '" + rpc + "'");
+ }
+ }
+
private void run() {
statsServer = NettyServerBuilder.forPort(statsPort).addService(new XdsStatsImpl()).build();
try {
@@ -186,6 +253,11 @@
private void runQps() throws InterruptedException, ExecutionException {
final SettableFuture<Void> failure = SettableFuture.create();
final class PeriodicRpc implements Runnable {
+ private final RpcType rpcType;
+
+ private PeriodicRpc(RpcType rpcType) {
+ this.rpcType = rpcType;
+ }
@Override
public void run() {
@@ -197,69 +269,137 @@
savedWatchers.addAll(watchers);
}
- SimpleRequest request = SimpleRequest.newBuilder().setFillServerId(true).build();
+ final Metadata headersToSend;
+ if (metadata.containsKey(rpcType)) {
+ headersToSend = metadata.get(rpcType);
+ } else {
+ headersToSend = new Metadata();
+ }
ManagedChannel channel = channels.get((int) (requestId % channels.size()));
- final ClientCall<SimpleRequest, SimpleResponse> call =
- channel.newCall(
- TestServiceGrpc.getUnaryCallMethod(),
- CallOptions.DEFAULT.withDeadlineAfter(rpcTimeoutSec, TimeUnit.SECONDS));
- call.start(
- new ClientCall.Listener<SimpleResponse>() {
- private String hostname;
+ TestServiceGrpc.TestServiceStub stub = TestServiceGrpc.newStub(channel);
+ final AtomicReference<ClientCall<?, ?>> clientCallRef = new AtomicReference<>();
+ final AtomicReference<String> hostnameRef = new AtomicReference<>();
+ stub =
+ stub.withDeadlineAfter(rpcTimeoutSec, TimeUnit.SECONDS)
+ .withInterceptors(
+ new ClientInterceptor() {
+ @Override
+ public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
+ MethodDescriptor<ReqT, RespT> method,
+ CallOptions callOptions,
+ Channel next) {
+ ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
+ clientCallRef.set(call);
+ return new SimpleForwardingClientCall<ReqT, RespT>(call) {
+ @Override
+ public void start(Listener<RespT> responseListener, Metadata headers) {
+ headers.merge(headersToSend);
+ super.start(
+ new SimpleForwardingClientCallListener<RespT>(responseListener) {
+ @Override
+ public void onHeaders(Metadata headers) {
+ hostnameRef.set(headers.get(XdsTestServer.HOSTNAME_KEY));
+ super.onHeaders(headers);
+ }
+ },
+ headers);
+ }
+ };
+ }
+ });
- @Override
- public void onMessage(SimpleResponse response) {
- hostname = response.getHostname();
- // TODO(ericgribkoff) Currently some test environments cannot access the stats RPC
- // service and rely on parsing stdout.
- if (printResponse) {
- System.out.println(
- "Greeting: Hello world, this is "
- + hostname
- + ", from "
- + call.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
+ if (rpcType == RpcType.EMPTY_CALL) {
+ stub.emptyCall(
+ EmptyProtos.Empty.getDefaultInstance(),
+ new StreamObserver<EmptyProtos.Empty>() {
+ @Override
+ public void onCompleted() {
+ notifyWatchers(savedWatchers, rpcType, requestId, hostnameRef.get());
}
- }
- @Override
- public void onClose(Status status, Metadata trailers) {
- if (printResponse && !status.isOk()) {
- logger.log(Level.WARNING, "Greeting RPC failed with status {0}", status);
+ @Override
+ public void onError(Throwable t) {
+ notifyWatchers(savedWatchers, rpcType, requestId, hostnameRef.get());
}
- for (XdsStatsWatcher watcher : savedWatchers) {
- watcher.rpcCompleted(requestId, hostname);
- }
- }
- },
- new Metadata());
- call.sendMessage(request);
- call.request(1);
- call.halfClose();
+ @Override
+ public void onNext(EmptyProtos.Empty response) {}
+ });
+ } else if (rpcType == RpcType.UNARY_CALL) {
+ SimpleRequest request = SimpleRequest.newBuilder().setFillServerId(true).build();
+ stub.unaryCall(
+ request,
+ new StreamObserver<SimpleResponse>() {
+ @Override
+ public void onCompleted() {
+ notifyWatchers(savedWatchers, rpcType, requestId, hostnameRef.get());
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ if (printResponse) {
+ logger.log(Level.WARNING, "Rpc failed: {0}", t);
+ }
+ notifyWatchers(savedWatchers, rpcType, requestId, hostnameRef.get());
+ }
+
+ @Override
+ public void onNext(SimpleResponse response) {
+ // TODO(ericgribkoff) Currently some test environments cannot access the stats RPC
+ // service and rely on parsing stdout.
+ if (printResponse) {
+ System.out.println(
+ "Greeting: Hello world, this is "
+ + response.getHostname()
+ + ", from "
+ + clientCallRef
+ .get()
+ .getAttributes()
+ .get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR));
+ }
+ // Use the hostname from the response if not present in the metadata.
+ // TODO(ericgribkoff) Delete when server is deployed that sets metadata value.
+ if (hostnameRef.get() == null) {
+ hostnameRef.set(response.getHostname());
+ }
+ }
+ });
+ }
}
}
long nanosPerQuery = TimeUnit.SECONDS.toNanos(1) / qps;
- ListenableScheduledFuture<?> future =
- exec.scheduleAtFixedRate(new PeriodicRpc(), 0, nanosPerQuery, TimeUnit.NANOSECONDS);
- Futures.addCallback(
- future,
- new FutureCallback<Object>() {
+ for (RpcType rpcType : rpcTypes) {
+ ListenableScheduledFuture<?> future =
+ exec.scheduleAtFixedRate(
+ new PeriodicRpc(rpcType), 0, nanosPerQuery, TimeUnit.NANOSECONDS);
- @Override
- public void onFailure(Throwable t) {
- failure.setException(t);
- }
+ Futures.addCallback(
+ future,
+ new FutureCallback<Object>() {
- @Override
- public void onSuccess(Object o) {}
- },
- MoreExecutors.directExecutor());
+ @Override
+ public void onFailure(Throwable t) {
+ failure.setException(t);
+ }
+
+ @Override
+ public void onSuccess(Object o) {}
+ },
+ MoreExecutors.directExecutor());
+ }
failure.get();
}
+ private void notifyWatchers(
+ Set<XdsStatsWatcher> watchers, RpcType rpcType, long requestId, String hostname) {
+ for (XdsStatsWatcher watcher : watchers) {
+ watcher.rpcCompleted(rpcType, requestId, hostname);
+ }
+ }
+
private class XdsStatsImpl extends LoadBalancerStatsServiceGrpc.LoadBalancerStatsServiceImplBase {
@Override
public void getClientStats(
@@ -286,6 +426,8 @@
private final long startId;
private final long endId;
private final Map<String, Integer> rpcsByPeer = new HashMap<>();
+ private final EnumMap<RpcType, Map<String, Integer>> rpcsByTypeAndPeer =
+ new EnumMap<>(RpcType.class);
private final Object lock = new Object();
private int noRemotePeer;
@@ -295,7 +437,7 @@
this.endId = endId;
}
- void rpcCompleted(long requestId, @Nullable String hostname) {
+ void rpcCompleted(RpcType rpcType, long requestId, @Nullable String hostname) {
synchronized (lock) {
if (startId <= requestId && requestId < endId) {
if (hostname != null) {
@@ -304,6 +446,19 @@
} else {
rpcsByPeer.put(hostname, 1);
}
+ if (rpcsByTypeAndPeer.containsKey(rpcType)) {
+ if (rpcsByTypeAndPeer.get(rpcType).containsKey(hostname)) {
+ rpcsByTypeAndPeer
+ .get(rpcType)
+ .put(hostname, rpcsByTypeAndPeer.get(rpcType).get(hostname) + 1);
+ } else {
+ rpcsByTypeAndPeer.get(rpcType).put(hostname, 1);
+ }
+ } else {
+ Map<String, Integer> rpcMap = new HashMap<>();
+ rpcMap.put(hostname, 1);
+ rpcsByTypeAndPeer.put(rpcType, rpcMap);
+ }
} else {
noRemotePeer += 1;
}
@@ -325,6 +480,12 @@
LoadBalancerStatsResponse.Builder builder = LoadBalancerStatsResponse.newBuilder();
synchronized (lock) {
builder.putAllRpcsByPeer(rpcsByPeer);
+ for (Map.Entry<RpcType, Map<String, Integer>> entry : rpcsByTypeAndPeer.entrySet()) {
+ LoadBalancerStatsResponse.RpcsByPeer.Builder rpcs =
+ LoadBalancerStatsResponse.RpcsByPeer.newBuilder();
+ rpcs.putAllRpcsByPeer(entry.getValue());
+ builder.putRpcsByMethod(entry.getKey().toCamelCase(), rpcs.build());
+ }
builder.setNumFailures(noRemotePeer + (int) latch.getCount());
}
return builder.build();
diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java
index a7fcbf7..b9c8e3a 100644
--- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestServer.java
@@ -16,7 +16,13 @@
package io.grpc.testing.integration;
+import io.grpc.ForwardingServerCall.SimpleForwardingServerCall;
+import io.grpc.Metadata;
import io.grpc.Server;
+import io.grpc.ServerCall;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import io.grpc.ServerInterceptors;
import io.grpc.health.v1.HealthCheckResponse.ServingStatus;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.protobuf.services.ProtoReflectionService;
@@ -32,12 +38,16 @@
/** Interop test server that implements the xDS testing service. */
public final class XdsTestServer {
+ static final Metadata.Key<String> HOSTNAME_KEY =
+ Metadata.Key.of("hostname", Metadata.ASCII_STRING_MARSHALLER);
+
private static Logger logger = Logger.getLogger(XdsTestServer.class.getName());
private int port = 8080;
private String serverId = "java_server";
private HealthStatusManager health;
private Server server;
+ private String host;
/**
* The main application allowing this client to be launched from the command line.
@@ -111,10 +121,18 @@
}
private void start() throws Exception {
+ try {
+ host = InetAddress.getLocalHost().getHostName();
+ } catch (UnknownHostException e) {
+ logger.log(Level.SEVERE, "Failed to get host", e);
+ throw new RuntimeException(e);
+ }
health = new HealthStatusManager();
server =
NettyServerBuilder.forPort(port)
- .addService(new TestServiceImpl(serverId))
+ .addService(
+ ServerInterceptors.intercept(
+ new TestServiceImpl(serverId, host), new HostnameInterceptor(host)))
.addService(new XdsUpdateHealthServiceImpl(health))
.addService(health.getHealthService())
.addService(ProtoReflectionService.newInstance())
@@ -140,14 +158,16 @@
private final String serverId;
private final String host;
- private TestServiceImpl(String serverId) {
+ private TestServiceImpl(String serverId, String host) {
this.serverId = serverId;
- try {
- host = InetAddress.getLocalHost().getHostName();
- } catch (UnknownHostException e) {
- logger.log(Level.SEVERE, "Failed to get host", e);
- throw new RuntimeException(e);
- }
+ this.host = host;
+ }
+
+ @Override
+ public void emptyCall(
+ EmptyProtos.Empty req, StreamObserver<EmptyProtos.Empty> responseObserver) {
+ responseObserver.onNext(EmptyProtos.Empty.getDefaultInstance());
+ responseObserver.onCompleted();
}
@Override
@@ -182,4 +202,28 @@
responseObserver.onCompleted();
}
}
+
+ private static class HostnameInterceptor implements ServerInterceptor {
+ private final String host;
+
+ private HostnameInterceptor(String host) {
+ this.host = host;
+ }
+
+ @Override
+ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
+ ServerCall<ReqT, RespT> call,
+ final Metadata requestHeaders,
+ ServerCallHandler<ReqT, RespT> next) {
+ return next.startCall(
+ new SimpleForwardingServerCall<ReqT, RespT>(call) {
+ @Override
+ public void sendHeaders(Metadata responseHeaders) {
+ responseHeaders.put(HOSTNAME_KEY, host);
+ super.sendHeaders(responseHeaders);
+ }
+ },
+ requestHeaders);
+ }
+ }
}
diff --git a/interop-testing/src/main/proto/grpc/testing/messages.proto b/interop-testing/src/main/proto/grpc/testing/messages.proto
index 5665de8..a84f708 100644
--- a/interop-testing/src/main/proto/grpc/testing/messages.proto
+++ b/interop-testing/src/main/proto/grpc/testing/messages.proto
@@ -195,4 +195,10 @@
map<string, int32> rpcs_by_peer = 1;
// The number of RPCs that failed to record a remote peer.
int32 num_failures = 2;
+ message RpcsByPeer {
+ // The number of completed RPCs for each peer.
+ map<string, int32> rpcs_by_peer = 1;
+ }
+ // The number of completed RPCs for each type (UnaryCall or EmptyCall).
+ map<string, RpcsByPeer> rpcs_by_method = 3;
}