[PTD] Introduce tracing friendly collectives. (#93990)

This change adds torch.distributed._functional_collectives.

This experimental API enables collectives to be fully traced by dynamo and FX.

See #93173 for the RFC

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93990
Approved by: https://github.com/wconstab, https://github.com/wanchaol, https://github.com/H-Huang
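
For reference, a minimal usage sketch mirroring the new tests in test/distributed/test_functional_api.py. It assumes `torch.distributed` has already been initialized and that ranks 0 and 1 participate; the rank list and tensor values are illustrative only.

```python
import torch
import torch.distributed._functional_collectives as ft_c
from functorch import make_fx

# Eager: the result is an AsyncCollectiveTensor that implicitly waits on first downstream use.
t = torch.ones(4)
out = ft_c.all_reduce(t, "sum", [0, 1])
print(out + 1)  # first use triggers wait_tensor() under the hood

# Tracing: make_fx captures aten::all_reduce and aten::wait_tensor nodes.
def allred(x):
    return ft_c.all_reduce(x, "sum", [0, 1]) + 1

graph = make_fx(allred)(torch.rand(4))
print(graph.graph)
```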
diff --git a/aten/src/ATen/native/Collectives.cpp b/aten/src/ATen/native/Collectives.cpp
new file mode 100644
index 0000000..44e1399
--- /dev/null
+++ b/aten/src/ATen/native/Collectives.cpp
@@ -0,0 +1,29 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+
+#include <ATen/core/Tensor.h>
+#include <ATen/Parallel.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#endif
+
+namespace at {
+namespace native {
+
+// Dummy impls required by the codegen infra; never called at runtime
+at::Tensor all_reduce(at::Tensor const& self, const c10::string_view reduceOp, const c10::string_view tag, c10::ArrayRef<int64_t> ranks, int64_t group_size) {
+    // This should never get called
+    // Defer to python impls in torch/distributed/_functional_collectives.py and _meta_registrations.py
+    TORCH_INTERNAL_ASSERT(false);
+}
+
+at::Tensor wait_tensor(at::Tensor const& self) {
+    // This should never get called
+    // Defer to python impls in torch/distributed/_functional_collectives.py and _meta_registrations.py
+    TORCH_INTERNAL_ASSERT(false);
+}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2b1ffb3..522cdcc 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -14670,3 +14670,18 @@
   dispatch:
     CUDA: _fused_adamw_kernel_cuda_
   autogen: _fused_adamw, _fused_adamw.out
+
+# Collectives
+- func: all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor
+  # python_module should eventually be `distributed`, but that requires changes in many other places
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: all_reduce
+  variants: function
+
+- func: wait_tensor(Tensor self) -> Tensor
+  # python_module should eventually be `distributed`, but that requires changes in many other places
+  python_module: nn
+  dispatch:
+    CompositeExplicitAutograd: wait_tensor
+  variants: function
diff --git a/build_variables.bzl b/build_variables.bzl
index 59e21c3..f5a465a 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -1231,6 +1231,7 @@
     "aten/src/ATen/native/Bucketization.cpp",
     "aten/src/ATen/native/CPUBlas.cpp",
     "aten/src/ATen/native/ChanelShuffle.cpp",
+    "aten/src/ATen/native/Collectives.cpp",
     "aten/src/ATen/native/Col2Im.cpp",
     "aten/src/ATen/native/PadNd.cpp",
     "aten/src/ATen/native/Convolution.cpp",
diff --git a/test/distributed/test_functional_api.py b/test/distributed/test_functional_api.py
new file mode 100644
index 0000000..8ccff34
--- /dev/null
+++ b/test/distributed/test_functional_api.py
@@ -0,0 +1,269 @@
+# Owner(s): ["oncall: distributed"]
+
+import sys
+import torch
+import torch.distributed as dist
+import torch.distributed._functional_collectives as ft_c
+import torch.distributed.distributed_c10d as c10d
+import torch.distributed._tensor as dt
+
+from functorch import make_fx
+
+if not dist.is_available():
+    print("Distributed not available, skipping tests", file=sys.stderr)
+    sys.exit(0)
+
+from torch.testing._internal.common_distributed import (
+    MultiThreadedTestCase,
+)
+from torch.testing._internal.common_utils import (
+    run_tests,
+    TestCase
+)
+
+def new_subgroups(group_size: int, pg_tag=None):
+    world_size = dist.get_world_size()
+    subgroups = []
+    cur_subgroup = None
+
+    for subgroup_id in range(world_size // group_size):
+        start_rank = subgroup_id * group_size
+        end_rank = start_rank + group_size
+        ranks_in_subgroup = list(range(start_rank, end_rank))
+        subgroup = c10d._new_group_with_tag(
+            ranks=ranks_in_subgroup,
+            pg_tag=pg_tag,
+        )
+        subgroups.append(subgroup)
+
+        rank = dist.get_rank()
+        if rank in ranks_in_subgroup:
+            cur_subgroup = subgroup
+
+    return cur_subgroup, subgroups
+
+
+class TestExpand(MultiThreadedTestCase):
+    @property
+    def world_size(self):
+        return 4
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_threads()
+
+    def test_expand_1d_rank_list(self):
+        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3])
+        self.assertEqual("", tag)
+        self.assertEqual([0, 1, 2, 3], rankset)
+        self.assertEqual(4, group_size)
+
+        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3], "bla")
+        self.assertEqual("bla", tag)
+
+    def test_expand_2d_rank_list(self):
+        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]])
+        self.assertEqual("", tag)
+        self.assertEqual([0, 1, 2, 3], rankset)
+        self.assertEqual(2, group_size)
+
+        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]], "blu")
+        self.assertEqual("blu", tag)
+
+        with self.assertRaisesRegex(ValueError, "group sizes must be identical"):
+            ft_c._expand_group([[0], [1, 2, 3]])
+
+    def test_expand_process_group(self):
+        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD)
+        self.assertEqual(c10d._get_group_tag(dist.group.WORLD), tag)
+        self.assertEqual([0, 1, 2, 3], rankset)
+        self.assertEqual(4, group_size)
+
+        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla")
+        self.assertEqual("bla", tag)
+
+        my_pg, others = new_subgroups(group_size=2)
+        tag, rankset, group_size = ft_c._expand_group(my_pg)
+        self.assertEqual(c10d._get_group_tag(my_pg), tag)
+        self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
+        self.assertEqual(2, group_size)
+
+        my_pg = None
+        for i in range(dist.get_world_size()):
+            group = c10d._new_group_with_tag([i], pg_tag="my_pg")
+            if i == dist.get_rank():
+                my_pg = group
+        tag, rankset, group_size = ft_c._expand_group(my_pg)
+        self.assertEqual("my_pg", tag)
+        self.assertEqual([dist.get_rank()], rankset)
+        self.assertEqual(1, group_size)
+
+        tag, rankset, group_size = ft_c._expand_group(my_pg, "bla")
+        self.assertEqual("bla", tag)
+
+    def test_expand_device_mesh(self):
+        mesh = dt.DeviceMesh("cpu", torch.arange(4))
+        tag, rankset, group_size = ft_c._expand_group(mesh)
+        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
+        self.assertEqual([0, 1, 2, 3], rankset)
+        self.assertEqual(4, group_size)
+
+        mesh = dt.DeviceMesh("cpu", torch.arange(4))
+        tag, rankset, group_size = ft_c._expand_group(mesh)
+        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
+        self.assertEqual([0, 1, 2, 3], rankset)
+        self.assertEqual(4, group_size)
+
+    def test_expand_device_mesh_tuple(self):
+        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
+        tag, rankset, group_size = ft_c._expand_group(mesh)
+        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
+        self.assertEqual([0, 2, 1, 3], rankset)
+        self.assertEqual(2, group_size)
+
+        tag, rankset, group_size = ft_c._expand_group((mesh, 0))
+        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
+        self.assertEqual([0, 2, 1, 3], rankset)
+        self.assertEqual(2, group_size)
+
+        tag, rankset, group_size = ft_c._expand_group((mesh, 1))
+        self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[1]), tag)
+        self.assertEqual([0, 1, 2, 3], rankset)
+        self.assertEqual(2, group_size)
+
+class TestPgTag(MultiThreadedTestCase):
+    @property
+    def world_size(self):
+        return 4
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_threads()
+
+    """
+    The behavior we want is as follows:
+
+    - rankset+tag will always result in the same PG.
+    Do we enforce this by failing creation of new PGs or returning existing ones?
+        Return existing one.
+
+    - default tag gives existing behavior.
+        This means we should create duplicates.
+    - _expand_group on a default-tagged pg should always resolve to it.
+        This means we can't depend on an empty tag + rankset.
+    """
+    def test_pg_creation_with_tag(self):
+        my_group, _ = new_subgroups(group_size=2, pg_tag="blu")
+        my_group2, _ = new_subgroups(group_size=2, pg_tag="blu")
+        self.assertEqual(my_group, my_group2)
+
+        my_group3, _ = new_subgroups(group_size=2, pg_tag="blu2")
+        self.assertNotEqual(my_group, my_group3)
+
+        my_group4, _ = new_subgroups(group_size=2)
+        self.assertNotEqual(my_group, my_group4)
+
+        my_group5, _ = new_subgroups(group_size=2)
+        self.assertNotEqual(my_group4, my_group5)
+
+    def test_pg_lookup_roundtrip(self):
+        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
+        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="blu2")
+        pg_notag0, _ = new_subgroups(group_size=2)
+        pg_notag1, _ = new_subgroups(group_size=2)
+
+        def roundtrip(pg):
+            tag, rankset, _ = ft_c._expand_group(pg)
+            return c10d._find_pg_by_ranks_and_tag(tag, rankset)
+
+        self.assertEqual(pg_tag0, roundtrip(pg_tag0))
+        self.assertEqual(pg_tag1, roundtrip(pg_tag1))
+        self.assertEqual(pg_notag0, roundtrip(pg_notag0))
+        self.assertEqual(pg_notag1, roundtrip(pg_notag1))
+
+    def test_pg_lookup_with_tag(self):
+        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
+        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="bla")
+        pg_notag0, _ = new_subgroups(group_size=2)
+
+        def roundtrip(pg, pg_tag):
+            tag, rankset, _ = ft_c._expand_group(pg, pg_tag)
+            return c10d._find_pg_by_ranks_and_tag(tag, rankset)
+
+        self.assertEqual(pg_tag0, roundtrip(pg_tag1, "blu"))
+        self.assertEqual(pg_tag0, roundtrip(pg_notag0, "blu"))
+        # Cannot erase the tag of a PG
+        self.assertEqual(pg_tag0, roundtrip(pg_tag0, ""))
+
+    def test_find_or_create_pg(self):
+        pg = c10d._find_or_create_pg_by_ranks_and_tag("blu", [0, 1, 2, 3], 2)
+        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
+        self.assertEqual(pg, pg_tag0)
+
+    def test_find_root_pg(self):
+        pg = c10d._find_pg_by_ranks_and_tag("", [0, 1, 2, 3])
+        self.assertEqual(dist.group.WORLD, pg)
+
+class TestTraceableCollectives(MultiThreadedTestCase):
+    @property
+    def world_size(self):
+        return 4
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_threads()
+
+    def test_all_reduce_eager(self):
+        tensor = torch.ones([4])
+        mesh = dt.DeviceMesh("cpu", torch.arange(4))
+
+        res = ft_c.all_reduce(tensor, "sum", mesh)
+        self.assertEqual(res, torch.tensor([4, 4, 4, 4], dtype=torch.float))
+
+        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
+        res2 = ft_c.all_reduce(tensor, "sum", (mesh, 1))
+        self.assertEqual(res2, torch.tensor([2, 2, 2, 2], dtype=torch.float))
+
+class TestMetaCollectives(TestCase):
+    def test_all_reduce(self):
+        x = torch.rand((2, 3, 4), device="meta")
+        out = ft_c.all_reduce(x, "sum", [1])
+        self.assertEqual(x.size(), out.size())
+
+class TestGradCollectives(MultiThreadedTestCase):
+    @property
+    def world_size(self):
+        return 2
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_threads()
+
+    def test_all_reduce(self):
+        x = torch.rand([4], requires_grad=True)
+        y = torch.rand([4], requires_grad=True)
+        out = ft_c.all_reduce(x, "sum", [0, 1])
+        (out + y).sum().backward()
+        self.assertIsNone(x.grad)
+
+class TestMakeFx(MultiThreadedTestCase):
+    @property
+    def world_size(self):
+        return 2
+
+    def setUp(self):
+        super().setUp()
+        self._spawn_threads()
+
+    def test_all_reduce_tracing(self):
+        def allred(input):
+            return ft_c.all_reduce(input, "sum", group=[0, 1]) + 1
+
+        graph = make_fx(allred)(torch.rand(4))
+        nodes = list(graph.graph.nodes)
+
+        self.assertEqual("aten::all_reduce", nodes[1].target.name())
+        self.assertEqual("aten::wait_tensor", nodes[2].target.name())
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index d5a174e..debcf63 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -575,6 +575,7 @@
 aten::affine_grid_generator.out
 aten::alias_copy
 aten::alias_copy.out
+aten::all_reduce
 aten::allclose
 aten::aminmax
 aten::aminmax.out
@@ -1339,6 +1340,7 @@
 aten::view_copy.dtype
 aten::view_copy.dtype_out
 aten::view_copy.out
+aten::wait_tensor
 aten::zeros.names
 aten::zeros.names_out
 aten::zeros.out
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 3ad1866..837c12b 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -2702,4 +2702,14 @@
                 _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)
 
 
+@register_meta(aten.all_reduce)
+def all_reduce_meta(self, reduceOp, tag, rankset, stride):
+    return torch.empty_like(self)
+
+
+@register_meta(aten.wait_tensor)
+def wait_tensor_meta(self):
+    return torch.empty_like(self)
+
+
 activate_meta()
diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py
new file mode 100644
index 0000000..8af8f5f
--- /dev/null
+++ b/torch/distributed/_functional_collectives.py
@@ -0,0 +1,237 @@
+from typing import Any, Tuple, Union, List, cast
+
+import weakref
+import warnings
+
+import torch
+import torch.distributed as dist
+
+from torch._C import _disabled_torch_function_impl
+from torch.utils._pytree import tree_map
+
+import torch.distributed.distributed_c10d as c10d
+"""
+New traceable, functional collectives.
+RFC: https://github.com/pytorch/pytorch/issues/93173
+
+  compiler: trace these ops with plain-old-data schemas, then choose how to lower them.
+  eager: execute these 'functional' ops, which in eager mode return AsyncCollectiveTensor subclasses
+         that automatically call .wait() on the underlying/hidden async 'work' object only when the
+         result is fed to a downstream op.
+
+Issues:
+* Where should these ops live? We couldn't `import torch` when these ops were put in existing torch.distributed files.
+* Proper support for eager mode requires inplace ops. We should explore offering them as an option in the API.
+"""
+
+"""
+Functional collectives are asynchronous only and we perform implicit stream synchronization
+on behalf of the user.
+
+We use AsyncCollectiveTensor to wrap the result tensor of a collective; it lets us observe the
+first use of the tensor and insert cross-stream synchronization at the right place.
+
+The above is the easy part; the hard part is matching the Work object returned by
+c10d with the tensor that AsyncCollectiveTensor wraps. We allocate the tensor inside the
+collective op implementation (see the ``clone()`` call in ``_all_reduce``), and it is then
+handled by the dispatcher, which might call other implementations that are allowed to change
+the returned tensor - or even return a tensor with a different shape (see ``torch.vmap``).
+
+This means the caller of our ops receives a Tensor that is not guaranteed to be the same one
+allocated by our implementations, which makes pairing the AsyncCollectiveTensor with the
+original tensor a lot harder. This pairing is needed so we can look up the Work object to use.
+
+Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's
+identity is not stable across dispatch, the op caller would end up with a different Tensor
+instance that would not match any in the dictionary.
+
+With Tensor identity out of the question, we decided to use the tensor data pointer, which
+should be stable across all the Tensor changes done during dispatch.
+
+We keep a dictionary from tensor data_ptr -> Work, inserting an entry right after we call into c10d.
+
+We consult this dictionary when an AsyncCollectiveTensor is first used, to invoke Work::wait().
+
+Finally, we set up a finalizer against the tensor wrapper to observe it getting collected so we
+can clean up stale entries in the dictionary.
+
+To eliminate the possibility of races, we keep a global version counter that is used by the finalizer.
+
+As a wise man said once: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo)
+
+"""
+data_ptr_to_work = dict()
+work_version = 0
+
+def _register_tensor_work(tensor, work):
+    global data_ptr_to_work
+    global work_version
+    data_ptr_to_work[tensor.data_ptr()] = (work_version, work)
+    work_version += 1
+
+def _clear_tensor(data_ptr, version):
+    global data_ptr_to_work
+    version_and_work = data_ptr_to_work.get(data_ptr)
+
+    if version_and_work is not None and version_and_work[0] == version:
+        del data_ptr_to_work[data_ptr]
+
+def _register_wrapper_tensor(tensor_wrapper, tensor):
+    global data_ptr_to_work
+    version, _ = data_ptr_to_work.get(tensor.data_ptr(), (None, None))
+    if version is None:
+        warnings.warn("Trying to register finalizers to AsyncCollectiveTensor but the inner tensor is already gone")
+    else:
+        weakref.finalize(tensor_wrapper, _clear_tensor, tensor.data_ptr(), version)
+
+def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
+    global data_ptr_to_work
+    data_ptr = tensor.data_ptr()
+    version_and_work = data_ptr_to_work.get(data_ptr)
+    if version_and_work is not None:
+        version_and_work[1].wait()
+        _clear_tensor(data_ptr, version_and_work[0])
+    return tensor
+
+
+class AsyncCollectiveTensor(torch.Tensor):
+    r"""
+    A Tensor subclass that is only used in eager mode. It defers waiting on the collective's
+    pending 'work' object until the wrapped tensor is first fed to a real op.
+
+    Usage, from inside a functional collective:
+    def functional_collective(input):
+        input = input.clone()
+        mutated_input, work = c10d.{inplace_collective}(input)
+        _register_tensor_work(mutated_input, work)
+        return AsyncCollectiveTensor(mutated_input)
+    """
+    _tensor: torch.Tensor
+
+    __torch_function__ = _disabled_torch_function_impl
+
+    @staticmethod
+    def __new__(cls, tensor: torch.Tensor):
+        t = tensor
+        r = torch.Tensor._make_subclass(cls, t, require_grad=t.requires_grad)
+        r._tensor = tensor  # type: ignore[attr-defined]
+        return r
+
+    def __repr__(self):
+        return f"AsyncCollectiveTensor({self._tensor})"
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        def unwrap(e: Any):
+            if isinstance(e, AsyncCollectiveTensor):
+                return wait_tensor(e._tensor)
+            return e
+
+        unwrapped_args = tree_map(unwrap, args)
+        unwrapped_kwargs = tree_map(unwrap, kwargs)
+
+        out = func(*unwrapped_args, **unwrapped_kwargs)
+        return out
+
+def _str_to_reduce_op(reduceOp: str) -> dist.ReduceOp:
+    reduceOp = reduceOp.upper()
+    op = dist.ReduceOp.RedOpType.__members__.get(reduceOp)
+    if op is None:
+        raise ValueError(f"Invalid reduce operation {reduceOp}")
+    return cast(dist.ReduceOp, op)
+
+# TODO: assert that ranks has no duplicated entries
+def _all_reduce(self, reduceOp, tag, ranks, group_size):
+    op = _str_to_reduce_op(reduceOp)
+    group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
+    assert group is not None
+
+    inplace_tensor = self.clone()
+    work = dist.all_reduce(inplace_tensor, op=op, group=group, async_op=True)
+    _register_tensor_work(inplace_tensor, work)
+
+    return inplace_tensor
+
+c10_lib_cpu = torch.library.Library("aten", "IMPL", "CPU")
+c10_lib_cuda = torch.library.Library("aten", "IMPL", "CUDA")
+
+c10_lib_cpu.impl("all_reduce", _all_reduce)
+c10_lib_cuda.impl("all_reduce", _all_reduce)
+
+c10_lib_cpu.impl("wait_tensor", _wait_tensor)
+c10_lib_cuda.impl("wait_tensor", _wait_tensor)
+
+
+RANK_TYPES = Union[List[int], List[List[int]], dist.ProcessGroup, "dist._tensor.DeviceMesh", Tuple["dist._tensor.DeviceMesh", int]]
+
+def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]:
+    # Imported here rather than at the top level to avoid a circular import
+    import torch.distributed._tensor as dt
+    rankset: List[int]
+    if isinstance(group, list):
+        if isinstance(group[0], list):
+            nested_list = cast(List[List[int]], group)
+            rankset = []
+            group_size = -1
+            for rs in nested_list:
+                rankset.extend(rs)
+                if group_size != -1 and group_size != len(rs):
+                    raise ValueError(f"group sizes must be identical found {group_size} and {len(rs)}")
+                group_size = len(rs)
+        else:
+            rankset = cast(List[int], group)
+            group_size = len(rankset)
+    elif isinstance(group, dist.ProcessGroup):
+        rankset = dist.get_process_group_ranks(group)
+        group_size = len(rankset)
+        tag = tag or c10d._get_group_tag(group)
+    elif isinstance(group, dt.DeviceMesh):
+        group_size = group.mesh.size(0)
+        rankset = group.mesh.swapdims(-1, 0).reshape(-1, group_size).flatten().tolist()
+        tag = tag or c10d._get_group_tag(group.get_dim_groups()[0])
+    elif isinstance(group, tuple):
+        if len(group) == 2 and isinstance(group[0], dt.DeviceMesh) and isinstance(group[1], int):
+            dmesh = group[0]
+            dim = group[1]
+            group_size = dmesh.mesh.size(dim)
+            rankset = dmesh.mesh.swapdims(-1, dim).reshape(-1, group_size).flatten().tolist()
+            tag = tag or c10d._get_group_tag(dmesh.get_dim_groups()[dim])
+        else:
+            raise ValueError("Invalid tuple for group, must be (DeviceMesh, int)")
+    else:
+        raise ValueError("Invalid type for group, must be one of List, ProcessGroup, DeviceMesh or (DeviceMesh, int).")
+
+    return (tag, rankset, group_size)
+
+
+def wait_tensor(tensor):
+    """
+    Wait on a tensor returned by the collectives ops.
+
+    Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
+    """
+    return torch._C._nn.wait_tensor(tensor)  # type: ignore[attr-defined]
+
+
+def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
+    """
+    Reduces the tensor data across all machines in such a way that all get
+    the final result.
+
+    The input tensor is left unmodified.
+
+    Group can be one of:
+        List[int]: ranks participating in the collective.
+        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
+        ProcessGroup: Will perform a collective using the ranks and tag of the PG.
+        DeviceMesh: Do an SPMD collective over all ranks of the mesh
+        (DeviceMesh, int): Do an MPMD collective over one dimension of the DeviceMesh
+
+    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
+    that information and perform collective algebraic optimization. Use other forms of input for that.
+    """
+    tag, rankset, group_size = _expand_group(group, tag)
+    tensor = torch._C._nn.all_reduce(self, reduceOp, tag, rankset, group_size)  # type: ignore[attr-defined]
+    res = AsyncCollectiveTensor(tensor)
+    _register_wrapper_tensor(res, tensor)
+    return res
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index be0006d..98fefed 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -10,7 +10,7 @@
 import warnings
 from collections import namedtuple
 from datetime import timedelta
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union, List
 
 import torch
 from torch._C._distributed_c10d import (
@@ -298,6 +298,8 @@
 # For a pg, it is a map from ProcessGroup to BackendConfig
 _pg_backend_config: Dict[ProcessGroup, str] = {}
 _group_count = 0
+_tags_to_pg: Dict[str, List[ProcessGroup]] = {}
+_pg_to_tag: Dict[ProcessGroup, str] = {}
 
 class _World:
     """
@@ -380,6 +382,15 @@
         global _group_count
         _group_count = value
 
+    @property
+    def tags_to_pg(self) -> Dict[str, List[ProcessGroup]]:
+        global _tags_to_pg
+        return _tags_to_pg
+
+    @property
+    def pg_to_tag(self) -> Dict[ProcessGroup, str]:
+        global _pg_to_tag
+        return _pg_to_tag
 
 _world = _World()
 """Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@@ -900,7 +911,7 @@
             store,
             pg_options=pg_options,
             group_name=group_name,
-            timeout=timeout,
+            timeout=timeout
         )
         _update_default_pg(default_pg)
 
@@ -929,6 +940,7 @@
     pg_options=None,
     group_name=None,
     timeout=default_pg_timeout,
+    pg_tag=None
 ):
     """
     Create a new distributed process group.
@@ -956,6 +968,12 @@
             "Expected timeout argument to be of type" "datetime.timedelta"
         )
 
+    if pg_tag not in [None, ""]:
+        # creating with the same tag and rank set results in the same underlying PG
+        existing_group = _find_pg_by_ranks_and_tag(pg_tag, global_ranks_in_group)
+        if existing_group:
+            return existing_group
+
     # The list of group ranks is empty if we're creating the default group.
     is_default_group = len(global_ranks_in_group) == 0
 
@@ -1084,8 +1102,16 @@
     _world.pg_map[pg] = (backend, prefix_store)
     _world.pg_names[pg] = group_name
     _world.pg_backend_config[pg] = str(backend_config)
-    return pg
+    # "" is the default tag for user PGs
+    if pg_tag in [None, ""]:
+        pg_tag = f"ptd:{group_name}"
+        _world.tags_to_pg.setdefault("", []).append(pg)
+    else:
+        pg_tag = f"user:{pg_tag}"
 
+    _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
+    _world.pg_to_tag[pg] = pg_tag
+    return pg
 
 def destroy_process_group(group: Optional[ProcessGroup] = None):
     """
@@ -3460,7 +3486,15 @@
     Returns:
         A handle of distributed group that can be given to collective calls.
     """
+    return _new_group_with_tag(ranks, timeout, backend, pg_options)
 
+def _new_group_with_tag(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None, pg_tag=None):
+    """
+    This is a variant of ``new_group`` that exposes tag creation.
+
+    :: N.B. The mechanism is experimental and tied to the functional collectives effort, see
+    ``torch.distributed._functional_collectives`` for reference on how to use it.
+    """
     global _world
 
     default_pg = _get_default_group()
@@ -3510,6 +3544,7 @@
             default_store,
             pg_options=pg_options,
             timeout=timeout,
+            pg_tag=pg_tag
         )
 
     # Create the global rank to group rank mapping
@@ -3767,3 +3802,53 @@
                 logger.info("Rank {} is assigned to subgroup {}".format(rank, ranks))
 
     return cur_subgroup, subgroups
+
+
+def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> Optional[ProcessGroup]:
+    if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"):
+        tag = f"user:{tag}"
+
+    for group in _world.tags_to_pg.get(tag, []):
+        if group.size() != len(ranks):
+            continue
+
+        group_ranks = get_process_group_ranks(group)
+        good = all(r in group_ranks for r in ranks)
+        if good:
+            return group
+    return None
+
+def _find_or_create_pg_by_ranks_and_tag(tag: str, ranks: List[int], stride: int) -> ProcessGroup:
+    assert len(ranks) % stride == 0, f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
+
+    my_rank = get_rank()
+    my_ranks = None
+
+    if stride == len(ranks):
+        my_ranks = ranks.copy()
+        assert my_rank in my_ranks, "rankset doesn't include the current node"
+    else:
+        for i in range(0, len(ranks), stride):
+            rank_set = ranks[i : i + stride]
+            if my_rank in rank_set:
+                my_ranks = rank_set
+        assert my_ranks is not None, "rankset doesn't include the current node"
+
+    my_ranks.sort()
+
+    pg = _find_pg_by_ranks_and_tag(tag, my_ranks)
+    if pg is not None:
+        return pg
+    if tag == "":
+        raise ValueError("Cannot automatically create PG with empty tag")
+    # TODO copy settings and timeout from default PG
+    return _new_group_with_tag(my_ranks, pg_tag=tag)
+
+def _get_group_tag(pg: ProcessGroup) -> str:
+    """
+    Returns the tag associated with ``pg``.
+    """
+    tag = _world.pg_to_tag[pg]
+    if tag.startswith("user:"):
+        tag = tag[5:]
+    return tag
diff --git a/torch/testing/_internal/distributed/multi_threaded_pg.py b/torch/testing/_internal/distributed/multi_threaded_pg.py
index 6b83d2d..c089103 100644
--- a/torch/testing/_internal/distributed/multi_threaded_pg.py
+++ b/torch/testing/_internal/distributed/multi_threaded_pg.py
@@ -1,7 +1,7 @@
 import sys
 import threading
 from dataclasses import dataclass
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -297,14 +297,15 @@
     pg_group_ranks: Dict[dist.ProcessGroup, Dict[int, int]]
     pg_backend_config: Dict[dist.ProcessGroup, str]
     group_count: int
-
+    tags_to_pg: Dict[str, List[dist.ProcessGroup]]
+    pg_to_tag: Dict[dist.ProcessGroup, str]
 
 class ThreadLocalWorld:
     _world = threading.local()
 
     def _get_world(self) -> WorldData:
         if not hasattr(ThreadLocalWorld._world, "world"):
-            ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0)
+            ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {})
         return ThreadLocalWorld._world.world
 
     @property
@@ -339,6 +340,14 @@
     def group_count(self, value):
         self._get_world().group_count = value
 
+    @property
+    def tags_to_pg(self):
+        return self._get_world().tags_to_pg
+
+    @property
+    def pg_to_tag(self):
+        return self._get_world().pg_to_tag
+
 
 _old_pg_world = None