# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
import math
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
import torch.distributed._functional_collectives as funcol

from torch.distributed.distributed_c10d import (
    _find_pg_by_ranks_and_tag,
    _get_default_group,
    _get_group_tag,
    get_rank,
    get_world_size,
    init_process_group,
    is_initialized,
    new_group,
    ProcessGroup,
)


logger = logging.getLogger(__name__)

# only import numpy typing when type checking
if TYPE_CHECKING:
    try:
        from numpy.typing import ArrayLike
    except ImportError:
        logger.warning(
            "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
        )


class _MeshEnv:
    def __init__(self) -> None:
        self.mesh_stack: List[DeviceMesh] = []
        self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}

    def get_current_mesh(self) -> "DeviceMesh":
        if len(self.mesh_stack) == 0:
            raise RuntimeError("No device mesh is currently active!")
        return self.mesh_stack[-1]

    def create_child_mesh(
        self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str
    ) -> "DeviceMesh":
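        """
        Create a 1D child DeviceMesh covering the ranks of `device_mesh` along
        `mesh_dim`, reusing the parent's process group for that dimension, and
        record the parent so the child can later be mapped back to `device_mesh`.
        """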
        # swap the current dim to the last dim then reshape to flatten out other
        # dims, so we can just extract the list of ranks which contains cur_rank.
        cur_rank = device_mesh.get_rank()
        pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
            -1, device_mesh.mesh.size(mesh_dim)
        )

        for mesh_1d in pg_ranks_by_dim:
            sub_mesh = DeviceMesh(
                device_mesh.device_type,
                mesh_1d,
                mesh_dim_names=(mesh_dim_name,),
                _init_process_groups=False,
            )
            if cur_rank in mesh_1d:
                res_sub_mesh = sub_mesh

        res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]
        # Assign the current DeviceMesh as the parent of the child DeviceMesh.
        self.child_to_parent_mapping[res_sub_mesh] = device_mesh
        return res_sub_mesh

    def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
        return self.child_to_parent_mapping.get(device_mesh, None)

    def get_parent_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
        """
        Return the index of the mesh dim in the parent mesh.
        The device_mesh passed in needs to have been sliced out of a parent mesh.
        """
        parent_mesh = self.get_parent_mesh(device_mesh)
        child_mesh_dim_names = device_mesh.mesh_dim_names
        if parent_mesh and child_mesh_dim_names:
            assert (
                len(child_mesh_dim_names) == 1
            ), "The child mesh can only be a 1D mesh."
            child_mesh_dim_name = child_mesh_dim_names[0]
            if parent_mesh.mesh_dim_names:
                return parent_mesh.mesh_dim_names.index(child_mesh_dim_name)
        return None

    @staticmethod
    def num_devices_per_host(device_type: str) -> int:
        return _get_device_handle(device_type).device_count()

    @staticmethod
    def num_hosts(device_type: str) -> int:
        # ProcessGroup can't tell us this info, so we have to infer it;
        # assume homogeneous hardware for now.
        return get_world_size() // _MeshEnv.num_devices_per_host(device_type)


_mesh_resources: _MeshEnv = _MeshEnv()


def _get_device_handle(device_type: str = "cuda"):
    """
    Get the module corresponding to the device_type, which is cuda or a cuda-like device.
    For example, when the device_type is "cuda", the module `torch.cuda` is returned.
    Return None when there is no corresponding module for device_type; otherwise
    return the corresponding module.
    """
    return getattr(torch, device_type, None)


class DeviceMesh:
    """
    DeviceMesh represents a mesh of devices, where the layout of devices can be
    represented as an n-dimensional array, and each value of the n-dimensional
    array is the global rank id in the default process group.

    DeviceMesh can be used to describe the layout of devices across the cluster,
    and serves as a proxy for communication among the device lists within the cluster.

    We use the default ProcessGroup in this DeviceMesh class to implement proper
    communications. Note that we also add collective wrappers in this class. This is
    used to decouple the detailed communication backend from the underlying
    DTensor implementation.

    DeviceMesh can be used as a context manager.

    Args:
        device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
        mesh (ndarray): could be a multi-dimensional array or an integer tensor that
            describes the layout of devices, where the ids are global ids of the
            default process group.

    Returns:
        A :class:`DeviceMesh` object

    Example (2 hosts with 4 GPUs each):
    ```
    # The following program runs on each process/rank in an SPMD manner.
    # Initialize the device mesh as (2, 4) to represent the topology
    # of cross-host (dim 0) and within-host (dim 1).
    mesh = DeviceMesh(device_type="cuda",
                      mesh=[
                          [0, 1, 2, 3],
                          [4, 5, 6, 7]
                      ])
    ```
    A reduction over the first dimension of the mesh reduces across
    columns (0, 4), ..., and (3, 7); a reduction over the second dimension
    of the mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7).
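
    As noted above, DeviceMesh can also be used as a context manager; a minimal
    sketch (assuming the same 2 x 4 mesh as above):
    ```
    with DeviceMesh(device_type="cuda",
                    mesh=[[0, 1, 2, 3], [4, 5, 6, 7]]) as mesh:
        # within the block, `mesh` is what _mesh_resources.get_current_mesh()
        # returns
        coordinate = mesh.get_coordinate()  # e.g. [0, 1] on global rank 1
    # on exit, the mesh is popped off the internal mesh stack again
    ```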

    """

    device_type: str
    mesh: torch.Tensor
    mesh_dim_names: Optional[Tuple[str, ...]]

    def __init__(
        self,
        device_type: str,
        mesh: Union[torch.Tensor, "ArrayLike"],
        *,
        mesh_dim_names: Optional[Tuple[str, ...]] = None,
        _init_process_groups: bool = True,
        _validate_mesh: bool = True,
    ) -> None:
        self.device_type = device_type
        self.mesh = (
            mesh.detach()
            if isinstance(mesh, torch.Tensor)
            else torch.tensor(mesh, dtype=torch.int)
        )
        self.mesh_dim_names = mesh_dim_names

        # private field to pre-generate DeviceMesh's hash
        self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
        self._hash = hash((self._flatten_mesh_list, self.mesh.shape))

        # Skip process group initialization if xla device.
        # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
        if device_type != "xla":
            # always try to create default (world) pg, even if it is not initialized
            # already. The world pg is used for device mesh identity (rank) on each
            # process (we need to know if the current global rank is in the mesh or not).
            self._get_or_create_default_group()
            if _init_process_groups:
                self._init_process_groups(_validate_mesh)

    def _get_or_create_default_group(self):
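        """
        Initialize the default (world) process group if it is not already
        initialized, optionally bind this process to a cuda/cuda-like device,
        compute the current rank's coordinates on the mesh, and return the
        default process group.
        """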
        default_initialized = is_initialized()
        if not default_initialized:
            init_process_group()

        world_size = get_world_size()
        if self.mesh.numel() > world_size:
            raise RuntimeError(
                f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
            )

        device_handle = _get_device_handle(self.device_type)
        # TODO: if users want to pass pg_options, offer a way to do it
        if not default_initialized and device_handle:
            # automatically set the current cuda/cuda-like device based on the number
            # of devices available on each host
            # NOTE: This device selection only works for homogeneous hardware.
            num_devices_per_host = device_handle.device_count()
            if (
                world_size > num_devices_per_host
                and world_size % num_devices_per_host != 0
            ):
                raise RuntimeError(
                    f"DeviceMesh only supports homogeneous hardware, but found "
                    f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
                )
            device_handle.set_device(get_rank() % num_devices_per_host)

        # calculate the coordinates of the current global rank on the mesh
        rank_coords = (self.mesh == get_rank()).nonzero()
        assert rank_coords.size(0) in (0, 1)
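        # the current rank appears at most once in the mesh; if it is absent,
        # this process is not part of the mesh and its coordinate stays None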
        self._coordinate_on_dim: Optional[List[int]] = (
            rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
        )
        return _get_default_group()

    def _validate_mesh(self):
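        """
        Check that the mesh tensor contains no duplicate ranks and that every
        participating rank passed in the same `mesh` argument.
        """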
        # check mesh tensor validity
        unique_mesh_values = self.mesh.unique(sorted=True)
        if unique_mesh_values.numel() != self.mesh.numel():
            raise RuntimeError(
                f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}"
            )

        # validate that all calling ranks pass in the same `mesh` argument.
        self_mesh = self.mesh.to(self.device_type).contiguous()
        mesh_tensor = funcol.all_gather_tensor(
            self_mesh, gather_dim=0, group=_get_default_group()
        )
        mesh_tensor_chunked = torch.chunk(mesh_tensor, get_world_size())
        for other_rank, other_mesh in enumerate(mesh_tensor_chunked):
            if not torch.equal(self_mesh, other_mesh):
                raise RuntimeError(
                    f"DeviceMesh initialization does not allow different mesh argument: "
                    f"rank {get_rank()} has mesh {self_mesh} while rank {other_rank} "
                    f"has mesh {other_mesh}!"
                )

    def _init_process_groups(self, _validate_mesh):
        if _validate_mesh:
            self._validate_mesh()

        # group tag/ranks associated with each mesh dimension, each mesh dimension should
        # have one sub-group per rank
        dim_group_infos: List[Tuple[str, List[int]]] = []

        if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
            # if the mesh is the same as world_pg, we just append the default
            # pg to the first dim groups, as new_group cannot have the exact
            # same ranks as world
            dim_group_infos.append(
                (_get_group_tag(_get_default_group()), list(range(get_world_size())))
            )
        else:
            # create sub pgs based on the mesh argument specified
            for dim in range(self.mesh.ndim):
                # swap the current dim to the last dim
                # then reshape to flatten out other dims
                pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
                    -1, self.mesh.size(dim)
                )
                # multi-dim mesh, create subgroups by looping over the pg_ranks
                # for each dim and append the groups
                for dim_mesh in pg_ranks_by_dim:
                    subgroup_ranks = dim_mesh.tolist()
                    # call new_group regardless of whether the current rank is in
                    # the pg or not; all ranks are required to participate
                    # in subgroup construction
                    dim_group = new_group(ranks=subgroup_ranks)
                    # only add to dim_groups if the current rank is in the subgroup
                    if self.get_rank() in subgroup_ranks:
                        if len(dim_group_infos) > dim:
                            raise RuntimeError(
                                f"Each device mesh dimension should get only one process group, but got {self.get_rank()} "
                                f"in {subgroup_ranks}!"
                            )
                        dim_group_infos.append(
                            (_get_group_tag(dim_group), subgroup_ranks)
                        )
        self._dim_group_infos = dim_group_infos

    def __enter__(self) -> "DeviceMesh":
        # set this mesh as the current mesh in mesh env
        _mesh_resources.mesh_stack.append(self)
        return self

    # pyre-fixme[2]: Parameter must be annotated.
    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
        # pop this mesh from mesh env
        _mesh_resources.mesh_stack.pop()

    def __repr__(self) -> str:
        return f"DeviceMesh:({self.mesh.tolist()})"

    def __hash__(self):
        return self._hash

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, DeviceMesh):
            return False
        if id(self.mesh) == id(other.mesh):
            return True
        return (
            self.mesh.shape == other.mesh.shape
            and self._flatten_mesh_list == other._flatten_mesh_list
        )

    def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":
        """
        Slice the current DeviceMesh based on the mesh_dim_name given to create a child
        DeviceMesh.

        Args:
            mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh
                to create a child DeviceMesh for.
        Returns:
            A :class:`DeviceMesh` object

        Example (2 hosts with 4 GPUs each):
        ```
        # Below is a DeviceMesh with mesh_shape of (2, 4) and mesh_dim_names of ("dp", "tp")
        mesh = DeviceMesh(device_type="cuda",
                          mesh=[
                              [0, 1, 2, 3],
                              [4, 5, 6, 7]
                          ],
                          mesh_dim_names=("dp", "tp"))
        ```
        Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
        Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
        Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
        Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
        Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
        Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
        """
        if self.mesh.ndim <= 1:
            raise RuntimeError(
                f"Cannot slice a DeviceMesh with {self.mesh.ndim} dimension."
            )
        if self.mesh_dim_names is None:
            raise KeyError(
                "No `mesh_dim_names` found.",
                "To slice the device mesh, please call `init_device_mesh` with `mesh_dim_names`.",
            )
        if mesh_dim_name not in self.mesh_dim_names:
            raise KeyError(
                f"Mesh dimension '{mesh_dim_name}' does not exist.",
                f"Available mesh dimensions are: {self.mesh_dim_names}",
            )
        mesh_dim = self.mesh_dim_names.index(mesh_dim_name)
        submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name)

        return submesh

    def get_dim_groups(
        self, mesh_dim: Optional[int] = None
    ) -> Union[ProcessGroup, List[ProcessGroup]]:
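        """
        Return the ProcessGroup of the given mesh dimension, or the list of
        ProcessGroups for all mesh dimensions when `mesh_dim` is None.
        """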
        if not hasattr(self, "_dim_group_infos"):
            raise RuntimeError("DeviceMesh process groups not initialized!")
        if mesh_dim is not None:
            return _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
        else:
            dim_groups = []
            for mesh_dim in range(self.mesh.ndim):
                dim_groups.append(
                    _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
                )
            return dim_groups

    def size(self, dim: Optional[int] = None) -> int:
        return self.mesh.numel() if dim is None else self.mesh.size(dim)

    @property
    def ndim(self) -> int:
        return self.mesh.ndim

    @property
    def shape(self) -> Tuple[int, ...]:
        return tuple(self.mesh.shape)

    def get_rank(self) -> int:
        return get_rank()

    def get_coordinate(self) -> Optional[List[int]]:
        """
        Return the coordinate of this rank along each dimension of the mesh.
        If this rank is not part of the mesh, return None.
        """
        return self._coordinate_on_dim if self._coordinate_on_dim else None


def init_device_mesh(
    device_type: str,
    mesh_shape: Tuple[int, ...],
    *,
    mesh_dim_names: Optional[Tuple[str, ...]] = None,
) -> DeviceMesh:
    """
    Initializes a `DeviceMesh` based on the `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
    This creates a DeviceMesh with an n-dimensional array layout, where n is the length of `mesh_shape`
    and the i-th dimension has size mesh_shape[i]. If `mesh_dim_names` is provided, each dimension is
    labeled as mesh_dim_names[i].

    Args:
        device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
        mesh_shape (Tuple[int]): A tuple that describes the dimensions of the multi-dimensional array
            that describes the layout of devices.
    Kwargs:
        mesh_dim_names (Optional[Tuple[str]]): A tuple of mesh dim names to be assigned to each dimension
            of the multi-dimensional array that describes the layout of devices. Its length must match the length
            of `mesh_shape`. Each string in mesh_dim_names must be unique.

    Returns:
        A :class:`DeviceMesh` object

    .. note:: If no process group is found, init_device_mesh will initialize distributed process group/groups
        behind the scenes, which are required for distributed communications.

    Example:
        >>> # xdoctest: +SKIP
        >>> from torch.distributed._tensor.device_mesh import init_device_mesh
        >>>
        >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
        >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
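        >>>
        >>> # A named 2D mesh can then be sliced into 1D child meshes by dim name
        >>> # (a minimal sketch, reusing the 16-rank mesh_2d created above):
        >>> tp_mesh = mesh_2d["tp"]
        >>> dp_mesh = mesh_2d["dp"]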
    """
    if mesh_dim_names is not None:
        if len(set(mesh_dim_names)) != len(mesh_dim_names):
            raise RuntimeError(
                "Each mesh_dim_name must be unique.",
                f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",
            )

        if len(mesh_shape) != len(mesh_dim_names):
            raise RuntimeError(
                "mesh_shape and mesh_dim_names should have the same length!",
                f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape): {len(mesh_shape)}.",
            )

    mesh = torch.arange(math.prod(mesh_shape)).view(mesh_shape)
    device_mesh = DeviceMesh(
        device_type=device_type,
        mesh=mesh,
        mesh_dim_names=mesh_dim_names,
    )

    return device_mesh