from __future__ import absolute_import, division, print_function, unicode_literals

import sys
import threading
import time
import unittest
from enum import Enum

import torch
import torch.distributed as dist
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from dist_utils import INIT_METHOD_TEMPLATE, dist_init
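
# The usage pattern exercised throughout these tests looks roughly like the
# sketch below (illustrative only, not executed here; the worker name and
# tensors are placeholders):
#
#   with dist_autograd.context() as context_id:
#       t3 = rpc.rpc_sync("worker1", torch.add, args=(t1, t2))
#       dist_autograd.backward([t3.sum()])
#       grads = dist_autograd.get_gradients(context_id)
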
# Right now we test up to 3-layer nested rpc calls.
# rpc_done[1] and ctx_ids[1] mean that the rpc from the previous rank is done,
# and hold the context id sent from the previous rank, respectively.
# rpc_done[2] and ctx_ids[2] are for the rank two hops back (prev of prev).
# rpc_done[3] and ctx_ids[3] are for the rank three hops back (prev of prev of prev).
# rpc_done[0] and ctx_ids[0] are for the current rank, but are mostly unused.
rpc_done = [False, False, False, False]
ctx_ids = [-1, -1, -1, -1]

known_context_ids = []

# We don't need a lock here since the GIL is held while executing remote
# Python UDFs, so access to known_context_ids is serialized across
# concurrent RPC handlers on this worker.
def _store_context_id(context_id):
    global known_context_ids
    known_context_ids.append(context_id)

# Send rpc done info and context_id to
# dst_rank = (self.rank + rank_distance) % self.world_size
def _set_rpc_done(ctx_id, rank_distance):
    global rpc_done
    global ctx_ids
    rpc_done[rank_distance] = True
    ctx_ids[rank_distance] = ctx_id


def my_py_add(t1, t2):
    return torch.add(t1, t2)


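# Recursively forwards the call to the next worker while hops > 0, and asks
# the final worker to run torch.add. For example, starting from rank 0 with
# hops=1 this produces a worker0 -> worker1 -> worker2 -> worker3 chain.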
def my_py_nested_call(t1, t2, dst, world_size, hops):
    next_dst = (dst + 1) % world_size
    if hops > 0:
        return rpc.rpc_sync("worker{}".format(next_dst), my_py_nested_call,
                            args=(t1, t2, next_dst, world_size, hops - 1))
    else:
        return rpc.rpc_sync("worker{}".format(next_dst), torch.add, args=(t1, t2))

# After a dist autograd context is cleaned up locally, it should also be
# cleaned up on the other nodes that participated. This helper waits up to
# timeout_seconds for those cleanup RPCs to complete and checks that all the
# known contexts have been released within that timeframe.
def _all_contexts_cleaned_up(num_contexts, timeout_seconds=10):
    global known_context_ids
    start = time.time()
    context_id_to_raised = {}
    while time.time() - start < timeout_seconds:
        for context_id in known_context_ids:
            try:
                dist_autograd._retrieve_context(context_id)
            except RuntimeError:
                context_id_to_raised[context_id] = True
        if len(context_id_to_raised) == num_contexts:
            break
    # All contexts have been cleaned up if retrieving every known context
    # raised a RuntimeError.
    success = len(context_id_to_raised) == num_contexts and all(context_id_to_raised.values())
    return success


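# Autograd Function that acts as the identity in forward() and raises in
# backward(); used to simulate an error surfacing during the backward pass.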
class SimulateBackwardError(Function):
    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    @once_differentiable
    def backward(ctx, input):
        raise Exception('Simulate error on backward pass')


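# Execution modes used by _exec_func: run an op in-process or dispatch it to
# another worker via rpc_sync.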
class ExecMode(Enum):
    LOCAL = 1   # Run the operation locally.
    REMOTE = 2  # Run the operation using RPC.


@unittest.skipIf(
    not torch._six.PY3,
    "PyTorch distributed autograd package does not support Python 2",
)
class DistAutogradTest(object):
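    """Distributed autograd tests.

    This class is not a TestCase by itself; the test launcher is expected to
    mix it into a unittest.TestCase that provides attributes such as
    self.rank, self.worker_id and self.file_name used below.
    """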

    # Runs `method` either in-process (ExecMode.LOCAL) or on the next rank via
    # rpc_sync (ExecMode.REMOTE), so tests can compare local and distributed
    # autograd results.
    def _exec_func(self, exec_mode, method, *args):
        if ExecMode.LOCAL == exec_mode:
            if len(args) == 1 and isinstance(args[0], list):
                return method(*args[0])
            return method(*args)
        else:
            return rpc.rpc_sync('worker{}'.format(self._next_rank()), method,
                                args=args)

    # Returns the next destination rank in a round-robin fashion, skipping
    # over our own rank.
    def _next_rank(self):
        if hasattr(self, 'dst_rank'):
            self.dst_rank = (self.dst_rank + 1) % self.world_size
            if self.dst_rank == self.rank:
                return self._next_rank()
        else:
            self.dst_rank = (self.rank + 1) % self.world_size
        return self.dst_rank

    # Blocks until the worker at the given rank distance has signalled, via
    # _set_rpc_done, that its rpc is done.
    def _check_rpc_done(self, rank_distance):
        while not rpc_done[rank_distance]:
            time.sleep(0.1)

    @property
    def world_size(self):
        return 4

    @property
    def init_method(self):
        return INIT_METHOD_TEMPLATE.format(file_name=self.file_name)

    @dist_init
    def test_autograd_context(self):
        # Verify max possible id.
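        # Context ids pack the 16-bit worker_id into the high bits and a
        # 48-bit auto-increment counter into the low bits.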
        max_auto_increment = 281474976710655
        self.assertEqual(
            max_auto_increment + (self.worker_id << 48), dist_autograd._get_max_id()
        )

        context_ids = []
        for i in range(1000):
            with dist_autograd.context() as context_id:
                self.assertEqual(
                    context_id,
                    dist_autograd._retrieve_context(context_id)._context_id(),
                )
                # The high 16 bits should be the worker_id.
                self.assertEqual(self.worker_id, context_id >> 48)
                context_ids.append(context_id)

        for context_id in context_ids:
            with self.assertRaisesRegex(
                RuntimeError,
                "Could not find autograd context with id: {}".format(context_id),
            ):
                dist_autograd._retrieve_context(context_id)

    @dist_init
    def test_nested_context(self):
        with dist_autograd.context() as context_id:
            # Nested contexts not supported.
            with self.assertRaisesRegex(RuntimeError, "Already have an autograd context id for this thread"):
                with dist_autograd.context() as context_id:
                    pass

    # For the current context, this rank sends tensors t1 and t2 to dst_rank,
    # then gets back the result tensor t3 = torch.add(t1, t2).
    # For the current context on this rank, the expected graph looks like:
    #  send function:
    #              rpcSendBackward
    #                  /          \
    #  t1.AccumulateGrad         t2.AccumulateGrad
    #
    #  recv function:
    #            |
    #          t3.rpcRecvBackward
    #
    def _verify_graph_for_first_rpc_call(self, send_function, recv_function, t1, t2, ret):
        # Retrieve the next functions in the graph.
        next_funcs = send_function.next_functions
        self.assertEqual(2, len(next_funcs))

        # We should now hit t1 and t2 in the autograd graph.
        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[0][0].name())
        self.assertEqual(t1, next_funcs[0][0].variable)
        self.assertEqual(0, next_funcs[0][1])
        self.assertEqual("torch::autograd::AccumulateGrad", next_funcs[1][0].name())
        self.assertEqual(t2, next_funcs[1][0].variable)
        self.assertEqual(0, next_funcs[1][1])

        # Test recv functions.
        self.assertEqual(ret.grad_fn, recv_function)

    # For a context passed down from a previous nested chain of calls, this
    # rank receives two tensors t1 and t2, executes torch.add(t1, t2) and
    # sends the result tensor t3 back.
    # For this context on this rank, the expected graph looks like:
    #  send and recv functions:
    #              rpcSendBackward
    #                   |
    #            t3.AddBackward0
    #               /         \
    #  t1.recvRpcBackward    t2.recvRpcBackward
    def _verify_graph_for_rpc_call_exec(self, send_function):
        # Verify that the next function is AddBackward0.
        next_funcs = send_function.next_functions
        self.assertEqual(1, len(next_funcs))
        add_backward_fn = next_funcs[0][0]
        self.assertEqual("AddBackward0", add_backward_fn.name())

        # Verify the next two functions are the same recv backward function.
        next_funcs = add_backward_fn.next_functions
        self.assertEqual(2, len(next_funcs))
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
        )
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
        )
        self.assertEqual(next_funcs[0][0], next_funcs[1][0])

    # For a context passed down from a previous nested chain of calls, this
    # rank receives two tensors t1 and t2 and forwards them to the next dst
    # via a nested rpc call. On the return route, it receives the result
    # tensor t3 from the next dst and forwards t3 back to the previous caller.
    # For this context on this rank, the expected graph looks like:
    #  send and recv functions for receiving and forwarding t1 and t2:
    #              rpcSendBackward
    #                  /          \
    #  t1.recvRpcBackward    t2.recvRpcBackward
    #  send and recv functions for receiving and forwarding t3:
    #              rpcSendBackward
    #                   |
    #          t3.recvRpcBackward
    def _verify_graph_for_nested_rpc_call(self, ctx):
        send_functions = ctx._send_functions()
        self.assertEqual(2, len(send_functions))

        # For the send function created when making the nested rpc call, the
        # next functions of the send function are the two recv functions for
        # the two tensors received from the previous call.
        next_funcs = list(send_functions.values())[0].next_functions
        self.assertEqual(2, len(next_funcs))
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
        )
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[1][0].name()
        )
        self.assertEqual(next_funcs[0][0], next_funcs[1][0])

        # For the send function created when returning the response to the
        # previous call, the next function of the send function is the recv
        # function for the result tensor returned from the nested call.
        next_funcs = list(send_functions.values())[1].next_functions
        self.assertEqual(1, len(next_funcs))
        self.assertEqual(
            "torch::distributed::autograd::RecvRpcBackward", next_funcs[0][0].name()
        )

    # Runs `fn` on the next rank inside a dist autograd context and verifies
    # the resulting send/recv autograd graphs on this rank and on the callee.
    def _test_graph(self, fn):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            ret = rpc.rpc_sync("worker{}".format(dst_rank), fn, args=(t1, t2))
            rpc.rpc_sync("worker{}".format(dst_rank),
                         _set_rpc_done, args=(context_id, 1))

            # Verify graph for current context id.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self._verify_graph_for_first_rpc_call(list(send_functions.values())[0],
                                                  list(recv_functions.values())[0],
                                                  t1, t2, ret)

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            # Verify graph for previous context id.
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            self._verify_graph_for_rpc_call_exec(list(send_functions.values())[0])
            # This barrier is needed so one worker does not clean up its
            # autograd context before another worker tries to access it.
            dist.barrier()

        # autograd context should be cleaned up by now.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._retrieve_context(context_id)

        # No autograd context available.
        with self.assertRaises(RuntimeError):
            ctx = dist_autograd._current_context()

    @dist_init
    def test_graph_for_builtin_call(self):
        self._test_graph(torch.add)

    @dist_init
    def test_graph_for_python_call(self):
        self._test_graph(my_py_add)

    # 3-layer nested calls
    @dist_init
    def test_graph_for_py_nested_call(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            nest_dst_rank = (dst_rank + 1) % self.world_size
            ret = rpc.rpc_sync("worker{}".format(dst_rank),
                               my_py_nested_call, args=(t1, t2, dst_rank, self.world_size, 1))
            for rd in [1, 2, 3]:
                rpc.rpc_sync("worker{}".format((self.rank + rd) % self.world_size),
                             _set_rpc_done, args=(context_id, rd))

            # self.rank has 4 graphs to verify:
            # The first is for the current context id, when this rank sends
            # the first rpc call.
            # The second is for the prev context id, when this rank makes the
            # 1st nested call.
            # The third is for the prev prev context id, when this rank makes
            # the 2nd nested call.
            # The last is for the prev prev prev context id, when this rank
            # executes the torch.add() operator.

            # Verify first graph for current context id.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(1, len(recv_functions))
            self._verify_graph_for_first_rpc_call(list(send_functions.values())[0],
                                                  list(recv_functions.values())[0],
                                                  t1, t2, ret)

            # Verify second graph for 1st nested call.
            self._check_rpc_done(1)
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            self._verify_graph_for_nested_rpc_call(ctx)

            # Verify third graph for 2nd nested call.
            self._check_rpc_done(2)
            ctx = dist_autograd._retrieve_context(ctx_ids[2])
            self._verify_graph_for_nested_rpc_call(ctx)

            # Verify last graph for rpc call execution.
            self._check_rpc_done(3)
            ctx = dist_autograd._retrieve_context(ctx_ids[3])
            send_functions = ctx._send_functions()
            self.assertEqual(1, len(send_functions))
            self._verify_graph_for_rpc_call_exec(list(send_functions.values())[0])
            # This barrier is needed so one worker does not clean up its
            # autograd context before another worker tries to access it.
            dist.barrier()

    # Rank0->Rank1->Rank0
    @dist_init
    def test_graph_for_py_nested_call_itself(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            ret = rpc.rpc_sync("worker{}".format(dst_rank),
                               my_py_nested_call,
                               args=(t1, t2, (self.rank - 1 + self.world_size) % self.world_size, self.world_size, 0))
            rpc.rpc_sync("worker{}".format((self.rank + 1) % self.world_size),
                         _set_rpc_done, args=(context_id, 1))

            # self.rank has 2 graphs to verify.
            # One is for the current context id, when this rank sends the
            # first rpc call and executes the torch.add() operator.
            # The other is for the prev context id, when this rank makes the
            # nested call.
            ctx = dist_autograd._current_context()
            self.assertEqual(context_id, ctx._context_id())
            send_functions = ctx._send_functions()
            self.assertEqual(2, len(send_functions))
            recv_functions = ctx._recv_functions()
            self.assertEqual(2, len(recv_functions))
            self._verify_graph_for_first_rpc_call(list(send_functions.values())[0],
                                                  list(recv_functions.values())[1],
                                                  t1, t2, ret)
            self._verify_graph_for_rpc_call_exec(list(send_functions.values())[1])

            # Verify the two pairs of send and recv functions for the nested
            # call.
            self._check_rpc_done(1)
            ctx = dist_autograd._retrieve_context(ctx_ids[1])
            self._verify_graph_for_nested_rpc_call(ctx)
            # This barrier is needed so one worker does not clean up its
            # autograd context before another worker tries to access it.
            dist.barrier()

    @dist_init
    def test_no_graph_with_tensors_not_require_grad(self):
        dst_rank = (self.rank + 1) % self.world_size
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=False)
            t2 = torch.zeros(3, 3, requires_grad=False)
            ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
            rpc.rpc_sync("worker{}".format(dst_rank),
                         _set_rpc_done, args=(context_id, 1))

            ctx = dist_autograd._current_context()
            send_functions = ctx._send_functions()
            self.assertEqual(len(send_functions), 0)
            recv_functions = ctx._recv_functions()
            self.assertEqual(len(recv_functions), 0)

            # Wait for the prev rank to be done with rpc.
            self._check_rpc_done(1)
            # The prev rank's context id was not propagated to us since the
            # tensors do not require grad.
            with self.assertRaises(RuntimeError):
                ctx = dist_autograd._retrieve_context(ctx_ids[1])

    @dist_init
    def test_rpc_complex_args(self):
        with dist_autograd.context() as context_id:
            num_tensors = 10
            tensors = []
            for i in range(num_tensors):
                tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
            ret = rpc.rpc_sync(
                "worker{}".format(self._next_rank()), torch.stack, args=(tensors,)
            )
            self.assertEqual(torch.stack(tensors), ret)

            # Verify appropriate tensors have been attached to the autograd graph.
            next_funcs = list(
                dist_autograd._current_context()._send_functions().values()
            )[0].next_functions
            for i in range(num_tensors):
                if i % 2 == 0:
                    self.assertEqual(
                        "torch::autograd::AccumulateGrad", next_funcs[i][0].name()
                    )
                    self.assertEqual(tensors[i], next_funcs[i][0].variable)
                else:
                    self.assertIsNone(next_funcs[i][0])

            # Verify that the worker id has been recorded in the context.
            ctx = dist_autograd._current_context()
            worker_ids = ctx._known_worker_ids()
            self.assertEqual(len(worker_ids), 1)
            dst_rank = (self.rank + 1) % self.world_size
            self.assertEqual(worker_ids[0], dst_rank)

    @dist_init
    def test_context_cleanup_many_workers(self):
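        # Every rank runs this test, so each worker receives the other ranks'
        # contexts via rpc and verifies that those contexts are cleaned up
        # locally once the senders exit their `with` blocks.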
        global known_context_ids
        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
        with dist_autograd.context() as context_id:
            t1 = torch.ones(3, 3, requires_grad=True)
            t2 = torch.zeros(3, 3, requires_grad=True)
            for dst_rank in dst_ranks:
                ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
                rpc.rpc_sync("worker{}".format(dst_rank), _store_context_id, args=(context_id,))
        # The context created above should be cleaned up once the `with`
        # block exits.
        with self.assertRaises(RuntimeError):
            dist_autograd._retrieve_context(context_id)
        # Check that all contexts have been cleaned up.
        success = _all_contexts_cleaned_up(num_contexts=len(dst_ranks))
        self.assertTrue(success)

    @dist_init
    def test_worker_ids_recorded(self):
        dst_ranks = {rank for rank in range(self.world_size) if rank != self.rank}
        with dist_autograd.context() as context_id:
            # If no tensors require grad, we do not add the send functions, so
            # no worker ids should be recorded.
            t1 = torch.ones(3, 3, requires_grad=False)
            t2 = torch.zeros(3, 3, requires_grad=False)
            for dst_rank in dst_ranks:
                ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
                rpc.rpc_sync(
                    "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
                )
            # No worker ids should be recorded.
            ctx = dist_autograd._current_context()
            worker_ids = ctx._known_worker_ids()
            self.assertEqual(len(worker_ids), 0)

            # worker_ids should be recorded when tensors do require grad.
            t1.requires_grad = True
            t2.requires_grad = True
            for dst_rank in dst_ranks:
                ret = rpc.rpc_sync("worker{}".format(dst_rank), torch.add, args=(t1, t2))
                rpc.rpc_sync(
                    "worker{}".format(dst_rank), _set_rpc_done, args=(context_id, 1)
                )
            # All worker_ids in dst_ranks should be recorded.
            worker_ids = ctx._known_worker_ids()
            self.assertEqual(len(worker_ids), len(dst_ranks))
            self.assertEqual(set(worker_ids), dst_ranks)

    @dist_init
    def test_error_in_context(self):
        with dist_autograd.context() as context_id:
            t1 = torch.rand(3, 3, requires_grad=True)
            t2 = torch.rand(6, 6, requires_grad=True)

            with self.assertRaises(RuntimeError):
                # This should throw an error since matrix sizes don't match.
                rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.matmul,
                             args=(t1, t2))

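    # Runs the backward pass for the given root tensors. In LOCAL mode this
    # uses torch.autograd and returns the gradients of *args; in REMOTE mode
    # it delegates to _verify_backwards_remote, which compares dist_autograd's
    # gradients against the previously computed local ones.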
    def _verify_backwards(self, exec_mode, tensors, context_id, local_grads, *args):
        if exec_mode == ExecMode.REMOTE:
            self._verify_backwards_remote(tensors, context_id, local_grads, *args)
        else:
            torch.autograd.backward(tensors)
            return [arg.grad for arg in args]

    def _verify_backwards_remote(self, tensors, context_id, local_grads, *args):
        dist_autograd.backward(tensors)

        # Verify grads were accumulated appropriately.
        grads = dist_autograd.get_gradients(context_id)
        nargs = len(args)
        ngrads = 0
        for i in range(0, nargs):
            if local_grads[i] is not None:
                self.assertIn(args[i], grads)
                self.assertEqual(local_grads[i], grads[args[i]])
                ngrads += 1
            else:
                self.assertNotIn(args[i], grads)

        self.assertEqual(ngrads, len(grads))

    @dist_init
    def test_backward_simple(self):
        # Run the same code locally and with dist autograd and verify that
        # the gradients are the same.
        local_grads = None
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                ret = self._exec_func(exec_mode, torch.add, t1, t2)
                loss = ret.sum()
                local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2)

    @dist_init
    def test_backward_multiple_round_trips(self):
        local_grads = None
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3))
        t3 = torch.rand((3, 3), requires_grad=True)
        t4 = torch.rand((3, 3))
        t5 = torch.rand((3, 3), requires_grad=True)

        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                # Multiple RPCs between different nodes.
                val = self._exec_func(exec_mode, torch.add, t1, t2)
                val = self._exec_func(exec_mode, torch.mul, t3, val)
                s1 = self._exec_func(exec_mode, torch.stack, (t4, val))
                s2 = self._exec_func(exec_mode, torch.stack, (t5, val))
                val = self._exec_func(exec_mode, torch.bmm, s1, s2)
                val = self._exec_func(exec_mode, torch.matmul, val, val)
                loss = val.sum()

                local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5)

    @dist_init
    def test_backward_different_tensor_dims(self):
        local_grads = None
        t1 = torch.rand((4, 6), requires_grad=True)
        t2 = torch.rand((6, 5))
        t3 = torch.rand((5, 7), requires_grad=True)
        t4 = torch.rand((7, 9))

        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                val = self._exec_func(exec_mode, torch.matmul, t1, t2)
                val = self._exec_func(exec_mode, torch.chain_matmul, [val, t3, t4])
                loss = val.sum()

                local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4)

    @dist_init
    def test_backward_unused_tensors(self):
        local_grads = None
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        t3 = torch.rand((3, 3), requires_grad=True)
        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                s = self._exec_func(exec_mode, torch.stack, (t1, t2, t3))
                val = self._exec_func(exec_mode, torch.matmul, torch.narrow(s, 0, 0, 1), torch.narrow(s, 0, 2, 1))

                loss = val.sum()
                local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t3)

    @dist_init
    def test_backward_multiple_output_tensors(self):
        local_grads = None
        t = torch.rand((10, 2), requires_grad=True)
        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                tensor_list = self._exec_func(exec_mode, torch.split, t, 2)
                t1 = tensor_list[0]
                t2 = tensor_list[2]
                t3 = tensor_list[4]

                val = self._exec_func(exec_mode, torch.chain_matmul, [t1, t2, t3])

                loss = val.sum()
                local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t)

    def _run_test_backward_unused_send_function_in_thread(self):
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)

            # We don't use the result of this RPC, so the backward pass would
            # hang in the "FAST" mode.
            res = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.add,
                               args=(t1, t2))

            val = torch.mul(t1, t2)

            # Run backward; this would hang forever.
            dist_autograd.backward([val.sum()])

    @dist_init
    def test_backward_unused_send_function(self):
        # Run the test in a daemon thread; it is expected to never finish
        # since the backward pass hangs.
        t = threading.Thread(target=self._run_test_backward_unused_send_function_in_thread)
        t.daemon = True
        t.start()
        t.join(10)  # Wait for 10s.

        # Verify the thread is still alive (indicating backward hasn't completed yet).
        self.assertTrue(t.is_alive())

    @dist_init
    def test_backward_autograd_engine_error(self):
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            t3 = SimulateBackwardError.apply(t1)

            # Run multiple round trips across different nodes and verify the
            # original node receives an error thrown on a node deep in the chain.
            val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.add,
                               args=(t2, t3))
            val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.mul,
                               args=(val, t2))
            val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.matmul,
                               args=(val, t2))
            val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.div,
                               args=(val, t2))

            with self.assertRaises(RuntimeError):
                # Run backwards, and validate we receive an error.
                dist_autograd.backward([val.sum()])

    @dist_init
    @unittest.skip("Skipping this test temporarily since ProcessGroupAgent does not report errors on node failures")
    def test_backward_node_failure(self):
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)

            res = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.add,
                               args=(t1, t2))

            if self.rank == 0:
                # Wait a bit for all other nodes to die.
                time.sleep(3)
                with self.assertRaises(RuntimeError):
                    # Run backwards, and validate we receive an error since all
                    # other nodes are dead.
                    dist_autograd.backward([res.sum()])
            else:
                # All other ranks exit to simulate node failures.
                sys.exit(0)

    @dist_init
    def test_backward_without_context(self):
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)

        with self.assertRaisesRegex(RuntimeError, "Current thread doesn't have a valid autograd context"):
            res = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.add,
                               args=(t1, t2))
            dist_autograd.backward([res.sum()])

    @dist_init
    def test_backward_without_rpc(self):
        dst_rank = self.rank
        with dist_autograd.context() as context_id:
            t1 = torch.rand((3, 3), requires_grad=True)
            t2 = torch.rand((3, 3), requires_grad=True)
            t3 = torch.add(t1, t2)

            dist_autograd.backward([t3.sum()])
            grads = dist_autograd.get_gradients(context_id)
            self.assertEqual(2, len(grads))
            self.assertIn(t1, grads)
            self.assertIn(t2, grads)
            self.assertEqual(torch.ones(3, 3), grads[t1])
            self.assertEqual(torch.ones(3, 3), grads[t2])

    @dist_init
    def test_backward_invalid_args(self):
        with dist_autograd.context() as context_id:

            with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
                dist_autograd.backward(None)

            with self.assertRaisesRegex(RuntimeError, "No tensors provided for gradient computation"):
                dist_autograd.backward([])

            with self.assertRaisesRegex(RuntimeError, "requires_grad not set on"):
                t = torch.rand(3, 3)
                dist_autograd.backward([t])

            with self.assertRaisesRegex(RuntimeError, "is not a scalar, all roots need to be scalar"):
                t = torch.rand(3, 3, requires_grad=True)
                dist_autograd.backward([t])

            with self.assertRaisesRegex(RuntimeError, "does not have a valid gradient function"):
                t = torch.rand(1, requires_grad=True)
                dist_autograd.backward([t])

    @dist_init
    def test_backward_multiple_roots(self):
        local_grads = None
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
            with dist_autograd.context() as context_id:
                r1 = self._exec_func(exec_mode, torch.add, t1, t2).sum()
                r2 = self._exec_func(exec_mode, torch.mul, t1, t2).sum()
                r3 = self._exec_func(exec_mode, torch.cos, t1).sum()
                r4 = self._exec_func(exec_mode, torch.div, t1, t2).sum()

                local_grads = self._verify_backwards(exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2)


if __name__ == '__main__':
    unittest.main()