apply linter to rpc test files (#32659)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32659
Applies the linter to the RPC test files so that linter shortcuts can be used
later without pulling in unrelated formatting changes across the whole file.
ghstack-source-id: 97361237
Test Plan: No functional changes; formatting only.
Differential Revision: D19584742
fbshipit-source-id: a11ce74ee0e2817e6f774fff7c39bcab06e99307
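The reformatting in the diff below (double-quoted strings, trailing commas, call sites wrapped at roughly 88 columns, sorted and parenthesized imports) is consistent with what a black-style formatter combined with an import sorter would produce. The summary does not name the exact tool, so the following is only a minimal, hypothetical sketch of reproducing such a pass over these files, assuming black and isort are installed; it is not the command actually used for this diff.

import black
import isort

RPC_TEST_FILES = [
    "torch/testing/_internal/distributed/rpc/dist_autograd_test.py",
    "torch/testing/_internal/distributed/rpc/dist_optimizer_test.py",
    "torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py",
    "torch/testing/_internal/distributed/rpc/rpc_test.py",
]

def reformat(path):
    # Read the source, sort the imports, then apply black-style formatting.
    with open(path) as f:
        src = f.read()
    src = isort.code(src)  # isort >= 5 programmatic API
    src = black.format_str(src, mode=black.FileMode())
    with open(path, "w") as f:
        f.write(src)

for path in RPC_TEST_FILES:
    reformat(path)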
diff --git a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
index b32f099..b408b7b 100644
--- a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
+++ b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
@@ -1,19 +1,29 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import sys
+import threading
import time
import unittest
+from enum import Enum
import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
import torch.testing._internal.dist_utils
-from torch.testing._internal.dist_utils import dist_init, wait_until_node_failure, initialize_pg, get_shutdown_error_regex
-from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import RpcAgentTestFixture
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
from torch.testing import FileCheck
+from torch.testing._internal.dist_utils import (
+ dist_init,
+ get_shutdown_error_regex,
+ initialize_pg,
+ wait_until_node_failure,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+ RpcAgentTestFixture,
+)
-import threading
# Right now we test up to 3-layer nested rpc calls.
# rpc_done[1] and ctx_ids[1] represent rpc is done in prev rank, and context id
@@ -53,12 +63,7 @@
# creates an owner rref on the given dst, and the rref holds a torch.ones tensor
# of the given size.
def _create_ones_rref_on(dst, sizes):
- return rpc.remote(
- dst,
- _torch_ones,
- args=(sizes,),
- kwargs={"requires_grad": True}
- )
+ return rpc.remote(dst, _torch_ones, args=(sizes,), kwargs={"requires_grad": True})
# This method must be called on the rref owner, and verifies that the grad of
@@ -71,6 +76,7 @@
def my_py_add(t1, t2):
return torch.add(t1, t2)
+
def my_scalar_add(a, b):
return a + b
@@ -79,10 +85,12 @@
ret = torch.add(rref_t1.local_value(), t2)
return ret
+
@torch.jit.script
def my_script_add(t1, t2):
return torch.add(t1, t2)
+
def my_nested_rref_add(dst, rref_t1, t2):
return rpc.rpc_sync(dst, my_rref_add, args=(rref_t1, t2))
@@ -94,11 +102,15 @@
def my_py_nested_call(t1, t2, dst, world_size, hops):
next_dst = (dst + 1) % world_size
if hops > 0:
- return rpc.rpc_sync("worker{}".format(next_dst), my_py_nested_call,
- args=(t1, t2, next_dst, world_size, hops - 1))
+ return rpc.rpc_sync(
+ "worker{}".format(next_dst),
+ my_py_nested_call,
+ args=(t1, t2, next_dst, world_size, hops - 1),
+ )
else:
return rpc.rpc_sync("worker{}".format(next_dst), my_py_add, args=(t1, t2))
+
# after dist autograd context is cleaned up, it should be cleaned up on other
# nodes. This helper allows timeout_seconds for those RPCs to be completed, and
# ensures that all the contexts have been cleaned up in that timeframe.
@@ -128,12 +140,9 @@
dist_autograd.backward([ret.sum()])
# prevent deleting dist autograd context
rpc.rpc_sync(ps, _set_rpc_done, args=(context_id, rank_diff))
- rpc.rpc_sync(ps, _check_rpc_done, args=(0, ))
+ rpc.rpc_sync(ps, _check_rpc_done, args=(0,))
-from torch.autograd import Function
-from torch.autograd.function import once_differentiable
-
class SimulateBackwardError(Function):
@staticmethod
def forward(ctx, input):
@@ -142,9 +151,8 @@
@staticmethod
@once_differentiable
def backward(ctx, input):
- raise Exception('Simulate error on backward pass')
+ raise Exception("Simulate error on backward pass")
-from enum import Enum
class ExecMode(Enum):
LOCAL = 1 # Run the operation locally.
@@ -152,31 +160,35 @@
REMOTE = 3 # Run the operation using remote.
RPC_ASYNC = 4 # Run the operation using rpc_async
+
@unittest.skipIf(
- not torch._six.PY3, "Pytorch distributed autograd package " "does not support python2"
+ not torch._six.PY3,
+ "Pytorch distributed autograd package " "does not support python2",
)
class DistAutogradTest(RpcAgentTestFixture):
-
def _exec_func(self, exec_mode, method, *args):
if ExecMode.LOCAL == exec_mode:
if len(args) == 1 and isinstance(args[0], list):
return method(*args[0])
return method(*args)
elif ExecMode.RPC_SYNC == exec_mode:
- return rpc.rpc_sync('worker{}'.format(self._next_rank()), method,
- args=(args))
+ return rpc.rpc_sync(
+ "worker{}".format(self._next_rank()), method, args=(args)
+ )
elif ExecMode.REMOTE == exec_mode:
- return rpc.remote('worker{}'.format(self._next_rank()), method,
- args=(args)).to_here()
+ return rpc.remote(
+ "worker{}".format(self._next_rank()), method, args=(args)
+ ).to_here()
elif ExecMode.RPC_ASYNC == exec_mode:
- fut = rpc.rpc_async('worker{}'.format(self._next_rank()), method,
- args=(args))
+ fut = rpc.rpc_async(
+ "worker{}".format(self._next_rank()), method, args=(args)
+ )
return fut.wait()
else:
raise ValueError("Unrecognized ExecMode {}".format(exec_mode))
def _next_rank(self):
- if hasattr(self, 'dst_rank'):
+ if hasattr(self, "dst_rank"):
self.dst_rank = (self.dst_rank + 1) % self.world_size
if self.dst_rank == self.rank:
return self._next_rank()
@@ -217,7 +229,9 @@
def test_nested_context(self):
with dist_autograd.context() as context_id:
# Nested contexts not supported.
- with self.assertRaisesRegex(RuntimeError, "Already have an autograd context id for this thread"):
+ with self.assertRaisesRegex(
+ RuntimeError, "Already have an autograd context id for this thread"
+ ):
with dist_autograd.context() as context_id:
pass
@@ -234,7 +248,9 @@
# |
# t3.rpcRecvBackward
#
- def _verify_graph_for_first_rpc_call(self, send_function, recv_function, t1, t2, ret):
+ def _verify_graph_for_first_rpc_call(
+ self, send_function, recv_function, t1, t2, ret
+ ):
# Retrieve the next functions in the graph.
next_funcs = send_function.next_functions
self.assertEqual(2, len(next_funcs))
@@ -308,7 +324,6 @@
)
self.assertEqual(next_funcs[0][0], next_funcs[1][0])
-
# For send function when returning response to previous call
# next function of the send function is the recv function
# for received tensor result returned from nested call
@@ -327,8 +342,7 @@
t1 = torch.ones(3, 3, requires_grad=True)
t2 = torch.zeros(3, 3, requires_grad=True)
if ExecMode.RPC_SYNC == exec_mode:
- ret = rpc.rpc_sync(
- "worker{}".format(dst_rank), fn, args=(t1, t2))
+ ret = rpc.rpc_sync("worker{}".format(dst_rank), fn, args=(t1, t2))
elif ExecMode.REMOTE == exec_mode:
ret = rpc.remote(
"worker{}".format(dst_rank), fn, args=(t1, t2)
@@ -336,8 +350,9 @@
else:
raise ValueError("Unrecognized ExecMode {}".format(exec_mode))
- rpc.rpc_sync("worker{}".format(dst_rank),
- _set_rpc_done, args=(context_id, 1))
+ rpc.rpc_sync(
+ "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
+ )
# Verify graph for current context id.
ctx = dist_autograd._current_context()
@@ -346,9 +361,13 @@
self.assertEqual(1, len(send_functions))
recv_functions = ctx._recv_functions()
self.assertEqual(1, len(recv_functions))
- self._verify_graph_for_first_rpc_call(list(send_functions.values())[0],
- list(recv_functions.values())[0],
- t1, t2, ret)
+ self._verify_graph_for_first_rpc_call(
+ list(send_functions.values())[0],
+ list(recv_functions.values())[0],
+ t1,
+ t2,
+ ret,
+ )
# Wait for the prev rank to be done with rpc.
self._check_rpc_done(1)
@@ -399,13 +418,13 @@
ret = rpc.rpc_sync(
"worker{}".format(dst_rank),
my_py_nested_call,
- args=(t1, t2, dst_rank, self.world_size, 1)
+ args=(t1, t2, dst_rank, self.world_size, 1),
)
elif ExecMode.REMOTE == exec_mode:
ret = rpc.remote(
"worker{}".format(dst_rank),
my_py_nested_call,
- args=(t1, t2, dst_rank, self.world_size, 1)
+ args=(t1, t2, dst_rank, self.world_size, 1),
).to_here()
else:
raise ValueError("Unrecognized ExecMode {}".format(exec_mode))
@@ -414,8 +433,11 @@
dist.barrier()
for rd in [1, 2, 3]:
- rpc.rpc_sync("worker{}".format((self.rank + rd) % self.world_size),
- _set_rpc_done, args=(context_id, rd))
+ rpc.rpc_sync(
+ "worker{}".format((self.rank + rd) % self.world_size),
+ _set_rpc_done,
+ args=(context_id, rd),
+ )
# Barrier to ensure all set_rpc_done have completed.
dist.barrier()
@@ -436,9 +458,13 @@
self.assertEqual(1, len(send_functions))
recv_functions = ctx._recv_functions()
self.assertEqual(1, len(recv_functions))
- self._verify_graph_for_first_rpc_call(list(send_functions.values())[0],
- list(recv_functions.values())[0],
- t1, t2, ret)
+ self._verify_graph_for_first_rpc_call(
+ list(send_functions.values())[0],
+ list(recv_functions.values())[0],
+ t1,
+ t2,
+ ret,
+ )
# Verify second graph for 1st nested call.
ctx = dist_autograd._retrieve_context(ctx_ids[1])
@@ -483,8 +509,8 @@
t2,
(self.rank - 1 + self.world_size) % self.world_size,
self.world_size,
- 0
- )
+ 0,
+ ),
)
elif ExecMode.REMOTE == exec_mode:
ret = rpc.remote(
@@ -495,14 +521,17 @@
t2,
(self.rank - 1 + self.world_size) % self.world_size,
self.world_size,
- 0
- )
+ 0,
+ ),
).to_here()
else:
raise ValueError("Unrecognized ExecMode {}".format(exec_mode))
- rpc.rpc_sync("worker{}".format((self.rank + 1) % self.world_size),
- _set_rpc_done, args=(context_id, 1))
+ rpc.rpc_sync(
+ "worker{}".format((self.rank + 1) % self.world_size),
+ _set_rpc_done,
+ args=(context_id, 1),
+ )
# For self.rank, it has 2 graphs to verify.
# One is for current context id when this rank sends first rpc
@@ -515,9 +544,13 @@
self.assertEqual(2, len(send_functions))
recv_functions = ctx._recv_functions()
self.assertEqual(2, len(recv_functions))
- self._verify_graph_for_first_rpc_call(list(send_functions.values())[0],
- list(recv_functions.values())[1],
- t1, t2, ret)
+ self._verify_graph_for_first_rpc_call(
+ list(send_functions.values())[0],
+ list(recv_functions.values())[1],
+ t1,
+ t2,
+ ret,
+ )
self._verify_graph_for_rpc_call_exec(list(send_functions.values())[1])
# Verify two pairs of send and recv functions for nested
@@ -545,21 +578,18 @@
t2 = torch.zeros(3, 3, requires_grad=False)
if ExecMode.RPC_SYNC == exec_mode:
ret = rpc.rpc_sync(
- "worker{}".format(dst_rank),
- torch.add,
- args=(t1, t2)
+ "worker{}".format(dst_rank), torch.add, args=(t1, t2)
)
elif ExecMode.REMOTE == exec_mode:
ret = rpc.remote(
- "worker{}".format(dst_rank),
- torch.add,
- args=(t1, t2)
+ "worker{}".format(dst_rank), torch.add, args=(t1, t2)
).to_here()
else:
raise ValueError("Unrecognized ExecMode {}".format(exec_mode))
- rpc.rpc_sync("worker{}".format(dst_rank),
- _set_rpc_done, args=(context_id, 1))
+ rpc.rpc_sync(
+ "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
+ )
ctx = dist_autograd._current_context()
send_functions = ctx._send_functions()
@@ -591,22 +621,19 @@
dst_rank = (self.rank + 1) % self.world_size
with dist_autograd.context() as context_id:
if ExecMode.RPC_SYNC == exec_mode:
- ret = rpc.rpc_sync(
- "worker{}".format(dst_rank),
- ret_requires_grad
- )
+ ret = rpc.rpc_sync("worker{}".format(dst_rank), ret_requires_grad)
elif ExecMode.REMOTE == exec_mode:
ret = rpc.remote(
- "worker{}".format(dst_rank),
- ret_requires_grad
+ "worker{}".format(dst_rank), ret_requires_grad
).to_here()
else:
raise ValueError("Unrecognized ExecMode {}".format(exec_mode))
dist_autograd.backward([ret.sum()])
- rpc.rpc_sync("worker{}".format(dst_rank),
- _set_rpc_done, args=(context_id, 1))
+ rpc.rpc_sync(
+ "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
+ )
# Wait for the prev rank to be done with rpc.
self._check_rpc_done(1)
@@ -636,15 +663,11 @@
dst_rank = self._next_rank()
if ExecMode.RPC_SYNC == exec_mode:
ret = rpc.rpc_sync(
- "worker{}".format(dst_rank),
- torch.stack,
- args=(tensors,)
+ "worker{}".format(dst_rank), torch.stack, args=(tensors,)
)
elif ExecMode.REMOTE == exec_mode:
ret = rpc.remote(
- "worker{}".format(dst_rank),
- torch.stack,
- args=(tensors,)
+ "worker{}".format(dst_rank), torch.stack, args=(tensors,)
).to_here()
else:
raise ValueError("Unrecognized ExecMode {}".format(exec_mode))
@@ -694,9 +717,15 @@
with dist_autograd.context() as context_id:
for dst_rank in dst_ranks:
rpc.rpc_sync("worker{}".format(dst_rank), func, args=rpc_args)
- rpc.rpc_sync("worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1))
+ rpc.rpc_sync(
+ "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
+ )
if nested:
- rpc.rpc_sync("worker{}".format(nested_dst_rank), _set_rpc_done, args=(context_id, 2))
+ rpc.rpc_sync(
+ "worker{}".format(nested_dst_rank),
+ _set_rpc_done,
+ args=(context_id, 2),
+ )
# the thread's context id should be cleaned up
with self.assertRaises(RuntimeError):
dist_autograd._retrieve_context(context_id)
@@ -728,7 +757,9 @@
t2 = torch.zeros(3, 3, requires_grad=True)
dst_rank = (self.rank + 1) % self.world_size
args = (t1, t2, dst_rank, self.world_size, 0)
- self.context_cleanup_test_helper(rpc_args=args, func=my_py_nested_call, nested=True)
+ self.context_cleanup_test_helper(
+ rpc_args=args, func=my_py_nested_call, nested=True
+ )
@dist_init
def test_worker_ids_recorded(self):
@@ -752,7 +783,9 @@
t1.requires_grad = True
t2.requires_grad = True
for dst_rank in dst_ranks:
- ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
+ ret = rpc.rpc_sync(
+ "worker{}".format(dst_rank), torch.add, args=(t1, t2)
+ )
rpc.rpc_sync(
"worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
)
@@ -766,11 +799,11 @@
t1 = torch.rand(3, 3, requires_grad=True)
t2 = torch.rand(6, 6, requires_grad=True)
-
with self.assertRaises(RuntimeError):
# This should throw an error since matrix sizes don't match.
- rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.matmul,
- args=(t1, t2))
+ rpc.rpc_sync(
+ "worker{}".format(self._next_rank()), torch.matmul, args=(t1, t2)
+ )
def _verify_backwards(self, exec_mode, tensors, context_id, local_grads, *args):
if exec_mode == ExecMode.LOCAL:
@@ -807,7 +840,9 @@
with dist_autograd.context() as context_id:
ret = self._exec_func(exec_mode, torch.add, t1, t2)
loss = ret.sum()
- ret = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2)
+ ret = self._verify_backwards(
+ exec_mode, [loss], context_id, local_grads, t1, t2
+ )
local_grads = ret if ret else local_grads
# The current rank first creates a tensor on the rref_owner, and then passes
@@ -824,19 +859,14 @@
local_ret.sum().backward()
with dist_autograd.context() as context_id:
rref_t1 = rpc.remote(
- rref_owner,
- _torch_ones,
- args=((3, 3),),
- kwargs={"requires_grad": True}
+ rref_owner, _torch_ones, args=((3, 3),), kwargs={"requires_grad": True}
)
if callee == rref_owner:
rref = rpc.remote(callee, my_rref_add, args=(rref_t1, t2))
else:
rref = rpc.remote(
- callee,
- my_nested_rref_add,
- args=(rref_owner, rref_t1, t2)
+ callee, my_nested_rref_add, args=(rref_owner, rref_t1, t2)
)
ret = rref.to_here()
dist_autograd.backward([ret.sum()])
@@ -851,7 +881,7 @@
rpc.rpc_sync(
rref_owner,
_compare_owner_value,
- args=(context_id, rref_t1, t1.grad)
+ args=(context_id, rref_t1, t1.grad),
)
)
@@ -898,18 +928,20 @@
rref_t1 = rpc.rpc_sync(
"worker{}".format(self._next_rank()),
_create_ones_rref_on,
- args=(self_name, (3, 3))
+ args=(self_name, (3, 3)),
)
# kick off forward and backward pass on three other workers (trainers)
rank_diffs = [1, 2, 3]
futures = []
for rank_diff in rank_diffs:
- futures.append(rpc.rpc_async(
- "worker{}".format((self.rank + rank_diff) % self.world_size),
- _run_trainer,
- args=(rref_t1, t2, self_name, rank_diff)
- ))
+ futures.append(
+ rpc.rpc_async(
+ "worker{}".format((self.rank + rank_diff) % self.world_size),
+ _run_trainer,
+ args=(rref_t1, t2, self_name, rank_diff),
+ )
+ )
# check that the trainers are done with their backward pass
for rank_diff in rank_diffs:
@@ -953,7 +985,9 @@
val = self._exec_func(exec_mode, torch.matmul, val, val)
loss = val.sum()
- ret = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5)
+ ret = self._verify_backwards(
+ exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5
+ )
local_grads = ret if ret else local_grads
@dist_init
@@ -970,7 +1004,9 @@
val = self._exec_func(exec_mode, torch.chain_matmul, [val, t3, t4])
loss = val.sum()
- ret = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t2, t3, t4)
+ ret = self._verify_backwards(
+ exec_mode, [loss], context_id, local_grads, t1, t2, t2, t3, t4
+ )
local_grads = ret if ret else local_grads
@dist_init
@@ -982,10 +1018,17 @@
for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
with dist_autograd.context() as context_id:
s = self._exec_func(exec_mode, torch.stack, (t1, t2, t3))
- val = self._exec_func(exec_mode, torch.matmul, torch.narrow(s, 0, 0, 1), torch.narrow(s, 0, 2, 1))
+ val = self._exec_func(
+ exec_mode,
+ torch.matmul,
+ torch.narrow(s, 0, 0, 1),
+ torch.narrow(s, 0, 2, 1),
+ )
loss = val.sum()
- ret = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t3)
+ ret = self._verify_backwards(
+ exec_mode, [loss], context_id, local_grads, t1, t2, t3
+ )
local_grads = ret if ret else local_grads
@dist_init
@@ -1002,7 +1045,9 @@
val = self._exec_func(exec_mode, torch.chain_matmul, [t1, t2, t3])
loss = val.sum()
- ret = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t)
+ ret = self._verify_backwards(
+ exec_mode, [loss], context_id, local_grads, t
+ )
local_grads = ret if ret else local_grads
def _run_test_backward_unused_send_function_in_thread(self):
@@ -1012,19 +1057,21 @@
# We don't use the result of an RPC function; as a result, the
# backward pass would hang in the "FAST" mode.
- res = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.add,
- args=(t1, t2))
+ res = rpc.rpc_sync(
+ "worker{}".format(self._next_rank()), torch.add, args=(t1, t2)
+ )
val = torch.mul(t1, t2)
# Run backward, this would hang forever.
dist_autograd.backward([val.sum()])
-
@dist_init
def test_backward_unused_send_function(self):
# Run the test in a thread which would never finish.
- t = threading.Thread(target=self._run_test_backward_unused_send_function_in_thread)
+ t = threading.Thread(
+ target=self._run_test_backward_unused_send_function_in_thread
+ )
t.daemon = True
t.start()
t.join(10) # Wait for 10s.
@@ -1043,21 +1090,30 @@
# Run multiple round trips across different nodes and verify the
# original node receives an error thrown on a node deep in the chain.
- val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.add,
- args=(t2, t3))
- val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.mul,
- args=(val, t2))
- val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.matmul,
- args=(val, t2))
- val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.div,
- args=(val, t2))
+ val = rpc.rpc_sync(
+ "worker{}".format(self._next_rank()), torch.add, args=(t2, t3)
+ )
+ val = rpc.rpc_sync(
+ "worker{}".format(self._next_rank()), torch.mul, args=(val, t2)
+ )
+ val = rpc.rpc_sync(
+ "worker{}".format(self._next_rank()), torch.matmul, args=(val, t2)
+ )
+ val = rpc.rpc_sync(
+ "worker{}".format(self._next_rank()), torch.div, args=(val, t2)
+ )
- with self.assertRaisesRegex(RuntimeError, 'Simulate error on backward pass'):
+ with self.assertRaisesRegex(
+ RuntimeError, "Simulate error on backward pass"
+ ):
# Run backwards, and validate we receive an error.
dist_autograd.backward([val.sum()])
- @unittest.skipIf(torch.testing._internal.dist_utils.TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP",
- "Skipping this test temporarily since ProcessGroupAgent does not report errors on node failures")
+ @unittest.skipIf(
+ torch.testing._internal.dist_utils.TEST_CONFIG.rpc_backend_name
+ == "PROCESS_GROUP",
+ "Skipping this test temporarily since ProcessGroupAgent does not report errors on node failures",
+ )
@dist_init(clean_shutdown=False)
def test_backward_node_failure(self):
initialize_pg(self.init_method, self.rank, self.world_size)
@@ -1066,8 +1122,9 @@
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
- res = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.add,
- args=(t1, t2))
+ res = rpc.rpc_sync(
+ "worker{}".format(self._next_rank()), torch.add, args=(t1, t2)
+ )
# Wait for all RPCs to be done.
dist.barrier()
@@ -1094,9 +1151,12 @@
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
- with self.assertRaisesRegex(RuntimeError, "Current thread doesn't have a valid autograd context"):
- res = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.add,
- args=(t1, t2))
+ with self.assertRaisesRegex(
+ RuntimeError, "Current thread doesn't have a valid autograd context"
+ ):
+ res = rpc.rpc_sync(
+ "worker{}".format(self._next_rank()), torch.add, args=(t1, t2)
+ )
dist_autograd.backward([res.sum()])
@dist_init
@@ -1122,18 +1182,24 @@
with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
dist_autograd.backward(None)
- with self.assertRaisesRegex(RuntimeError, "No tensors provided for gradient computation"):
+ with self.assertRaisesRegex(
+ RuntimeError, "No tensors provided for gradient computation"
+ ):
dist_autograd.backward([])
with self.assertRaisesRegex(RuntimeError, "requires_grad not set on"):
t = torch.rand(3, 3)
dist_autograd.backward([t])
- with self.assertRaisesRegex(RuntimeError, "is not a scalar, all roots need to be scalar"):
+ with self.assertRaisesRegex(
+ RuntimeError, "is not a scalar, all roots need to be scalar"
+ ):
t = torch.rand(3, 3, requires_grad=True)
dist_autograd.backward([t])
- with self.assertRaisesRegex(RuntimeError, "does not have a valid gradient function"):
+ with self.assertRaisesRegex(
+ RuntimeError, "does not have a valid gradient function"
+ ):
t = torch.rand(1, requires_grad=True)
dist_autograd.backward([t])
@@ -1149,7 +1215,9 @@
r3 = self._exec_func(exec_mode, torch.cos, t1).sum()
r4 = self._exec_func(exec_mode, torch.div, t1, t2).sum()
- local_grads = self._verify_backwards(exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2)
+ local_grads = self._verify_backwards(
+ exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2
+ )
@dist_init
def test_backward_different_dtypes(self):
@@ -1160,7 +1228,9 @@
with dist_autograd.context() as context_id:
loss = self._exec_func(exec_mode, torch.add, t1, t2).sum()
- local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2)
+ local_grads = self._verify_backwards(
+ exec_mode, [loss], context_id, local_grads, t1, t2
+ )
@dist_init
def test_backward_simple_python_udf(self):
@@ -1173,7 +1243,9 @@
with dist_autograd.context() as context_id:
ret = self._exec_func(exec_mode, my_py_add, t1, t2)
loss = ret.sum()
- local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2)
+ local_grads = self._verify_backwards(
+ exec_mode, [loss], context_id, local_grads, t1, t2
+ )
@dist_init
def test_backward_simple_script_call(self):
@@ -1182,11 +1254,18 @@
local_grads = None
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
- for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.RPC_ASYNC, ExecMode.REMOTE]:
+ for exec_mode in [
+ ExecMode.LOCAL,
+ ExecMode.RPC_SYNC,
+ ExecMode.RPC_ASYNC,
+ ExecMode.REMOTE,
+ ]:
with dist_autograd.context() as context_id:
ret = self._exec_func(exec_mode, my_script_add, t1, t2)
loss = ret.sum()
- ret = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2)
+ ret = self._verify_backwards(
+ exec_mode, [loss], context_id, local_grads, t1, t2
+ )
local_grads = ret if ret else local_grads
@staticmethod
@@ -1205,9 +1284,13 @@
t2 = torch.rand((3, 3), requires_grad=True)
for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
with dist_autograd.context() as context_id:
- ret = self._exec_func(exec_mode, DistAutogradTest._complex_python_udf, t1, t2)
+ ret = self._exec_func(
+ exec_mode, DistAutogradTest._complex_python_udf, t1, t2
+ )
loss = ret.sum()
- local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2)
+ local_grads = self._verify_backwards(
+ exec_mode, [loss], context_id, local_grads, t1, t2
+ )
@staticmethod
def _python_udf_with_backward_error(t1, t2):
@@ -1219,24 +1302,28 @@
def _nested_rpc_call_backward_error(t1, t2, dst):
t1 = t1 * t2
t2 = t1 + t2
- res = rpc.rpc_sync('worker{}'.format(dst),
- DistAutogradTest._python_udf_with_backward_error,
- args=(t1, t2))
+ res = rpc.rpc_sync(
+ "worker{}".format(dst),
+ DistAutogradTest._python_udf_with_backward_error,
+ args=(t1, t2),
+ )
return torch.chain_matmul(t1, t2, res)
-
@dist_init
def test_backward_python_udf_error(self):
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
with dist_autograd.context() as context_id:
- loss = rpc.rpc_sync('worker{}'.format(self._next_rank()),
- DistAutogradTest._nested_rpc_call_backward_error,
- args=(t1, t2, self._next_rank()))
- with self.assertRaisesRegex(RuntimeError, 'Simulate error on backward pass'):
+ loss = rpc.rpc_sync(
+ "worker{}".format(self._next_rank()),
+ DistAutogradTest._nested_rpc_call_backward_error,
+ args=(t1, t2, self._next_rank()),
+ )
+ with self.assertRaisesRegex(
+ RuntimeError, "Simulate error on backward pass"
+ ):
dist_autograd.backward([loss.sum()])
-
_backward_done = False
@staticmethod
@@ -1248,9 +1335,12 @@
while not DistAutogradTest._backward_done:
time.sleep(0.1)
- @unittest.skipIf(torch.testing._internal.dist_utils.TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP",
- "Skipping this test temporarily since ProcessGroupAgent " +
- "does not report errors on node failures")
+ @unittest.skipIf(
+ torch.testing._internal.dist_utils.TEST_CONFIG.rpc_backend_name
+ == "PROCESS_GROUP",
+ "Skipping this test temporarily since ProcessGroupAgent "
+ + "does not report errors on node failures",
+ )
@dist_init(clean_shutdown=False)
def test_backward_node_failure_python_udf(self):
initialize_pg(self.init_method, self.rank, self.world_size)
@@ -1260,8 +1350,11 @@
t2 = torch.rand((3, 3), requires_grad=True)
dst = self._next_rank()
- res = rpc.rpc_sync('worker{}'.format(dst), my_py_nested_call,
- args=(t1, t2, dst, self.world_size, 1))
+ res = rpc.rpc_sync(
+ "worker{}".format(dst),
+ my_py_nested_call,
+ args=(t1, t2, dst, self.world_size, 1),
+ )
# Wait for all RPCs to be done.
dist.barrier()
@@ -1283,7 +1376,11 @@
# Tell other nodes RPC is done.
for i in range(self.world_size):
if i != self.rank and i != 2:
- rpc.rpc_sync('worker{}'.format(i), DistAutogradTest._set_backward_done, args=())
+ rpc.rpc_sync(
+ "worker{}".format(i),
+ DistAutogradTest._set_backward_done,
+ args=(),
+ )
else:
# Wait for backward to finish on rank 0.
DistAutogradTest._wait_backward_done()
@@ -1292,7 +1389,7 @@
def _nested_python_udf(t1, t2, dst):
t3 = t1 * t2
t4 = t1 + t2
- res = rpc.rpc_sync('worker{}'.format(dst), my_py_add, args=(t3, t4))
+ res = rpc.rpc_sync("worker{}".format(dst), my_py_add, args=(t3, t4))
return torch.chain_matmul(t1, t2, t3, t4, res)
@dist_init
@@ -1308,16 +1405,17 @@
# Now run distributed autograd.
with dist_autograd.context() as context_id:
- loss = rpc.rpc_sync('worker{}'.format(self._next_rank()),
- DistAutogradTest._nested_python_udf,
- args=(t1, t2, self._next_rank()))
+ loss = rpc.rpc_sync(
+ "worker{}".format(self._next_rank()),
+ DistAutogradTest._nested_python_udf,
+ args=(t1, t2, self._next_rank()),
+ )
dist_autograd.backward([loss.sum()])
grads = dist_autograd.get_gradients(context_id)
self.assertEqual(t1.grad, grads[t1])
self.assertEqual(t2.grad, grads[t2])
-
_test_clean_context_backward_context_id = None
class MyBackwardFunc(Function):
@@ -1328,21 +1426,23 @@
@staticmethod
@once_differentiable
def backward(ctx, input):
- assert(DistAutogradTest._test_clean_context_backward_context_id is not None)
+ assert DistAutogradTest._test_clean_context_backward_context_id is not None
# Release the context to simulate error (use barrier before releasing
# context to ensure all nodes execute the backward function).
dist.barrier()
- dist_autograd._release_context(DistAutogradTest._test_clean_context_backward_context_id)
+ dist_autograd._release_context(
+ DistAutogradTest._test_clean_context_backward_context_id
+ )
# Verify all contexts are cleaned up.
- assert(_all_contexts_cleaned_up())
+ assert _all_contexts_cleaned_up()
return input
@dist_init
def test_clean_context_during_backward(self):
- '''
+ """
This test simulates the situation where the 'backward' call might throw
an exception locally which would lead to the autograd context being
cleaned up if we're using the context manager. As a result, the autograd
@@ -1351,7 +1451,7 @@
It is fine for the 'backward' call to throw an exception in this test,
but the process should not crash.
- '''
+ """
initialize_pg(self.init_method, self.rank, self.world_size)
context = dist_autograd._new_context()
@@ -1362,7 +1462,11 @@
for i in range(0, self.world_size):
if i != self.rank:
rank_distance = (i - self.rank + self.world_size) % self.world_size
- rpc.rpc_sync("worker{}".format(i), _set_rpc_done, args=(context_id, rank_distance))
+ rpc.rpc_sync(
+ "worker{}".format(i),
+ _set_rpc_done,
+ args=(context_id, rank_distance),
+ )
dist.barrier()
@@ -1379,7 +1483,9 @@
t1 = DistAutogradTest.MyBackwardFunc.apply(t1)
self.assertEqual(100, len(context._send_functions()))
- with self.assertRaisesRegex(RuntimeError, "Could not find autograd context with id"):
+ with self.assertRaisesRegex(
+ RuntimeError, "Could not find autograd context with id"
+ ):
dist_autograd.backward([t1.sum()])
# HACK: Killing workers since otherwise the autograd engine gets stuck on
@@ -1408,10 +1514,13 @@
@dist_init
def test_embedding_bag_with_no_grad_tensors(self):
dst = self._next_rank()
- remote_embedding = rpc.remote("worker{}".format(dst),
- torch.nn.EmbeddingBag, args=(16, 16),
- kwargs={'mode': 'sum', 'sparse': True})
- local_embedding = torch.nn.EmbeddingBag(16, 16, mode='sum', sparse=True)
+ remote_embedding = rpc.remote(
+ "worker{}".format(dst),
+ torch.nn.EmbeddingBag,
+ args=(16, 16),
+ kwargs={"mode": "sum", "sparse": True},
+ )
+ local_embedding = torch.nn.EmbeddingBag(16, 16, mode="sum", sparse=True)
input = torch.LongTensor([1, 2, 4, 5, 4, 3, 2, 9])
# requires_grad = True to record send/recv functions
@@ -1423,15 +1532,19 @@
local_grad = local_embedding.weight.grad
with dist_autograd.context() as context_id:
- res = rpc.rpc_sync("worker{}".format(dst),
- DistAutogradTest._call_remote_embedding,
- args=(remote_embedding, input, offsets, per_sample_weights))
+ res = rpc.rpc_sync(
+ "worker{}".format(dst),
+ DistAutogradTest._call_remote_embedding,
+ args=(remote_embedding, input, offsets, per_sample_weights),
+ )
dist_autograd.backward([res.sum()])
- remote_grad = rpc.rpc_sync("worker{}".format(dst),
- DistAutogradTest._get_grad,
- args=(remote_embedding, context_id))
+ remote_grad = rpc.rpc_sync(
+ "worker{}".format(dst),
+ DistAutogradTest._get_grad,
+ args=(remote_embedding, context_id),
+ )
self.assertEqual(local_grad.to_dense(), remote_grad)
@@ -1448,7 +1561,9 @@
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=False)
with dist_autograd.context() as context_id:
- ret = self._exec_func(exec_mode, DistAutogradTest._mixed_requires_grad, t1, t2)
+ ret = self._exec_func(
+ exec_mode, DistAutogradTest._mixed_requires_grad, t1, t2
+ )
self.assertEqual(t1 * t2, ret)
dist_autograd.backward([ret.sum()])
self.assertTrue(t1.requires_grad)
@@ -1467,13 +1582,13 @@
@once_differentiable
def backward(ctx, input):
debug_info = dist_autograd._get_debug_info()
- assert (debug_info is not None)
- backward_passes = int(debug_info['num_current_backward_passes'])
+ assert debug_info is not None
+ backward_passes = int(debug_info["num_current_backward_passes"])
# Hard to validate exact numbers because of the distributed nature.
# We can't use a barrier() here since that would block the single
# CPU thread available for autograd and can cause deadlocks.
- assert (backward_passes >= 1 and backward_passes <= 4)
+ assert backward_passes >= 1 and backward_passes <= 4
return input
@dist_init
@@ -1488,8 +1603,9 @@
res[i] = t1
for rank in range(self.world_size):
if rank != self.rank:
- res[i + 1] = rpc.rpc_sync('worker{}'.format(rank), torch.add,
- args=(res[i], t2))
+ res[i + 1] = rpc.rpc_sync(
+ "worker{}".format(rank), torch.add, args=(res[i], t2)
+ )
i += 1
# Call custom function in middle of backward pass to ensure all
@@ -1499,34 +1615,38 @@
for rank in range(self.world_size):
if rank != self.rank:
- res[i + 1] = rpc.rpc_sync('worker{}'.format(rank), torch.add,
- args=(res[i], t2))
+ res[i + 1] = rpc.rpc_sync(
+ "worker{}".format(rank), torch.add, args=(res[i], t2)
+ )
i += 1
dist_autograd.backward([res[i].sum()])
debug_info = dist_autograd._get_debug_info()
- num_autograd_context = int(debug_info['num_autograd_contexts'])
+ num_autograd_context = int(debug_info["num_autograd_contexts"])
# Need at least one context and not more than 4.
self.assertTrue(num_autograd_context >= 1 and num_autograd_context <= 4)
for rd in range(self.world_size - 1):
- rpc.rpc_sync("worker{}".format((self.rank + rd + 1) % self.world_size),
- _set_rpc_done, args=(context_id, rd + 1))
+ rpc.rpc_sync(
+ "worker{}".format((self.rank + rd + 1) % self.world_size),
+ _set_rpc_done,
+ args=(context_id, rd + 1),
+ )
dist.barrier()
# Validate information
debug_info = dist_autograd._get_debug_info()
- assert (debug_info is not None)
- self.assertEqual(0, int(debug_info['num_current_backward_passes']))
- self.assertEqual(0, int(debug_info['local_autograd_engine_cpu_queue_size']))
+ assert debug_info is not None
+ self.assertEqual(0, int(debug_info["num_current_backward_passes"]))
+ self.assertEqual(0, int(debug_info["local_autograd_engine_cpu_queue_size"]))
self.assertTrue(_all_contexts_cleaned_up())
# All contexts should be cleaned up.
debug_info = dist_autograd._get_debug_info()
- self.assertEqual(0, int(debug_info['num_autograd_contexts']))
+ self.assertEqual(0, int(debug_info["num_autograd_contexts"]))
@staticmethod
def _workload_thread():
@@ -1542,11 +1662,11 @@
@dist_init
def test_async_dist_autograd(self):
- '''
+ """
This test ensures async processing for distributed autograd works
appropriately. This is achieved by spawning multiple threads and
hammering a single node with a lot of backward() calls.
- '''
+ """
initialize_pg(self.init_method, self.rank, self.world_size)
if self.rank != 0:
@@ -1563,14 +1683,15 @@
dist.barrier()
-
@unittest.skipIf(
- not torch._six.PY3, "Pytorch distributed autograd package " "does not support python2"
+ not torch._six.PY3,
+ "Pytorch distributed autograd package " "does not support python2",
)
class DistAutogradJitTest(RpcAgentTestFixture):
@dist_init
def test_get_gradients(self):
dst_rank = self.rank
+
@torch.jit.script
def dist_get_gradients(context_id):
# type: (int) -> (Dict[Tensor, Tensor])
@@ -1592,5 +1713,5 @@
self.assertEqual(torch.ones(3, 3), grads[t2])
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py b/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py
index 842ce31..29fed93 100644
--- a/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py
+++ b/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py
@@ -1,15 +1,17 @@
from __future__ import absolute_import, division, print_function, unicode_literals
+import threading
import unittest
-from torch.testing._internal.dist_utils import dist_init
-from torch import optim
-from torch.distributed.optim import DistributedOptimizer
import torch
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
-import threading
-from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import RpcAgentTestFixture
+from torch import optim
+from torch.distributed.optim import DistributedOptimizer
+from torch.testing._internal.dist_utils import dist_init
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+ RpcAgentTestFixture,
+)
class MyModule:
@@ -34,13 +36,13 @@
super(FailingOptimizer, self).__init__(params, {})
def step(self, closure=None):
- raise ValueError('Error running optimizer.')
+ raise ValueError("Error running optimizer.")
class OptimizerFailingOnConstructor(optim.Optimizer):
def __init__(self, params):
super(OptimizerFailingOnConstructor, self).__init__(params, {})
- raise ValueError('Error creating optimizer.')
+ raise ValueError("Error creating optimizer.")
def step(self, closure=None):
raise NotImplementedError
@@ -66,7 +68,7 @@
obj_rref.owner(),
_call_method,
args=[method, obj_rref] + list(args),
- kwargs=kwargs
+ kwargs=kwargs,
)
@@ -86,7 +88,7 @@
obj_rref.owner(),
_call_method,
args=[method, obj_rref] + list(args),
- kwargs=kwargs
+ kwargs=kwargs,
)
@@ -94,12 +96,11 @@
not torch._six.PY3, "Pytorch distributed optim does not support python2"
)
class DistOptimizerTest(RpcAgentTestFixture):
-
@dist_init()
def test_dist_optim_exception(self):
# distributed version
- owner1 = 'worker%d' % ((self.rank + 1) % self.world_size)
- owner2 = 'worker%d' % ((self.rank + 2) % self.world_size)
+ owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
+ owner2 = "worker%d" % ((self.rank + 2) % self.world_size)
remote_module1 = rpc.remote(owner1, MyModule)
remote_module2 = rpc.remote(owner2, MyModule)
@@ -107,8 +108,7 @@
remote_param2 = remote_method(MyModule.get_w, remote_module2)
dist_optim = DistributedOptimizer(
- FailingOptimizer,
- [remote_param1, remote_param2],
+ FailingOptimizer, [remote_param1, remote_param2]
)
with dist_autograd.context():
@@ -116,8 +116,7 @@
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
- output2 = rpc_async_method(
- MyModule.forward, remote_module2, output1.wait())
+ output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
loss = torch.add(output2.wait(), t1).sum()
dist_autograd.backward([loss])
@@ -127,8 +126,8 @@
@dist_init()
def test_dist_optim_exception_on_constructor(self):
# distributed version
- owner1 = 'worker%d' % ((self.rank + 1) % self.world_size)
- owner2 = 'worker%d' % ((self.rank + 2) % self.world_size)
+ owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
+ owner2 = "worker%d" % ((self.rank + 2) % self.world_size)
remote_module1 = rpc.remote(owner1, MyModule)
remote_module2 = rpc.remote(owner2, MyModule)
@@ -137,8 +136,7 @@
with self.assertRaisesRegex(Exception, "Error creating optimizer."):
dist_optim = DistributedOptimizer(
- OptimizerFailingOnConstructor,
- [remote_param1, remote_param2],
+ OptimizerFailingOnConstructor, [remote_param1, remote_param2]
)
@dist_init()
@@ -163,8 +161,8 @@
local_optim.step()
# distributed version
- owner1 = 'worker%d' % ((self.rank + 1) % self.world_size)
- owner2 = 'worker%d' % ((self.rank + 2) % self.world_size)
+ owner1 = "worker%d" % ((self.rank + 1) % self.world_size)
+ owner2 = "worker%d" % ((self.rank + 2) % self.world_size)
remote_module1 = rpc.remote(owner1, MyModule)
remote_module2 = rpc.remote(owner2, MyModule)
@@ -178,9 +176,7 @@
self.assertEqual(old_w2, remote_param2.to_here())
dist_optim = DistributedOptimizer(
- optim.SGD,
- [remote_param1, remote_param2],
- lr=0.05,
+ optim.SGD, [remote_param1, remote_param2], lr=0.05
)
with dist_autograd.context():
@@ -188,8 +184,7 @@
t1 = torch.rand((3, 3), requires_grad=True)
t2 = torch.rand((3, 3), requires_grad=True)
output1 = rpc_async_method(MyModule.forward, remote_module1, t2)
- output2 = rpc_async_method(
- MyModule.forward, remote_module2, output1.wait())
+ output2 = rpc_async_method(MyModule.forward, remote_module2, output1.wait())
loss = torch.add(output2.wait(), t1)
dist_autograd.backward([loss.sum()])
diff --git a/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py b/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py
index b15eba2..9bd17bd 100644
--- a/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py
+++ b/torch/testing/_internal/distributed/rpc/rpc_agent_test_fixture.py
@@ -1,5 +1,5 @@
-import torch.testing._internal.dist_utils
import torch.distributed.rpc as rpc
+import torch.testing._internal.dist_utils
class RpcAgentTestFixture(object):
@@ -9,12 +9,18 @@
@property
def init_method(self):
- return torch.testing._internal.dist_utils.INIT_METHOD_TEMPLATE.format(file_name=self.file_name)
+ return torch.testing._internal.dist_utils.INIT_METHOD_TEMPLATE.format(
+ file_name=self.file_name
+ )
@property
def rpc_backend(self):
- return rpc.backend_registry.BackendType[torch.testing._internal.dist_utils.TEST_CONFIG.rpc_backend_name]
+ return rpc.backend_registry.BackendType[
+ torch.testing._internal.dist_utils.TEST_CONFIG.rpc_backend_name
+ ]
@property
def rpc_backend_options(self):
- return torch.testing._internal.dist_utils.TEST_CONFIG.build_rpc_backend_options(self)
+ return torch.testing._internal.dist_utils.TEST_CONFIG.build_rpc_backend_options(
+ self
+ )
diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py
index 1f7c6fb..626d8a4 100644
--- a/torch/testing/_internal/distributed/rpc/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -1,30 +1,39 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import concurrent.futures
-from datetime import timedelta
import sys
import time
import unittest
from collections import namedtuple
+from datetime import timedelta
from unittest import mock
import torch
import torch.distributed as dist
import torch.distributed.rpc as rpc
-from torch.testing._internal.common_utils import load_tests, IS_MACOS
-from torch.distributed.rpc import RRef, _get_debug_info, _rref_context_get_debug_info
import torch.testing._internal.dist_utils
-from torch.testing._internal.dist_utils import dist_init, wait_until_node_failure, initialize_pg, get_shutdown_error_regex
-from torch.distributed.rpc.api import _use_rpc_pickler
-from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler, RPCExecMode
-from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import RpcAgentTestFixture
from torch._jit_internal import _qualified_name
+from torch.distributed.rpc import RRef, _get_debug_info, _rref_context_get_debug_info
+from torch.distributed.rpc.api import _use_rpc_pickler
+from torch.distributed.rpc.internal import PythonUDF, RPCExecMode, _internal_rpc_pickler
+from torch.testing._internal.common_utils import IS_MACOS, load_tests
+from torch.testing._internal.dist_utils import (
+ dist_init,
+ get_shutdown_error_regex,
+ initialize_pg,
+ wait_until_node_failure,
+)
+from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
+ RpcAgentTestFixture,
+)
def requires_process_group_agent(message=""):
def decorator(old_func):
return unittest.skipUnless(
- torch.testing._internal.dist_utils.TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP", message
+ torch.testing._internal.dist_utils.TEST_CONFIG.rpc_backend_name
+ == "PROCESS_GROUP",
+ message,
)(old_func)
return decorator
@@ -40,22 +49,16 @@
def get_worker_infos(self):
return {
- rpc.WorkerInfo(
- name="worker{}".format(rank),
- id=rank,
- ) for rank in range(self.world_size)
+ rpc.WorkerInfo(name="worker{}".format(rank), id=rank)
+ for rank in range(self.world_size)
}
-def _stub_construct_rpc_backend_options_handler(
- **kwargs
-):
+def _stub_construct_rpc_backend_options_handler(**kwargs):
return mock.Mock() # RpcBackendOptions.
-def _stub_init_rpc_backend_handler(
- store, name, rank, world_size, rpc_backend_options
-):
+def _stub_init_rpc_backend_handler(store, name, rank, world_size, rpc_backend_options):
return StubRpcAgent(world_size=world_size)
@@ -147,6 +150,7 @@
def my_tensor_function(a, b):
return a + b
+
def my_sleep_func(seconds=1):
time.sleep(seconds)
@@ -229,12 +233,15 @@
def raise_func():
raise ValueError("Expected error")
+
global_rref = None
+
def set_global_rref(rref):
global global_rref
global_rref = rref
+
def clear_global_rref():
global global_rref
global_rref = None
@@ -288,20 +295,14 @@
def test_get_worker_infos(self):
worker_infos = rpc.api._get_current_rpc_agent().get_worker_infos()
- worker_names = {
- worker_info.name for worker_info in worker_infos
- }
+ worker_names = {worker_info.name for worker_info in worker_infos}
expected_worker_names = {
"worker{}".format(rank) for rank in range(self.world_size)
}
self.assertEqual(worker_names, expected_worker_names)
- worker_ids = {
- worker_info.id for worker_info in worker_infos
- }
- expected_worker_ids = {
- rank for rank in range(self.world_size)
- }
+ worker_ids = {worker_info.id for worker_info in worker_infos}
+ expected_worker_ids = {rank for rank in range(self.world_size)}
self.assertEqual(worker_ids, expected_worker_ids)
@dist_init
@@ -340,7 +341,9 @@
self_worker_info = rpc.get_worker_info()
rref = rpc.remote(self_worker_info, my_function, args=(torch.ones(2, 2), 1, 3))
ret_rref = rpc.remote(dst, add_rref_to_value, args=(rref, torch.ones(2, 2)))
- self.assertEqual(ret_rref.to_here(), torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2))
+ self.assertEqual(
+ ret_rref.to_here(), torch.ones(2, 2) + 1 + 3 + torch.ones(2, 2)
+ )
@dist_init
def test_self_remote_rref_as_remote_arg(self):
@@ -386,9 +389,11 @@
@dist_init(setup_rpc=False)
def test_duplicate_name(self):
with self.assertRaisesRegex(RuntimeError, "is not unique"):
- store, _, _ = next(torch.distributed.rendezvous(
- self.init_method, rank=self.rank, world_size=self.world_size
- ))
+ store, _, _ = next(
+ torch.distributed.rendezvous(
+ self.init_method, rank=self.rank, world_size=self.world_size
+ )
+ )
rpc.api._init_rpc_backend(
backend=self.rpc_backend,
store=store,
@@ -425,6 +430,7 @@
@dist_init(setup_rpc=False)
def test_invalid_names(self):
from torch.distributed.rpc import WorkerInfo
+
worker_id = 0
with self.assertRaisesRegex(RuntimeError, "Worker name must match"):
info = WorkerInfo("abc*", worker_id)
@@ -648,13 +654,19 @@
# this, we wait until the current RRef context doesn't have
# any pending users, which indicates that the confirmation
# was processed on this worker.
- num_pending_users = int(_rref_context_get_debug_info()["num_pending_users"])
+ num_pending_users = int(
+ _rref_context_get_debug_info()["num_pending_users"]
+ )
while num_pending_users != 0:
time.sleep(0.1)
- num_pending_users = int(_rref_context_get_debug_info()["num_pending_users"])
+ num_pending_users = int(
+ _rref_context_get_debug_info()["num_pending_users"]
+ )
events = prof.function_events
- rpc_event = [event for event in events if rpc_exec_mode.value in event.name][0]
+ rpc_event = [
+ event for event in events if rpc_exec_mode.value in event.name
+ ][0]
# the sender, dest worker, function run, and type of RPC should all
# be recorded.
self_worker_name = "worker{}".format(self.rank)
@@ -671,7 +683,9 @@
@dist_init
def test_profiler_with_sync_rpc_builtin(self):
- self._profiler_test_with_rpc(RPCExecMode.SYNC, torch.add, args=(torch.ones(1), torch.ones(1)))
+ self._profiler_test_with_rpc(
+ RPCExecMode.SYNC, torch.add, args=(torch.ones(1), torch.ones(1))
+ )
@dist_init
def test_profiler_with_async_rpc_udf(self):
@@ -679,7 +693,9 @@
@dist_init
def test_profiler_with_async_rpc_builtin(self):
- self._profiler_test_with_rpc(RPCExecMode.ASYNC, torch.add, args=(torch.ones(1), torch.ones(1)))
+ self._profiler_test_with_rpc(
+ RPCExecMode.ASYNC, torch.add, args=(torch.ones(1), torch.ones(1))
+ )
@dist_init
def test_profiler_with_remote_udf(self):
@@ -687,7 +703,9 @@
@dist_init
def test_profiler_with_remote_builtin(self):
- self._profiler_test_with_rpc(RPCExecMode.REMOTE, torch.add, args=(torch.ones(1), torch.ones(1)))
+ self._profiler_test_with_rpc(
+ RPCExecMode.REMOTE, torch.add, args=(torch.ones(1), torch.ones(1))
+ )
@dist_init
def test_py_class_constructor(self):
@@ -820,7 +838,9 @@
with self.assertRaisesRegex(Exception, "no_args"):
ret = rpc.rpc_sync("worker{}".format(dst_rank), no_args, args=(10,))
- with self.assertRaisesRegex(Exception, r"no_args\(\) expected at most 0 argument"):
+ with self.assertRaisesRegex(
+ Exception, r"no_args\(\) expected at most 0 argument"
+ ):
rref = rpc.remote("worker{}".format(dst_rank), no_args, args=(10,))
@dist_init
@@ -833,32 +853,35 @@
# accept script module and script module method.
n = self.rank + 1
dst_rank = n % self.world_size
- with self.assertRaisesRegex(RuntimeError, "attempted to get undefined function"):
+ with self.assertRaisesRegex(
+ RuntimeError, "attempted to get undefined function"
+ ):
ret = rpc._rpc_sync_torchscript(
- 'worker{}'.format(dst_rank),
- _qualified_name(MyScriptClass),
- args=())
- ret = rpc.rpc_sync(
- 'worker{}'.format(dst_rank), MyScriptClass, args=())
+ "worker{}".format(dst_rank), _qualified_name(MyScriptClass), args=()
+ )
+ ret = rpc.rpc_sync("worker{}".format(dst_rank), MyScriptClass, args=())
- with self.assertRaisesRegex(RuntimeError, "attempted to get undefined function"):
+ with self.assertRaisesRegex(
+ RuntimeError, "attempted to get undefined function"
+ ):
ret = rpc._rpc_sync_torchscript(
- 'worker{}'.format(dst_rank),
- _qualified_name(MyScriptModule),
- args=())
+ "worker{}".format(dst_rank), _qualified_name(MyScriptModule), args=()
+ )
- with self.assertRaisesRegex(RuntimeError, "attempted to get undefined function"):
+ with self.assertRaisesRegex(
+ RuntimeError, "attempted to get undefined function"
+ ):
ret = rpc._rpc_sync_torchscript(
- 'worker{}'.format(dst_rank),
+ "worker{}".format(dst_rank),
_qualified_name(MyScriptModule().my_method),
- args=())
+ args=(),
+ )
# Python 3.5 and Python 3.6 throw different error messages; the only
# common word that can be grepped is "pickle".
with self.assertRaisesRegex(Exception, "pickle"):
ret = rpc.rpc_sync(
- 'worker{}'.format(dst_rank),
- MyScriptModule().my_method,
- args=())
+ "worker{}".format(dst_rank), MyScriptModule().my_method, args=()
+ )
@dist_init
def test_nested_rpc(self):
@@ -1118,9 +1141,12 @@
# ensure that an error message is thrown if a user tries to call
# local_value() on a non-owning node.
next_rank = (self.rank + 1) % self.world_size
- rref = rpc.remote("worker{}".format(next_rank), torch.add, args=(
- torch.ones(1), torch.ones(1)))
- with self.assertRaisesRegex(RuntimeError, "Call it on worker{}".format(next_rank)):
+ rref = rpc.remote(
+ "worker{}".format(next_rank), torch.add, args=(torch.ones(1), torch.ones(1))
+ )
+ with self.assertRaisesRegex(
+ RuntimeError, "Call it on worker{}".format(next_rank)
+ ):
rref.local_value()
@dist_init
@@ -1129,17 +1155,22 @@
dst_rank = n % self.world_size
rref_list = rpc.rpc_sync(
- "worker{}".format(dst_rank), get_rref_list, args=(
- [1, 2, 3], ))
+ "worker{}".format(dst_rank), get_rref_list, args=([1, 2, 3],)
+ )
for rref in rref_list:
- rpc.rpc_sync(rref.owner(), _call_method_on_rref, args=(
- MyClass.increment_value, rref, 10))
+ rpc.rpc_sync(
+ rref.owner(),
+ _call_method_on_rref,
+ args=(MyClass.increment_value, rref, 10),
+ )
rets = [
- rpc.rpc_sync(rref.owner(), _call_method_on_rref, args=(
- MyClass.get_value, rref))
- for rref in rref_list]
+ rpc.rpc_sync(
+ rref.owner(), _call_method_on_rref, args=(MyClass.get_value, rref)
+ )
+ for rref in rref_list
+ ]
self.assertEqual(rets, [11, 12, 13])
@@ -1183,16 +1214,14 @@
rref = RRef(40)
self.assertEqual(
- rpc.rpc_sync(
- dst_worker, add_rref_to_value, args=(rref, 50)), 90)
+ rpc.rpc_sync(dst_worker, add_rref_to_value, args=(rref, 50)), 90
+ )
self.assertEqual(
- rpc.rpc_async(
- dst_worker, add_rref_to_value, args=(rref, 50)).wait(), 90)
+ rpc.rpc_async(dst_worker, add_rref_to_value, args=(rref, 50)).wait(), 90
+ )
self.assertEqual(
- rpc.remote(
- dst_worker,
- add_rref_to_value,
- args=(rref, 50)).to_here(), 90)
+ rpc.remote(dst_worker, add_rref_to_value, args=(rref, 50)).to_here(), 90
+ )
@dist_init
def test_remote_same_worker(self):
@@ -1220,19 +1249,29 @@
dst_worker = "worker{}".format(dst_rank)
# creates a remote object
- rref = rpc.remote(dst_worker, MyClass, args=(vals[0], ))
+ rref = rpc.remote(dst_worker, MyClass, args=(vals[0],))
# modifies state of the remote object
- rpc.rpc_sync(rref.owner(), _call_method_on_rref, args=(
- MyClass.increment_value, rref, vals[1]))
- rpc.rpc_async(rref.owner(), _call_method_on_rref, args=(
- MyClass.increment_value, rref, vals[2])).wait()
- rpc.remote(rref.owner(), _call_method_on_rref, args=(
- MyClass.increment_value, rref, vals[3])).to_here()
+ rpc.rpc_sync(
+ rref.owner(),
+ _call_method_on_rref,
+ args=(MyClass.increment_value, rref, vals[1]),
+ )
+ rpc.rpc_async(
+ rref.owner(),
+ _call_method_on_rref,
+ args=(MyClass.increment_value, rref, vals[2]),
+ ).wait()
+ rpc.remote(
+ rref.owner(),
+ _call_method_on_rref,
+ args=(MyClass.increment_value, rref, vals[3]),
+ ).to_here()
# queries state of the remote object
- result = rpc.rpc_sync(dst_worker, _call_method_on_rref, args=(
- MyClass.get_value, rref))
+ result = rpc.rpc_sync(
+ dst_worker, _call_method_on_rref, args=(MyClass.get_value, rref)
+ )
self.assertEqual(result, sum(vals))
@@ -1252,10 +1291,11 @@
rref = rpc.remote(
"worker{}".format((self.rank + 1) % self.world_size),
torch.add,
- args=(torch.ones(2, 2), 1)
+ args=(torch.ones(2, 2), 1),
)
import torch.distributed.rpc.api as api
+
if ignore_leak:
api._ignore_rref_leak = True
rpc.shutdown(graceful=True)
@@ -1277,15 +1317,18 @@
rref1 = RRef(self.rank)
id_class = "GloballyUniqueId"
self.assertEqual(
- "OwnerRRef({}({}, 0))".format(id_class, self.rank),
- rref1.__str__()
+ "OwnerRRef({}({}, 0))".format(id_class, self.rank), rref1.__str__()
)
dst_rank = (self.rank + 1) % self.world_size
- rref2 = rpc.remote("worker{}".format(dst_rank), torch.add, args=(torch.ones(2, 2), 1))
+ rref2 = rpc.remote(
+ "worker{}".format(dst_rank), torch.add, args=(torch.ones(2, 2), 1)
+ )
self.assertEqual(
rref2.__str__(),
- "UserRRef(RRefId = {0}({1}, 1), ForkId = {0}({1}, 2))".format(id_class, self.rank)
+ "UserRRef(RRefId = {0}({1}, 1), ForkId = {0}({1}, 2))".format(
+ id_class, self.rank
+ ),
)
@dist_init
@@ -1316,11 +1359,7 @@
###########################################################
dst_rank = (self.rank + 1) % self.world_size
- rpc.rpc_sync(
- "worker{}".format(dst_rank),
- set_global_rref,
- args=(rref1,)
- )
+ rpc.rpc_sync("worker{}".format(dst_rank), set_global_rref, args=(rref1,))
# barrier before check 2
dist.barrier()
@@ -1339,14 +1378,10 @@
# Check 3: rpc.remote call should update owners_ map
####################################################
rref2 = rpc.remote(
- "worker{}".format(dst_rank),
- torch.add,
- args=(torch.ones(2, 2), 1)
+ "worker{}".format(dst_rank), torch.add, args=(torch.ones(2, 2), 1)
)
rref3 = rpc.remote(
- "worker{}".format(dst_rank),
- torch.add,
- args=(torch.ones(2, 2), 1)
+ "worker{}".format(dst_rank), torch.add, args=(torch.ones(2, 2), 1)
)
rref2.to_here()
rref3.to_here()
@@ -1370,11 +1405,15 @@
# GIL profiling should be disabled by default.
dst_rank = (self.rank + 1) % self.world_size
- rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1)))
+ rpc.rpc_sync(
+ "worker{}".format(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1))
+ )
info = rpc.api._get_current_rpc_agent().get_debug_info()
self.assertRaises(KeyError, lambda: info["agent.gil_average_wait_time_us"])
rpc.enable_gil_profiling(True)
- rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1)))
+ rpc.rpc_sync(
+ "worker{}".format(dst_rank), torch.add, args=(torch.ones(1), torch.ones(1))
+ )
info = rpc.api._get_current_rpc_agent().get_debug_info()
self.assertIn("agent.gil_average_wait_time_us", info)
@@ -1399,9 +1438,7 @@
dist.barrier()
dst_rank = (self.rank + 1) % self.world_size
fut = rpc.rpc_async(
- "worker{}".format(dst_rank),
- set_and_check_done,
- args=(dst_rank,)
+ "worker{}".format(dst_rank), set_and_check_done, args=(dst_rank,)
)
# blocks until the request arrives
self.assertEqual(self.rank, VALUE_FUTURE.result())
@@ -1486,8 +1523,10 @@
self.assertEqual(expected.keys(), info.keys())
@dist_init(setup_rpc=False)
- @unittest.skipIf(IS_MACOS,
- "Test is flaky on MacOS, see https://github.com/pytorch/pytorch/issues/32019")
+ @unittest.skipIf(
+ IS_MACOS,
+ "Test is flaky on MacOS, see https://github.com/pytorch/pytorch/issues/32019",
+ )
def test_handle_send_exceptions(self):
# test that if a callee node has gone down, we raise an appropriate
# exception instead of just crashing.
@@ -1516,7 +1555,8 @@
error_str = (
"Encountered exception in ProcessGroupAgent::enqueueSend"
if self.rpc_backend == rpc.backend_registry.BackendType.PROCESS_GROUP
- else get_shutdown_error_regex())
+ else get_shutdown_error_regex()
+ )
with self.assertRaisesRegex(RuntimeError, error_str):
fut.wait()
# exit all workers non-gracefully.
@@ -1574,7 +1614,10 @@
dst_rank = (self.rank + 1) % self.world_size
rpc._set_rpc_timeout(timedelta(milliseconds=1))
# futures should time out and be marked with an exception indicating it as such.
- futs = [rpc.rpc_async("worker{}".format(dst_rank), my_sleep_func, args=()) for _ in range(10)]
+ futs = [
+ rpc.rpc_async("worker{}".format(dst_rank), my_sleep_func, args=())
+ for _ in range(10)
+ ]
for fut in futs:
with self.assertRaisesRegex(RuntimeError, "RPC ran for more than"):
fut.wait()
@@ -1603,7 +1646,10 @@
def test_func():
return "expected result"
- if torch.testing._internal.dist_utils.TEST_CONFIG.rpc_backend_name == "PROCESS_GROUP":
+ if (
+ torch.testing._internal.dist_utils.TEST_CONFIG.rpc_backend_name
+ == "PROCESS_GROUP"
+ ):
self.assertEqual(test_func(), "expected result")
def test_dist_init_decorator(self):
@@ -1620,9 +1666,12 @@
self.assertEqual(test_func(self), "expected result")
def test_use_rpc_pickler(self):
- class TestPickler():
+ class TestPickler:
pass
+
test_pickler = TestPickler()
with _use_rpc_pickler(test_pickler):
self.assertTrue(torch.distributed.rpc.api._default_pickler is test_pickler)
- self.assertTrue(torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler)
+ self.assertTrue(
+ torch.distributed.rpc.api._default_pickler is _internal_rpc_pickler
+ )