Merge "Allow Python test class to extend timeout"
diff --git a/harnesses/tradefed/src/com/android/tradefed/testtype/VtsMultiDeviceTest.java b/harnesses/tradefed/src/com/android/tradefed/testtype/VtsMultiDeviceTest.java
index d827087..324cfbc 100644
--- a/harnesses/tradefed/src/com/android/tradefed/testtype/VtsMultiDeviceTest.java
+++ b/harnesses/tradefed/src/com/android/tradefed/testtype/VtsMultiDeviceTest.java
@@ -86,6 +86,7 @@
     static final String TEST_BED = "test_bed";
     static final String TEST_PLAN_REPORT_FILE = "TEST_PLAN_REPORT_FILE";
     static final String TEST_SUITE = "test_suite";
+    static final String TEST_TIMEOUT = "test_timeout";
     static final String ABI_NAME = "abi_name";
     static final String ABI_BITNESS = "abi_bitness";
     static final String SKIP_ON_32BIT_ABI = "skip_on_32bit_abi";
@@ -153,22 +154,23 @@
     static final String TEMPLATE_HAL_HIDL_GTEST_PATH = "vts/testcases/template/hal_hidl_gtest/hal_hidl_gtest";
     static final String TEMPLATE_HAL_HIDL_REPLAY_TEST_PATH = "vts/testcases/template/hal_hidl_replay_test/hal_hidl_replay_test";
     static final String TEMPLATE_HOST_BINARY_TEST_PATH = "vts/testcases/template/host_binary_test/host_binary_test";
-    static final long TEST_ABORT_TIMEOUT_MSECS = 1000 * 15;
     static final String TEST_RUN_SUMMARY_FILE_NAME = "test_run_summary.json";
     static final float DEFAULT_TARGET_VERSION = -1;
     static final String DEFAULT_TESTCASE_CONFIG_PATH =
             "vts/tools/vts-tradefed/res/default/DefaultTestCase.runner_conf";
+    // TODO(hsinyichen): Read max-test-timeout from configuration
+    static final long MAX_TEST_TIMEOUT_MSECS = 1000 * 60 * 60 * 10;
 
     private ITestDevice mDevice = null;
     private IAbi mAbi = null;
 
     @Option(name = "test-timeout",
             description = "The amount of time (in milliseconds) for a test invocation. "
-                    + "If the test cannot finish before timeout, it is interrupted and cleans up "
-                    + "in " + TEST_ABORT_TIMEOUT_MSECS + "ms. Hence the actual timeout is the "
-                    + "specified value + " + TEST_ABORT_TIMEOUT_MSECS + "ms.",
+                    + "If the test cannot finish before timeout, it is interrupted. As some "
+                    + "classes generate test cases during setup, they can use the given timeout "
+                    + "value for each generated test set.",
             isTimeVal = true)
-    private long mTestTimeout = 1000 * 60 * 60 * 3;
+    private long mTestTimeout = 1000 * 60 * 3;
 
     @Option(name = "test-module-name",
         description = "The name for a test module.")
@@ -857,6 +859,9 @@
         jsonObject.put(TEST_SUITE, suite);
         CLog.d("Added %s to the Json object", TEST_SUITE);
 
