Multiple fixes to test_c10d.py. (#25334)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25334

1) There was a bug in https://github.com/pytorch/pytorch/pull/25012, where the
list of tests which needed to be skipped for return code checking was incorrect.
2) Added proper setup and teardown for the nccl_error tests.
3) Ensured that an AssertionError is not ignored for tests that skip return
code checking.

Test Plan: unit tests

Differential Revision: D17003555

fbshipit-source-id: 0e0429367fb6dae251b74e9f8b2baa67a48a0d22
diff --git a/test/common_distributed.py b/test/common_distributed.py
index 15980e9..006bea6 100644
--- a/test/common_distributed.py
+++ b/test/common_distributed.py
@@ -5,6 +5,7 @@
 import tempfile
 import time
 import unittest
+import logging
 
 from collections import namedtuple
 from functools import wraps
@@ -88,6 +89,11 @@
 
 class MultiProcessTestCase(TestCase):
     MAIN_PROCESS_RANK = -1
+    # This exit code is used to indicate that the test code had an error and
+    # exited abnormally. There are certain tests that might use sys.exit() to
+    # simulate failures and in those cases, we can't have an exit code of 0,
+    # but we still want to ensure we didn't run into any other errors.
+    TEST_ERROR_EXIT_CODE = 10
 
     @property
     def world_size(self):
@@ -100,7 +106,12 @@
             if self.rank == self.MAIN_PROCESS_RANK:
                 self._join_processes(fn)
             else:
-                fn(self)
+                try:
+                    fn(self)
+                except Exception as e:
+                    logging.error('Caught exception: {}, exiting process with exit code: {}'
+                                  .format(e, MultiProcessTestCase.TEST_ERROR_EXIT_CODE))
+                    sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
         return wrapper
 
     # The main process spawns N subprocesses that run the test.
@@ -147,8 +158,19 @@
             p.join(timeout)
         elapsed_time = time.time() - start_time
         if fn in self.skip_return_code_checks:
+            self._check_no_test_errors(elapsed_time)
+        else:
             self._check_return_codes(elapsed_time)
 
+    def _check_no_test_errors(self, elapsed_time):
+        """
+        Checks that we didn't have any errors thrown in the child processes.
+        """
+        for i, p in enumerate(self.processes):
+            if p.exitcode is None:
+                raise RuntimeError('Process {} timed out after {} seconds'.format(i, elapsed_time))
+            self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)
+
     def _check_return_codes(self, elapsed_time):
         """
         Checks that the return codes of all spawned processes match, and skips
diff --git a/test/test_c10d.py b/test/test_c10d.py
index def4f12..4667069 100644
--- a/test/test_c10d.py
+++ b/test/test_c10d.py
@@ -645,6 +645,7 @@
                     (i * self.world_size) + (i % self.world_size)
                 ]),
                 inputs[i],
+                None,
                 "Mismatch in iteration %d" % i,
             )
 
@@ -727,6 +728,7 @@
                     (self.world_size * (self.world_size - 1) / 2)
                 ]),
                 inputs[i],
+                None,
                 "Mismatch in iteration %d" % i,
             )
 
@@ -972,6 +974,7 @@
             self.assertEqual(
                 torch.Tensor([iter + root]),
                 outputs[iter][root],
+                None,
                 "Mismatch in iteration %d for rank %d" % (iter, root)
             )
 
@@ -1118,6 +1121,7 @@
                 self.assertEqual(
                     expected_outputs[iter],
                     outputs[iter],
+                    None,
                     "Mismatch in iteration %d for root %d" % (iter, root)
                 )
 
@@ -1218,6 +1222,7 @@
             self.assertEqual(
                 expected_outputs[i],
                 outputs[i],
+                None,
                 "Mismatch in iteration %d" % i
             )
 
@@ -1306,6 +1311,7 @@
                         (self.world_size * (self.world_size - 1) / 2)
                     ]),
                     outputs[i],
+                    None,
                     "Mismatch in iteration %d with root rank %d" % (iter, root),
                 )
 
@@ -2893,12 +2899,20 @@
         # Need to skip return code checking for these tests since the child
         # processes don't exit cleanly.
         self.skip_return_code_checks = [
-            self.test_nccl_errors_blocking_abort,
-            self.test_nccl_errors_blocking_sigkill,
-            self.test_nccl_errors_blocking_sigstop,
-            self.test_nccl_errors_blocking_sigterm,
+            self._get_wrapped_func(self.test_nccl_errors_blocking_abort),
+            self._get_wrapped_func(self.test_nccl_errors_blocking_sigkill),
+            self._get_wrapped_func(self.test_nccl_errors_blocking_sigterm),
+            self._get_wrapped_func(self.test_nccl_errors_blocking_nonzero_exit),
         ]
-        self.op_timeout_sec = 1
+
+    def _get_wrapped_func(self, func):
+        # Get the original function which was wrapped in the decorator.
+        if hasattr(func, '__wrapped__'):
+            # py3 way.
+            return func.__wrapped__
+        else:
+            # py2 way.
+            return func.func_closure[0].cell_contents
 
     def tearDown(self):
         super(CommTest, self).tearDown()
@@ -2906,7 +2920,10 @@
             os.remove(self.file.name)
         except OSError:
             pass
-        os.environ["NCCL_BLOCKING_WAIT"] = "0"
+
+    @property
+    def op_timeout_sec(self):
+        return 1
 
     @property
     def world_size(self):
@@ -2959,8 +2976,9 @@
             # Now the work scheduled next should hang forever since the previous
             # allreduce will never complete.
             t = threading.Thread(target=self._run_all_reduce, args=(process_group,))
+            t.daemon = True
             t.start()
-            t.join(int(get_timeout(self.id()) / 2))
+            t.join(int(get_timeout(self.id()) / 5))
             self.assertTrue(t.is_alive())
 
     def _test_nccl_errors_blocking(self, func):
@@ -2983,6 +3001,11 @@
 
     @requires_nccl()
     @skip_if_not_multigpu
+    def test_nccl_errors_blocking_nonzero_exit(self):
+        self._test_nccl_errors_blocking(lambda : sys.exit(1))
+
+    @requires_nccl()
+    @skip_if_not_multigpu
     def test_nccl_errors_blocking_abort(self):
         self._test_nccl_errors_blocking(lambda : os.abort())
 
@@ -2993,15 +3016,6 @@
 
     @requires_nccl()
     @skip_if_not_multigpu
-    def test_nccl_errors_blocking_sigstop(self):
-        self._test_nccl_errors_blocking(lambda : os.kill(os.getpid(), signal.SIGSTOP))
-        if self.rank == 0:
-            time.sleep(2 * self.op_timeout_sec)
-            for i in range(1, len(self.processes)):
-                os.kill(self.processes[i].pid, signal.SIGCONT)
-
-    @requires_nccl()
-    @skip_if_not_multigpu
     def test_nccl_errors_blocking_sigterm(self):
         self._test_nccl_errors_blocking(lambda : os.kill(os.getpid(), signal.SIGTERM))