[DeviceMesh] Move DeviceMesh out from torch.distributed._tensor (#112364)

Move DeviceMesh out as a standalone module. Once we make sure everything is migrated and doc is ready, we will make `torch.distributed._device_mesh` public in follow-up PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112364
Approved by: https://github.com/wanchaol, https://github.com/fegin, https://github.com/fduwjj
diff --git a/.ci/pytorch/multigpu-test.sh b/.ci/pytorch/multigpu-test.sh
index f3bf768..70ae4d2 100755
--- a/.ci/pytorch/multigpu-test.sh
+++ b/.ci/pytorch/multigpu-test.sh
@@ -36,10 +36,12 @@
 
 
 # DTensor tests
-time python test/run_test.py --verbose -i distributed/_tensor/test_device_mesh
 time python test/run_test.py --verbose -i distributed/_tensor/test_random_ops
 time python test/run_test.py --verbose -i distributed/_tensor/test_dtensor_compile
 
+# DeviceMesh test
+time python test/run_test.py --verbose -i distributed/test_device_mesh
+
 # DTensor/TP tests
 time python test/run_test.py --verbose -i distributed/tensor/parallel/test_ddp_2d_parallel
 time python test/run_test.py --verbose -i distributed/tensor/parallel/test_fsdp_2d_parallel
diff --git a/test/distributed/_tensor/test_device_mesh.py b/test/distributed/test_device_mesh.py
similarity index 100%
rename from test/distributed/_tensor/test_device_mesh.py
rename to test/distributed/test_device_mesh.py
diff --git a/test/test_public_bindings.py b/test/test_public_bindings.py
index 08e4a17..546d88b 100644
--- a/test/test_public_bindings.py
+++ b/test/test_public_bindings.py
@@ -290,6 +290,7 @@
             "torch.backends._coreml.preprocess",
             "torch.contrib._tensorboard_vis",
             "torch.distributed._composable",
+            "torch.distributed._device_mesh",
             "torch.distributed._functional_collectives",
             "torch.distributed._functional_collectives_impl",
             "torch.distributed._shard",
diff --git a/torch/_dynamo/skipfiles.py b/torch/_dynamo/skipfiles.py
index c9eadc1..f5196d3 100644
--- a/torch/_dynamo/skipfiles.py
+++ b/torch/_dynamo/skipfiles.py
@@ -183,6 +183,7 @@
     LEGACY_MOD_INLINELIST |= {
         "torch.distributed._tensor.api",
         "torch.distributed._tensor.device_mesh",
+        "torch.distributed._device_mesh",
         "torch.distributed.algorithms._checkpoint.checkpoint_wrapper",
         "torch.distributed.tensor.parallel._data_parallel_utils",
         "torch.distributed.tensor.parallel._utils",
diff --git a/torch/distributed/_device_mesh.py b/torch/distributed/_device_mesh.py
new file mode 100644
index 0000000..04c82d1
--- /dev/null
+++ b/torch/distributed/_device_mesh.py
@@ -0,0 +1,454 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+import logging
+import math
+from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+
+import torch
+import torch.distributed._functional_collectives as funcol
+
+from torch.distributed.distributed_c10d import (
+    _find_pg_by_ranks_and_tag,
+    _get_default_group,
+    _get_group_tag,
+    get_rank,
+    get_world_size,
+    init_process_group,
+    is_initialized,
+    new_group,
+    ProcessGroup,
+)
+
+
+logger = logging.getLogger(__name__)
+
+# only import numpy typing when type checking
+if TYPE_CHECKING:
+    try:
+        from numpy.typing import ArrayLike
+    except ImportError:
+        logger.warning(
+            "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
+        )
+
+
+class _MeshEnv:
+    def __init__(self) -> None:
+        self.mesh_stack: List[DeviceMesh] = []
+        self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}
+
+    def get_current_mesh(self) -> "DeviceMesh":
+        if len(self.mesh_stack) == 0:
+            raise RuntimeError("No device mesh is currently active!")
+        return self.mesh_stack[-1]
+
+    def create_child_mesh(
+        self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str
+    ) -> "DeviceMesh":
+        # swap the current dim to the last dim then reshape to flatten out other
+        # dims, so we can just extract the list of ranks which contains cur_rank.
+        cur_rank = device_mesh.get_rank()
+        pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
+            -1, device_mesh.mesh.size(mesh_dim)
+        )
+
+        for mesh_1d in pg_ranks_by_dim:
+            sub_mesh = DeviceMesh(
+                device_mesh.device_type,
+                mesh_1d,
+                mesh_dim_names=(mesh_dim_name,),
+                _init_process_groups=False,
+            )
+            if cur_rank in mesh_1d:
+                res_sub_mesh = sub_mesh
+
+        res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]
+        # Assign the current DeviceMesh as the parent of the child DeviceMesh.
+        self.child_to_parent_mapping[res_sub_mesh] = device_mesh
+        return res_sub_mesh
+
+    def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
+        return self.child_to_parent_mapping.get(device_mesh, None)
+
+    def get_parent_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
+        """
+        Return the index of the mesh dim in the parent mesh.
+        The device_mesh passed in needs to be sliced out from a parent mesh.
+        """
+        parent_mesh = self.get_parent_mesh(device_mesh)
+        child_mesh_dim_names = device_mesh.mesh_dim_names
+        if parent_mesh and child_mesh_dim_names:
+            assert (
+                len(child_mesh_dim_names) == 1
+            ), "The child mesh can only be a 1D mesh."
+            child_mesh_dim_name = child_mesh_dim_names[0]
+            if parent_mesh.mesh_dim_names:
+                return parent_mesh.mesh_dim_names.index(child_mesh_dim_name)
+        return None
+
+    @staticmethod
+    def num_devices_per_host(device_type: str) -> int:
+        return _get_device_handle(device_type).device_count()
+
+    @staticmethod
+    def num_hosts(device_type: str) -> int:
+        # ProcessGroup can't tell us this info so we have to infer it, assume
+        # homogeneous hardware for now
+        return get_world_size() // _MeshEnv.num_devices_per_host(device_type)
+
+
+_mesh_resources: _MeshEnv = _MeshEnv()
+
+
+def _get_device_handle(device_type: str = "cuda"):
+    """
+    Get the module corresponding to the device_type which is cuda or cuda-like device.
+    For example, when the device_type is cuda, the module `torch.cuda` is returned.
+    Return None when there is no corresponding module for device_type, otherwise
+    return the corresponding module.
+    """
+    return getattr(torch, device_type, None)
+
+
+class DeviceMesh:
+    """
+    DeviceMesh represents a mesh of devices, where layout of devices could be
+    represented as a n-d dimension array, and each value of the n-d dimensional
+    array is the global id of the default process group ranks.
+
+    DeviceMesh could be used to describe the layout of devices across the cluster,
+    and serves as a proxy for communication among the device lists within the cluster.
+
+    We use the default ProcessGroup in this DeviceMesh class to implement proper
+    communications. Note that we also add collective wrappers in this class. This is
+    used to decouple detailed communication backend with the underlying
+    DTensor implementation.
+
+    DeviceMesh can be used as a context manager.
+    Args:
+        device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
+        mesh (ndarray): could be a multi-dimension array or an integer tensor that
+            describes the layout of devices, the ids are global ids of the
+            default process group.
+
+    Returns:
+        A :class:`DeviceMesh` object
+
+    Example (2 host with 4 GPUs each):
+        ```
+        # The following program runs on each process/rank in SPMD manner.
+        # initialize device mesh as (2, 4) to represent the topology
+        # of cross-host(dim 0), and within-host (dim 1)
+        mesh = DeviceMesh(device_type="cuda",
+                          mesh=[
+                            [0, 1, 2, 3],
+                            [4, 5, 6, 7]
+                          ])
+        ```
+        A reduction over the first dimension of mesh will reduce across
+        columns (0, 4), .. and (3, 7), a reduction over the second dimension
+        of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7)
+
+    """
+
+    device_type: str
+    mesh: torch.Tensor
+    mesh_dim_names: Optional[Tuple[str, ...]]
+
+    def __init__(
+        self,
+        device_type: str,
+        mesh: Union[torch.Tensor, "ArrayLike"],
+        *,
+        mesh_dim_names: Optional[Tuple[str, ...]] = None,
+        _init_process_groups: bool = True,
+        _validate_mesh: bool = True,
+    ) -> None:
+        self.device_type = device_type
+        self.mesh = (
+            mesh.detach()
+            if isinstance(mesh, torch.Tensor)
+            else torch.tensor(mesh, dtype=torch.int)
+        )
+        self.mesh_dim_names = mesh_dim_names
+
+        # private field to pre-generate DeviceMesh's hash
+        self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
+        self._hash = hash((self._flatten_mesh_list, self.mesh.shape))
+
+        # Skip process group initialization if xla device.
+        # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
+        if device_type != "xla":
+            # always try to create default (world) pg, even if it is not initialized
+            # already. The world pg is used for device mesh identity (rank) on each
+            # process (we need to know if the current global rank is in the mesh or not).
+            self._get_or_create_default_group()
+            if _init_process_groups:
+                self._init_process_groups(_validate_mesh)
+
+    def _get_or_create_default_group(self):
+        default_initialized = is_initialized()
+        if not default_initialized:
+            init_process_group()
+
+        world_size = get_world_size()
+        if self.mesh.numel() > world_size:
+            raise RuntimeError(
+                f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
+            )
+
+        device_handle = _get_device_handle(self.device_type)
+        # TODO: if user want to pass pg_options, offer a way to do it
+        if not default_initialized and device_handle:
+            # automatically set the current cuda/cuda-like device base on num of gpu devices available in each host
+            # NOTE: This device selection would only work for homogeneous hardware.
+            num_devices_per_host = device_handle.device_count()
+            if (
+                world_size > num_devices_per_host
+                and world_size % num_devices_per_host != 0
+            ):
+                raise RuntimeError(
+                    f"DeviceMesh only support homogeneous hardware, but found "
+                    f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
+                )
+            device_handle.set_device(get_rank() % num_devices_per_host)
+
+        # calculate the coordinates of the current global rank on the mesh
+        rank_coords = (self.mesh == get_rank()).nonzero()
+        assert rank_coords.size(0) in (0, 1)
+        self._coordinate_on_dim: Optional[List[int]] = (
+            rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
+        )
+        return _get_default_group()
+
+    def _validate_mesh(self):
+        # check mesh tensor validity
+        unique_mesh_values = self.mesh.unique(sorted=True)
+        if unique_mesh_values.numel() != self.mesh.numel():
+            raise RuntimeError(
+                f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}"
+            )
+
+        # validate that all calling ranks pass in the same `mesh` argument.
+        self_mesh = self.mesh.to(self.device_type).contiguous()
+        mesh_tensor = funcol.all_gather_tensor(
+            self_mesh, gather_dim=0, group=_get_default_group()
+        )
+        mesh_tensor_chunked = torch.chunk(mesh_tensor, get_world_size())
+        for other_rank, other_mesh in enumerate(mesh_tensor_chunked):
+            if not torch.equal(self_mesh, other_mesh):
+                raise RuntimeError(
+                    f"DeviceMesh initialization does not allow different mesh argument:"
+                    f"rank {get_rank()} has mesh {self_mesh} while rank {other_rank}"
+                    f"has mesh {other_mesh}!"
+                )
+
+    def _init_process_groups(self, _validate_mesh):
+        if _validate_mesh:
+            self._validate_mesh()
+
+        # group tag/ranks associated with each mesh dimension, each mesh dimension should
+        # have one sub-group per rank
+        dim_group_infos: List[Tuple[str, List[int]]] = []
+
+        if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
+            # if the mesh is the same as world_pg, we just append the default
+            # pg to the first dim groups, as new_group cannot have the exact
+            # same ranks as world
+            dim_group_infos.append(
+                (_get_group_tag(_get_default_group()), list(range(get_world_size())))
+            )
+        else:
+            # create sub pgs base on the mesh argument specified
+            for dim in range(self.mesh.ndim):
+                # swap the current dim to the last dim
+                # then reshape to flatten out other dims
+                pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
+                    -1, self.mesh.size(dim)
+                )
+                # multi-dim mesh, create subgroups by looping over the pg_ranks
+                # for each dim and append the groups
+                for dim_mesh in pg_ranks_by_dim:
+                    subgroup_ranks = dim_mesh.tolist()
+                    # call new_group regardless of the current rank in the
+                    # pg or not, it's required that all ranks participate
+                    # in subgroup construction
+                    dim_group = new_group(ranks=subgroup_ranks)
+                    # only add to dim_groups if the current rank in the subgroup
+                    if self.get_rank() in subgroup_ranks:
+                        if len(dim_group_infos) > dim:
+                            raise RuntimeError(
+                                f"Each device mesh dimension should get only one process group, but got {self.get_rank} "
+                                f"in {subgroup_ranks}!"
+                            )
+                        dim_group_infos.append(
+                            (_get_group_tag(dim_group), subgroup_ranks)
+                        )
+        self._dim_group_infos = dim_group_infos
+
+    def __enter__(self) -> "DeviceMesh":
+        # set this mesh as the current mesh in mesh env
+        _mesh_resources.mesh_stack.append(self)
+        return self
+
+    # pyre-fixme[2]: Parameter must be annotated.
+    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
+        # pop this mesh from mesh env
+        _mesh_resources.mesh_stack.pop()
+
+    def __repr__(self) -> str:
+        return f"DeviceMesh:({self.mesh.tolist()})"
+
+    def __hash__(self):
+        return self._hash
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, DeviceMesh):
+            return False
+        if id(self.mesh) == id(other.mesh):
+            return True
+        return (
+            self.mesh.shape == other.mesh.shape
+            and self._flatten_mesh_list == other._flatten_mesh_list
+        )
+
+    def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":
+        """
+        Slice the current DeviceMesh based on the mesh_dim_name given to create a child
+        DeviceMesh.
+
+        Args:
+            mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh
+            to create a child DeviceMesh for.
+        Returns:
+            A :class:`DeviceMesh` object
+
+        Example (2 host with 4 GPUs each):
+        ```
+        # Below is a DeviceMesh with mesh_shape of (2, 4) and mesh_dim_name of ("dp", "tp")
+        mesh = DeviceMesh(device_type="cuda",
+                          mesh=[
+                            [0, 1, 2, 3],
+                            [4, 5, 6, 7]
+                          ],
+                          mesh_dim_names=["dp", "tp"])
+                          )
+        ```
+        Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
+        Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
+        Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
+        Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
+        Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
+        Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
+        """
+        if self.mesh.ndim <= 1:
+            raise RuntimeError(
+                f"Cannot slice a DeviceMesh with {self.mesh.ndim} dimension."
+            )
+        if self.mesh_dim_names is None:
+            raise KeyError(
+                "No `mesh_dim_names` found.",
+                "To slice the device mesh, please call `init_device_mesh` with `mesh_dim_names`.",
+            )
+        if mesh_dim_name not in self.mesh_dim_names:
+            raise KeyError(
+                f"Mesh dimension '{mesh_dim_name}' does not exist.",
+                f"Available mesh dimensions are: {self.mesh_dim_names}",
+            )
+        mesh_dim = self.mesh_dim_names.index(mesh_dim_name)
+        submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name)
+
+        return submesh
+
+    def get_dim_groups(
+        self, mesh_dim: Optional[int] = None
+    ) -> Union[ProcessGroup, List[ProcessGroup]]:
+        if not hasattr(self, "_dim_group_infos"):
+            raise RuntimeError("DeviceMesh process groups not initialized!")
+        if mesh_dim is not None:
+            return _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
+        else:
+            dim_groups = []
+            for mesh_dim in range(self.mesh.ndim):
+                dim_groups.append(
+                    _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
+                )
+            return dim_groups
+
+    def size(self, dim: Optional[int] = None) -> int:
+        return self.mesh.numel() if dim is None else self.mesh.size(dim)
+
+    @property
+    def ndim(self) -> int:
+        return self.mesh.ndim
+
+    @property
+    def shape(self) -> Tuple[int, ...]:
+        return tuple(self.mesh.shape)
+
+    def get_rank(self) -> int:
+        return get_rank()
+
+    def get_coordinate(self) -> Optional[List[int]]:
+        """
+        Return the relative indices of this rank relative to all
+        dimensions of the mesh. If this rank is not part of the mesh, return None.
+        """
+        return self._coordinate_on_dim if self._coordinate_on_dim else None
+
+
+def init_device_mesh(
+    device_type: str,
+    mesh_shape: Tuple[int, ...],
+    *,
+    mesh_dim_names: Optional[Tuple[str, ...]] = None,
+) -> DeviceMesh:
+    """
+    Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
+    This creates a DeviceMesh with a mesh layout of n-d dimensional array, n being the len(mesh_shape)
+    and ith dimension being in size mesh_shape[i]. If mesh_dim_names is provided, each dimension is
+    labeled as mesh_dim_names[i].
+
+
+    Args:
+        device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
+        mesh_shape: Tuple[int]: A tuple describes the dimension of the multi-dimesnion array
+        that describes the layout of devices.
+    Kwargs:
+        mesh_dim_names: Optional[Tuple[str]]: A tuple of mesh dim names to be assigned to each dimension
+        of the multi-dimensional array that describes the layout of devices. Its length must match the length
+        of `mesh_shape`. Each string in mesh_dim_names must be unique.
+
+    Returns:
+        A :class:`DeviceMesh` object
+
+    .. note: If no process group is found, init_device_mesh will initialize distributed process group/groups
+    behind the scene, which are required for distributed communications.
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> from torch.distributed._tensor.device_mesh import init_device_mesh
+        >>>
+        >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
+        >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
+    """
+    if mesh_dim_names is not None:
+        if len(set(mesh_dim_names)) != len(mesh_dim_names):
+            raise RuntimeError(
+                "Each mesh_dim_name must be uqique.",
+                f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",
+            )
+
+        if len(mesh_shape) != len(mesh_dim_names):
+            raise RuntimeError(
+                "mesh_shape and mesh_dim_names should have same length!",
+                f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",
+            )
+
+    mesh = torch.arange(math.prod(mesh_shape)).view(mesh_shape)
+    device_mesh = DeviceMesh(
+        device_type=device_type,
+        mesh=mesh,
+        mesh_dim_names=mesh_dim_names,
+    )
+
+    return device_mesh
diff --git a/torch/distributed/_tensor/device_mesh.py b/torch/distributed/_tensor/device_mesh.py
index 04c82d1..7489cc0 100644
--- a/torch/distributed/_tensor/device_mesh.py
+++ b/torch/distributed/_tensor/device_mesh.py
@@ -1,454 +1,6 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates
-import logging
-import math
-from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union
-
-import torch
-import torch.distributed._functional_collectives as funcol
-
-from torch.distributed.distributed_c10d import (
-    _find_pg_by_ranks_and_tag,
-    _get_default_group,
-    _get_group_tag,
-    get_rank,
-    get_world_size,
-    init_process_group,
-    is_initialized,
-    new_group,
-    ProcessGroup,
+from torch.distributed._device_mesh import (  # noqa: F401
+    _get_device_handle,
+    _mesh_resources,
+    DeviceMesh,
+    init_device_mesh,
 )
-
-
-logger = logging.getLogger(__name__)
-
-# only import numpy typing when type checking
-if TYPE_CHECKING:
-    try:
-        from numpy.typing import ArrayLike
-    except ImportError:
-        logger.warning(
-            "DeviceMesh requires numpy >= 1.21 to be installed for type checking"
-        )
-
-
-class _MeshEnv:
-    def __init__(self) -> None:
-        self.mesh_stack: List[DeviceMesh] = []
-        self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {}
-
-    def get_current_mesh(self) -> "DeviceMesh":
-        if len(self.mesh_stack) == 0:
-            raise RuntimeError("No device mesh is currently active!")
-        return self.mesh_stack[-1]
-
-    def create_child_mesh(
-        self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str
-    ) -> "DeviceMesh":
-        # swap the current dim to the last dim then reshape to flatten out other
-        # dims, so we can just extract the list of ranks which contains cur_rank.
-        cur_rank = device_mesh.get_rank()
-        pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(
-            -1, device_mesh.mesh.size(mesh_dim)
-        )
-
-        for mesh_1d in pg_ranks_by_dim:
-            sub_mesh = DeviceMesh(
-                device_mesh.device_type,
-                mesh_1d,
-                mesh_dim_names=(mesh_dim_name,),
-                _init_process_groups=False,
-            )
-            if cur_rank in mesh_1d:
-                res_sub_mesh = sub_mesh
-
-        res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]]
-        # Assign the current DeviceMesh as the parent of the child DeviceMesh.
-        self.child_to_parent_mapping[res_sub_mesh] = device_mesh
-        return res_sub_mesh
-
-    def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]:
-        return self.child_to_parent_mapping.get(device_mesh, None)
-
-    def get_parent_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]:
-        """
-        Return the index of the mesh dim in the parent mesh.
-        The device_mesh passed in needs to be sliced out from a parent mesh.
-        """
-        parent_mesh = self.get_parent_mesh(device_mesh)
-        child_mesh_dim_names = device_mesh.mesh_dim_names
-        if parent_mesh and child_mesh_dim_names:
-            assert (
-                len(child_mesh_dim_names) == 1
-            ), "The child mesh can only be a 1D mesh."
-            child_mesh_dim_name = child_mesh_dim_names[0]
-            if parent_mesh.mesh_dim_names:
-                return parent_mesh.mesh_dim_names.index(child_mesh_dim_name)
-        return None
-
-    @staticmethod
-    def num_devices_per_host(device_type: str) -> int:
-        return _get_device_handle(device_type).device_count()
-
-    @staticmethod
-    def num_hosts(device_type: str) -> int:
-        # ProcessGroup can't tell us this info so we have to infer it, assume
-        # homogeneous hardware for now
-        return get_world_size() // _MeshEnv.num_devices_per_host(device_type)
-
-
-_mesh_resources: _MeshEnv = _MeshEnv()
-
-
-def _get_device_handle(device_type: str = "cuda"):
-    """
-    Get the module corresponding to the device_type which is cuda or cuda-like device.
-    For example, when the device_type is cuda, the module `torch.cuda` is returned.
-    Return None when there is no corresponding module for device_type, otherwise
-    return the corresponding module.
-    """
-    return getattr(torch, device_type, None)
-
-
-class DeviceMesh:
-    """
-    DeviceMesh represents a mesh of devices, where layout of devices could be
-    represented as a n-d dimension array, and each value of the n-d dimensional
-    array is the global id of the default process group ranks.
-
-    DeviceMesh could be used to describe the layout of devices across the cluster,
-    and serves as a proxy for communication among the device lists within the cluster.
-
-    We use the default ProcessGroup in this DeviceMesh class to implement proper
-    communications. Note that we also add collective wrappers in this class. This is
-    used to decouple detailed communication backend with the underlying
-    DTensor implementation.
-
-    DeviceMesh can be used as a context manager.
-    Args:
-        device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
-        mesh (ndarray): could be a multi-dimension array or an integer tensor that
-            describes the layout of devices, the ids are global ids of the
-            default process group.
-
-    Returns:
-        A :class:`DeviceMesh` object
-
-    Example (2 host with 4 GPUs each):
-        ```
-        # The following program runs on each process/rank in SPMD manner.
-        # initialize device mesh as (2, 4) to represent the topology
-        # of cross-host(dim 0), and within-host (dim 1)
-        mesh = DeviceMesh(device_type="cuda",
-                          mesh=[
-                            [0, 1, 2, 3],
-                            [4, 5, 6, 7]
-                          ])
-        ```
-        A reduction over the first dimension of mesh will reduce across
-        columns (0, 4), .. and (3, 7), a reduction over the second dimension
-        of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7)
-
-    """
-
-    device_type: str
-    mesh: torch.Tensor
-    mesh_dim_names: Optional[Tuple[str, ...]]
-
-    def __init__(
-        self,
-        device_type: str,
-        mesh: Union[torch.Tensor, "ArrayLike"],
-        *,
-        mesh_dim_names: Optional[Tuple[str, ...]] = None,
-        _init_process_groups: bool = True,
-        _validate_mesh: bool = True,
-    ) -> None:
-        self.device_type = device_type
-        self.mesh = (
-            mesh.detach()
-            if isinstance(mesh, torch.Tensor)
-            else torch.tensor(mesh, dtype=torch.int)
-        )
-        self.mesh_dim_names = mesh_dim_names
-
-        # private field to pre-generate DeviceMesh's hash
-        self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
-        self._hash = hash((self._flatten_mesh_list, self.mesh.shape))
-
-        # Skip process group initialization if xla device.
-        # TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
-        if device_type != "xla":
-            # always try to create default (world) pg, even if it is not initialized
-            # already. The world pg is used for device mesh identity (rank) on each
-            # process (we need to know if the current global rank is in the mesh or not).
-            self._get_or_create_default_group()
-            if _init_process_groups:
-                self._init_process_groups(_validate_mesh)
-
-    def _get_or_create_default_group(self):
-        default_initialized = is_initialized()
-        if not default_initialized:
-            init_process_group()
-
-        world_size = get_world_size()
-        if self.mesh.numel() > world_size:
-            raise RuntimeError(
-                f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!"
-            )
-
-        device_handle = _get_device_handle(self.device_type)
-        # TODO: if user want to pass pg_options, offer a way to do it
-        if not default_initialized and device_handle:
-            # automatically set the current cuda/cuda-like device base on num of gpu devices available in each host
-            # NOTE: This device selection would only work for homogeneous hardware.
-            num_devices_per_host = device_handle.device_count()
-            if (
-                world_size > num_devices_per_host
-                and world_size % num_devices_per_host != 0
-            ):
-                raise RuntimeError(
-                    f"DeviceMesh only support homogeneous hardware, but found "
-                    f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!"
-                )
-            device_handle.set_device(get_rank() % num_devices_per_host)
-
-        # calculate the coordinates of the current global rank on the mesh
-        rank_coords = (self.mesh == get_rank()).nonzero()
-        assert rank_coords.size(0) in (0, 1)
-        self._coordinate_on_dim: Optional[List[int]] = (
-            rank_coords[0].tolist() if rank_coords.size(0) > 0 else None
-        )
-        return _get_default_group()
-
-    def _validate_mesh(self):
-        # check mesh tensor validity
-        unique_mesh_values = self.mesh.unique(sorted=True)
-        if unique_mesh_values.numel() != self.mesh.numel():
-            raise RuntimeError(
-                f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}"
-            )
-
-        # validate that all calling ranks pass in the same `mesh` argument.
-        self_mesh = self.mesh.to(self.device_type).contiguous()
-        mesh_tensor = funcol.all_gather_tensor(
-            self_mesh, gather_dim=0, group=_get_default_group()
-        )
-        mesh_tensor_chunked = torch.chunk(mesh_tensor, get_world_size())
-        for other_rank, other_mesh in enumerate(mesh_tensor_chunked):
-            if not torch.equal(self_mesh, other_mesh):
-                raise RuntimeError(
-                    f"DeviceMesh initialization does not allow different mesh argument:"
-                    f"rank {get_rank()} has mesh {self_mesh} while rank {other_rank}"
-                    f"has mesh {other_mesh}!"
-                )
-
-    def _init_process_groups(self, _validate_mesh):
-        if _validate_mesh:
-            self._validate_mesh()
-
-        # group tag/ranks associated with each mesh dimension, each mesh dimension should
-        # have one sub-group per rank
-        dim_group_infos: List[Tuple[str, List[int]]] = []
-
-        if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size():
-            # if the mesh is the same as world_pg, we just append the default
-            # pg to the first dim groups, as new_group cannot have the exact
-            # same ranks as world
-            dim_group_infos.append(
-                (_get_group_tag(_get_default_group()), list(range(get_world_size())))
-            )
-        else:
-            # create sub pgs base on the mesh argument specified
-            for dim in range(self.mesh.ndim):
-                # swap the current dim to the last dim
-                # then reshape to flatten out other dims
-                pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(
-                    -1, self.mesh.size(dim)
-                )
-                # multi-dim mesh, create subgroups by looping over the pg_ranks
-                # for each dim and append the groups
-                for dim_mesh in pg_ranks_by_dim:
-                    subgroup_ranks = dim_mesh.tolist()
-                    # call new_group regardless of the current rank in the
-                    # pg or not, it's required that all ranks participate
-                    # in subgroup construction
-                    dim_group = new_group(ranks=subgroup_ranks)
-                    # only add to dim_groups if the current rank in the subgroup
-                    if self.get_rank() in subgroup_ranks:
-                        if len(dim_group_infos) > dim:
-                            raise RuntimeError(
-                                f"Each device mesh dimension should get only one process group, but got {self.get_rank} "
-                                f"in {subgroup_ranks}!"
-                            )
-                        dim_group_infos.append(
-                            (_get_group_tag(dim_group), subgroup_ranks)
-                        )
-        self._dim_group_infos = dim_group_infos
-
-    def __enter__(self) -> "DeviceMesh":
-        # set this mesh as the current mesh in mesh env
-        _mesh_resources.mesh_stack.append(self)
-        return self
-
-    # pyre-fixme[2]: Parameter must be annotated.
-    def __exit__(self, exc_type, exc_value, exc_traceback) -> None:
-        # pop this mesh from mesh env
-        _mesh_resources.mesh_stack.pop()
-
-    def __repr__(self) -> str:
-        return f"DeviceMesh:({self.mesh.tolist()})"
-
-    def __hash__(self):
-        return self._hash
-
-    def __eq__(self, other: object) -> bool:
-        if not isinstance(other, DeviceMesh):
-            return False
-        if id(self.mesh) == id(other.mesh):
-            return True
-        return (
-            self.mesh.shape == other.mesh.shape
-            and self._flatten_mesh_list == other._flatten_mesh_list
-        )
-
-    def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh":
-        """
-        Slice the current DeviceMesh based on the mesh_dim_name given to create a child
-        DeviceMesh.
-
-        Args:
-            mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh
-            to create a child DeviceMesh for.
-        Returns:
-            A :class:`DeviceMesh` object
-
-        Example (2 host with 4 GPUs each):
-        ```
-        # Below is a DeviceMesh with mesh_shape of (2, 4) and mesh_dim_name of ("dp", "tp")
-        mesh = DeviceMesh(device_type="cuda",
-                          mesh=[
-                            [0, 1, 2, 3],
-                            [4, 5, 6, 7]
-                          ],
-                          mesh_dim_names=["dp", "tp"])
-                          )
-        ```
-        Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
-        Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
-        Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
-        Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
-        Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
-        Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
-        """
-        if self.mesh.ndim <= 1:
-            raise RuntimeError(
-                f"Cannot slice a DeviceMesh with {self.mesh.ndim} dimension."
-            )
-        if self.mesh_dim_names is None:
-            raise KeyError(
-                "No `mesh_dim_names` found.",
-                "To slice the device mesh, please call `init_device_mesh` with `mesh_dim_names`.",
-            )
-        if mesh_dim_name not in self.mesh_dim_names:
-            raise KeyError(
-                f"Mesh dimension '{mesh_dim_name}' does not exist.",
-                f"Available mesh dimensions are: {self.mesh_dim_names}",
-            )
-        mesh_dim = self.mesh_dim_names.index(mesh_dim_name)
-        submesh = _mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name)
-
-        return submesh
-
-    def get_dim_groups(
-        self, mesh_dim: Optional[int] = None
-    ) -> Union[ProcessGroup, List[ProcessGroup]]:
-        if not hasattr(self, "_dim_group_infos"):
-            raise RuntimeError("DeviceMesh process groups not initialized!")
-        if mesh_dim is not None:
-            return _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
-        else:
-            dim_groups = []
-            for mesh_dim in range(self.mesh.ndim):
-                dim_groups.append(
-                    _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])
-                )
-            return dim_groups
-
-    def size(self, dim: Optional[int] = None) -> int:
-        return self.mesh.numel() if dim is None else self.mesh.size(dim)
-
-    @property
-    def ndim(self) -> int:
-        return self.mesh.ndim
-
-    @property
-    def shape(self) -> Tuple[int, ...]:
-        return tuple(self.mesh.shape)
-
-    def get_rank(self) -> int:
-        return get_rank()
-
-    def get_coordinate(self) -> Optional[List[int]]:
-        """
-        Return the relative indices of this rank relative to all
-        dimensions of the mesh. If this rank is not part of the mesh, return None.
-        """
-        return self._coordinate_on_dim if self._coordinate_on_dim else None
-
-
-def init_device_mesh(
-    device_type: str,
-    mesh_shape: Tuple[int, ...],
-    *,
-    mesh_dim_names: Optional[Tuple[str, ...]] = None,
-) -> DeviceMesh:
-    """
-    Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters.
-    This creates a DeviceMesh with a mesh layout of n-d dimensional array, n being the len(mesh_shape)
-    and ith dimension being in size mesh_shape[i]. If mesh_dim_names is provided, each dimension is
-    labeled as mesh_dim_names[i].
-
-
-    Args:
-        device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
-        mesh_shape: Tuple[int]: A tuple describes the dimension of the multi-dimesnion array
-        that describes the layout of devices.
-    Kwargs:
-        mesh_dim_names: Optional[Tuple[str]]: A tuple of mesh dim names to be assigned to each dimension
-        of the multi-dimensional array that describes the layout of devices. Its length must match the length
-        of `mesh_shape`. Each string in mesh_dim_names must be unique.
-
-    Returns:
-        A :class:`DeviceMesh` object
-
-    .. note: If no process group is found, init_device_mesh will initialize distributed process group/groups
-    behind the scene, which are required for distributed communications.
-
-    Example:
-        >>> # xdoctest: +SKIP
-        >>> from torch.distributed._tensor.device_mesh import init_device_mesh
-        >>>
-        >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,))
-        >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
-    """
-    if mesh_dim_names is not None:
-        if len(set(mesh_dim_names)) != len(mesh_dim_names):
-            raise RuntimeError(
-                "Each mesh_dim_name must be uqique.",
-                f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}",
-            )
-
-        if len(mesh_shape) != len(mesh_dim_names):
-            raise RuntimeError(
-                "mesh_shape and mesh_dim_names should have same length!",
-                f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.",
-            )
-
-    mesh = torch.arange(math.prod(mesh_shape)).view(mesh_shape)
-    device_mesh = DeviceMesh(
-        device_type=device_type,
-        mesh=mesh,
-        mesh_dim_names=mesh_dim_names,
-    )
-
-    return device_mesh