[DeviceMesh] Move DeviceMesh out from torch.distributed._tensor (#112364)
Move DeviceMesh out of `torch.distributed._tensor` into a standalone module. Once everything is migrated and the documentation is ready, we will make `torch.distributed._device_mesh` public in follow-up PRs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112364
Approved by: https://github.com/wanchaol, https://github.com/fegin, https://github.com/fduwjj
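
For context, a minimal usage sketch of the relocated module (assuming a hypothetical two-host, eight-rank SPMD launch; the old `torch.distributed._tensor.device_mesh` import path keeps working through the re-export added below):

```
# Run with e.g. `torchrun --nproc-per-node=4 --nnodes=2 ...` so that a default
# process group spanning 8 ranks can be initialized.
from torch.distributed._device_mesh import init_device_mesh

# Same call on every rank (SPMD): 2 hosts x 4 GPUs, with named dimensions.
mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 4), mesh_dim_names=("dp", "tp"))

tp_mesh = mesh_2d["tp"]                # 1D child mesh covering this rank's host
dp_group = mesh_2d.get_dim_groups(0)   # ProcessGroup along the "dp" dimension
```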
diff --git a/.ci/pytorch/multigpu-test.sh b/.ci/pytorch/multigpu-test.sh
index f3bf768..70ae4d2 100755
--- a/.ci/pytorch/multigpu-test.sh
+++ b/.ci/pytorch/multigpu-test.sh
@@ -36,10 +36,12 @@
# DTensor tests
-time python test/run_test.py --verbose -i distributed/_tensor/test_device_mesh
time python test/run_test.py --verbose -i distributed/_tensor/test_random_ops
time python test/run_test.py --verbose -i distributed/_tensor/test_dtensor_compile
+# DeviceMesh test
+time python test/run_test.py --verbose -i distributed/test_device_mesh
+
# DTensor/TP tests
time python test/run_test.py --verbose -i distributed/tensor/parallel/test_ddp_2d_parallel
time python test/run_test.py --verbose -i distributed/tensor/parallel/test_fsdp_2d_parallel
diff --git a/test/distributed/_tensor/test_device_mesh.py b/test/distributed/test_device_mesh.py
similarity index 100%
rename from test/distributed/_tensor/test_device_mesh.py
rename to test/distributed/test_device_mesh.py
diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py
index 08e4a17..546d88b 100644
--- a/test/test_public_bindings.py
+++ b/test/test_public_bindings.py
@@ -290,6 +290,7 @@
"torch.backends._coreml.preprocess",
"torch.contrib._tensorboard_vis",
"torch.distributed._composable",
+ "torch.distributed._device_mesh",
"torch.distributed._functional_collectives",
"torch.distributed._functional_collectives_impl",
"torch.distributed._shard",
diff --git a/torch/_dynamo/skipfiles.py b/torch/_dynamo/skipfiles.py
index c9eadc1..f5196d3 100644
--- a/torch/_dynamo/skipfiles.py
+++ b/torch/_dynamo/skipfiles.py
@@ -183,6 +183,7 @@
LEGACY_MOD_INLINELIST |= {
"torch.distributed._tensor.api",
"torch.distributed._tensor.device_mesh",
+ "torch.distributed._device_mesh",
"torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
"torch.distributed.tensor.parallel._data_parallel_utils",
"torch.distributed.tensor.parallel._utils",
diff --git a/torch/distributed/_device_mesh.py b/torch/distributed/_device_mesh.py
new file mode 100644
index 0000000..04c82d1
--- /dev/null
+++ b/torch/distributed/_device_mesh.py
@@ -0,0 +1,454 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import logging
+import math
+from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+
+import torch
+import torch.distributed._functional_collectives as funcol
+
+from torch.distributed.distributed_c10d import (
+ _find_pg_by_ranks_and_tag,
+ _get_default_group,
+ _get_group_tag,
+ get_rank,
+ get_world_size,
+ init_process_group,
+ is_initialized,
+ new_group,
+ ProcessGroup,
+)
+
+
+logger = logging.getLogger(__name__)
+
+# only import numpy typing when type checking
+if TYPE_CHECKING:
+ try:
+ from numpy.typing import ArrayLike
+ except ImportError:
+ logger.warning(
+ "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
+ )
+
+
+class _MeshEnv:
+ def __init__(self) -> None:
+ self.mesh_stack: List[DeviceMesh] = []
+ self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}
+
+ def get_current_mesh(self) -> "DeviceMesh":
+ if len(self.mesh_stack) == 0:
+ raise RuntimeError("No device mesh is currently active!")
+ return self.mesh_stack[-1]
+
+ def create_child_mesh(
+ self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str
+ ) -> "DeviceMesh":
+ # Swap the current dim to the last dim, then reshape to flatten out the other
+ # dims, so we can simply extract the list of ranks that contains cur_rank.
+ cur_rank = device_mesh.get_rank()
+ pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
+ -1, device_mesh.mesh.size(mesh_dim)
+ )
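+ # For example, with mesh [[0, 1, 2, 3], [4, 5, 6, 7]] and mesh_dim=1 this
+ # yields [[0, 1, 2, 3], [4, 5, 6, 7]], while mesh_dim=0 yields
+ # [[0, 4], [1, 5], [2, 6], [3, 7]]; each row is one candidate 1D child mesh.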
+
+ for mesh_1d in pg_ranks_by_dim:
+ sub_mesh = DeviceMesh(
+ device_mesh.device_type,
+ mesh_1d,
+ mesh_dim_names=(mesh_dim_name,),
+ _init_process_groups=False,
+ )
+ if cur_rank in mesh_1d:
+ res_sub_mesh = sub_mesh
+
+ res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]
+ # Assign the current DeviceMesh as the parent of the child DeviceMesh.
+ self.child_to_parent_mapping[res_sub_mesh] = device_mesh
+ return res_sub_mesh
+
+ def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
+ return self.child_to_parent_mapping.get(device_mesh, None)
+
+ def get_parent_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
+ """
+ Return the index of the mesh dim in the parent mesh.
+ The device_mesh passed in needs to be sliced out from a parent mesh.
+ """
+ parent_mesh = self.get_parent_mesh(device_mesh)
+ child_mesh_dim_names = device_mesh.mesh_dim_names
+ if parent_mesh and child_mesh_dim_names:
+ assert (
+ len(child_mesh_dim_names) == 1
+ ), "The child mesh can only be a 1D mesh."
+ child_mesh_dim_name = child_mesh_dim_names[0]
+ if parent_mesh.mesh_dim_names:
+ return parent_mesh.mesh_dim_names.index(child_mesh_dim_name)
+ return None
+
+ @staticmethod
+ def num_devices_per_host(device_type: str) -> int:
+ return _get_device_handle(device_type).device_count()
+
+ @staticmethod
+ def num_hosts(device_type: str) -> int:
+ # ProcessGroup can't tell us this info, so we have to infer it; assume
+ # homogeneous hardware for now.
+ return get_world_size() // _MeshEnv.num_devices_per_host(device_type)
+
+
+_mesh_resources: _MeshEnv = _MeshEnv()
+
+
+def _get_device_handle(device_type: str = "cuda"):
+ """
+ Get the module corresponding to the device_type, which is cuda or a cuda-like device.
+ For example, when the device_type is "cuda", the module `torch.cuda` is returned.
+ Return None when there is no corresponding module for the device_type; otherwise,
+ return the corresponding module.
+ """
+ return getattr(torch, device_type, None)
+
+
+class DeviceMesh:
+ """
+ DeviceMesh represents a mesh of devices, where the layout of devices can be
+ represented as an n-dimensional array, and each value of the n-dimensional
+ array is the global rank of a device in the default process group.
+
+ DeviceMesh can be used to describe the layout of devices across the cluster,
+ and serves as a proxy for communication among the device lists within the cluster.
+
+ We use the default ProcessGroup in this DeviceMesh class to implement proper
+ communications. Note that we also add collective wrappers in this class. This is
+ used to decouple the details of the communication backend from the underlying
+ DTensor implementation.
+
+ DeviceMesh can be used as a context manager.
+ Args:
+ device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
+ mesh (ndarray): could be a multi-dimensional array or an integer tensor that
+ describes the layout of devices, where the ids are the global ranks of the
+ default process group.
+
+ Returns:
+ A :class:`DeviceMesh` object
+
+ Example (2 hosts with 4 GPUs each):
+ ```
+ # The following program runs on each process/rank in SPMD manner.
+ # initialize device mesh as (2, 4) to represent the topology
+ # of cross-host(dim 0), and within-host (dim 1)
+ mesh = DeviceMesh(device_type="cuda",
+ mesh=[
+ [0, 1, 2, 3],
+ [4, 5, 6, 7]
+ ])
+ ```
+ A reduction over the first dimension of mesh will reduce across
+ columns (0, 4), ..., (3, 7), while a reduction over the second dimension
+ of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
+
+ """
+
+ device_type: str
+ mesh: torch.Tensor
+ mesh_dim_names: Optional[Tuple[str, ...]]
+
+ def __init__(
+ self,
+ device_type: str,
+ mesh: Union[torch.Tensor, "ArrayLike"],
+ *,
+ mesh_dim_names: Optional[Tuple[str, ...]] = None,
+ _init_process_groups: bool = True,
+ _validate_mesh: bool = True,
+ ) -> None:
+ self.device_type = device_type
+ self.mesh = (
+ mesh.detach()
+ if isinstance(mesh, torch.Tensor)
+ else torch.tensor(mesh, dtype=torch.int)
+ )
+ self.mesh_dim_names = mesh_dim_names
+
+ # private field to pre-generate DeviceMesh's hash
+ self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
+ self._hash = hash((self._flatten_mesh_list, self.mesh.shape))
+
+ # Skip process group initialization if xla device.
+ # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
+ if device_type != "xla":
+ # always try to create default (world) pg, even if it is not initialized
+ # already. The world pg is used for device mesh identity (rank) on each
+ # process (we need to know if the current global rank is in the mesh or not).
+ self._get_or_create_default_group()
+ if _init_process_groups:
+ self._init_process_groups(_validate_mesh)
+
+ def _get_or_create_default_group(self):
+ default_initialized = is_initialized()
+ if not default_initialized:
+ init_process_group()
+
+ world_size = get_world_size()
+ if self.mesh.numel() > world_size:
+ raise RuntimeError(
+ f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
+ )
+
+ device_handle = _get_device_handle(self.device_type)
+ # TODO: if users want to pass pg_options, offer a way to do it
+ if not default_initialized and device_handle:
+ # Automatically set the current cuda/cuda-like device based on the number of
+ # GPU devices available on each host.
+ # NOTE: This device selection would only work for homogeneous hardware.
+ num_devices_per_host = device_handle.device_count()
+ if (
+ world_size > num_devices_per_host
+ and world_size % num_devices_per_host != 0
+ ):
+ raise RuntimeError(
+ f"DeviceMesh only supports homogeneous hardware, but found "
+ f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
+ )
+ device_handle.set_device(get_rank() % num_devices_per_host)
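+ # e.g., with 8 GPUs per host, global rank 11 is assigned local device index 3 (11 % 8).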
+
+ # calculate the coordinates of the current global rank on the mesh
+ rank_coords = (self.mesh == get_rank()).nonzero()
+ assert rank_coords.size(0) in (0, 1)
+ self._coordinate_on_dim: Optional[List[int]] = (
+ rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
+ )
+ return _get_default_group()
+
+ def _validate_mesh(self):
+ # check mesh tensor validity
+ unique_mesh_values = self.mesh.unique(sorted=True)
+ if unique_mesh_values.numel() != self.mesh.numel():
+ raise RuntimeError(
+ f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}"
+ )
+
+ # validate that all calling ranks pass in the same `mesh` argument.
+ self_mesh = self.mesh.to(self.device_type).contiguous()
+ mesh_tensor = funcol.all_gather_tensor(
+ self_mesh, gather_dim=0, group=_get_default_group()
+ )
+ mesh_tensor_chunked = torch.chunk(mesh_tensor, get_world_size())
+ for other_rank, other_mesh in enumerate(mesh_tensor_chunked):
+ if not torch.equal(self_mesh, other_mesh):
+ raise RuntimeError(
+ f"DeviceMesh initialization does not allow a different mesh argument: "
+ f"rank {get_rank()} has mesh {self_mesh} while rank {other_rank} "
+ f"has mesh {other_mesh}!"
+ )
+
+ def _init_process_groups(self, _validate_mesh):
+ if _validate_mesh:
+ self._validate_mesh()
+
+ # Group tag/ranks associated with each mesh dimension; each mesh dimension
+ # should have one sub-group per rank.
+ dim_group_infos: List[Tuple[str, List[int]]] = []
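+ # For example, on rank 0 of a [[0, 1], [2, 3]] mesh, dim_group_infos ends up as
+ # [(<tag of group [0, 2]>, [0, 2]), (<tag of group [0, 1]>, [0, 1])].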
+
+ if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
+ # if the mesh is the same as world_pg, we just append the default
+ # pg to the first dim groups, as new_group cannot have the exact
+ # same ranks as world
+ dim_group_infos.append(
+ (_get_group_tag(_get_default_group()), list(range(get_world_size())))
+ )
+ else:
+ # create sub pgs based on the mesh argument specified
+ for dim in range(self.mesh.ndim):
+ # swap the current dim to the last dim
+ # then reshape to flatten out other dims
+ pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
+ -1, self.mesh.size(dim)
+ )
+ # multi-dim mesh, create subgroups by looping over the pg_ranks
+ # for each dim and append the groups
+ for dim_mesh in pg_ranks_by_dim:
+ subgroup_ranks = dim_mesh.tolist()
+ # Call new_group regardless of whether the current rank is in the
+ # subgroup or not; it's required that all ranks participate in
+ # subgroup construction.
+ dim_group = new_group(ranks=subgroup_ranks)
+ # only add to dim_group_infos if the current rank is in the subgroup
+ if self.get_rank() in subgroup_ranks:
+ if len(dim_group_infos) > dim:
+ raise RuntimeError(
+ f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
+ f"in {subgroup_ranks}!"
+ )
+ dim_group_infos.append(
+ (_get_group_tag(dim_group), subgroup_ranks)
+ )
+ self._dim_group_infos = dim_group_infos
+
+ def __enter__(self) -> "DeviceMesh":
+ # set this mesh as the current mesh in mesh env
+ _mesh_resources.mesh_stack.append(self)
+ return self
+
+ # pyre-fixme[2]: Parameter must be annotated.
+ def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
+ # pop this mesh from mesh env
+ _mesh_resources.mesh_stack.pop()
+
+ def __repr__(self) -> str:
+ return f"DeviceMesh:({self.mesh.tolist()})"
+
+ def __hash__(self):
+ return self._hash
+
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, DeviceMesh):
+ return False
+ if id(self.mesh) == id(other.mesh):
+ return True
+ return (
+ self.mesh.shape == other.mesh.shape
+ and self._flatten_mesh_list == other._flatten_mesh_list
+ )
+
+ def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":
+ """
+ Slice the current DeviceMesh based on the mesh_dim_name given to create a child
+ DeviceMesh.
+
+ Args:
+ mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh
+ to create a child DeviceMesh for.
+ Returns:
+ A :class:`DeviceMesh` object
+
+ Example (2 hosts with 4 GPUs each):
+ ```
+ # Below is a DeviceMesh with mesh_shape of (2, 4) and mesh_dim_names of ("dp", "tp")
+ mesh = DeviceMesh(device_type="cuda",
+ mesh=[
+ [0, 1, 2, 3],
+ [4, 5, 6, 7]
+ ],
+ mesh_dim_names=("dp", "tp"))
+ ```
+ Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
+ Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
+ Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
+ Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
+ Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
+ Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
+ """
+ if self.mesh.ndim <= 1:
+ raise RuntimeError(
+ f"Cannot slice a DeviceMesh with {self.mesh.ndim} dimension."
+ )
+ if self.mesh_dim_names is None:
+ raise KeyError(
+ "No `mesh_dim_names` found.",
+ "To slice the device mesh, please call `init_device_mesh` with `mesh_dim_names`.",
+ )
+ if mesh_dim_name not in self.mesh_dim_names:
+ raise KeyError(
+ f"Mesh dimension '{mesh_dim_name}' does not exist.",
+ f"Available mesh dimensions are: {self.mesh_dim_names}",
+ )
+ mesh_dim = self.mesh_dim_names.index(mesh_dim_name)
+ submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name)
+
+ return submesh
+
+ def get_dim_groups(
+ self, mesh_dim: Optional[int] = None
+ ) -> Union[ProcessGroup, List[ProcessGroup]]:
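+ """
+ Return the ProcessGroup for the given mesh_dim, or, when mesh_dim is None,
+ the list of ProcessGroups for all mesh dimensions (indexed by mesh dim).
+ """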
+ if not hasattr(self, "_dim_group_infos"):
+ raise RuntimeError("DeviceMesh process groups not initialized!")
+ if mesh_dim is not None:
+ return _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
+ else:
+ dim_groups = []
+ for mesh_dim in range(self.mesh.ndim):
+ dim_groups.append(
+ _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
+ )
+ return dim_groups
+
+ def size(self, dim: Optional[int] = None) -> int:
+ return self.mesh.numel() if dim is None else self.mesh.size(dim)
+
+ @property
+ def ndim(self) -> int:
+ return self.mesh.ndim
+
+ @property
+ def shape(self) -> Tuple[int, ...]:
+ return tuple(self.mesh.shape)
+
+ def get_rank(self) -> int:
+ return get_rank()
+
+ def get_coordinate(self) -> Optional[List[int]]:
+ """
+ Return the coordinate of this rank with respect to all dimensions of the
+ mesh. If this rank is not part of the mesh, return None.
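+ For example, rank 6 of the (2, 4) mesh in the class docstring has coordinate [1, 2].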
+ """
+ return self._coordinate_on_dim if self._coordinate_on_dim else None
+
+
+def init_device_mesh(
+ device_type: str,
+ mesh_shape: Tuple[int, ...],
+ *,
+ mesh_dim_names: Optional[Tuple[str, ...]] = None,
+) -> DeviceMesh:
+ """
+ Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
+ This creates a DeviceMesh with an n-dimensional array layout, where n is len(mesh_shape) and the
+ i-th dimension has size mesh_shape[i]. If mesh_dim_names is provided, each dimension is
+ labeled as mesh_dim_names[i].
+
+
+ Args:
+ device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
+ mesh_shape (Tuple[int]): A tuple that defines the dimensions of the multi-dimensional array
+ that describes the layout of devices.
+ Kwargs:
+ mesh_dim_names (Optional[Tuple[str]]): A tuple of mesh dim names to be assigned to each dimension
+ of the multi-dimensional array that describes the layout of devices. Its length must match the length
+ of `mesh_shape`. Each string in mesh_dim_names must be unique.
+
+ Returns:
+ A :class:`DeviceMesh` object
+
+ .. note:: If no process group is found, init_device_mesh will initialize distributed process groups
+ behind the scenes, which are required for distributed communications.
+
+ Example:
+ >>> # xdoctest: +SKIP
+ >>> from torch.distributed._device_mesh import init_device_mesh
+ >>>
+ >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
+ >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
+ """
+ if mesh_dim_names is not None:
+ if len(set(mesh_dim_names)) != len(mesh_dim_names):
+ raise RuntimeError(
+ "Each mesh_dim_name must be unique.",
+ f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",
+ )
+
+ if len(mesh_shape) != len(mesh_dim_names):
+ raise RuntimeError(
+ "mesh_shape and mesh_dim_names should have the same length!",
+ f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape): {len(mesh_shape)}.",
+ )
+
+ mesh = torch.arange(math.prod(mesh_shape)).view(mesh_shape)
+ device_mesh = DeviceMesh(
+ device_type=device_type,
+ mesh=mesh,
+ mesh_dim_names=mesh_dim_names,
+ )
+
+ return device_mesh
diff --git a/torch/distributed/_tensor/device_mesh.py b/torch/distributed/_tensor/device_mesh.py
index 04c82d1..7489cc0 100644
--- a/torch/distributed/_tensor/device_mesh.py
+++ b/torch/distributed/_tensor/device_mesh.py
@@ -1,454 +1,6 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates
-import logging
-import math
-from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
-
-import torch
-import torch.distributed._functional_collectives as funcol
-
-from torch.distributed.distributed_c10d import (
- _find_pg_by_ranks_and_tag,
- _get_default_group,
- _get_group_tag,
- get_rank,
- get_world_size,
- init_process_group,
- is_initialized,
- new_group,
- ProcessGroup,
+from torch.distributed._device_mesh import ( # noqa: F401
+ _get_device_handle,
+ _mesh_resources,
+ DeviceMesh,
+ init_device_mesh,
)
-
-
-logger = logging.getLogger(__name__)
-
-# only import numpy typing when type checking
-if TYPE_CHECKING:
- try:
- from numpy.typing import ArrayLike
- except ImportError:
- logger.warning(
- "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
- )
-
-
-class _MeshEnv:
- def __init__(self) -> None:
- self.mesh_stack: List[DeviceMesh] = []
- self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}
-
- def get_current_mesh(self) -> "DeviceMesh":
- if len(self.mesh_stack) == 0:
- raise RuntimeError("No device mesh is currently active!")
- return self.mesh_stack[-1]
-
- def create_child_mesh(
- self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str
- ) -> "DeviceMesh":
- # swap the current dim to the last dim then reshape to flatten out other
- # dims, so we can just extract the list of ranks which contains cur_rank.
- cur_rank = device_mesh.get_rank()
- pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
- -1, device_mesh.mesh.size(mesh_dim)
- )
-
- for mesh_1d in pg_ranks_by_dim:
- sub_mesh = DeviceMesh(
- device_mesh.device_type,
- mesh_1d,
- mesh_dim_names=(mesh_dim_name,),
- _init_process_groups=False,
- )
- if cur_rank in mesh_1d:
- res_sub_mesh = sub_mesh
-
- res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]
- # Assign the current DeviceMesh as the parent of the child DeviceMesh.
- self.child_to_parent_mapping[res_sub_mesh] = device_mesh
- return res_sub_mesh
-
- def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
- return self.child_to_parent_mapping.get(device_mesh, None)
-
- def get_parent_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
- """
- Return the index of the mesh dim in the parent mesh.
- The device_mesh passed in needs to be sliced out from a parent mesh.
- """
- parent_mesh = self.get_parent_mesh(device_mesh)
- child_mesh_dim_names = device_mesh.mesh_dim_names
- if parent_mesh and child_mesh_dim_names:
- assert (
- len(child_mesh_dim_names) == 1
- ), "The child mesh can only be a 1D mesh."
- child_mesh_dim_name = child_mesh_dim_names[0]
- if parent_mesh.mesh_dim_names:
- return parent_mesh.mesh_dim_names.index(child_mesh_dim_name)
- return None
-
- @staticmethod
- def num_devices_per_host(device_type: str) -> int:
- return _get_device_handle(device_type).device_count()
-
- @staticmethod
- def num_hosts(device_type: str) -> int:
- # ProcessGroup can't tell us this info so we have to infer it, assume
- # homogeneous hardware for now
- return get_world_size() // _MeshEnv.num_devices_per_host(device_type)
-
-
-_mesh_resources: _MeshEnv = _MeshEnv()
-
-
-def _get_device_handle(device_type: str = "cuda"):
- """
- Get the module corresponding to the device_type which is cuda or cuda-like device.
- For example, when the device_type is cuda, the module `torch.cuda` is returned.
- Return None when there is no corresponding module for device_type, otherwise
- return the corresponding module.
- """
- return getattr(torch, device_type, None)
-
-
-class DeviceMesh:
- """
- DeviceMesh represents a mesh of devices, where layout of devices could be
- represented as a n-d dimension array, and each value of the n-d dimensional
- array is the global id of the default process group ranks.
-
- DeviceMesh could be used to describe the layout of devices across the cluster,
- and serves as a proxy for communication among the device lists within the cluster.
-
- We use the default ProcessGroup in this DeviceMesh class to implement proper
- communications. Note that we also add collective wrappers in this class. This is
- used to decouple detailed communication backend with the underlying
- DTensor implementation.
-
- DeviceMesh can be used as a context manager.
- Args:
- device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
- mesh (ndarray): could be a multi-dimension array or an integer tensor that
- describes the layout of devices, the ids are global ids of the
- default process group.
-
- Returns:
- A :class:`DeviceMesh` object
-
- Example (2 host with 4 GPUs each):
- ```
- # The following program runs on each process/rank in SPMD manner.
- # initialize device mesh as (2, 4) to represent the topology
- # of cross-host(dim 0), and within-host (dim 1)
- mesh = DeviceMesh(device_type="cuda",
- mesh=[
- [0, 1, 2, 3],
- [4, 5, 6, 7]
- ])
- ```
- A reduction over the first dimension of mesh will reduce across
- columns (0, 4), .. and (3, 7), a reduction over the second dimension
- of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7)
-
- """
-
- device_type: str
- mesh: torch.Tensor
- mesh_dim_names: Optional[Tuple[str, ...]]
-
- def __init__(
- self,
- device_type: str,
- mesh: Union[torch.Tensor, "ArrayLike"],
- *,
- mesh_dim_names: Optional[Tuple[str, ...]] = None,
- _init_process_groups: bool = True,
- _validate_mesh: bool = True,
- ) -> None:
- self.device_type = device_type
- self.mesh = (
- mesh.detach()
- if isinstance(mesh, torch.Tensor)
- else torch.tensor(mesh, dtype=torch.int)
- )
- self.mesh_dim_names = mesh_dim_names
-
- # private field to pre-generate DeviceMesh's hash
- self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
- self._hash = hash((self._flatten_mesh_list, self.mesh.shape))
-
- # Skip process group initialization if xla device.
- # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
- if device_type != "xla":
- # always try to create default (world) pg, even if it is not initialized
- # already. The world pg is used for device mesh identity (rank) on each
- # process (we need to know if the current global rank is in the mesh or not).
- self._get_or_create_default_group()
- if _init_process_groups:
- self._init_process_groups(_validate_mesh)
-
- def _get_or_create_default_group(self):
- default_initialized = is_initialized()
- if not default_initialized:
- init_process_group()
-
- world_size = get_world_size()
- if self.mesh.numel() > world_size:
- raise RuntimeError(
- f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
- )
-
- device_handle = _get_device_handle(self.device_type)
- # TODO: if user want to pass pg_options, offer a way to do it
- if not default_initialized and device_handle:
- # automatically set the current cuda/cuda-like device base on num of gpu devices available in each host
- # NOTE: This device selection would only work for homogeneous hardware.
- num_devices_per_host = device_handle.device_count()
- if (
- world_size > num_devices_per_host
- and world_size % num_devices_per_host != 0
- ):
- raise RuntimeError(
- f"DeviceMesh only support homogeneous hardware, but found "
- f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
- )
- device_handle.set_device(get_rank() % num_devices_per_host)
-
- # calculate the coordinates of the current global rank on the mesh
- rank_coords = (self.mesh == get_rank()).nonzero()
- assert rank_coords.size(0) in (0, 1)
- self._coordinate_on_dim: Optional[List[int]] = (
- rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
- )
- return _get_default_group()
-
- def _validate_mesh(self):
- # check mesh tensor validity
- unique_mesh_values = self.mesh.unique(sorted=True)
- if unique_mesh_values.numel() != self.mesh.numel():
- raise RuntimeError(
- f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}"
- )
-
- # validate that all calling ranks pass in the same `mesh` argument.
- self_mesh = self.mesh.to(self.device_type).contiguous()
- mesh_tensor = funcol.all_gather_tensor(
- self_mesh, gather_dim=0, group=_get_default_group()
- )
- mesh_tensor_chunked = torch.chunk(mesh_tensor, get_world_size())
- for other_rank, other_mesh in enumerate(mesh_tensor_chunked):
- if not torch.equal(self_mesh, other_mesh):
- raise RuntimeError(
- f"DeviceMesh initialization does not allow different mesh argument:"
- f"rank {get_rank()} has mesh {self_mesh} while rank {other_rank}"
- f"has mesh {other_mesh}!"
- )
-
- def _init_process_groups(self, _validate_mesh):
- if _validate_mesh:
- self._validate_mesh()
-
- # group tag/ranks associated with each mesh dimension, each mesh dimension should
- # have one sub-group per rank
- dim_group_infos: List[Tuple[str, List[int]]] = []
-
- if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
- # if the mesh is the same as world_pg, we just append the default
- # pg to the first dim groups, as new_group cannot have the exact
- # same ranks as world
- dim_group_infos.append(
- (_get_group_tag(_get_default_group()), list(range(get_world_size())))
- )
- else:
- # create sub pgs base on the mesh argument specified
- for dim in range(self.mesh.ndim):
- # swap the current dim to the last dim
- # then reshape to flatten out other dims
- pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
- -1, self.mesh.size(dim)
- )
- # multi-dim mesh, create subgroups by looping over the pg_ranks
- # for each dim and append the groups
- for dim_mesh in pg_ranks_by_dim:
- subgroup_ranks = dim_mesh.tolist()
- # call new_group regardless of the current rank in the
- # pg or not, it's required that all ranks participate
- # in subgroup construction
- dim_group = new_group(ranks=subgroup_ranks)
- # only add to dim_groups if the current rank in the subgroup
- if self.get_rank() in subgroup_ranks:
- if len(dim_group_infos) > dim:
- raise RuntimeError(
- f"Each device mesh dimension should get only one process group, but got {self.get_rank} "
- f"in {subgroup_ranks}!"
- )
- dim_group_infos.append(
- (_get_group_tag(dim_group), subgroup_ranks)
- )
- self._dim_group_infos = dim_group_infos
-
- def __enter__(self) -> "DeviceMesh":
- # set this mesh as the current mesh in mesh env
- _mesh_resources.mesh_stack.append(self)
- return self
-
- # pyre-fixme[2]: Parameter must be annotated.
- def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
- # pop this mesh from mesh env
- _mesh_resources.mesh_stack.pop()
-
- def __repr__(self) -> str:
- return f"DeviceMesh:({self.mesh.tolist()})"
-
- def __hash__(self):
- return self._hash
-
- def __eq__(self, other: object) -> bool:
- if not isinstance(other, DeviceMesh):
- return False
- if id(self.mesh) == id(other.mesh):
- return True
- return (
- self.mesh.shape == other.mesh.shape
- and self._flatten_mesh_list == other._flatten_mesh_list
- )
-
- def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":
- """
- Slice the current DeviceMesh based on the mesh_dim_name given to create a child
- DeviceMesh.
-
- Args:
- mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh
- to create a child DeviceMesh for.
- Returns:
- A :class:`DeviceMesh` object
-
- Example (2 host with 4 GPUs each):
- ```
- # Below is a DeviceMesh with mesh_shape of (2, 4) and mesh_dim_name of ("dp", "tp")
- mesh = DeviceMesh(device_type="cuda",
- mesh=[
- [0, 1, 2, 3],
- [4, 5, 6, 7]
- ],
- mesh_dim_names=["dp", "tp"])
- )
- ```
- Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
- Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
- Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
- Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
- Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
- Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
- """
- if self.mesh.ndim <= 1:
- raise RuntimeError(
- f"Cannot slice a DeviceMesh with {self.mesh.ndim} dimension."
- )
- if self.mesh_dim_names is None:
- raise KeyError(
- "No `mesh_dim_names` found.",
- "To slice the device mesh, please call `init_device_mesh` with `mesh_dim_names`.",
- )
- if mesh_dim_name not in self.mesh_dim_names:
- raise KeyError(
- f"Mesh dimension '{mesh_dim_name}' does not exist.",
- f"Available mesh dimensions are: {self.mesh_dim_names}",
- )
- mesh_dim = self.mesh_dim_names.index(mesh_dim_name)
- submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name)
-
- return submesh
-
- def get_dim_groups(
- self, mesh_dim: Optional[int] = None
- ) -> Union[ProcessGroup, List[ProcessGroup]]:
- if not hasattr(self, "_dim_group_infos"):
- raise RuntimeError("DeviceMesh process groups not initialized!")
- if mesh_dim is not None:
- return _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
- else:
- dim_groups = []
- for mesh_dim in range(self.mesh.ndim):
- dim_groups.append(
- _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
- )
- return dim_groups
-
- def size(self, dim: Optional[int] = None) -> int:
- return self.mesh.numel() if dim is None else self.mesh.size(dim)
-
- @property
- def ndim(self) -> int:
- return self.mesh.ndim
-
- @property
- def shape(self) -> Tuple[int, ...]:
- return tuple(self.mesh.shape)
-
- def get_rank(self) -> int:
- return get_rank()
-
- def get_coordinate(self) -> Optional[List[int]]:
- """
- Return the relative indices of this rank relative to all
- dimensions of the mesh. If this rank is not part of the mesh, return None.
- """
- return self._coordinate_on_dim if self._coordinate_on_dim else None
-
-
-def init_device_mesh(
- device_type: str,
- mesh_shape: Tuple[int, ...],
- *,
- mesh_dim_names: Optional[Tuple[str, ...]] = None,
-) -> DeviceMesh:
- """
- Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
- This creates a DeviceMesh with a mesh layout of n-d dimensional array, n being the len(mesh_shape)
- and ith dimension being in size mesh_shape[i]. If mesh_dim_names is provided, each dimension is
- labeled as mesh_dim_names[i].
-
-
- Args:
- device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
- mesh_shape: Tuple[int]: A tuple describes the dimension of the multi-dimesnion array
- that describes the layout of devices.
- Kwargs:
- mesh_dim_names: Optional[Tuple[str]]: A tuple of mesh dim names to be assigned to each dimension
- of the multi-dimensional array that describes the layout of devices. Its length must match the length
- of `mesh_shape`. Each string in mesh_dim_names must be unique.
-
- Returns:
- A :class:`DeviceMesh` object
-
- .. note: If no process group is found, init_device_mesh will initialize distributed process group/groups
- behind the scene, which are required for distributed communications.
-
- Example:
- >>> # xdoctest: +SKIP
- >>> from torch.distributed._tensor.device_mesh import init_device_mesh
- >>>
- >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
- >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
- """
- if mesh_dim_names is not None:
- if len(set(mesh_dim_names)) != len(mesh_dim_names):
- raise RuntimeError(
- "Each mesh_dim_name must be uqique.",
- f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",
- )
-
- if len(mesh_shape) != len(mesh_dim_names):
- raise RuntimeError(
- "mesh_shape and mesh_dim_names should have same length!",
- f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",
- )
-
- mesh = torch.arange(math.prod(mesh_shape)).view(mesh_shape)
- device_mesh = DeviceMesh(
- device_type=device_type,
- mesh=mesh,
- mesh_dim_names=mesh_dim_names,
- )
-
- return device_mesh