[PTD] Introduce tracing-friendly collectives. (#93990)
This change adds torch.distributed._functional_collectives, an experimental API that enables
collectives to be fully traced by dynamo and FX.
See #93173 for the RFC.
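A rough sketch of the eager API, adapted from the new test/distributed/test_functional_api.py
(the helper name, the tensor, and the 2-rank default group below are illustrative assumptions):

    import torch
    import torch.distributed._functional_collectives as ft_c

    # assumes torch.distributed is already initialized with world_size == 2
    def layer(x: torch.Tensor) -> torch.Tensor:
        # all_reduce returns an AsyncCollectiveTensor; the wait on the underlying
        # c10d Work object happens implicitly when the result is first used.
        y = ft_c.all_reduce(x, "sum", [0, 1])
        return y + 1  # implicit wait_tensor before the add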
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93990
Approved by: https://github.com/wconstab, https://github.com/wanchaol, https://github.com/H-Huang
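For reference, a sketch of how the new ops surface under tracing, mirroring
TestMakeFx.test_all_reduce_tracing from the test file below (assumes an initialized
2-rank process group):

    import torch
    import torch.distributed._functional_collectives as ft_c
    from functorch import make_fx

    def allred(x):
        return ft_c.all_reduce(x, "sum", group=[0, 1]) + 1

    graph = make_fx(allred)(torch.rand(4))
    # The traced graph contains explicit aten::all_reduce and aten::wait_tensor
    # nodes that a compiler can match, reorder, and lower.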
diff --git a/aten/src/ATen/native/Collectives.cpp b/aten/src/ATen/native/Collectives.cpp
new file mode 100644
index 0000000..44e1399
--- /dev/null
+++ b/aten/src/ATen/native/Collectives.cpp
@@ -0,0 +1,29 @@
+#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
+
+#include <ATen/core/Tensor.h>
+#include <ATen/Parallel.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
+#endif
+
+namespace at {
+namespace native {
+
+// Dummy impl required by codegen infra, not used
+at::Tensor all_reduce(at::Tensor const& self, const c10::string_view reduceOp, const c10::string_view tag, c10::ArrayRef<int64_t> ranks, int64_t group_size) {
+ // This should never get called
+ // Defer to python impls in torch/distributed/_functional_collectives.py and _meta_registrations.py
+ TORCH_INTERNAL_ASSERT(false);
+}
+
+at::Tensor wait_tensor(at::Tensor const& self) {
+ // This should never get called
+ // Defer to python impls in torch/distributed/_functional_collectives.py and _meta_registrations.py
+ TORCH_INTERNAL_ASSERT(false);
+}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2b1ffb3..522cdcc 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -14670,3 +14670,18 @@
dispatch:
CUDA: _fused_adamw_kernel_cuda_
autogen: _fused_adamw, _fused_adamw.out
+
+# Collectives
+- func: all_reduce(Tensor self, str reduceOp, str tag, int[] ranks, int group_size) -> Tensor
+  # The python_module should eventually be `distributed`, but that requires changes all over the place to work
+ python_module: nn
+ dispatch:
+ CompositeExplicitAutograd: all_reduce
+ variants: function
+
+- func: wait_tensor(Tensor self) -> Tensor
+  # The python_module should eventually be `distributed`, but that requires changes all over the place to work
+ python_module: nn
+ dispatch:
+ CompositeExplicitAutograd: wait_tensor
+ variants: function
diff --git a/build_variables.bzl b/build_variables.bzl
index 59e21c3..f5a465a 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -1231,6 +1231,7 @@
"aten/src/ATen/native/Bucketization.cpp",
"aten/src/ATen/native/CPUBlas.cpp",
"aten/src/ATen/native/ChanelShuffle.cpp",
+ "aten/src/ATen/native/Collectives.cpp",
"aten/src/ATen/native/Col2Im.cpp",
"aten/src/ATen/native/PadNd.cpp",
"aten/src/ATen/native/Convolution.cpp",
diff --git a/test/distributed/test_functional_api.py b/test/distributed/test_functional_api.py
new file mode 100644
index 0000000..8ccff34
--- /dev/null
+++ b/test/distributed/test_functional_api.py
@@ -0,0 +1,269 @@
+# Owner(s): ["oncall: distributed"]
+
+import sys
+import torch
+import torch.distributed as dist
+import torch.distributed._functional_collectives as ft_c
+import torch.distributed.distributed_c10d as c10d
+import torch.distributed._tensor as dt
+
+from functorch import make_fx
+
+if not dist.is_available():
+ print("Distributed not available, skipping tests", file=sys.stderr)
+ sys.exit(0)
+
+from torch.testing._internal.common_distributed import (
+ MultiThreadedTestCase,
+)
+from torch.testing._internal.common_utils import (
+ run_tests,
+ TestCase
+)
+
+def new_subgroups(group_size: int, pg_tag=None):
+ world_size = dist.get_world_size()
+ subgroups = []
+ cur_subgroup = None
+
+ for subgroup_id in range(world_size // group_size):
+ start_rank = subgroup_id * group_size
+ end_rank = start_rank + group_size
+ ranks_in_subgroup = list(range(start_rank, end_rank))
+ subgroup = c10d._new_group_with_tag(
+ ranks=ranks_in_subgroup,
+ pg_tag=pg_tag,
+ )
+ subgroups.append(subgroup)
+
+ rank = dist.get_rank()
+ if rank in ranks_in_subgroup:
+ cur_subgroup = subgroup
+
+ return cur_subgroup, subgroups
+
+
+class TestExpand(MultiThreadedTestCase):
+ @property
+ def world_size(self):
+ return 4
+
+ def setUp(self):
+ super().setUp()
+ self._spawn_threads()
+
+ def test_expand_1d_rank_list(self):
+ tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3])
+ self.assertEqual("", tag)
+ self.assertEqual([0, 1, 2, 3], rankset)
+ self.assertEqual(4, group_size)
+
+ tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3], "bla")
+ self.assertEqual("bla", tag)
+
+ def test_expand_2d_rank_list(self):
+ tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]])
+ self.assertEqual("", tag)
+ self.assertEqual([0, 1, 2, 3], rankset)
+ self.assertEqual(2, group_size)
+
+ tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]], "blu")
+ self.assertEqual("blu", tag)
+
+ with self.assertRaisesRegex(ValueError, "group sizes must be identical"):
+ ft_c._expand_group([[0], [1, 2, 3]])
+
+ def test_expand_process_group(self):
+ tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD)
+ self.assertEqual(c10d._get_group_tag(dist.group.WORLD), tag)
+ self.assertEqual([0, 1, 2, 3], rankset)
+ self.assertEqual(4, group_size)
+
+ tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla")
+ self.assertEqual("bla", tag)
+
+ my_pg, others = new_subgroups(group_size=2)
+ tag, rankset, group_size = ft_c._expand_group(my_pg)
+ self.assertEqual(c10d._get_group_tag(my_pg), tag)
+ self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
+ self.assertEqual(2, group_size)
+
+ my_pg = None
+ for i in range(dist.get_world_size()):
+ group = c10d._new_group_with_tag([i], pg_tag="my_pg")
+ if i == dist.get_rank():
+ my_pg = group
+ tag, rankset, group_size = ft_c._expand_group(my_pg)
+ self.assertEqual("my_pg", tag)
+ self.assertEqual([dist.get_rank()], rankset)
+ self.assertEqual(1, group_size)
+
+ tag, rankset, group_size = ft_c._expand_group(my_pg, "bla")
+ self.assertEqual("bla", tag)
+
+ def test_expand_device_mesh(self):
+ mesh = dt.DeviceMesh("cpu", torch.arange(4))
+ tag, rankset, group_size = ft_c._expand_group(mesh)
+ self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
+ self.assertEqual([0, 1, 2, 3], rankset)
+ self.assertEqual(4, group_size)
+
+ mesh = dt.DeviceMesh("cpu", torch.arange(4))
+ tag, rankset, group_size = ft_c._expand_group(mesh)
+ self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
+ self.assertEqual([0, 1, 2, 3], rankset)
+ self.assertEqual(4, group_size)
+
+ def test_expand_device_mesh_tuple(self):
+ mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
+ tag, rankset, group_size = ft_c._expand_group(mesh)
+ self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
+ self.assertEqual([0, 2, 1, 3], rankset)
+ self.assertEqual(2, group_size)
+
+ tag, rankset, group_size = ft_c._expand_group((mesh, 0))
+ self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[0]), tag)
+ self.assertEqual([0, 2, 1, 3], rankset)
+ self.assertEqual(2, group_size)
+
+ tag, rankset, group_size = ft_c._expand_group((mesh, 1))
+ self.assertEqual(c10d._get_group_tag(mesh.get_dim_groups()[1]), tag)
+ self.assertEqual([0, 1, 2, 3], rankset)
+ self.assertEqual(2, group_size)
+
+class TestPgTag(MultiThreadedTestCase):
+ @property
+ def world_size(self):
+ return 4
+
+ def setUp(self):
+ super().setUp()
+ self._spawn_threads()
+
+ """
+    The behavior we want is as follows:
+
+ - rankset+tag will always result in the same PG.
+ Do we enforce this by failing creation of new PGs or returning existing ones?
+ Return existing one.
+
+ - default tag gives existing behavior.
+ This means we should create duplicates.
+ - _expand_group on _default-tagged pg should always resolve to it
+        This means we can't depend on empty tag + rankset.
+ """
+ def test_pg_creation_with_tag(self):
+ my_group, _ = new_subgroups(group_size=2, pg_tag="blu")
+ my_group2, _ = new_subgroups(group_size=2, pg_tag="blu")
+ self.assertEqual(my_group, my_group2)
+
+ my_group3, _ = new_subgroups(group_size=2, pg_tag="blu2")
+ self.assertNotEqual(my_group, my_group3)
+
+ my_group4, _ = new_subgroups(group_size=2)
+ self.assertNotEqual(my_group, my_group4)
+
+ my_group5, _ = new_subgroups(group_size=2)
+ self.assertNotEqual(my_group4, my_group5)
+
+ def test_pg_lookup_roundtrip(self):
+ pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
+ pg_tag1, _ = new_subgroups(group_size=2, pg_tag="blu2")
+ pg_notag0, _ = new_subgroups(group_size=2)
+ pg_notag1, _ = new_subgroups(group_size=2)
+
+ def roundtrip(pg):
+ tag, rankset, _ = ft_c._expand_group(pg)
+ return c10d._find_pg_by_ranks_and_tag(tag, rankset)
+
+ self.assertEqual(pg_tag0, roundtrip(pg_tag0))
+ self.assertEqual(pg_tag1, roundtrip(pg_tag1))
+ self.assertEqual(pg_notag0, roundtrip(pg_notag0))
+ self.assertEqual(pg_notag1, roundtrip(pg_notag1))
+
+ def test_pg_lookup_with_tag(self):
+ pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
+ pg_tag1, _ = new_subgroups(group_size=2, pg_tag="bla")
+ pg_notag0, _ = new_subgroups(group_size=2)
+
+ def roundtrip(pg, pg_tag):
+ tag, rankset, _ = ft_c._expand_group(pg, pg_tag)
+ return c10d._find_pg_by_ranks_and_tag(tag, rankset)
+
+ self.assertEqual(pg_tag0, roundtrip(pg_tag1, "blu"))
+ self.assertEqual(pg_tag0, roundtrip(pg_notag0, "blu"))
+ # Cannot erase the tag of a PG
+ self.assertEqual(pg_tag0, roundtrip(pg_tag0, ""))
+
+ def test_find_or_create_pg(self):
+ pg = c10d._find_or_create_pg_by_ranks_and_tag("blu", [0, 1, 2, 3], 2)
+ pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
+ self.assertEqual(pg, pg_tag0)
+
+ def test_find_root_pg(self):
+ pg = c10d._find_pg_by_ranks_and_tag("", [0, 1, 2, 3])
+ self.assertEqual(dist.group.WORLD, pg)
+
+class TestTraceableCollectives(MultiThreadedTestCase):
+ @property
+ def world_size(self):
+ return 4
+
+ def setUp(self):
+ super().setUp()
+ self._spawn_threads()
+
+ def test_all_reduce_eager(self):
+ tensor = torch.ones([4])
+ mesh = dt.DeviceMesh("cpu", torch.arange(4))
+
+ res = ft_c.all_reduce(tensor, "sum", mesh)
+ self.assertEqual(res, torch.tensor([4, 4, 4, 4], dtype=torch.float))
+
+ mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
+ res2 = ft_c.all_reduce(tensor, "sum", (mesh, 1))
+ self.assertEqual(res2, torch.tensor([2, 2, 2, 2], dtype=torch.float))
+
+class TestMetaCollectives(TestCase):
+ def test_all_reduce(self):
+ x = torch.rand((2, 3, 4), device="meta")
+ out = ft_c.all_reduce(x, "sum", [1])
+ self.assertEqual(x.size(), out.size())
+
+class TestGradCollectives(MultiThreadedTestCase):
+ @property
+ def world_size(self):
+ return 2
+
+ def setUp(self):
+ super().setUp()
+ self._spawn_threads()
+
+ def test_all_reduce(self):
+ x = torch.rand([4], requires_grad=True)
+ y = torch.rand([4], requires_grad=True)
+ out = ft_c.all_reduce(x, "sum", [0, 1])
+ (out + y).sum().backward()
+ self.assertIsNone(x.grad)
+
+class TestMakeFx(MultiThreadedTestCase):
+ @property
+ def world_size(self):
+ return 2
+
+ def setUp(self):
+ super().setUp()
+ self._spawn_threads()
+
+ def test_all_reduce_tracing(self):
+ def allred(input):
+ return ft_c.all_reduce(input, "sum", group=[0, 1]) + 1
+
+ graph = make_fx(allred)(torch.rand(4))
+ nodes = list(graph.graph.nodes)
+
+ self.assertEqual("aten::all_reduce", nodes[1].target.name())
+ self.assertEqual("aten::wait_tensor", nodes[2].target.name())
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index d5a174e..debcf63 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -575,6 +575,7 @@
aten::affine_grid_generator.out
aten::alias_copy
aten::alias_copy.out
+aten::all_reduce
aten::allclose
aten::aminmax
aten::aminmax.out
@@ -1339,6 +1340,7 @@
aten::view_copy.dtype
aten::view_copy.dtype_out
aten::view_copy.out
+aten::wait_tensor
aten::zeros.names
aten::zeros.names_out
aten::zeros.out
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 3ad1866..837c12b 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -2702,4 +2702,14 @@
_meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)
+@register_meta(aten.all_reduce)
+def all_reduce_meta(self, reduceOp, tag, rankset, stride):
+ return torch.empty_like(self)
+
+
+@register_meta(aten.wait_tensor)
+def wait_tensor_meta(self):
+ return torch.empty_like(self)
+
+
activate_meta()
diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py
new file mode 100644
index 0000000..8af8f5f
--- /dev/null
+++ b/torch/distributed/_functional_collectives.py
@@ -0,0 +1,237 @@
+from typing import Any, Tuple, Union, List, cast
+
+import weakref
+import warnings
+
+import torch
+import torch.distributed as dist
+
+from torch._C import _disabled_torch_function_impl
+from torch.utils._pytree import tree_map
+
+import torch.distributed.distributed_c10d as c10d
+"""
+New traceable, functional collectives.
+RFC: https://github.com/pytorch/pytorch/issues/93173
+
+ compiler: trace these ops with plain-old-data schemas, then choose how to lower them.
+ eager: execute these 'functional' ops which in eager return AsyncCollectiveTensor subclasses,
+ automatically calling .wait() on underlying/hidden async 'work' obj only when fed to
+ a downstream op.
+
+Issues:
+* Where should these ops live? Couldn't `import torch` if putting these ops in existing torch.distributed files
+* Proper support for eager requires inplace ops. We should explore having it as an option for the API.
+"""
+
+"""
+Functional collectives are asynchronous only and we perform implicit stream synchronization
+on behalf of the user.
+
+We use AsyncCollectiveTensor to wrap the result tensor of a collective and it lets us witness
+first usage of the tensor and insert cross stream sync at the right place.
+
+The above are the easy bits, the hard one is how we match the Work object returned by
+c10d and the tensor AsyncCollectiveTensor wraps. We alloc the tensor inside the collective
+op implementation (see ``clone()`` call in ``_all_reduce``) and then it's handled by the
+dispatcher which might call other implementations that are allowed to change the returned
+tensor - even return a tensor with a different shape (see ``torch.vmap``).
+
+This means the caller of our ops receives a Tensor that is not guaranteed to be the same
+allocated by our implementations and that makes pairing The AsyncTensor to the original
+tensor a lot harder. This pairing is needed so we can lookup the Work object to use.
+
+Originally, we tried WeakKeyDictionary to map from Tensor to Work, but because Tensor's
+identity is not stable across dispatch, the op caller would end up with a different Tensor
+instance that would not match any in the dictionary.
+
+With Tensor identity out of the question, we decided use the tensor data pointer, which
+should be stable across all the Tensor changes done during dispatch.
+
+We have a dictionary of tensor::data_ptr -> Work that we insert right after we call into c10d.
+
+We use this dictionary when AsyncCollectiveTensor is used to invoke Work::wait()
+
+Finally, we setup a finalizer against the tensor wrapper to observe it getting collected so we
+can clean up stale entries in the dictionary.
+
+To eliminate the possiblity of races we have a global version counter that is used by the finalizer.
+
+As a wise man said once: Don't cross the streams (https://www.youtube.com/watch?v=wyKQe_i9yyo)
+
+"""
+data_ptr_to_work = dict()
+work_version = 0
+
+def _register_tensor_work(tensor, work):
+ global data_ptr_to_work
+ global work_version
+ data_ptr_to_work[tensor.data_ptr()] = (work_version, work)
+ work_version += 1
+
+def _clear_tensor(data_ptr, version):
+ global data_ptr_to_work
+ version_and_work = data_ptr_to_work.get(data_ptr)
+
+ if version_and_work is not None and version_and_work[0] == version:
+ del data_ptr_to_work[data_ptr]
+
+def _register_wrapper_tensor(tensor_wrapper, tensor):
+ global data_ptr_to_work
+ version, _ = data_ptr_to_work.get(tensor.data_ptr(), (None, None))
+ if version is None:
+ warnings.warn("Trying to register finalizers to AsyncCollectiveTensor but the inner tensor is already gone")
+ else:
+ weakref.finalize(tensor_wrapper, _clear_tensor, tensor.data_ptr(), version)
+
+def _wait_tensor(tensor: torch.Tensor) -> torch.Tensor:
+ global data_ptr_to_work
+ data_ptr = tensor.data_ptr()
+ version_and_work = data_ptr_to_work.get(data_ptr)
+ if version_and_work is not None:
+ version_and_work[1].wait()
+ _clear_tensor(data_ptr, version_and_work[0])
+ return tensor
+
+
+class AsyncCollectiveTensor(torch.Tensor):
+ r"""
+    A Tensor subclass that is only used in eager mode; it defers waiting on the
+    collective's async 'work' object until the tensor is first used by a real op.
+
+    Usage, from inside a functional collective:
+    def functional_collective(input):
+        input = input.clone()
+        mutated_input, work = c10d.{inplace_collective}(input)
+        return AsyncCollectiveTensor(mutated_input)  # work is looked up via the data_ptr map
+ """
+ _tensor: torch.Tensor
+
+ __torch_function__ = _disabled_torch_function_impl
+
+ @staticmethod
+ def __new__(cls, tensor: torch.Tensor):
+ t = tensor
+ r = torch.Tensor._make_subclass(cls, t, require_grad=t.requires_grad)
+ r._tensor = tensor # type: ignore[attr-defined]
+ return r
+
+ def __repr__(self):
+ return f"AsyncCollectiveTensor({self._tensor})"
+
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+ def unwrap(e: Any):
+ if isinstance(e, AsyncCollectiveTensor):
+ return wait_tensor(e._tensor)
+ return e
+
+ unwrapped_args = tree_map(unwrap, args)
+ unwrapped_kwargs = tree_map(unwrap, kwargs)
+
+ out = func(*unwrapped_args, **unwrapped_kwargs)
+ return out
+
+def _str_to_reduce_op(reduceOp: str) -> dist.ReduceOp:
+ reduceOp = reduceOp.upper()
+ op = dist.ReduceOp.RedOpType.__members__.get(reduceOp)
+ if op is None:
+ raise ValueError(f"Invalid reduce operation {reduceOp}")
+ return cast(dist.ReduceOp, op)
+
+# TODO: assert that ranks has no duplicated entries
+def _all_reduce(self, reduceOp, tag, ranks, group_size):
+ op = _str_to_reduce_op(reduceOp)
+ group = c10d._find_or_create_pg_by_ranks_and_tag(tag, ranks, group_size)
+ assert group is not None
+
+ inplace_tensor = self.clone()
+ work = dist.all_reduce(inplace_tensor, op=op, group=group, async_op=True)
+ _register_tensor_work(inplace_tensor, work)
+
+ return inplace_tensor
+
+c10_lib_cpu = torch.library.Library("aten", "IMPL", "CPU")
+c10_lib_cuda = torch.library.Library("aten", "IMPL", "CUDA")
+
+c10_lib_cpu.impl("all_reduce", _all_reduce)
+c10_lib_cuda.impl("all_reduce", _all_reduce)
+
+c10_lib_cpu.impl("wait_tensor", _wait_tensor)
+c10_lib_cuda.impl("wait_tensor", _wait_tensor)
+
+
+RANK_TYPES = Union[List[int], List[List[int]], dist.ProcessGroup, "dist._tensor.DeviceMesh", Tuple["dist._tensor.DeviceMesh", int]]
+
+def _expand_group(group: RANK_TYPES, tag: str = "") -> Tuple[str, List[int], int]:
+ # Cannot import on the top level to avoid circular imports
+ import torch.distributed._tensor as dt
+ rankset: List[int]
+ if isinstance(group, list):
+ if isinstance(group[0], list):
+ nested_list = cast(List[List[int]], group)
+ rankset = []
+ group_size = -1
+ for rs in nested_list:
+ rankset.extend(rs)
+ if group_size != -1 and group_size != len(rs):
+                    raise ValueError(f"group sizes must be identical, found {group_size} and {len(rs)}")
+ group_size = len(rs)
+ else:
+ rankset = cast(List[int], group)
+ group_size = len(rankset)
+ elif isinstance(group, dist.ProcessGroup):
+ rankset = dist.get_process_group_ranks(group)
+ group_size = len(rankset)
+ tag = tag or c10d._get_group_tag(group)
+ elif isinstance(group, dt.DeviceMesh):
+ rankset = group.mesh.flatten().tolist()
+ group_size = group.mesh.size(0)
+ rankset = group.mesh.swapdims(-1, 0).reshape(-1, group_size).flatten().tolist()
+ tag = tag or c10d._get_group_tag(group.get_dim_groups()[0])
+ elif isinstance(group, tuple):
+ if len(group) == 2 and isinstance(group[0], dt.DeviceMesh) and isinstance(group[1], int):
+ dmesh = group[0]
+ dim = group[1]
+ group_size = dmesh.mesh.size(dim)
+ rankset = dmesh.mesh.swapdims(-1, dim).reshape(-1, group_size).flatten().tolist()
+ tag = tag or c10d._get_group_tag(dmesh.get_dim_groups()[dim])
+ else:
+            raise ValueError("Invalid tuple for group: must be (DeviceMesh, int)")
+ else:
+        raise ValueError("Invalid type for group; must be one of List[int], List[List[int]], ProcessGroup, DeviceMesh, or (DeviceMesh, int).")
+
+ return (tag, rankset, group_size)
+
+
+def wait_tensor(tensor):
+ """
+ Wait on a tensor returned by the collectives ops.
+
+ Waiting follows device semantics, which means blocking on CPU and synchronizing streams on CUDA.
+ """
+ return torch._C._nn.wait_tensor(tensor) # type: ignore[attr-defined]
+
+
+def all_reduce(self: torch.Tensor, reduceOp: str, group: RANK_TYPES, tag: str = ""):
+ """
+ Reduces the tensor data across all machines in such a way that all get
+ the final result.
+
+ The input tensor is left unmodified.
+
+    Group can be one of:
+        List[int]: ranks participating in the collective.
+        List[List[int]]: 2D mesh of ranks taking part in this collective in MPMD.
+        ProcessGroup: will perform a collective using the ranks and tag of the PG.
+        DeviceMesh: do an SPMD collective over all ranks of the mesh.
+        (DeviceMesh, int): do an MPMD collective over one dimension of the DeviceMesh.
+
+    :: N.B. If you pass a PG or a 1D list to perform an MPMD collective, the compiler won't be able to recover
+    that information and perform collective algebraic optimizations. Use other forms of input for that.
+ """
+ tag, rankset, group_size = _expand_group(group, tag)
+ tensor = torch._C._nn.all_reduce(self, reduceOp, tag, rankset, group_size) # type: ignore[attr-defined]
+ res = AsyncCollectiveTensor(tensor)
+ _register_wrapper_tensor(res, tensor)
+ return res
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index be0006d..98fefed 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -10,7 +10,7 @@
import warnings
from collections import namedtuple
from datetime import timedelta
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union, List
import torch
from torch._C._distributed_c10d import (
@@ -298,6 +298,8 @@
# For a pg, it is a map from ProcessGroup to BackendConfig
_pg_backend_config: Dict[ProcessGroup, str] = {}
_group_count = 0
+_tags_to_pg: Dict[str, List[ProcessGroup]] = {}
+_pg_to_tag: Dict[ProcessGroup, str] = {}
class _World:
"""
@@ -380,6 +382,15 @@
global _group_count
_group_count = value
+ @property
+ def tags_to_pg(self) -> Dict[str, List[ProcessGroup]]:
+ global _tags_to_pg
+ return _tags_to_pg
+
+ @property
+ def pg_to_tag(self) -> Dict[ProcessGroup, str]:
+ global _pg_to_tag
+ return _pg_to_tag
_world = _World()
"""Holds the singleton instance of ``_World`` used by c10. Experimental extension point to override it"""
@@ -900,7 +911,7 @@
store,
pg_options=pg_options,
group_name=group_name,
- timeout=timeout,
+ timeout=timeout
)
_update_default_pg(default_pg)
@@ -929,6 +940,7 @@
pg_options=None,
group_name=None,
timeout=default_pg_timeout,
+ pg_tag=None
):
"""
Create a new distributed process group.
@@ -956,6 +968,12 @@
"Expected timeout argument to be of type" "datetime.timedelta"
)
+ if pg_tag not in [None, ""]:
+ # creating with the same tag and rank set results in the same underlying PG
+ existing_group = _find_pg_by_ranks_and_tag(pg_tag, global_ranks_in_group)
+ if existing_group:
+ return existing_group
+
# The list of group ranks is empty if we're creating the default group.
is_default_group = len(global_ranks_in_group) == 0
@@ -1084,8 +1102,16 @@
_world.pg_map[pg] = (backend, prefix_store)
_world.pg_names[pg] = group_name
_world.pg_backend_config[pg] = str(backend_config)
- return pg
+ # "" is the default tag for user PGs
+ if pg_tag in [None, ""]:
+ pg_tag = f"ptd:{group_name}"
+ _world.tags_to_pg.setdefault("", []).append(pg)
+ else:
+ pg_tag = f"user:{pg_tag}"
+ _world.tags_to_pg.setdefault(pg_tag, []).append(pg)
+ _world.pg_to_tag[pg] = pg_tag
+ return pg
def destroy_process_group(group: Optional[ProcessGroup] = None):
"""
@@ -3460,7 +3486,15 @@
Returns:
A handle of distributed group that can be given to collective calls.
"""
+ return _new_group_with_tag(ranks, timeout, backend, pg_options)
+def _new_group_with_tag(ranks=None, timeout=default_pg_timeout, backend=None, pg_options=None, pg_tag=None):
+ """
+ This is a variant of ``new_group`` that exposes tag creation.
+
+    :: N.B. This mechanism is experimental and tied to the functional collectives effort; see
+ ``torch.distributed._functional_collectives`` for reference on how to use it.
+ """
global _world
default_pg = _get_default_group()
@@ -3510,6 +3544,7 @@
default_store,
pg_options=pg_options,
timeout=timeout,
+ pg_tag=pg_tag
)
# Create the global rank to group rank mapping
@@ -3767,3 +3802,53 @@
logger.info("Rank {} is assigned to subgroup {}".format(rank, ranks))
return cur_subgroup, subgroups
+
+
+def _find_pg_by_ranks_and_tag(tag: str, ranks: List[int]) -> ProcessGroup:
+ if len(tag) > 0 and not tag.startswith("ptd:") and not tag.startswith("user:"):
+ tag = f"user:{tag}"
+
+ for group in _world.tags_to_pg.get(tag, []):
+ if group.size() != len(ranks):
+ continue
+
+ group_ranks = get_process_group_ranks(group)
+ good = all(r in group_ranks for r in ranks)
+ if good:
+ return group
+ return None
+
+def _find_or_create_pg_by_ranks_and_tag(tag: str, ranks: List[int], stride: int) -> ProcessGroup:
+ assert len(ranks) % stride == 0, f"Ranks length ({len(ranks)}) must be divisible by stride ({stride})"
+
+ my_rank = get_rank()
+ my_ranks = None
+
+ if stride == len(ranks):
+ my_ranks = ranks.copy()
+ assert my_rank in my_ranks, "rankset doesn't include the current node"
+ else:
+ for i in range(0, len(ranks), stride):
+ rank_set = ranks[i : i + stride]
+ if my_rank in rank_set:
+ my_ranks = rank_set
+ assert my_ranks is not None, "rankset doesn't include the current node"
+
+ my_ranks.sort()
+
+ pg = _find_pg_by_ranks_and_tag(tag, my_ranks)
+ if pg is not None:
+ return pg
+ if tag == "":
+ raise ValueError("Cannot automatically create PG with empty tag")
+ # TODO copy settings and timeout from default PG
+ return _new_group_with_tag(my_ranks, pg_tag=tag)
+
+def _get_group_tag(pg: ProcessGroup) -> str:
+ """
+ Returns the tag associated with ``pg``.
+ """
+ tag = _world.pg_to_tag[pg]
+ if tag.startswith("user:"):
+ tag = tag[5:]
+ return tag
diff --git a/torch/testing/_internal/distributed/multi_threaded_pg.py b/torch/testing/_internal/distributed/multi_threaded_pg.py
index 6b83d2d..c089103 100644
--- a/torch/testing/_internal/distributed/multi_threaded_pg.py
+++ b/torch/testing/_internal/distributed/multi_threaded_pg.py
@@ -1,7 +1,7 @@
import sys
import threading
from dataclasses import dataclass
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
@@ -297,14 +297,15 @@
pg_group_ranks: Dict[dist.ProcessGroup, Dict[int, int]]
pg_backend_config: Dict[dist.ProcessGroup, str]
group_count: int
-
+ tags_to_pg: Dict[str, List[dist.ProcessGroup]]
+ pg_to_tag: Dict[dist.ProcessGroup, str]
class ThreadLocalWorld:
_world = threading.local()
def _get_world(self) -> WorldData:
if not hasattr(ThreadLocalWorld._world, "world"):
- ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0)
+ ThreadLocalWorld._world.world = WorldData(None, {}, {}, {}, {}, 0, {}, {})
return ThreadLocalWorld._world.world
@property
@@ -339,6 +340,14 @@
def group_count(self, value):
self._get_world().group_count = value
+ @property
+ def tags_to_pg(self):
+ return self._get_world().tags_to_pg
+
+ @property
+ def pg_to_tag(self):
+ return self._get_world().pg_to_tag
+
_old_pg_world = None