interop-testing: add fake altsHandshakerService for test (#7847)
diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java
index 769309a..758f99d 100644
--- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java
@@ -170,6 +170,7 @@
private ScheduledExecutorService testServiceExecutor;
private Server server;
+ private Server handshakerServer;
private final LinkedBlockingQueue<ServerStreamTracerInfo> serverStreamTracers =
new LinkedBlockingQueue<>();
@@ -223,6 +224,7 @@
protected static final Empty EMPTY = Empty.getDefaultInstance();
private void startServer() {
+ maybeStartHandshakerServer();
ServerBuilder<?> builder = getServerBuilder();
if (builder == null) {
server = null;
@@ -251,6 +253,17 @@
}
}
+ private void maybeStartHandshakerServer() {
+ ServerBuilder<?> handshakerServerBuilder = getHandshakerServerBuilder();
+ if (handshakerServerBuilder != null) {
+ try {
+ handshakerServer = handshakerServerBuilder.build().start();
+ } catch (IOException ex) {
+ throw new RuntimeException(ex);
+ }
+ }
+ }
+
private void stopServer() {
if (server != null) {
server.shutdownNow();
@@ -258,6 +271,9 @@
if (testServiceExecutor != null) {
testServiceExecutor.shutdown();
}
+ if (handshakerServer != null) {
+ handshakerServer.shutdownNow();
+ }
}
@VisibleForTesting
@@ -348,6 +364,11 @@
return null;
}
+ @Nullable
+ protected ServerBuilder<?> getHandshakerServerBuilder() {
+ return null;
+ }
+
protected final ClientInterceptor createCensusStatsClientInterceptor() {
return
InternalCensusStatsAccessor
diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AltsHandshakerTestService.java b/interop-testing/src/main/java/io/grpc/testing/integration/AltsHandshakerTestService.java
new file mode 100644
index 0000000..bf4a2fe
--- /dev/null
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/AltsHandshakerTestService.java
@@ -0,0 +1,146 @@
+/*
+ * Copyright 2021 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.testing.integration;
+
+import static com.google.common.base.Preconditions.checkState;
+import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.CLIENT_START;
+import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.NEXT;
+import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.SERVER_START;
+
+import com.google.protobuf.ByteString;
+import io.grpc.alts.internal.HandshakerReq;
+import io.grpc.alts.internal.HandshakerResp;
+import io.grpc.alts.internal.HandshakerResult;
+import io.grpc.alts.internal.HandshakerServiceGrpc.HandshakerServiceImplBase;
+import io.grpc.alts.internal.Identity;
+import io.grpc.alts.internal.RpcProtocolVersions;
+import io.grpc.alts.internal.RpcProtocolVersions.Version;
+import io.grpc.stub.StreamObserver;
+import java.util.Random;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * A fake HandshakeService for ALTS integration testing in non-gcp environments.
+ * */
+public class AltsHandshakerTestService extends HandshakerServiceImplBase {
+ private static final Logger log = Logger.getLogger(AltsHandshakerTestService.class.getName());
+
+ private final Random random = new Random();
+ private static final int FIXED_LENGTH_OUTPUT = 16;
+ private final ByteString fakeOutput = data(FIXED_LENGTH_OUTPUT);
+ private final ByteString secret = data(128);
+ private State expectState = State.CLIENT_INIT;
+
+ @Override
+ public StreamObserver<HandshakerReq> doHandshake(
+ final StreamObserver<HandshakerResp> responseObserver) {
+ return new StreamObserver<HandshakerReq>() {
+ @Override
+ public void onNext(HandshakerReq value) {
+ log.log(Level.FINE, "request received: " + value);
+ synchronized (this) {
+ switch (expectState) {
+ case CLIENT_INIT:
+ checkState(CLIENT_START.equals(value.getReqOneofCase()));
+ HandshakerResp initClient = HandshakerResp.newBuilder()
+ .setOutFrames(fakeOutput)
+ .build();
+ log.log(Level.FINE, "init client response " + initClient);
+ responseObserver.onNext(initClient);
+ expectState = State.SERVER_INIT;
+ break;
+ case SERVER_INIT:
+ checkState(SERVER_START.equals(value.getReqOneofCase()));
+ HandshakerResp initServer = HandshakerResp.newBuilder()
+ .setBytesConsumed(FIXED_LENGTH_OUTPUT)
+ .setOutFrames(fakeOutput)
+ .build();
+ log.log(Level.FINE, "init server response" + initServer);
+ responseObserver.onNext(initServer);
+ expectState = State.CLIENT_FINISH;
+ break;
+ case CLIENT_FINISH:
+ checkState(NEXT.equals(value.getReqOneofCase()));
+ HandshakerResp resp = HandshakerResp.newBuilder()
+ .setResult(getResult())
+ .setBytesConsumed(FIXED_LENGTH_OUTPUT)
+ .setOutFrames(fakeOutput)
+ .build();
+ log.log(Level.FINE, "client finished response " + resp);
+ responseObserver.onNext(resp);
+ expectState = State.SERVER_FINISH;
+ break;
+ case SERVER_FINISH:
+ resp = HandshakerResp.newBuilder()
+ .setResult(getResult())
+ .setBytesConsumed(FIXED_LENGTH_OUTPUT)
+ .build();
+ log.log(Level.FINE, "server finished response " + resp);
+ responseObserver.onNext(resp);
+ expectState = State.CLIENT_INIT;
+ break;
+ default:
+ throw new RuntimeException("unknown state");
+ }
+ }
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ log.log(Level.INFO, "onError " + t);
+ }
+
+ @Override
+ public void onCompleted() {
+ responseObserver.onCompleted();
+ }
+ };
+ }
+
+ private HandshakerResult getResult() {
+ return HandshakerResult.newBuilder().setApplicationProtocol("grpc")
+ .setRecordProtocol("ALTSRP_GCM_AES128_REKEY")
+ .setKeyData(secret)
+ .setMaxFrameSize(131072)
+ .setPeerIdentity(Identity.newBuilder()
+ .setServiceAccount("123456789-compute@developer.gserviceaccount.com")
+ .build())
+ .setPeerRpcVersions(RpcProtocolVersions.newBuilder()
+ .setMaxRpcVersion(Version.newBuilder()
+ .setMajor(2).setMinor(1)
+ .build())
+ .setMinRpcVersion(Version.newBuilder()
+ .setMajor(2).setMinor(1)
+ .build())
+ .build())
+ .build();
+ }
+
+ private ByteString data(int len) {
+ byte[] k = new byte[len];
+ random.nextBytes(k);
+ return ByteString.copyFrom(k);
+ }
+
+ private enum State {
+ CLIENT_INIT,
+ SERVER_INIT,
+ CLIENT_FINISH,
+ SERVER_FINISH
+ }
+}
diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java
index 82a379c..3ca35df 100644
--- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java
@@ -21,8 +21,10 @@
import io.grpc.ChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
+import io.grpc.InsecureServerCredentials;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
+import io.grpc.ServerBuilder;
import io.grpc.TlsChannelCredentials;
import io.grpc.alts.AltsChannelCredentials;
import io.grpc.alts.ComputeEngineChannelCredentials;
@@ -42,6 +44,7 @@
import java.io.FileInputStream;
import java.nio.charset.Charset;
import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
/**
* Application that starts a client for the {@link TestServiceGrpc.TestServiceImplBase} and runs
@@ -83,6 +86,7 @@
private String serviceAccountKeyFile;
private String oauthScope;
private boolean fullStreamDecompression;
+ private int localHandshakerPort = -1;
private Tester tester = new Tester();
@@ -141,6 +145,8 @@
oauthScope = value;
} else if ("full_stream_decompression".equals(key)) {
fullStreamDecompression = Boolean.parseBoolean(value);
+ } else if ("local_handshaker_port".equals(key)) {
+ localHandshakerPort = Integer.parseInt(value);
} else {
System.err.println("Unknown argument: " + key);
usage = true;
@@ -165,6 +171,9 @@
+ "\n --use_tls=true|false Whether to use TLS. Default " + c.useTls
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
+ "\n Default " + c.useAlts
+ + "\n --local_handshaker_port=PORT"
+ + "\n Use local ALTS handshaker service on the specified "
+ + "\n port for testing. Only effective when --use_alts=true."
+ "\n --use_upgrade=true|false Whether to use the h2c Upgrade mechanism."
+ "\n Enabling h2c Upgrade will disable TLS."
+ "\n Default " + c.useH2cUpgrade
@@ -398,7 +407,13 @@
} else if (useAlts) {
useGeneric = true; // Retain old behavior; avoids erroring if incompatible
- channelCredentials = AltsChannelCredentials.create();
+ if (localHandshakerPort > -1) {
+ channelCredentials = AltsChannelCredentials.newBuilder()
+ .enableUntrustedAltsForTesting()
+ .setHandshakerAddressForTesting("localhost:" + localHandshakerPort).build();
+ } else {
+ channelCredentials = AltsChannelCredentials.create();
+ }
} else if (useTls) {
if (!useTestCa) {
@@ -475,6 +490,18 @@
// TODO(zhangkun83): remove this override once the said issue is fixed.
return false;
}
+
+ @Override
+ @Nullable
+ protected ServerBuilder<?> getHandshakerServerBuilder() {
+ if (localHandshakerPort > -1) {
+ return Grpc.newServerBuilderForPort(localHandshakerPort,
+ InsecureServerCredentials.create())
+ .addService(new AltsHandshakerTestService());
+ } else {
+ return null;
+ }
+ }
}
private static String validTestCasesHelpText() {
diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java
index 2a5c0eb..19946ec 100644
--- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java
@@ -70,6 +70,7 @@
private ScheduledExecutorService executor;
private Server server;
+ private int localHandshakerPort = -1;
@VisibleForTesting
void parseArgs(String[] args) {
@@ -98,6 +99,8 @@
useTls = Boolean.parseBoolean(value);
} else if ("use_alts".equals(key)) {
useAlts = Boolean.parseBoolean(value);
+ } else if ("local_handshaker_port".equals(key)) {
+ localHandshakerPort = Integer.parseInt(value);
} else if ("grpc_version".equals(key)) {
if (!"2".equals(value)) {
System.err.println("Only grpc version 2 is supported");
@@ -122,6 +125,9 @@
+ "\n --use_tls=true|false Whether to use TLS. Default " + s.useTls
+ "\n --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
+ "\n Default " + s.useAlts
+ + "\n --local_handshaker_port=PORT"
+ + "\n Use local ALTS handshaker service on the specified port "
+ + "\n for testing. Only effective when --use_alts=true."
);
System.exit(1);
}
@@ -132,7 +138,13 @@
executor = Executors.newSingleThreadScheduledExecutor();
ServerCredentials serverCreds;
if (useAlts) {
- serverCreds = AltsServerCredentials.create();
+ if (localHandshakerPort > -1) {
+ serverCreds = AltsServerCredentials.newBuilder()
+ .enableUntrustedAltsForTesting()
+ .setHandshakerAddressForTesting("localhost:" + localHandshakerPort).build();
+ } else {
+ serverCreds = AltsServerCredentials.create();
+ }
} else if (useTls) {
serverCreds = TlsServerCredentials.create(
TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"));
diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/AltsHandshakerTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/AltsHandshakerTest.java
new file mode 100644
index 0000000..c6c1d2b
--- /dev/null
+++ b/interop-testing/src/test/java/io/grpc/testing/integration/AltsHandshakerTest.java
@@ -0,0 +1,100 @@
+/*
+ * Copyright 2021 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.testing.integration;
+
+import static org.junit.Assert.assertEquals;
+
+import com.google.protobuf.ByteString;
+import io.grpc.ChannelCredentials;
+import io.grpc.Grpc;
+import io.grpc.ManagedChannel;
+import io.grpc.Server;
+import io.grpc.ServerCredentials;
+import io.grpc.alts.AltsChannelCredentials;
+import io.grpc.alts.AltsServerCredentials;
+import io.grpc.netty.NettyServerBuilder;
+import io.grpc.stub.StreamObserver;
+import io.grpc.testing.GrpcCleanupRule;
+import io.grpc.testing.integration.Messages.Payload;
+import io.grpc.testing.integration.Messages.SimpleRequest;
+import io.grpc.testing.integration.Messages.SimpleResponse;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.util.concurrent.DefaultThreadFactory;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class AltsHandshakerTest {
+ @Rule
+ public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+ private Server handshakerServer;
+ private Server testServer;
+ private ManagedChannel channel;
+
+ private void startAltsServer() throws Exception {
+ ServerCredentials serverCredentials = AltsServerCredentials.newBuilder()
+ .enableUntrustedAltsForTesting()
+ .setHandshakerAddressForTesting("localhost:" + handshakerServer.getPort())
+ .build();
+ testServer = grpcCleanup.register(
+ Grpc.newServerBuilderForPort(0, serverCredentials)
+ .addService(new TestServiceGrpc.TestServiceImplBase() {
+ @Override
+ public void unaryCall(SimpleRequest request, StreamObserver<SimpleResponse> so) {
+ so.onNext(SimpleResponse.getDefaultInstance());
+ so.onCompleted();
+ }
+ })
+ .build())
+ .start();
+ }
+
+ @Before
+ public void setup() throws Exception {
+ // create new EventLoopGroups to avoid deadlock at server side handshake negotiation, e.g.
+ // happens when handshakerServer and testServer child channels are on the same eventloop.
+ handshakerServer = grpcCleanup.register(NettyServerBuilder.forPort(0)
+ .bossEventLoopGroup(
+ new NioEventLoopGroup(0, new DefaultThreadFactory("test-alts-boss")))
+ .workerEventLoopGroup(
+ new NioEventLoopGroup(0, new DefaultThreadFactory("test-alts-worker")))
+ .channelType(NioServerSocketChannel.class)
+ .addService(new AltsHandshakerTestService())
+ .build()).start();
+ startAltsServer();
+
+ ChannelCredentials channelCredentials = AltsChannelCredentials.newBuilder()
+ .enableUntrustedAltsForTesting()
+ .setHandshakerAddressForTesting("localhost:" + handshakerServer.getPort()).build();
+ channel = grpcCleanup.register(
+ Grpc.newChannelBuilderForAddress("localhost", testServer.getPort(), channelCredentials)
+ .build());
+ }
+
+ @Test
+ public void testAlts() {
+ TestServiceGrpc.TestServiceBlockingStub blockingStub = TestServiceGrpc.newBlockingStub(channel);
+ final SimpleRequest request = SimpleRequest.newBuilder()
+ .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[10])))
+ .build();
+ assertEquals(SimpleResponse.getDefaultInstance(), blockingStub.unaryCall(request));
+ }
+}