DDP communication hook: skip dividing grads by world_size if hook registered. (#42400)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42400
mcarilli spotted that in the original DDP communication hook design described in [39272](https://github.com/pytorch/pytorch/issues/39272), the hooks receive grads that are already predivided by world size.
It makes sense to skip the division entirely when a hook is registered, since the hook is meant to let the user completely override DDP communication. For example, if the user would like to implement something like GossipGrad, always predividing by world_size would not be a good idea.
We also added a warning to the `register_comm_hook` API:
> Grad bucket's tensors will not be predivided by world_size. User is responsible to divide by the world_size in case of operations like allreduce.
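Concretely, a hook now has to perform the division itself. A minimal sketch, assuming an existing `process_group` and `ddp_model` (mirroring the docstring example updated in this PR):

```python
import torch
import torch.distributed as dist

def allreduce_hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future:
    # Grads arrive undivided now; divide by world size before allreducing.
    tensors = [t / process_group.size() for t in bucket.get_tensors()]
    return process_group.allreduce(tensors).get_future()

ddp_model._register_comm_hook(state=None, hook=allreduce_hook)
```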
ghstack-source-id: 109548696
**Update:** We discovered and fixed a bug in the sparse gradient case. See the new unit test `test_ddp_comm_hook_sparse_gradients` and the changes in `reducer.cpp`.
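The new test drives that path with a gloo-compatible hook. Since gloo work objects do not support `get_future()` (see issue #42048), the sketch below (names mirroring the test, with `process_group` assumed to be a `ProcessGroupGloo` instance) waits on the allreduce and hands the result back through a `torch.futures.Future`, again dividing by world size inside the hook:

```python
import torch
import torch.distributed as dist

def allreduce_hook_gloo(state: object, bucket: dist._GradBucket) -> torch.futures.Future:
    # Gloo work handles have no get_future() (GH issue #42048), so run the
    # allreduce synchronously and wrap its result in a Future by hand.
    work = process_group.allreduce(bucket.get_tensors())
    work.wait()

    fut = torch.futures.Future()
    # The reducer no longer predivides, so divide by world size here.
    fut.set_result([t / process_group.size() for t in bucket.get_tensors()])
    return fut
```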
Test Plan: `python test/distributed/test_c10d.py` and perf benchmark tests.
Reviewed By: ezyang
Differential Revision: D22883905
fbshipit-source-id: 3277323fe9bd7eb6e638b7ef0535cab1fc72f89e
diff --git a/test/distributed/test_c10d.py b/test/distributed/test_c10d.py
index 89de755..57ffec2 100644
--- a/test/distributed/test_c10d.py
+++ b/test/distributed/test_c10d.py
@@ -1942,6 +1942,15 @@
return self.t0(x + rank)
+class SparseGradientModule(nn.Module):
+ def __init__(self):
+ super(SparseGradientModule, self).__init__()
+ self.embedding = nn.EmbeddingBag(10, 10, sparse=True)
+
+ def forward(self, x):
+ return F.softmax(self.embedding(x), dim=1)
+
+
@unittest.skipIf(TEST_WITH_TSAN, "TSAN is not fork-safe since we're forking in a multi-threaded environment")
class DistributedDataParallelTest(MultiProcessTestCase):
def setUp(self):
@@ -2822,28 +2831,7 @@
loss = criterion(output, target)
loss.backward()
- @requires_gloo()
- def test_sparse_gradients(self):
- store = c10d.FileStore(self.file_name, self.world_size)
- process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
-
- class SparseGradientModule(nn.Module):
- def __init__(self):
- super(SparseGradientModule, self).__init__()
- self.embedding = nn.EmbeddingBag(10, 10, sparse=True)
-
- def forward(self, x):
- return F.softmax(self.embedding(x), dim=1)
-
- # Ensure initialized weights and inputs are identical across processes
- torch.manual_seed(1337)
-
- vanilla_model = SparseGradientModule()
- ddp_model = DistributedDataParallel(
- copy.deepcopy(vanilla_model),
- process_group=process_group,
- )
-
+ def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model):
mult = 2
batch_size = mult * self.world_size
criterion = nn.CrossEntropyLoss()
@@ -2863,6 +2851,22 @@
ddp_parameter = next(ddp_model.parameters())
self.assertEqual(vanilla_parameter.grad, ddp_parameter.grad)
+ @requires_gloo()
+ def test_sparse_gradients(self):
+ store = c10d.FileStore(self.file_name, self.world_size)
+ process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
+
+ # Ensure initialized weights and inputs are identical across processes
+ torch.manual_seed(1337)
+
+ vanilla_model = SparseGradientModule()
+ ddp_model = DistributedDataParallel(
+ copy.deepcopy(vanilla_model),
+ process_group=process_group,
+ )
+
+ self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)
+
def _test_grad_layout(self, replica_devices, layer_devs, local_batch_size):
store = c10d.FileStore(self.file_name, self.world_size)
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
@@ -3113,7 +3117,8 @@
process_group = c10d.ProcessGroupNCCL(store, self.rank, self.world_size)
def allreduce_hook(state: object, bucket: dist._GradBucket) -> torch._C.Future:
- return process_group.allreduce(bucket.get_tensors()).get_future()
+ tensors = [t / self.world_size for t in bucket.get_tensors()]
+ return process_group.allreduce(tensors).get_future()
# Get GPU model with allreduce_hook registered.
gpu_model = self._gpu_model_with_ddp_comm_hook(process_group, allreduce_hook)
@@ -3134,7 +3139,8 @@
def allreduce_with_then_hook(
state: object, bucket: dist._GradBucket
) -> torch.futures.Future:
- fut = process_group.allreduce(bucket.get_tensors()).get_future()
+ tensors = [t / self.world_size for t in bucket.get_tensors()]
+ fut = process_group.allreduce(tensors).get_future()
def mult(fut):
# Multiply the result by 10.
@@ -3240,6 +3246,39 @@
):
model._register_comm_hook(None, dummy_hook)
+ @requires_gloo()
+ def test_ddp_comm_hook_sparse_gradients(self):
+ """
+ Runs "test_sparse_gradients" unit test with DDP communication hook. We define a
+ simple hook that does allreduce and works with gloo backend for this test.
+ """
+ store = c10d.FileStore(self.file_name, self.world_size)
+ process_group = c10d.ProcessGroupGloo(store, self.rank, self.world_size)
+
+ # Ensure initialized weights and inputs are identical across processes
+ torch.manual_seed(1337)
+
+ vanilla_model = SparseGradientModule()
+ ddp_model = DistributedDataParallel(
+ copy.deepcopy(vanilla_model),
+ process_group=process_group,
+ )
+
+ # "get_future" API does not support gloo backend, see GH Issue #42048.
+ # Instead, we wait for an allreduce work, and write its result to a Future.
+ def allreduce_hook_gloo(state: object, bucket: dist._GradBucket) -> torch.futures.Future:
+ # Prepare allreduced grad bucket tensors by running an async work.
+ work = process_group.allreduce(bucket.get_tensors())
+ work.wait()
+
+ fut = torch.futures.Future()
+ fut.set_result([t / self.world_size for t in bucket.get_tensors()])
+ return fut
+
+ ddp_model._register_comm_hook(None, allreduce_hook_gloo)
+
+ self._run_and_verify_sparse_gradients(vanilla_model, ddp_model)
+
class ReducerModule(nn.Module):
def __init__(self):
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index 5d58d3a..15e4658 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -727,7 +727,8 @@
``allreduce`` work.
>>> def allreduce(state: object, bucket: dist._GradBucket): -> torch._C.Future
- >>> work = process_group.allreduce(bucket.get_tensors())
+ >>> tensors = [t / process_group.world_size for t in bucket.get_tensors()]
+ >>> work = process_group.allreduce(tensors)
>>> return work.get_future()
>>> ddp_model._register_comm_hook(state = None, hook = allreduce)
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index ff645e0..5a68b40 100644
--- a/torch/csrc/distributed/c10d/reducer.cpp
+++ b/torch/csrc/distributed/c10d/reducer.cpp
@@ -175,7 +175,11 @@
// ~~~~~~~~~~~~~~~~~~~~~~~~~~
// If DDP communication hook is not registered, the reducer reduces the buckets
// by just calling allreduce. If registered, it calls the hook and uses future
-// work handle.
+// work handle. If registered, reducer also skips dividing grads by world size.
+// The reason for this is that the communication hook is expected to completely
+// override how we perform communication and the user should have complete
+// control over how the grads are handled.
+//
// DDP communication hook is an enhancement that provides a hook which can be
// used to override how DDP communicates gradients across ranks, this can be
// used for algorithms like Gradient Compression/GossipGrad. This hook can be
@@ -351,12 +355,17 @@
", strides() = ",
bucket_view.strides());
}
- // imitates wrapped_scalar_tensor in ATen/native/BinaryOps.cpp
- auto wrapped =
- c10::scalar_to_tensor(double(1.) / process_group_->getSize());
- wrapped.unsafeGetTensorImpl()->set_wrapped_number(true);
- // Divides while copying into the bucket view.
- at::native::mul_out(bucket_view, grad, wrapped);
+ // See Note [DDP Communication Hook]
+ if (comm_hook_ == nullptr) {
+ // imitates wrapped_scalar_tensor in ATen/native/BinaryOps.cpp
+ auto wrapped =
+ c10::scalar_to_tensor(double(1.) / process_group_->getSize());
+ wrapped.unsafeGetTensorImpl()->set_wrapped_number(true);
+ // Divides while copying into the bucket view.
+ at::native::mul_out(bucket_view, grad, wrapped);
+ } else {
+ bucket_view.copy_(grad);
+ }
} else {
bucket_view.zero_();
}
@@ -385,7 +394,10 @@
// struct are empty, and there is no pre-existing accumulation tensor.
// Directly assign the sparse tensor to the `contents` field.
replica.contents = grad;
- replica.contents.div_(process_group_->getSize());
+ // See Note [DDP Communication Hook]
+ if (comm_hook_ == nullptr) {
+ replica.contents.div_(process_group_->getSize());
+ }
// The grad is modified in place and needs to be written back.
return true;
});
@@ -968,11 +980,15 @@
auto future_result =
comm_hook_->processFuture(bucket.future_work->value());
- // Reinitialize bucket_views with the future_result by following
- // the same logic in `inititalize_buckets`.
for (size_t i = 0; i < future_result.size(); i++) {
- bucket.replicas[i].bucket_views.clear();
- initialize_bucketviews(bucket.replicas[i], future_result[i]);
+ if (bucket.expect_sparse_gradient) {
+ bucket.replicas[i].contents.copy_(future_result[i]);
+ } else {
+ // Reinitialize bucket_views with the future_result by following
+ // the same logic in `inititalize_buckets`.
+ bucket.replicas[i].bucket_views.clear();
+ initialize_bucketviews(bucket.replicas[i], future_result[i]);
+ }
}
}
if (!bucket.expect_sparse_gradient) {
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index c0f958c..26626f3 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -645,6 +645,10 @@
Future associated with the completion of ``c10d.ProcessGroup.work``.
.. warning ::
+ Grad bucket's tensors will not be predivided by world_size. User is responsible
+ to divide by the world_size in case of operations like allreduce.
+
+ .. warning ::
DDP communication hook can only be registered once and should be registered
before calling backward.
@@ -680,7 +684,8 @@
allreduce, and then decoded after allreduce.
>>> def encode_and_decode(state: object, bucket: dist._GradBucket): -> torch.futures.Future
- >>> encoded_tensors = encode(bucket.get_tensors()) # encode gradients
+ >>> tensors = [t / process_group.world_size for t in bucket.get_tensors()]
+ >>> encoded_tensors = encode(tensors) # encode gradients
>>> fut = process_group.allreduce(encoded_tensors).get_future()
>>> # Define the then callback to decode.
>>> def decode(fut):