distributed: templated ring attention (#124215)

This adds a templated version of the ring attention forward function and tests it with memory-efficient attention. It does not add support for memory-efficient attention in DTensor; that will be added in a follow-up PR.

The templating is also a proof of concept for supporting other attention ops, such as jagged/nested tensor attention, and for implementing striped attention in a scalable way; a sketch of how a new variant would plug into the template follows below.
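
As a rough sketch of how the template composes (assuming this PR is applied; `_scaled_dot_product_ring_my_attention` is a hypothetical wrapper introduced only for illustration), a new variant just picks an aten SDPA op and forwards its op-specific arguments to `_templated_ring_attention`, which handles the K/V ring rotation and the logsumexp merge:

```
# Hypothetical wrapper sketch; only _templated_ring_attention and the aten op come from this PR.
from typing import Optional, Tuple

import torch
from torch.distributed._tensor.experimental.attention import _templated_ring_attention
from torch.distributed.device_mesh import DeviceMesh


def _scaled_dot_product_ring_my_attention(
    mesh: DeviceMesh,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    *,
    scale: Optional[float] = None,
) -> Tuple[torch.Tensor, ...]:
    # The template rotates K/V around the ring and merges the per-step outputs
    # via their logsumexps; the wrapper only selects the aten op and forwards
    # its op-specific arguments.
    return _templated_ring_attention(
        mesh,
        torch.ops.aten._scaled_dot_product_flash_attention,
        query=query,
        key=key,
        value=value,
        dropout_p=dropout_p,
        is_causal=is_causal,
        scale=scale,
    )
```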

Misc changes:

* Fixes the `all_to_all_single` autograd implementation on CUDA and adds an NCCL test
* Adds `torch.compile` support to the ring attention implementations (this required some tweaks to process group handling); a usage sketch covering both changes follows this list
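
A minimal usage sketch of the fixed autograd collective, mirroring `test_all_to_all_single` below (assumes a process group launched via `torchrun` with one GPU per rank; `exchange` is an illustrative helper, not part of this change):

```
# Minimal sketch mirroring the tests below; run under torchrun with NCCL.
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)


def exchange(t: torch.Tensor) -> torch.Tensor:
    # Send one row to every rank and receive one row from every rank.
    sizes = [1] * world_size
    return ft_c.all_to_all_single_autograd(t * 2, sizes, sizes, dist.group.WORLD.group_name)


t = torch.ones((world_size, 2), device="cuda", requires_grad=True)
# The collective now also traces under torch.compile (aot_eager in the tests).
out = torch.compile(exchange, fullgraph=True, backend="aot_eager")(t)
out.sum().backward()  # gradients flow back through the collective
assert torch.equal(t.grad, torch.full_like(t, 2.0))
```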

Test plan:

```
pytest test/distributed/_tensor/test_attention.py
pytest test/distributed/test_functional_api.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124215
Approved by: https://github.com/wanchaol
diff --git a/test/distributed/_tensor/test_attention.py b/test/distributed/_tensor/test_attention.py
index 3a34af1..db5a26d 100644
--- a/test/distributed/_tensor/test_attention.py
+++ b/test/distributed/_tensor/test_attention.py
@@ -10,6 +10,8 @@
     _CausalBehavior,
     _is_causal_behavior,
     _scaled_dot_product_chunk_flash_attention,
+    _scaled_dot_product_ring_efficient_attention,
+    _scaled_dot_product_ring_flash_attention,
     attention_context_parallel,
     AttentionContextParallel,
 )
@@ -295,6 +297,86 @@
             },
         )
 
+    @skip_if_lt_x_gpu(2)
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention"
+    )
+    @with_comms
+    @parametrize(
+        "attention_fn",
+        [
+            _scaled_dot_product_ring_flash_attention,
+            _scaled_dot_product_ring_efficient_attention,
+            # _scaled_dot_product_ring_cudnn_attention, # TODO: not built by default
+        ],
+    )
+    def test_ring_attention_compile(self, attention_fn: object) -> None:
+        device_mesh = DeviceMesh(
+            self.device_type,
+            torch.arange(0, self.world_size),
+        )
+        dtype = torch.bfloat16
+        bs = 8
+        query_tokens = 8
+        context_tokens = 24
+        dim = 32
+        nheads = 8
+        query = torch.rand(
+            (bs, nheads, self.world_size * query_tokens, dim),
+            device=self.device_type,
+            dtype=dtype,
+            requires_grad=True,
+        )
+        key = torch.rand(
+            (bs, nheads, self.world_size * context_tokens, dim),
+            device=self.device_type,
+            dtype=dtype,
+        )
+        value = torch.rand(
+            (bs, nheads, self.world_size * context_tokens, dim),
+            device=self.device_type,
+            dtype=dtype,
+        )
+
+        query_placement = [Shard(2)]
+        dquery = distribute_tensor(query, device_mesh, query_placement)
+        self.assertEqual(query.shape, (bs, nheads, self.world_size * query_tokens, dim))
+
+        context_placement = [Shard(2)]
+        dkey = distribute_tensor(key, device_mesh, context_placement)
+        dvalue = distribute_tensor(value, device_mesh, context_placement)
+
+        # compiled = attention_fn
+        compiled = torch.compile(attention_fn, fullgraph=True, backend="aot_eager")
+
+        out, lse, *args = compiled(
+            device_mesh.get_group(),
+            dquery.to_local(),
+            dkey.to_local(),
+            dvalue.to_local(),
+        )
+        self.assertEqual(out.shape, (bs, nheads, query_tokens, dim))
+        self.assertIsInstance(lse, torch.Tensor)
+
+        (
+            out_chunk,
+            *others,
+        ) = _scaled_dot_product_chunk_flash_attention(
+            query,
+            key,
+            value,
+            size=self.world_size,
+            is_causal=False,
+        )
+        self.assertEqual(
+            out,
+            out_chunk[
+                :, :, self.rank * query_tokens : (self.rank + 1) * query_tokens, :
+            ],
+        )
+
+        out.sum().backward()
+
 
 instantiate_parametrized_tests(RingAttentionTest)
 
diff --git a/test/distributed/test_functional_api.py b/test/distributed/test_functional_api.py
index d26dcf9..f225563 100644
--- a/test/distributed/test_functional_api.py
+++ b/test/distributed/test_functional_api.py
@@ -634,14 +634,14 @@
     def test_all_to_all_single(self, compile: bool = True) -> None:
         group = dist.group.WORLD.group_name
 
-        t = torch.rand((self.world_size, 2), requires_grad=True)
+        t = torch.ones((self.world_size, 2), requires_grad=True)
 
         def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
             sizes = [1] * world_size
-            t = t * 10
+            t = t * 2
             assert t.requires_grad
             out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
-            out = out + 2
+            out = out + 0
             return out
 
         if compile:
@@ -650,11 +650,13 @@
             compiled = my_func
 
         out = compiled(t, self.world_size)
+        self.assertEqual(out.shape, t.shape)
+        self.assertEqual(out, torch.full_like(t, 2.0))
         self.assertIsNotNone(out.grad_fn)
         self.assertTrue(out.requires_grad)
         loss = out.sum()
         loss.backward()
-        self.assertIsNotNone(t.grad)
+        self.assertEqual(t.grad, torch.full_like(t, 2.0))
 
     def test_all_to_all_single_inductor(self) -> None:
         group = dist.group.WORLD.group_name
@@ -752,5 +754,61 @@
             self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))
 
 
+class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
+    def setUp(self):
+        super().setUp()
+        os.environ["WORLD_SIZE"] = str(self.world_size)
+        os.environ["BACKEND"] = dist.Backend.NCCL
+        self._spawn_processes()
+
+    @property
+    def device(self):
+        return torch.device(self.rank)
+
+    @property
+    def world_size(self):
+        return 2
+
+    @property
+    def process_group(self):
+        return dist.group.WORLD
+
+    def dist_init(self):
+        dist.init_process_group(
+            backend=BACKEND,
+            world_size=self.world_size,
+            rank=self.rank,
+            init_method=f"file://{self.file_name}",
+        )
+
+        # set device for nccl pg for collectives
+        if BACKEND == "nccl":
+            torch.cuda.set_device(self.rank)
+
+    def destroy_comms(self):
+        # Wait for all ranks to reach here before starting shutdown.
+        dist.barrier()
+        dist.destroy_process_group()
+
+    @requires_nccl()
+    @with_comms()
+    def test_all_to_all_single(self) -> None:
+        group = self.process_group.group_name
+
+        t = torch.ones((self.world_size, 2), requires_grad=True, device=self.device)
+
+        sizes = [1] * self.world_size
+        assert t.requires_grad
+        out = ft_c.all_to_all_single_autograd(t * 2, sizes, sizes, group) + 0
+
+        self.assertEqual(out.shape, t.shape)
+        self.assertEqual(out, torch.full_like(t, 2.0))
+        self.assertIsNotNone(out.grad_fn)
+        self.assertTrue(out.requires_grad)
+        loss = out.sum()
+        loss.backward()
+        self.assertEqual(t.grad, torch.full_like(t, 2.0))
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/csrc/distributed/c10d/Functional.cpp b/torch/csrc/distributed/c10d/Functional.cpp
index 942ae73..5728774 100644
--- a/torch/csrc/distributed/c10d/Functional.cpp
+++ b/torch/csrc/distributed/c10d/Functional.cpp
@@ -409,7 +409,7 @@
     const std::string& group_name = ctx->saved_data["group_name"].toStringRef();
 
     DCHECK(grad_out_list.size() == 1);
-    auto grad_out = grad_out_list[0];
+    auto grad_out = grad_out_list[0].contiguous();
 
     auto out =
         c10::Dispatcher::singleton()
@@ -434,7 +434,7 @@
     const std::vector<int64_t>& input_split_sizes,
     const std::string& group_name) {
   return AllToAllSingle::apply(
-      input, output_split_sizes, input_split_sizes, group_name)[0];
+      input, output_split_sizes, input_split_sizes, group_name);
 }
 
 class ReduceScatterTensor
diff --git a/torch/distributed/_tensor/debug/comm_mode.py b/torch/distributed/_tensor/debug/comm_mode.py
index 43def0b..b195b30 100644
--- a/torch/distributed/_tensor/debug/comm_mode.py
+++ b/torch/distributed/_tensor/debug/comm_mode.py
@@ -8,6 +8,7 @@
 
 funcol_native = torch.ops._c10d_functional
 funcol_py = torch.ops.c10d_functional
+funcol_autograd = torch.ops._c10d_functional_autograd
 
 NATIVE_TO_PY_MAPPING = {
     funcol_native.all_gather_into_tensor: funcol_py.all_gather_into_tensor,
@@ -17,6 +18,8 @@
     funcol_native.broadcast: funcol_py.broadcast,
     funcol_native.reduce_scatter_tensor: funcol_py.reduce_scatter_tensor,
     funcol_native.reduce_scatter_tensor_coalesced: funcol_py.reduce_scatter_tensor_coalesced,
+    # functional ops
+    funcol_autograd.all_to_all_single: funcol_py.all_to_all_single,
 }
 
 
diff --git a/torch/distributed/_tensor/experimental/attention.py b/torch/distributed/_tensor/experimental/attention.py
index 195a94f..eb7703a 100644
--- a/torch/distributed/_tensor/experimental/attention.py
+++ b/torch/distributed/_tensor/experimental/attention.py
@@ -1,7 +1,7 @@
 import contextlib
 import weakref
 from enum import Enum
-from typing import Any, Dict, Generator, List, Optional, Tuple, Union
+from typing import Any, Dict, Generator, List, Optional, Protocol, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -54,6 +54,10 @@
     """
     assert len(chunks) == len(logsumexps)
 
+    # LSE may be padded in the sequence dimension such as with memory efficient attention.
+    seq_len = chunks[0].size(2)
+    logsumexps = [lse[:, :, :seq_len] for lse in logsumexps]
+
     softmax_lse = torch.stack([lse.exp() for lse in logsumexps]).sum(dim=0).log_()
 
     out = []
@@ -80,19 +84,148 @@
     if return_debug_mask:
         raise NotImplementedError("return_debug_mask is not supported yet")
 
+    return _templated_ring_attention(
+        mesh,
+        torch.ops.aten._scaled_dot_product_flash_attention,
+        query=query,
+        key=key,
+        value=value,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        scale=scale,
+    )
+
+
+def _scaled_dot_product_ring_efficient_attention(
+    mesh: DeviceMesh,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    compute_log_sumexp: bool = True,
+    *,
+    scale: Optional[float] = None,
+) -> Tuple[torch.Tensor, ...]:
+    if attn_bias is not None:
+        raise NotImplementedError("attn_bias is not supported yet")
+    if not compute_log_sumexp:
+        raise NotImplementedError("compute_log_sumexp must be set")
+
+    return _templated_ring_attention(
+        mesh,
+        torch.ops.aten._scaled_dot_product_efficient_attention,
+        query=query,
+        key=key,
+        value=value,
+        attn_bias=attn_bias,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        scale=scale,
+        compute_log_sumexp=compute_log_sumexp,
+    )
+
+
+def _scaled_dot_product_ring_cudnn_attention(
+    mesh: DeviceMesh,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_bias: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    return_debug_mask: bool = True,
+    *,
+    scale: Optional[float] = None,
+) -> Tuple[torch.Tensor, ...]:
+    if not return_debug_mask:
+        raise NotImplementedError("return_debug_mask must be set")
+
+    return _templated_ring_attention(
+        mesh,
+        torch.ops.aten._scaled_dot_product_cudnn_attention,
+        query=query,
+        key=key,
+        value=value,
+        dropout_p=dropout_p,
+        is_causal=is_causal,
+        return_debug_mask=return_debug_mask,
+        scale=scale,
+    )
+
+
+def _ring_rotate(block: torch.Tensor, pg: dist.ProcessGroup) -> torch.Tensor:
+    rank = dist.get_rank(pg)
+    size = dist.get_world_size(pg)
+
+    # rank 0 sends to rank 1, rank 1 sends to rank 2, ..., rank n-1 sends to rank 0
+    input_split_sizes = [0] * size
+    input_split_sizes[(rank + 1) % size] = len(block)
+    output_split_sizes = [0] * size
+    output_split_sizes[(rank - 1) % size] = len(block)
+
+    out = ft_c.all_to_all_single_autograd(
+        block, input_split_sizes, output_split_sizes, pg
+    )
+    return out
+
+
+class AttentionOp(Protocol):
+    def __call__(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        *args: object,
+        is_causal: bool = False,
+        **kwargs: object,
+    ) -> Tuple[torch.Tensor, ...]:
+        ...
+
+
+def _templated_ring_attention(
+    mesh: DeviceMesh,
+    op: AttentionOp,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    *args: object,
+    is_causal: bool = False,
+    **kwargs: object,
+) -> Tuple[torch.Tensor, ...]:
+    """
+    This is a generalized ring attention implementation that can support multiple attention ops.
+
+    Parameters
+    ----------
+    op:
+        The attention op to use
+    *args:
+        additional args are passed to the op
+    **kwargs:
+        additional kwargs are passed to the op
+
+    Returns
+    -------
+    out:
+        The merged attention output
+    softmax_lse:
+        The logsumexp of the merged attention output
+    """
     if is_causal and (query.size(2) != key.size(2)):
         raise NotImplementedError(
             "is_causal requires the same query and context sequence lengths"
         )
 
-    pg = mesh.get_group()
-    assert isinstance(pg, dist.ProcessGroup), "must be single dimension"
+    if isinstance(mesh, dist.ProcessGroup):
+        pg: Union[dist.ProcessGroup, List[dist.ProcessGroup]] = mesh
+    else:
+        pg = mesh.get_group()
+    assert isinstance(pg, dist.ProcessGroup), "process group must be single dimension"
     rank = dist.get_rank(pg)
     size = dist.get_world_size(pg)
 
-    # rank 0 sends to rank 1, rank 1 sends to rank 2, ..., rank n-1 sends to rank 0
-    right_dsts = list(range(1, size)) + [0]
-
     next_kv = None
 
     chunks = []
@@ -106,20 +239,20 @@
 
         if i < (size - 1):
             next_kv = torch.cat([key.flatten(), value.flatten()])
-            next_kv = ft_c.permute_tensor(next_kv, right_dsts, pg)
+            next_kv = _ring_rotate(next_kv, pg)
 
         is_causal_behavior = _is_causal_behavior(
             rank=rank, world_size=size, i=i, is_causal=is_causal
         )
 
         if is_causal_behavior != _CausalBehavior.SKIP:
-            local_results = torch.ops.aten._scaled_dot_product_flash_attention(
+            local_results = op(
                 query,
                 key,
                 value,
-                dropout_p=dropout_p,
+                *args,
                 is_causal=is_causal_behavior.value,
-                scale=scale,
+                **kwargs,
             )
             chunks.append(local_results[0])
             logsumexps.append(local_results[1])