+        jsonObject.put(TEST_TIMEOUT, mTestTimeout);
+        CLog.i("Added %s to the Json object: %d", TEST_TIMEOUT, mTestTimeout);
+
         if (!mLogSeverity.isEmpty()) {
             String logSeverity = mLogSeverity.toUpperCase();
             ArrayList<String> severityList =
@@ -1227,7 +1232,7 @@
 
         CommandResult commandResult = new CommandResult();
         String interruptMessage = vtsPythonRunnerHelper.runPythonRunner(
-                cmd.toArray(new String[0]), commandResult, mTestTimeout);
+                cmd.toArray(new String[0]), commandResult, MAX_TEST_TIMEOUT_MSECS);
 
         if (commandResult != null) {
             CommandStatus commandStatus = commandResult.getStatus();
diff --git a/harnesses/tradefed/src/com/android/tradefed/util/VtsPythonRunnerHelper.java b/harnesses/tradefed/src/com/android/tradefed/util/VtsPythonRunnerHelper.java
index 4f05a7a..64cca4b 100644
--- a/harnesses/tradefed/src/com/android/tradefed/util/VtsPythonRunnerHelper.java
+++ b/harnesses/tradefed/src/com/android/tradefed/util/VtsPythonRunnerHelper.java
@@ -38,7 +38,7 @@
     static final String PATH = "PATH";
     static final String PYTHONHOME = "PYTHONHOME";
     static final String VTS = "vts";
-    static final long TEST_ABORT_TIMEOUT_MSECS = 1000 * 15;
+    static final long TEST_ABORT_TIMEOUT_MSECS = 1000 * 40;
 
     // Python virtual environment root path
     private File mVirtualenvPath;
diff --git a/runners/host/base_test.py b/runners/host/base_test.py
index d10c983..c6f4d25 100644
--- a/runners/host/base_test.py
+++ b/runners/host/base_test.py
@@ -18,6 +18,7 @@
 import os
 import re
 import sys
+import threading
 
 from vts.proto import VtsReportMessage_pb2 as ReportMsg
 from vts.runners.host import asserts
@@ -48,9 +49,11 @@
 RESULT_LINE_TEMPLATE = TEST_CASE_TEMPLATE + " %s"
 STR_TEST = "test"
 STR_GENERATE = "generate"
+TEARDOWN_CLASS_TIMEOUT_SECS = 30
 _REPORT_MESSAGE_FILE_NAME = "report_proto.msg"
 _BUG_REPORT_FILE_PREFIX = "bugreport_"
 _BUG_REPORT_FILE_EXTENSION = ".zip"
+_DEFAULT_TEST_TIMEOUT_SECS = 60 * 3
 _LOGCAT_FILE_PREFIX = "logcat_"
 _LOGCAT_FILE_EXTENSION = ".txt"
 _ANDROID_DEVICES = '_android_devices'
@@ -90,6 +93,11 @@
         _current_record: A records.TestResultRecord object for the test case
                          currently being executed. If no test is running, this
                          should be None.
+        _interrupted: Whether the test execution has been interrupted.
+        _interrupt_lock: The threading.Lock object that protects _interrupted.
+        _timer: The threading.Timer object that interrupts main thread when
+                timeout.
+        timeout: A float, the timeout, in seconds, configured for this object.
         include_filer: A list of string, each representing a test case name to
                        include.
         exclude_filer: A list of string, each representing a test case name to
@@ -119,6 +127,22 @@
         self.log = logger.LoggerProxy()
         self._current_record = None
 
+        # Timeout
+        self._interrupted = False
+        self._interrupt_lock = threading.Lock()
+        self._timer = None
+        self.timeout = self.getUserParam(
+            keys.ConfigKeys.KEY_TEST_TIMEOUT,
+            default_value=_DEFAULT_TEST_TIMEOUT_SECS * 1000.0)
+        try:
+            self.timeout = float(self.timeout) / 1000.0
+        except (TypeError, ValueError):
+            logging.error("Cannot parse timeout: %s", self.timeout)
+            self.timeout = _DEFAULT_TEST_TIMEOUT_SECS
+        if self.timeout <= 0:
+            logging.error("Invalid timeout: %s", self.timeout)
+            self.timeout = _DEFAULT_TEST_TIMEOUT_SECS
+
         # Setup test filters
         self.include_filter = self.getUserParam(
             [
@@ -391,6 +415,8 @@
         """Proxy function to guarantee the base implementation of setUpClass
         is called.
         """
+        self.resetTimeout(self.timeout)
+
         if not precondition_utils.MeetFirstApiLevelPrecondition(self):
             self.skipAllTests("The device's first API level doesn't meet the "
                               "precondition.")
@@ -446,6 +472,8 @@
         is called.
         """
         ret = self.tearDownClass()
+
+        self.resetTimeout(TEARDOWN_CLASS_TIMEOUT_SECS)
         if self.log_uploading.enabled:
             self.log_uploading.UploadLogs()
         if self.web.enabled:
@@ -473,6 +501,40 @@
         """
         pass
 
+    def interrupt(self):
+        """Interrupts the main thread once, then terminates the process."""
+        with self._interrupt_lock:
+            if self._interrupted:
+                logging.warning("Cannot interrupt more than once.")
+                return
+            self._interrupted = True
+        # Called outside the lock: it sleeps before sending SIGTERM.
+        utils.stop_current_process(TEARDOWN_CLASS_TIMEOUT_SECS)
+
+    def resetTimeout(self, timeout):
+        """Restarts the timer that will interrupt the main thread.
+
+        This class starts the timer before setUpClass. As the timeout depends
+        on the number of generated tests, a subclass can restart the timer.
+
+        Args:
+            timeout: A float, seconds to wait before calling interrupt().
+        """
+        with self._interrupt_lock:
+            if self._interrupted:
+                logging.warning("Test execution has been interrupted. "
+                                "Cannot reset timeout.")
+                return
+        # Replace any pending timer with a new one for the given timeout.
+        if self._timer:
+            logging.info("Cancel timer.")
+            self._timer.cancel()
+
+        logging.info("Start timer with timeout=%ssec.", timeout)
+        self._timer = threading.Timer(timeout, self.interrupt)
+        self._timer.daemon = True
+        self._timer.start()
+
     def _testEntry(self, test_record):
         """Internal function to be called upon entry of a test case.
 
@@ -1028,16 +1090,23 @@
         """
         # Setup for the class with retry.
         for i in xrange(_SETUP_RETRY_NUMBER):
+            setup_done = False
+            caught_exception = None
             try:
                 if self._setUpClass() is False:
                     raise signals.TestFailure(
                         "Failed to setup %s." % self.test_module_name)
                 else:
-                    break
+                    setup_done = True
             except Exception as e:
+                caught_exception = e
                 logging.exception("Failed to setup %s.", self.test_module_name)
-                if i + 1 == _SETUP_RETRY_NUMBER:
-                    self.results.failClass(self.test_module_name, e)
+            finally:
+                if setup_done:
+                    break
+                elif not caught_exception or i + 1 == _SETUP_RETRY_NUMBER:
+                    self.results.failClass(self.test_module_name,
+                                           caught_exception)
                     self._exec_func(self._tearDownClass)
                     return self.results
                 else:
diff --git a/runners/host/keys.py b/runners/host/keys.py
index 51bd477..5ff8cd0 100644
--- a/runners/host/keys.py
+++ b/runners/host/keys.py
@@ -30,6 +30,7 @@
     KEY_TESTBED_NAME = "name"
     KEY_TEST_PATHS = "test_paths"
     KEY_TEST_SUITE = "test_suite"
+    KEY_TEST_TIMEOUT = "test_timeout"
 
     # Keys in test suite
     KEY_INCLUDE_FILTER = "include_filter"
diff --git a/runners/host/test_runner.py b/runners/host/test_runner.py
index 4980a02..85f768b 100644
--- a/runners/host/test_runner.py
+++ b/runners/host/test_runner.py
@@ -25,10 +25,6 @@
 import pkgutil
 import signal
 import sys
-try:
-    import thread
-except ImportError as e:
-    import _thread as thread
 import threading
 
 from vts.runners.host import base_test
@@ -106,25 +102,12 @@
     test_identifiers = [(test_cls_name, None)]
 
     for config in test_configs:
-        watcher_enabled = threading.Event()
-
         def watchStdin():
             while True:
                 line = sys.stdin.readline()
                 if not line:
                     break
-            watcher_enabled.wait()
-            logging.info("Attempt to interrupt runner thread.")
-            if not utils.is_on_windows():
-                # Default SIGINT handler sends KeyboardInterrupt to main thread
-                # and unblocks it.
-                os.kill(os.getpid(), signal.SIGINT)
-            else:
-                # On Windows, raising CTRL_C_EVENT, which is received as
-                # SIGINT, has no effect on non-console process.
-                # interrupt_main() behaves like SIGINT but does not unblock
-                # main thread immediately.
-                thread.interrupt_main()
+            utils.stop_current_process(base_test.TEARDOWN_CLASS_TIMEOUT_SECS)
 
         watcher_thread = threading.Thread(target=watchStdin, name="watchStdin")
         watcher_thread.daemon = True
@@ -133,7 +116,6 @@
         tr = TestRunner(config, test_identifiers)
         tr.parseTestConfig(config)
         try:
-            watcher_enabled.set()
             tr.runTestClass(test_class, None)
         except KeyboardInterrupt as e:
             logging.exception("Aborted")
@@ -141,7 +123,6 @@
             logging.error("Unexpected exception")
             logging.exception(e)
         finally:
-            watcher_enabled.clear()
             tr.stop()
             return tr.results
 
diff --git a/runners/host/utils.py b/runners/host/utils.py
index 972d991..21f9e36 100755
--- a/runners/host/utils.py
+++ b/runners/host/utils.py
@@ -29,6 +29,12 @@
 import time
 import traceback
 
+try:
+    # TODO: remove when we stop supporting Python 2
+    import thread
+except ImportError as e:
+    import _thread as thread
+
 # File name length is limited to 255 chars on some OS, so we need to make sure
 # the file names we output fits within the limit.
 MAX_FILENAME_LEN = 255
@@ -362,6 +368,33 @@
     return os.name == "nt"
 
 
+def stop_current_process(terminate_timeout):
+    """Interrupts the main thread, then terminates the current process.
+
+    Called from daemon threads on test timeout or when stdin reaches EOF.
+
+    Args:
+        terminate_timeout: A float, the seconds this call blocks between the
+                           interrupt and the termination.
+    """
+    logging.error("Interrupt main thread.")
+    if not is_on_windows():
+        # Default SIGINT handler sends KeyboardInterrupt to main thread
+        # and unblocks it.
+        os.kill(os.getpid(), signal.SIGINT)
+    else:
+        # On Windows, raising CTRL_C_EVENT, which is received as
+        # SIGINT, has no effect on non-console process.
+        # interrupt_main() behaves like SIGINT but does not unblock
+        # main thread immediately.
+        thread.interrupt_main()
+    # Give the main thread terminate_timeout seconds to handle the interrupt.
+    time.sleep(terminate_timeout)
+    logging.error("Terminate current process.")
+    # Send SIGTERM on POSIX. Call TerminateProcess() on Windows.
+    os.kill(os.getpid(), signal.SIGTERM)
+
+
 def kill_process_group(proc, signal_no=signal.SIGTERM):
     """Sends signal to a process group.
 
diff --git a/testcases/template/hal_hidl_gtest/hal_hidl_gtest.py b/testcases/template/hal_hidl_gtest/hal_hidl_gtest.py
index 13f5bd4..3e91ed0 100644
--- a/testcases/template/hal_hidl_gtest/hal_hidl_gtest.py
+++ b/testcases/template/hal_hidl_gtest/hal_hidl_gtest.py
@@ -40,12 +40,14 @@
         testcases: list of GtestTestCase objects, list of test cases to run
         _cpu_freq: CpuFrequencyScalingController instance of a target device.
         _dut: AndroidDevice, the device under test as config
+        _initial_test_case_cnt: Number of initial test cases.
         _target_hals: List of String, the targeting hal service of the test.
                       e.g (["android.hardware.foo@1.0::IFoo"])
     """
 
     def setUpClass(self):
         """Checks precondition."""
+        self._initial_test_case_cnt = 0
         super(HidlHalGTest, self).setUpClass()
         if not hasattr(self, "_target_hals"):
             self._target_hals = []
@@ -76,6 +78,12 @@
             if not ret:
                 self.skipAllTests("HIDL HAL precondition check failed.")
 
+        # Extend timeout if there are multiple service instance combinations.
+        if (not self.isSkipAllTests() and self._initial_test_case_cnt and
+                len(self.testcases) > self._initial_test_case_cnt):
+            self.resetTimeout(self.timeout * len(self.testcases) /
+                              float(self._initial_test_case_cnt))
+
         if self.sancov.enabled and self._target_hals:
             self.sancov.InitializeDeviceCoverage(self._dut,
                                                  self._target_hals)
@@ -119,6 +127,7 @@
         """
         initial_test_cases = super(HidlHalGTest, self).CreateTestCase(path,
                                                                       tag)
+        self._initial_test_case_cnt += len(initial_test_cases)
         if not initial_test_cases:
             return initial_test_cases
         # first, run one test with --list_registered_services.
diff --git a/testcases/template/hal_hidl_replay_test/hal_hidl_replay_test.py b/testcases/template/hal_hidl_replay_test/hal_hidl_replay_test.py
index ecb263d..ad1224f 100644
--- a/testcases/template/hal_hidl_replay_test/hal_hidl_replay_test.py
+++ b/testcases/template/hal_hidl_replay_test/hal_hidl_replay_test.py
@@ -48,6 +48,12 @@
         if self.isSkipAllTests():
             return
 
+        # Extend timeout if there are multiple service instance combinations.
+        if (len(self.trace_paths) and
+                len(self.testcases) > len(self.trace_paths)):
+            self.resetTimeout(self.timeout * len(self.testcases) /
+                              float(len(self.trace_paths)))
+
         if self.coverage.enabled and self._test_hal_services is not None:
             self.coverage.SetHalNames(self._test_hal_services)