Multiple fixes to test_c10d.py. (#25334)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25334
1) There was a bug in https://github.com/pytorch/pytorch/pull/25012, where the
list of tests which needed to be skipped for return code checking was incorrect.
2) Added proper setup and teardown for the nccl_error tests.
3) Ensure AssertionError is not ignored for tests that skip return code
checking.
Test Plan: unit tests
Differential Revision: D17003555
fbshipit-source-id: 0e0429367fb6dae251b74e9f8b2baa67a48a0d22
diff --git a/test/common_distributed.py b/test/common_distributed.py
index 15980e9..006bea6 100644
--- a/test/common_distributed.py
+++ b/test/common_distributed.py
@@ -5,6 +5,7 @@
import tempfile
import time
import unittest
+import logging
from collections import namedtuple
from functools import wraps
@@ -88,6 +89,11 @@
class MultiProcessTestCase(TestCase):
MAIN_PROCESS_RANK = -1
+ # This exit code is used to indicate that the test code had an error and
+ # exited abnormally. There are certain tests that might use sys.exit() to
+ # simulate failures and in those cases, we can't have an exit code of 0,
+ # but we still want to ensure we didn't run into any other errors.
+ TEST_ERROR_EXIT_CODE = 10
@property
def world_size(self):
@@ -100,7 +106,12 @@
if self.rank == self.MAIN_PROCESS_RANK:
self._join_processes(fn)
else:
- fn(self)
+ try:
+ fn(self)
+ except Exception as e:
+ logging.error('Caught exception: {}, exiting process with exit code: {}'
+ .format(e, MultiProcessTestCase.TEST_ERROR_EXIT_CODE))
+ sys.exit(MultiProcessTestCase.TEST_ERROR_EXIT_CODE)
return wrapper
# The main process spawns N subprocesses that run the test.
@@ -147,8 +158,19 @@
p.join(timeout)
elapsed_time = time.time() - start_time
if fn in self.skip_return_code_checks:
+ self._check_no_test_errors(elapsed_time)
+ else:
self._check_return_codes(elapsed_time)
+ def _check_no_test_errors(self, elapsed_time):
+ """
+ Checks that we didn't have any errors thrown in the child processes.
+ """
+ for i, p in enumerate(self.processes):
+ if p.exitcode is None:
+ raise RuntimeError('Process {} timed out after {} seconds'.format(i, elapsed_time))
+ self.assertNotEqual(self.TEST_ERROR_EXIT_CODE, p.exitcode)
+
def _check_return_codes(self, elapsed_time):
"""
Checks that the return codes of all spawned processes match, and skips
diff --git a/test/test_c10d.py b/test/test_c10d.py
index def4f12..4667069 100644
--- a/test/test_c10d.py
+++ b/test/test_c10d.py
@@ -645,6 +645,7 @@
(i * self.world_size) + (i % self.world_size)
]),
inputs[i],
+ None,
"Mismatch in iteration %d" % i,
)
@@ -727,6 +728,7 @@
(self.world_size * (self.world_size - 1) / 2)
]),
inputs[i],
+ None,
"Mismatch in iteration %d" % i,
)
@@ -972,6 +974,7 @@
self.assertEqual(
torch.Tensor([iter + root]),
outputs[iter][root],
+ None,
"Mismatch in iteration %d for rank %d" % (iter, root)
)
@@ -1118,6 +1121,7 @@
self.assertEqual(
expected_outputs[iter],
outputs[iter],
+ None,
"Mismatch in iteration %d for root %d" % (iter, root)
)
@@ -1218,6 +1222,7 @@
self.assertEqual(
expected_outputs[i],
outputs[i],
+ None,
"Mismatch in iteration %d" % i
)
@@ -1306,6 +1311,7 @@
(self.world_size * (self.world_size - 1) / 2)
]),
outputs[i],
+ None,
"Mismatch in iteration %d with root rank %d" % (iter, root),
)
@@ -2893,12 +2899,20 @@
# Need to skip return code checking for these tests since the child
# processes don't exit cleanly.
self.skip_return_code_checks = [
- self.test_nccl_errors_blocking_abort,
- self.test_nccl_errors_blocking_sigkill,
- self.test_nccl_errors_blocking_sigstop,
- self.test_nccl_errors_blocking_sigterm,
+ self._get_wrapped_func(self.test_nccl_errors_blocking_abort),
+ self._get_wrapped_func(self.test_nccl_errors_blocking_sigkill),
+ self._get_wrapped_func(self.test_nccl_errors_blocking_sigterm),
+ self._get_wrapped_func(self.test_nccl_errors_blocking_nonzero_exit),
]
- self.op_timeout_sec = 1
+
+ def _get_wrapped_func(self, func):
+ # Get the original function which was wrapped in the decorator.
+ if hasattr(func, '__wrapped__'):
+ # py3 way.
+ return func.__wrapped__
+ else:
+ # py2 way.
+ return func.func_closure[0].cell_contents
def tearDown(self):
super(CommTest, self).tearDown()
@@ -2906,7 +2920,10 @@
os.remove(self.file.name)
except OSError:
pass
- os.environ["NCCL_BLOCKING_WAIT"] = "0"
+
+ @property
+ def op_timeout_sec(self):
+ return 1
@property
def world_size(self):
@@ -2959,8 +2976,9 @@
# Now the work scheduled next should hang forever since the previous
# allreduce will never complete.
t = threading.Thread(target=self._run_all_reduce, args=(process_group,))
+ t.daemon = True
t.start()
- t.join(int(get_timeout(self.id()) / 2))
+ t.join(int(get_timeout(self.id()) / 5))
self.assertTrue(t.is_alive())
def _test_nccl_errors_blocking(self, func):
@@ -2983,6 +3001,11 @@
@requires_nccl()
@skip_if_not_multigpu
+ def test_nccl_errors_blocking_nonzero_exit(self):
+ self._test_nccl_errors_blocking(lambda : sys.exit(1))
+
+ @requires_nccl()
+ @skip_if_not_multigpu
def test_nccl_errors_blocking_abort(self):
self._test_nccl_errors_blocking(lambda : os.abort())
@@ -2993,15 +3016,6 @@
@requires_nccl()
@skip_if_not_multigpu
- def test_nccl_errors_blocking_sigstop(self):
- self._test_nccl_errors_blocking(lambda : os.kill(os.getpid(), signal.SIGSTOP))
- if self.rank == 0:
- time.sleep(2 * self.op_timeout_sec)
- for i in range(1, len(self.processes)):
- os.kill(self.processes[i].pid, signal.SIGCONT)
-
- @requires_nccl()
- @skip_if_not_multigpu
def test_nccl_errors_blocking_sigterm(self):
self._test_nccl_errors_blocking(lambda : os.kill(os.getpid(), signal.SIGTERM))