okhttp: forceful close after MAX_CONNECTION_AGE_GRACE_TIME (#9968)

diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java
index 1fd9807..5ec393e 100644
--- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java
+++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerTransport.java
@@ -60,6 +60,7 @@
 import java.util.concurrent.TimeUnit;
 import java.util.logging.Level;
 import java.util.logging.Logger;
+import javax.annotation.Nullable;
 import javax.annotation.concurrent.GuardedBy;
 import okio.Buffer;
 import okio.BufferedSource;
@@ -73,6 +74,9 @@
       ExceptionHandlingFrameWriter.TransportExceptionHandler, OutboundFlowController.Transport {
   private static final Logger log = Logger.getLogger(OkHttpServerTransport.class.getName());
   private static final int GRACEFUL_SHUTDOWN_PING = 0x1111;
+
+  private static final long GRACEFUL_SHUTDOWN_PING_TIMEOUT_NANOS = TimeUnit.SECONDS.toNanos(1);
+
   private static final int KEEPALIVE_PING = 0xDEAD;
   private static final ByteString HTTP_METHOD = ByteString.encodeUtf8(":method");
   private static final ByteString CONNECT_METHOD = ByteString.encodeUtf8("CONNECT");
@@ -132,6 +136,8 @@
   /** Non-{@code null} when waiting for forceful close GOAWAY to be sent. */
   @GuardedBy("lock")
   private ScheduledFuture<?> forcefulCloseTimer;
+  @GuardedBy("lock")
+  private Long gracefulShutdownPeriod = null;
 
   public OkHttpServerTransport(Config config, Socket bareSocket) {
     this.config = Preconditions.checkNotNull(config, "config");
@@ -250,15 +256,16 @@
 
   @Override
   public void shutdown() {
-    shutdown(TimeUnit.SECONDS.toNanos(1L));
+    shutdown(null);
   }
 
-  private void shutdown(Long graceTimeInNanos) {
+  private void shutdown(@Nullable Long gracefulShutdownPeriod) {
     synchronized (lock) {
       if (gracefulShutdown || abruptShutdown) {
         return;
       }
       gracefulShutdown = true;
+      this.gracefulShutdownPeriod = gracefulShutdownPeriod;
       if (frameWriter == null) {
         handshakeShutdown = true;
         GrpcUtil.closeQuietly(bareSocket);
@@ -267,7 +274,8 @@
         // we also set a timer to limit the upper bound in case the PING is excessively stalled or
         // the client is malicious.
         secondGoawayTimer = scheduledExecutorService.schedule(
-            this::triggerGracefulSecondGoaway, graceTimeInNanos, TimeUnit.NANOSECONDS);
+            this::triggerGracefulSecondGoaway,
+            GRACEFUL_SHUTDOWN_PING_TIMEOUT_NANOS, TimeUnit.NANOSECONDS);
         frameWriter.goAway(Integer.MAX_VALUE, ErrorCode.NO_ERROR, new byte[0]);
         frameWriter.ping(false, 0, GRACEFUL_SHUTDOWN_PING);
         frameWriter.flush();
@@ -289,6 +297,10 @@
       } else {
         frameWriter.flush();
       }
+      if (gracefulShutdownPeriod != null) {
+        forcefulCloseTimer = scheduledExecutorService.schedule(
+            this::triggerForcefulClose, gracefulShutdownPeriod, TimeUnit.NANOSECONDS);
+      }
     }
   }
 
diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java
index b58f21b..816272f 100644
--- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java
+++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpServerTransportTest.java
@@ -155,7 +155,7 @@
   @Test
   public void maxConnectionAge() throws Exception {
     serverBuilder.maxConnectionAge(5, TimeUnit.SECONDS)
-        .maxConnectionAgeGrace(1, TimeUnit.SECONDS);
+        .maxConnectionAgeGrace(3, TimeUnit.SECONDS);
     initTransport();
     handshake();
     clientFrameWriter.headers(1, Arrays.asList(
@@ -169,8 +169,20 @@
         new Header("some-client-sent-trailer", "trailer-value")));
     pingPong();
     fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(6)); // > 1.1 * 5
-    fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(1));
     verifyGracefulShutdown(1);
+    pingPong();
+    fakeClock.forwardNanos(TimeUnit.SECONDS.toNanos(3));
+    assertThat(socket.isClosed()).isTrue();
+  }
+
+  @Test
+  public void maxConnectionAge_shutdown() throws Exception {
+    serverBuilder.maxConnectionAge(5, TimeUnit.SECONDS)
+        .maxConnectionAgeGrace(3, TimeUnit.SECONDS);
+    initTransport();
+    handshake();
+    shutdownAndTerminate(0);
+    assertThat(fakeClock.numPendingTasks()).isEqualTo(0);
   }
 
   @Test
@@ -1369,6 +1381,7 @@
         // PipedInputStream can only be woken by PipedOutputStream, so PipedOutputStream.close() is
         // a better imitation of Socket.close().
         inputStreamSource.close();
+        super.close();
       }
     }