interop-testing: add fake altsHandshakerService for test (#7847)

diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java
index 769309a..758f99d 100644
--- a/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/AbstractInteropTest.java
@@ -170,6 +170,7 @@
 
   private ScheduledExecutorService testServiceExecutor;
   private Server server;
+  private Server handshakerServer;
 
   private final LinkedBlockingQueue<ServerStreamTracerInfo> serverStreamTracers =
       new LinkedBlockingQueue<>();
@@ -223,6 +224,7 @@
   protected static final Empty EMPTY = Empty.getDefaultInstance();
 
   private void startServer() {
+    maybeStartHandshakerServer();
     ServerBuilder<?> builder = getServerBuilder();
     if (builder == null) {
       server = null;
@@ -251,6 +253,17 @@
     }
   }
 
+  private void maybeStartHandshakerServer() {
+    ServerBuilder<?> handshakerServerBuilder = getHandshakerServerBuilder();
+    if (handshakerServerBuilder != null) {
+      try {
+        handshakerServer = handshakerServerBuilder.build().start();
+      } catch (IOException ex) {
+        throw new RuntimeException(ex);
+      }
+    }
+  }
+
   private void stopServer() {
     if (server != null) {
       server.shutdownNow();
@@ -258,6 +271,9 @@
     if (testServiceExecutor != null) {
       testServiceExecutor.shutdown();
     }
+    if (handshakerServer != null) {
+      handshakerServer.shutdownNow();
+    }
   }
 
   @VisibleForTesting
@@ -348,6 +364,11 @@
     return null;
   }
 
+  @Nullable
+  protected ServerBuilder<?> getHandshakerServerBuilder() {
+    return null;
+  }
+
   protected final ClientInterceptor createCensusStatsClientInterceptor() {
     return
         InternalCensusStatsAccessor
diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/AltsHandshakerTestService.java b/interop-testing/src/main/java/io/grpc/testing/integration/AltsHandshakerTestService.java
new file mode 100644
index 0000000..bf4a2fe
--- /dev/null
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/AltsHandshakerTestService.java
@@ -0,0 +1,146 @@
+/*
+ * Copyright 2021 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.testing.integration;
+
+import static com.google.common.base.Preconditions.checkState;
+import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.CLIENT_START;
+import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.NEXT;
+import static io.grpc.alts.internal.HandshakerReq.ReqOneofCase.SERVER_START;
+
+import com.google.protobuf.ByteString;
+import io.grpc.alts.internal.HandshakerReq;
+import io.grpc.alts.internal.HandshakerResp;
+import io.grpc.alts.internal.HandshakerResult;
+import io.grpc.alts.internal.HandshakerServiceGrpc.HandshakerServiceImplBase;
+import io.grpc.alts.internal.Identity;
+import io.grpc.alts.internal.RpcProtocolVersions;
+import io.grpc.alts.internal.RpcProtocolVersions.Version;
+import io.grpc.stub.StreamObserver;
+import java.util.Random;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+/**
+ * A fake HandshakeService for ALTS integration testing in non-gcp environments.
+ * */
+public class AltsHandshakerTestService extends HandshakerServiceImplBase {
+  private static final Logger log = Logger.getLogger(AltsHandshakerTestService.class.getName());
+
+  private final Random random = new Random();
+  private static final int FIXED_LENGTH_OUTPUT = 16;
+  private final ByteString fakeOutput = data(FIXED_LENGTH_OUTPUT);
+  private final ByteString secret = data(128);
+  private State expectState = State.CLIENT_INIT;
+
+  @Override
+  public StreamObserver<HandshakerReq> doHandshake(
+      final StreamObserver<HandshakerResp> responseObserver) {
+    return new StreamObserver<HandshakerReq>() {
+      @Override
+      public void onNext(HandshakerReq value) {
+        log.log(Level.FINE, "request received: " + value);
+        synchronized (this) {
+          switch (expectState) {
+            case CLIENT_INIT:
+              checkState(CLIENT_START.equals(value.getReqOneofCase()));
+              HandshakerResp initClient = HandshakerResp.newBuilder()
+                  .setOutFrames(fakeOutput)
+                  .build();
+              log.log(Level.FINE, "init client response " + initClient);
+              responseObserver.onNext(initClient);
+              expectState = State.SERVER_INIT;
+              break;
+            case SERVER_INIT:
+              checkState(SERVER_START.equals(value.getReqOneofCase()));
+              HandshakerResp initServer = HandshakerResp.newBuilder()
+                  .setBytesConsumed(FIXED_LENGTH_OUTPUT)
+                  .setOutFrames(fakeOutput)
+                  .build();
+              log.log(Level.FINE, "init server response" + initServer);
+              responseObserver.onNext(initServer);
+              expectState = State.CLIENT_FINISH;
+              break;
+            case CLIENT_FINISH:
+              checkState(NEXT.equals(value.getReqOneofCase()));
+              HandshakerResp resp = HandshakerResp.newBuilder()
+                  .setResult(getResult())
+                  .setBytesConsumed(FIXED_LENGTH_OUTPUT)
+                  .setOutFrames(fakeOutput)
+                  .build();
+              log.log(Level.FINE, "client finished response " + resp);
+              responseObserver.onNext(resp);
+              expectState = State.SERVER_FINISH;
+              break;
+            case SERVER_FINISH:
+              resp = HandshakerResp.newBuilder()
+                  .setResult(getResult())
+                  .setBytesConsumed(FIXED_LENGTH_OUTPUT)
+                  .build();
+              log.log(Level.FINE, "server finished response " + resp);
+              responseObserver.onNext(resp);
+              expectState = State.CLIENT_INIT;
+              break;
+            default:
+              throw new RuntimeException("unknown state");
+          }
+        }
+      }
+
+      @Override
+      public void onError(Throwable t) {
+        log.log(Level.INFO, "onError " + t);
+      }
+
+      @Override
+      public void onCompleted() {
+        responseObserver.onCompleted();
+      }
+    };
+  }
+
+  private HandshakerResult getResult() {
+    return HandshakerResult.newBuilder().setApplicationProtocol("grpc")
+        .setRecordProtocol("ALTSRP_GCM_AES128_REKEY")
+        .setKeyData(secret)
+        .setMaxFrameSize(131072)
+        .setPeerIdentity(Identity.newBuilder()
+            .setServiceAccount("123456789-compute@developer.gserviceaccount.com")
+            .build())
+        .setPeerRpcVersions(RpcProtocolVersions.newBuilder()
+            .setMaxRpcVersion(Version.newBuilder()
+                .setMajor(2).setMinor(1)
+                .build())
+            .setMinRpcVersion(Version.newBuilder()
+                .setMajor(2).setMinor(1)
+                .build())
+            .build())
+        .build();
+  }
+
+  private ByteString data(int len) {
+    byte[] k = new byte[len];
+    random.nextBytes(k);
+    return ByteString.copyFrom(k);
+  }
+
+  private enum State {
+    CLIENT_INIT,
+    SERVER_INIT,
+    CLIENT_FINISH,
+    SERVER_FINISH
+  }
+}
diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java
index 82a379c..3ca35df 100644
--- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java
@@ -21,8 +21,10 @@
 import io.grpc.ChannelCredentials;
 import io.grpc.Grpc;
 import io.grpc.InsecureChannelCredentials;
+import io.grpc.InsecureServerCredentials;
 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
+import io.grpc.ServerBuilder;
 import io.grpc.TlsChannelCredentials;
 import io.grpc.alts.AltsChannelCredentials;
 import io.grpc.alts.ComputeEngineChannelCredentials;
@@ -42,6 +44,7 @@
 import java.io.FileInputStream;
 import java.nio.charset.Charset;
 import java.util.concurrent.TimeUnit;
+import javax.annotation.Nullable;
 
 /**
  * Application that starts a client for the {@link TestServiceGrpc.TestServiceImplBase} and runs
@@ -83,6 +86,7 @@
   private String serviceAccountKeyFile;
   private String oauthScope;
   private boolean fullStreamDecompression;
+  private int localHandshakerPort = -1;
 
   private Tester tester = new Tester();
 
@@ -141,6 +145,8 @@
         oauthScope = value;
       } else if ("full_stream_decompression".equals(key)) {
         fullStreamDecompression = Boolean.parseBoolean(value);
+      } else if ("local_handshaker_port".equals(key)) {
+        localHandshakerPort = Integer.parseInt(value);
       } else {
         System.err.println("Unknown argument: " + key);
         usage = true;
@@ -165,6 +171,9 @@
           + "\n  --use_tls=true|false        Whether to use TLS. Default " + c.useTls
           + "\n  --use_alts=true|false       Whether to use ALTS. Enable ALTS will disable TLS."
           + "\n                              Default " + c.useAlts
+          + "\n  --local_handshaker_port=PORT"
+          + "\n                              Use local ALTS handshaker service on the specified "
+          + "\n                              port for testing. Only effective when --use_alts=true."
           + "\n  --use_upgrade=true|false    Whether to use the h2c Upgrade mechanism."
           + "\n                              Enabling h2c Upgrade will disable TLS."
           + "\n                              Default " + c.useH2cUpgrade
@@ -398,7 +407,13 @@
 
       } else if (useAlts) {
         useGeneric = true; // Retain old behavior; avoids erroring if incompatible
-        channelCredentials = AltsChannelCredentials.create();
+        if (localHandshakerPort > -1) {
+          channelCredentials = AltsChannelCredentials.newBuilder()
+              .enableUntrustedAltsForTesting()
+              .setHandshakerAddressForTesting("localhost:" + localHandshakerPort).build();
+        } else {
+          channelCredentials = AltsChannelCredentials.create();
+        }
 
       } else if (useTls) {
         if (!useTestCa) {
@@ -475,6 +490,18 @@
       // TODO(zhangkun83): remove this override once the said issue is fixed.
       return false;
     }
+
+    @Override
+    @Nullable
+    protected ServerBuilder<?> getHandshakerServerBuilder() {
+      if (localHandshakerPort > -1) {
+        return Grpc.newServerBuilderForPort(localHandshakerPort,
+            InsecureServerCredentials.create())
+            .addService(new AltsHandshakerTestService());
+      } else {
+        return null;
+      }
+    }
   }
 
   private static String validTestCasesHelpText() {
diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java
index 2a5c0eb..19946ec 100644
--- a/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java
+++ b/interop-testing/src/main/java/io/grpc/testing/integration/TestServiceServer.java
@@ -70,6 +70,7 @@
 
   private ScheduledExecutorService executor;
   private Server server;
+  private int localHandshakerPort = -1;
 
   @VisibleForTesting
   void parseArgs(String[] args) {
@@ -98,6 +99,8 @@
         useTls = Boolean.parseBoolean(value);
       } else if ("use_alts".equals(key)) {
         useAlts = Boolean.parseBoolean(value);
+      } else if ("local_handshaker_port".equals(key)) {
+        localHandshakerPort = Integer.parseInt(value);
       } else if ("grpc_version".equals(key)) {
         if (!"2".equals(value)) {
           System.err.println("Only grpc version 2 is supported");
@@ -122,6 +125,9 @@
               + "\n  --use_tls=true|false  Whether to use TLS. Default " + s.useTls
               + "\n  --use_alts=true|false Whether to use ALTS. Enable ALTS will disable TLS."
               + "\n                        Default " + s.useAlts
+              + "\n  --local_handshaker_port=PORT"
+              + "\n                        Use local ALTS handshaker service on the specified port "
+              + "\n                        for testing. Only effective when --use_alts=true."
       );
       System.exit(1);
     }
@@ -132,7 +138,13 @@
     executor = Executors.newSingleThreadScheduledExecutor();
     ServerCredentials serverCreds;
     if (useAlts) {
-      serverCreds = AltsServerCredentials.create();
+      if (localHandshakerPort > -1) {
+        serverCreds = AltsServerCredentials.newBuilder()
+            .enableUntrustedAltsForTesting()
+            .setHandshakerAddressForTesting("localhost:" + localHandshakerPort).build();
+      } else {
+        serverCreds = AltsServerCredentials.create();
+      }
     } else if (useTls) {
       serverCreds = TlsServerCredentials.create(
           TestUtils.loadCert("server1.pem"), TestUtils.loadCert("server1.key"));
diff --git a/interop-testing/src/test/java/io/grpc/testing/integration/AltsHandshakerTest.java b/interop-testing/src/test/java/io/grpc/testing/integration/AltsHandshakerTest.java
new file mode 100644
index 0000000..c6c1d2b
--- /dev/null
+++ b/interop-testing/src/test/java/io/grpc/testing/integration/AltsHandshakerTest.java
@@ -0,0 +1,100 @@
+/*
+ * Copyright 2021 The gRPC Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.grpc.testing.integration;
+
+import static org.junit.Assert.assertEquals;
+
+import com.google.protobuf.ByteString;
+import io.grpc.ChannelCredentials;
+import io.grpc.Grpc;
+import io.grpc.ManagedChannel;
+import io.grpc.Server;
+import io.grpc.ServerCredentials;
+import io.grpc.alts.AltsChannelCredentials;
+import io.grpc.alts.AltsServerCredentials;
+import io.grpc.netty.NettyServerBuilder;
+import io.grpc.stub.StreamObserver;
+import io.grpc.testing.GrpcCleanupRule;
+import io.grpc.testing.integration.Messages.Payload;
+import io.grpc.testing.integration.Messages.SimpleRequest;
+import io.grpc.testing.integration.Messages.SimpleResponse;
+import io.netty.channel.nio.NioEventLoopGroup;
+import io.netty.channel.socket.nio.NioServerSocketChannel;
+import io.netty.util.concurrent.DefaultThreadFactory;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class AltsHandshakerTest {
+  @Rule
+  public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule();
+  private Server handshakerServer;
+  private Server testServer;
+  private ManagedChannel channel;
+
+  private void startAltsServer() throws Exception {
+    ServerCredentials serverCredentials = AltsServerCredentials.newBuilder()
+        .enableUntrustedAltsForTesting()
+        .setHandshakerAddressForTesting("localhost:" + handshakerServer.getPort())
+        .build();
+    testServer = grpcCleanup.register(
+        Grpc.newServerBuilderForPort(0, serverCredentials)
+            .addService(new TestServiceGrpc.TestServiceImplBase() {
+              @Override
+              public void unaryCall(SimpleRequest request, StreamObserver<SimpleResponse> so) {
+                so.onNext(SimpleResponse.getDefaultInstance());
+                so.onCompleted();
+              }
+            })
+            .build())
+        .start();
+  }
+
+  @Before
+  public void setup() throws Exception {
+    // create new EventLoopGroups to avoid deadlock at server side handshake negotiation, e.g.
+    // happens when handshakerServer and testServer child channels are on the same eventloop.
+    handshakerServer = grpcCleanup.register(NettyServerBuilder.forPort(0)
+        .bossEventLoopGroup(
+            new NioEventLoopGroup(0, new DefaultThreadFactory("test-alts-boss")))
+        .workerEventLoopGroup(
+            new NioEventLoopGroup(0, new DefaultThreadFactory("test-alts-worker")))
+        .channelType(NioServerSocketChannel.class)
+        .addService(new AltsHandshakerTestService())
+        .build()).start();
+    startAltsServer();
+
+    ChannelCredentials channelCredentials = AltsChannelCredentials.newBuilder()
+        .enableUntrustedAltsForTesting()
+        .setHandshakerAddressForTesting("localhost:" + handshakerServer.getPort()).build();
+    channel = grpcCleanup.register(
+        Grpc.newChannelBuilderForAddress("localhost", testServer.getPort(), channelCredentials)
+            .build());
+  }
+
+  @Test
+  public void testAlts() {
+    TestServiceGrpc.TestServiceBlockingStub blockingStub = TestServiceGrpc.newBlockingStub(channel);
+    final SimpleRequest request = SimpleRequest.newBuilder()
+            .setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[10])))
+            .build();
+    assertEquals(SimpleResponse.getDefaultInstance(), blockingStub.unaryCall(request));
+  }
+}