Ensure tensors are contiguous in functional all_gather.

We called `tensor.contiguous()` in the forward pass, but only after
`out_tensor_list` had already been allocated with `torch.empty_like(tensor)`.
When `tensor` is non-contiguous, those output buffers inherit its layout, so
`out_tensor_list` ends up holding non-contiguous tensors and the collective
errors out.

Fix this by making the input contiguous before allocating `out_tensor_list`.
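
A minimal sketch of the layout issue (not part of the patch; it runs on CPU
with no process group, assumes a hypothetical world size of 2, and relies on
`torch.empty_like`'s default `preserve_format` behavior):

```python
import torch

# A transposed view is dense but non-contiguous, like the gradient the new
# test produces via NonContiguousGrad.
grad = torch.rand(5, 5).transpose(0, 1)
assert not grad.is_contiguous()

# Old ordering: the output buffers are allocated from the non-contiguous
# tensor and inherit its layout, which is what the collective then rejects.
out_old = [torch.empty_like(grad) for _ in range(2)]
assert not out_old[0].is_contiguous()

# New ordering: make the input contiguous first, then allocate the outputs.
grad = grad.contiguous()
out_new = [torch.empty_like(grad) for _ in range(2)]
assert out_new[0].is_contiguous()
```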

Differential Revision: [D37222870](https://our.internmc.facebook.com/intern/diff/D37222870/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79747

Approved by: https://github.com/fduwjj, https://github.com/wanchaol
diff --git a/test/distributed/test_c10d_spawn_nccl.py b/test/distributed/test_c10d_spawn_nccl.py
index 54f176b..11583f2 100644
--- a/test/distributed/test_c10d_spawn_nccl.py
+++ b/test/distributed/test_c10d_spawn_nccl.py
@@ -157,6 +157,34 @@
         @requires_nccl()
         @skip_if_lt_x_gpu(2)
         @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
+        def test_reduce_scatter_non_contiguous(self):
+            store = c10d.FileStore(self.file_name, self.world_size)
+            # This is required because these functions call into dist directly
+            # and need the default process group to be initialized
+            c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl')
+            device = torch.device(f"cuda:{self.rank}")
+
+            class NonContiguousGrad(torch.autograd.Function):
+
+                @staticmethod
+                def forward(ctx, input):
+                    return input
+
+                @staticmethod
+                def backward(ctx, grad_output):
+                    # Make grad non-contiguous
+                    return grad_output.clone().transpose(0, 1)
+
+            x0 = torch.rand(5, 5, device=device, requires_grad=True)
+            x1 = torch.rand(5, 5, device=device, requires_grad=True)
+            y = torch.empty(5, 5, device=device)
+
+            y = torch.distributed.nn.reduce_scatter(y, [x0, x1])
+            NonContiguousGrad.apply(y).sum().backward()
+
+        @requires_nccl()
+        @skip_if_lt_x_gpu(2)
+        @sandcastle_skip_if(not _torch_dist_nn_available, "torch.distributed.nn is not available")
         def test_all_gather_base(self):
             store = c10d.FileStore(self.file_name, self.world_size)
             c10d.init_process_group(store=store, rank=self.rank, world_size=self.world_size, backend='nccl')
diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py
index cb98ba0..6460239 100644
--- a/torch/distributed/nn/functional.py
+++ b/torch/distributed/nn/functional.py
@@ -318,12 +318,15 @@
 class _AllGather(Function):
     @staticmethod
     def forward(ctx, group, tensor):
+        # Need contiguous tensors for collectives.
+        tensor = tensor.contiguous()
+
         ctx.group = group
         out_tensor_list = [
             torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group))
         ]
 
-        dist.all_gather(out_tensor_list, tensor.contiguous(), group=group)
+        dist.all_gather(out_tensor_list, tensor, group=group)
         return tuple(out_tensor_list)
 
     @staticmethod
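
For reference, a hedged usage sketch of the patched functional `all_gather`
(it mirrors the new test and assumes each rank has already called
`dist.init_process_group` with the NCCL backend and owns one GPU):

```python
import torch
import torch.distributed as dist
import torch.distributed.nn

# Assumes dist.init_process_group(backend="nccl", ...) has run on every rank.
device = torch.device(f"cuda:{dist.get_rank()}")
x = torch.rand(5, 5, device=device, requires_grad=True)

# A transposed (non-contiguous) view previously made the forward fail, since
# the output buffers were allocated from it before the contiguous() call.
outputs = torch.distributed.nn.all_gather(x.transpose(0, 1))
torch.stack(outputs).sum().backward()
```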