Ensure tensors are contiguous in functional all_gather.
We called `tensor.contiguous()` in the forward pass, but only after
`out_tensor_list` had been built with `torch.empty_like` on the (possibly
non-contiguous) input tensor. As a result, `out_tensor_list` could contain
non-contiguous tensors, which caused the collective to error.
Fix this by making the tensor contiguous before `out_tensor_list` is built.
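
For context, a minimal sketch of why the ordering matters (illustrative only, not part of the patch): `torch.empty_like` preserves the layout of a dense non-contiguous input by default, so buffers allocated from the input before the `.contiguous()` call are themselves non-contiguous.

```python
import torch

# A transposed view is non-contiguous; torch.empty_like preserves that layout
# by default, so buffers allocated from it are non-contiguous as well.
t = torch.rand(5, 5).transpose(0, 1)
print(t.is_contiguous())                    # False
print(torch.empty_like(t).is_contiguous())  # False

# Making the input contiguous first means the allocated buffers are contiguous.
t = t.contiguous()
print(torch.empty_like(t).is_contiguous())  # True
```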
Differential Revision: [D37222870](https://our.internmc.facebook.com/intern/diff/D37222870/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79747
Approved by: https://github.com/fduwjj, https://github.com/wanchaol
diff --git a/test/distributed/test_c10d_spawn_nccl.py b/test/distributed/test_c10d_spawn_nccl.py
index 54f176b..11583f2 100644
--- a/test/distributed/test_c10d_spawn_nccl.py
+++ b/test/distributed/test_c10d_spawn_nccl.py
@@ -157,6 +157,34 @@
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
     @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
+    def test_reduce_scatter_non_contiguous(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        # This is required because these functions call directly into dist and need
+        # the default process group to be initialized
+        c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl')
+        device = torch.device(f"cuda:{self.rank}")
+
+        class NonContiguousGrad(torch.autograd.Function):
+
+            @staticmethod
+            def forward(ctx, input):
+                return input
+
+            @staticmethod
+            def backward(ctx, grad_output):
+                # Make grad non-contiguous
+                return grad_output.clone().transpose(0, 1)
+
+        x0 = torch.rand(5, 5, device=device, requires_grad=True)
+        x1 = torch.rand(5, 5, device=device, requires_grad=True)
+        y = torch.empty(5, 5, device=device)
+
+        y = torch.distributed.nn.reduce_scatter(y, [x0, x1])
+        NonContiguousGrad.apply(y).sum().backward()
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
     def test_all_gather_base(self):
         store = c10d.FileStore(self.file_name, self.world_size)
         c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl')
diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py
index cb98ba0..6460239 100644
--- a/torch/distributed/nn/functional.py
+++ b/torch/distributed/nn/functional.py
@@ -318,12 +318,15 @@
 class _AllGather(Function):
     @staticmethod
     def forward(ctx, group, tensor):
+        # Need contiguous tensors for collectives.
+        tensor = tensor.contiguous()
+
         ctx.group = group
         out_tensor_list = [
             torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group))
         ]
-        dist.all_gather(out_tensor_list, tensor.contiguous(), group=group)
+        dist.all_gather(out_tensor_list, tensor, group=group)
         return tuple(out_tensor_list)
     @staticmethod