fix test_backward_node_failure flakiness (#31588)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31588
Per title. This test can sometimes fail with a different error message
than the one the test currently matches, so extend the expected regex to also
accept that message and make the test pass consistently.
Differential Revision: D19222275
fbshipit-source-id: 89c95276d4d9beccf9e0961f970493750d78a96b
diff --git a/test/dist_autograd_test.py b/test/dist_autograd_test.py
index 4717bc6..e08634c 100644
--- a/test/dist_autograd_test.py
+++ b/test/dist_autograd_test.py
@@ -1135,7 +1135,8 @@
if rank % 2 != 0:
wait_until_node_failure(rank)
- with self.assertRaisesRegex(RuntimeError, "Request aborted during client shutdown"):
+ with self.assertRaisesRegex(RuntimeError, "(Request aborted during client shutdown)|"
+ "(worker.: Error in reponse from worker.: server shutting down)"):
# Run backwards, and validate we receive an error since all
# other nodes are dead.
dist_autograd.backward([res.sum()])
@@ -1314,7 +1315,8 @@
# Wait for rank 2 to die.
wait_until_node_failure(2)
- with self.assertRaisesRegex(RuntimeError, "Request aborted during client shutdown"):
+ with self.assertRaisesRegex(RuntimeError, "(Request aborted during client shutdown)|"
+ "(worker.: Error in reponse from worker.: server shutting down)"):
# Run backwards, and validate we receive an error since rank 2 is dead.
dist_autograd.backward([res.sum()])