Clean up command interrupter.

Bug: 123529934
Test: unit tests
Bug: 117420403
Change-Id: Ie87e5f2b698d675ebcb34c0da923d011354ed495
Merged-In: Ie87e5f2b698d675ebcb34c0da923d011354ed495
diff --git a/src/com/android/tradefed/command/CommandInterrupter.java b/src/com/android/tradefed/command/CommandInterrupter.java
index 78b86bd..f01c55c 100644
--- a/src/com/android/tradefed/command/CommandInterrupter.java
+++ b/src/com/android/tradefed/command/CommandInterrupter.java
@@ -19,12 +19,17 @@
 import com.android.tradefed.util.RunInterruptedException;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.MapMaker;
 
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.Map;
-import java.util.Timer;
-import java.util.TimerTask;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 
 /** Service allowing TradeFederation commands to be interrupted or marked as uninterruptible. */
 public class CommandInterrupter {
@@ -32,149 +37,94 @@
     /** Singleton. */
     public static final CommandInterrupter INSTANCE = new CommandInterrupter();
 
-    private Map<Thread, Boolean> mMapIsInterruptAllowed = new HashMap<>();
-    private Map<Thread, String> mMapInterruptThreads = new HashMap<>();
-    private Map<Thread, Timer> mWatchdogInterrupt = new HashMap<>();
+    private final ScheduledExecutorService mExecutor = Executors.newScheduledThreadPool(0);
+
+    // tracks whether a thread is currently interruptible
+    private ConcurrentMap<Thread, Boolean> mInterruptible = new MapMaker().weakKeys().makeMap();
+    // presence of an interrupt error message indicates that the thread should be interrupted
+    private ConcurrentMap<Thread, String> mInterruptMessage = new MapMaker().weakKeys().makeMap();
 
     @VisibleForTesting
     // FIXME: reduce visibility once RunUtil interrupt tests are removed
     public CommandInterrupter() {}
 
-    /** Remove the thread that are not alive anymore from our tracking to keep the list small. */
-    private void cleanInterruptStateThreadMap() {
-        synchronized (mMapIsInterruptAllowed) {
-            for (Iterator<Thread> iterator = mMapIsInterruptAllowed.keySet().iterator();
-                    iterator.hasNext();
-                    ) {
-                Thread t = iterator.next();
-                if (!t.isAlive()) {
-                    iterator.remove();
-                }
-            }
-        }
-    }
-
-    /**
-     * Allows/disallows run interrupts on the current thread. If it is allowed, run operations of
-     * the current thread can be interrupted from other threads via {@link #interrupt} method.
-     *
-     * @param allow whether to allow run interrupts on the current thread.
-     */
-    public void allowInterrupt(boolean allow) {
-        CLog.d("run interrupt allowed: %s", allow);
-        synchronized (mMapIsInterruptAllowed) {
-            mMapIsInterruptAllowed.put(Thread.currentThread(), allow);
-        }
+    /** Allow current thread to be interrupted. */
+    public void allowInterrupt() {
+        CLog.d("Interrupt allowed");
+        mInterruptible.put(Thread.currentThread(), true);
         checkInterrupted();
     }
 
-    /**
-     * Give the interrupt status of the RunUtil.
-     *
-     * @return true if the Run can be interrupted, false otherwise.
-     */
-    public boolean isInterruptAllowed() {
-        synchronized (mMapIsInterruptAllowed) {
-            if (mMapIsInterruptAllowed.get(Thread.currentThread()) == null) {
-                // We don't add in this case to keep the map relatively small.
-                return false;
-            }
-            return mMapIsInterruptAllowed.get(Thread.currentThread());
-        }
-    }
-
-    /**
-     * Set as interruptible after some waiting time. {@link CommandScheduler#shutdownHard()} to
-     * enforce we terminate eventually.
-     *
-     * @param thread the thread that will become interruptible.
-     * @param timeMs time to wait before setting interruptible.
-     */
-    // FIXME: reduce visibility once RunUtil interrupt methods are removed
-    public void setInterruptibleInFuture(Thread thread, final long timeMs) {
-        CLog.w("Setting future interruption in %s ms", timeMs);
-        synchronized (mMapIsInterruptAllowed) {
-            if (Boolean.TRUE.equals(mMapIsInterruptAllowed.get(thread))) {
-                CLog.v("Thread is already interruptible. setInterruptibleInFuture is inop.");
-                return;
-            }
-        }
-        Timer timer = new Timer(true);
-        synchronized (mWatchdogInterrupt) {
-            mWatchdogInterrupt.put(thread, timer);
-        }
-        timer.schedule(new InterruptTask(thread), timeMs);
-    }
-
-    /**
-     * Interrupts the ongoing/forthcoming run operations on the given thread. The run operations on
-     * the given thread will throw {@link RunInterruptedException}.
-     *
-     * @param thread
-     * @param message the message for {@link RunInterruptedException}.
-     */
-    // FIXME: reduce visibility once RunUtil interrupt methods are removed
-    public synchronized void interrupt(Thread thread, String message) {
-        if (message == null) {
-            throw new IllegalArgumentException("message cannot be null.");
-        }
-        mMapInterruptThreads.put(thread, message);
+    /** Prevent current thread from being interrupted. */
+    public void blockInterrupt() {
+        CLog.d("Interrupt blocked");
+        mInterruptible.put(Thread.currentThread(), false);
         checkInterrupted();
     }
 
-    public synchronized void checkInterrupted() {
-        // Keep the map of thread's state clean of dead threads.
-        this.cleanInterruptStateThreadMap();
+    /** @return true if current thread is interruptible */
+    public boolean isInterruptible() {
+        return isInterruptible(Thread.currentThread());
+    }
 
-        final Thread thread = Thread.currentThread();
-        if (isInterruptAllowed()) {
-            final String message = mMapInterruptThreads.remove(thread);
+    /** @return true if specified thread is interruptible */
+    public boolean isInterruptible(@Nonnull Thread thread) {
+        return Boolean.TRUE.equals(mInterruptible.get(thread));
+    }
+
+    /**
+     * Allow a specified thread to be interrupted after a delay.
+     *
+     * @param thread thread to mark as interruptible
+     * @param delay time from now to delay execution
+     * @param unit time unit of the delay parameter
+     */
+    // FIXME: reduce visibility once RunUtil interrupt methods are removed
+    public Future<?> allowInterruptAsync(
+            @Nonnull Thread thread, long delay, @Nonnull TimeUnit unit) {
+        if (isInterruptible(thread)) {
+            CLog.v("Thread already interruptible");
+            return CompletableFuture.completedFuture(null);
+        }
+
+        CLog.w("Allowing interrupt in %d ms", unit.toMillis(delay));
+        return mExecutor.schedule(
+                () -> {
+                    CLog.e("Interrupt allowed asynchronously");
+                    mInterruptible.put(thread, true);
+                },
+                delay,
+                unit);
+    }
+
+    /**
+     * Flag a thread, interrupting it if and when it becomes interruptible.
+     *
+     * @param thread thread to mark for interruption
+     * @param template interruption error template
+     * @param args interruption error arguments
+     */
+    // FIXME: reduce visibility once RunUtil interrupt methods are removed
+    public void interrupt(
+            @Nonnull Thread thread, @Nullable String template, @Nullable Object... args) {
+        String message = String.format(String.valueOf(template), args);
+        mInterruptMessage.put(thread, message);
+        if (isInterruptible(thread)) {
+            thread.interrupt();
+        }
+    }
+
+    /**
+     * Interrupts the current thread if it should be interrupted. Threads are encouraged to
+     * periodically call this method in order to throw the right {@link RunInterruptedException}.
+     */
+    public void checkInterrupted() {
+        Thread thread = Thread.currentThread();
+        if (isInterruptible()) {
+            String message = mInterruptMessage.remove(thread);
             if (message != null) {
-                thread.interrupt();
                 throw new RunInterruptedException(message);
             }
         }
     }
-
-    /** Allow to stop the Timer Thread for the run util instance if started. */
-    @VisibleForTesting
-    // FIXME: reduce visibility once RunUtil interrupt tests are removed
-    public void terminateTimer() {
-        if (mWatchdogInterrupt != null && !mWatchdogInterrupt.isEmpty()) {
-            for (Timer t : mWatchdogInterrupt.values()) {
-                t.purge();
-                t.cancel();
-            }
-        }
-    }
-
-    /** Timer that will execute a interrupt on the Thread registered. */
-    private class InterruptTask extends TimerTask {
-
-        private Thread mToInterrupt = null;
-
-        public InterruptTask(Thread t) {
-            mToInterrupt = t;
-        }
-
-        @Override
-        public void run() {
-            if (mToInterrupt != null) {
-                synchronized (mWatchdogInterrupt) {
-                    // Ensure that the timer associated with the task is cancelled too.
-                    mWatchdogInterrupt.get(mToInterrupt).cancel();
-                }
-
-                CLog.e("Interrupting with TimerTask");
-                synchronized (mMapIsInterruptAllowed) {
-                    mMapIsInterruptAllowed.put(mToInterrupt, true);
-                }
-                mToInterrupt.interrupt();
-
-                synchronized (mWatchdogInterrupt) {
-                    mWatchdogInterrupt.remove(mToInterrupt);
-                }
-            }
-        }
-    }
 }
diff --git a/src/com/android/tradefed/util/RunUtil.java b/src/com/android/tradefed/util/RunUtil.java
index 40d3bcc..62f73ee 100644
--- a/src/com/android/tradefed/util/RunUtil.java
+++ b/src/com/android/tradefed/util/RunUtil.java
@@ -463,25 +463,30 @@
     /** {@inheritDoc} */
     @Override
     public void allowInterrupt(boolean allow) {
-        mInterrupter.allowInterrupt(allow);
+        if (allow) {
+            mInterrupter.allowInterrupt();
+        } else {
+            mInterrupter.blockInterrupt();
+        }
     }
 
     /** {@inheritDoc} */
     @Override
     public boolean isInterruptAllowed() {
-        return mInterrupter.isInterruptAllowed();
+        return mInterrupter.isInterruptible();
     }
 
     /** {@inheritDoc} */
     @Override
     public void setInterruptibleInFuture(Thread thread, final long timeMs) {
-        mInterrupter.setInterruptibleInFuture(thread, timeMs);
+        mInterrupter.allowInterruptAsync(thread, timeMs, TimeUnit.MILLISECONDS);
     }
 
     /** {@inheritDoc} */
     @Override
     public synchronized void interrupt(Thread thread, String message) {
         mInterrupter.interrupt(thread, message);
+        mInterrupter.checkInterrupted();
     }
 
     /**
@@ -722,15 +727,7 @@
         return t;
     }
 
-    /** Allow to stop the Timer Thread for the run util instance if started. */
-    @VisibleForTesting
-    void terminateTimer() {
-        mInterrupter.terminateTimer();
-    }
-
-    /**
-     * {@inheritDoc}
-     */
+    /** {@inheritDoc} */
     @Override
     public void setEnvVariablePriority(EnvPriority priority) {
         if (this.equals(sDefaultInstance)) {
diff --git a/tests/src/com/android/tradefed/command/CommandInterrupterTest.java b/tests/src/com/android/tradefed/command/CommandInterrupterTest.java
index 596b37c..df67c9a 100644
--- a/tests/src/com/android/tradefed/command/CommandInterrupterTest.java
+++ b/tests/src/com/android/tradefed/command/CommandInterrupterTest.java
@@ -28,6 +28,9 @@
 import org.junit.runner.RunWith;
 import org.junit.runners.JUnit4;
 
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
 /** Unit tests for {@link CommandInterrupter} */
 @RunWith(JUnit4.class)
 public class CommandInterrupterTest {
@@ -46,11 +49,11 @@
         execute(
                 () -> {
                     // interrupts initially blocked
-                    assertFalse(mInterrupter.isInterruptAllowed());
+                    assertFalse(mInterrupter.isInterruptible());
 
                     // thread can be made interruptible
-                    mInterrupter.allowInterrupt(true);
-                    assertTrue(mInterrupter.isInterruptAllowed());
+                    mInterrupter.allowInterrupt();
+                    assertTrue(mInterrupter.isInterruptible());
                 });
     }
 
@@ -58,10 +61,14 @@
     public void testInterrupt() throws InterruptedException {
         execute(
                 () -> {
+                    // flag thread for interruption
+                    mInterrupter.allowInterrupt();
+                    mInterrupter.interrupt(Thread.currentThread(), MESSAGE);
+                    assertTrue(Thread.interrupted());
+
                     try {
-                        // can interrupt the thread
-                        mInterrupter.allowInterrupt(true);
-                        mInterrupter.interrupt(Thread.currentThread(), MESSAGE);
+                        // will be interrupted
+                        mInterrupter.checkInterrupted();
                         fail("RunInterruptedException was expected");
                     } catch (RunInterruptedException e) {
                         assertEquals(MESSAGE, e.getMessage());
@@ -73,60 +80,79 @@
     public void testInterrupt_blocked() throws InterruptedException {
         execute(
                 () -> {
-                    // track whether interrupts were successfully blocked
-                    boolean success = false;
+                    // block interrupts, but flag for interruption
+                    mInterrupter.blockInterrupt();
+                    mInterrupter.interrupt(Thread.currentThread(), MESSAGE);
+                    assertFalse(Thread.interrupted());
+
+                    // not interrupted
+                    mInterrupter.checkInterrupted();
 
                     try {
-                        // not interrupted if interrupts disallowed
-                        mInterrupter.allowInterrupt(false);
-                        mInterrupter.interrupt(Thread.currentThread(), MESSAGE);
-                        success = true;
-
-                        // interrupted once interrupts allowed
-                        mInterrupter.allowInterrupt(true);
+                        // will be interrupted once interrupts are allowed
+                        mInterrupter.allowInterrupt();
                         fail("RunInterruptedException was expected");
                     } catch (RunInterruptedException e) {
                         assertEquals(MESSAGE, e.getMessage());
-                        assertTrue(success);
                     }
                 });
     }
 
     @Test
-    public void testSetInterruptibleInFuture() throws InterruptedException {
+    public void testInterrupt_clearsFlag() throws InterruptedException {
         execute(
                 () -> {
+                    // flag thread for interruption
+                    mInterrupter.allowInterrupt();
+                    mInterrupter.interrupt(Thread.currentThread(), MESSAGE);
+                    assertTrue(Thread.interrupted());
+
                     try {
-                        // allow interruptions after a delay
-                        mInterrupter.setInterruptibleInFuture(Thread.currentThread(), 200L);
-
-                        // not yet marked as interruptible
-                        RunUtil.getDefault().sleep(50);
-                        assertFalse(mInterrupter.isInterruptAllowed());
-
-                        // marked as interruptible after enough time has passed
-                        RunUtil.getDefault().sleep(200L);
-                        assertTrue(mInterrupter.isInterruptAllowed());
-                    } finally {
-                        mInterrupter.terminateTimer();
+                        // interrupt the thread
+                        mInterrupter.checkInterrupted();
+                        fail("RunInterruptedException was expected");
+                    } catch (RunInterruptedException e) {
+                        // ignore
                     }
+
+                    // interrupt flag was cleared, exception no longer thrown
+                    mInterrupter.checkInterrupted();
                 });
     }
 
     @Test
-    public void testSetInterruptibleInFuture_alreadyAllowed() throws InterruptedException {
+    public void testAllowInterruptAsync() throws InterruptedException {
         execute(
                 () -> {
-                    try {
-                        // interrupts allowed
-                        mInterrupter.allowInterrupt(true);
+                    // allow interruptions after a delay
+                    Future<?> future =
+                            mInterrupter.allowInterruptAsync(
+                                    Thread.currentThread(), 200L, TimeUnit.MILLISECONDS);
+                    assertFalse(future.isDone());
 
-                        // unchanged after asynchronously allowing interrupts
-                        mInterrupter.setInterruptibleInFuture(Thread.currentThread(), 200L);
-                        assertTrue(mInterrupter.isInterruptAllowed());
-                    } finally {
-                        mInterrupter.terminateTimer();
-                    }
+                    // not yet marked as interruptible
+                    RunUtil.getDefault().sleep(50);
+                    assertFalse(mInterrupter.isInterruptible());
+
+                    // marked as interruptible after enough time has passed
+                    RunUtil.getDefault().sleep(200L);
+                    assertTrue(mInterrupter.isInterruptible());
+                });
+    }
+
+    @Test
+    public void testAllowInterruptsAsync_alreadyAllowed() throws InterruptedException {
+        execute(
+                () -> {
+                    // interrupts allowed
+                    mInterrupter.allowInterrupt();
+
+                    // unchanged after asynchronously allowing interrupt
+                    Future<?> future =
+                            mInterrupter.allowInterruptAsync(
+                                    Thread.currentThread(), 200L, TimeUnit.MILLISECONDS);
+                    assertTrue(future.isDone());
+                    assertTrue(mInterrupter.isInterruptible());
                 });
     }
 
diff --git a/tests/src/com/android/tradefed/util/RunUtilTest.java b/tests/src/com/android/tradefed/util/RunUtilTest.java
index 0ef6bf1..a6cc003 100644
--- a/tests/src/com/android/tradefed/util/RunUtilTest.java
+++ b/tests/src/com/android/tradefed/util/RunUtilTest.java
@@ -426,7 +426,6 @@
                                     assertEquals("TEST", rie.getMessage());
                                 }
                                 success = mRunUtil.isInterruptAllowed();
-                                mRunUtil.terminateTimer();
                             }
                         });
         mRunUtil.interrupt(test, "TEST");
@@ -445,16 +444,13 @@
     public void testSetInterruptibleInFuture_beforeTimeout() {
         mRunUtil.allowInterrupt(false);
         assertFalse(mRunUtil.isInterruptAllowed());
-        try {
-            mRunUtil.setInterruptibleInFuture(Thread.currentThread(), SHORT_TIMEOUT_MS);
-            mRunUtil.sleep(50);
-            // Should still be false
-            assertFalse(mRunUtil.isInterruptAllowed());
-            mRunUtil.sleep(SHORT_TIMEOUT_MS);
-            assertTrue(mRunUtil.isInterruptAllowed());
-        } finally {
-            mRunUtil.terminateTimer();
-        }
+
+        mRunUtil.setInterruptibleInFuture(Thread.currentThread(), SHORT_TIMEOUT_MS);
+        mRunUtil.sleep(50);
+        // Should still be false
+        assertFalse(mRunUtil.isInterruptAllowed());
+        mRunUtil.sleep(SHORT_TIMEOUT_MS);
+        assertTrue(mRunUtil.isInterruptAllowed());
     }
 
     /**