| # Copyright (c) Meta Platforms, Inc. and affiliates |
| # Owner(s): ["oncall: distributed"] |
| import os |
| |
| import torch |
| import torch.distributed._functional_collectives as funcol |
| from torch.distributed._tensor import DTensor |
| from torch.distributed._tensor._collective_utils import ( |
| mesh_all_to_all, |
| mesh_broadcast, |
| mesh_scatter, |
| ) |
| from torch.distributed._tensor.placement_types import _Partial, Shard |
| from torch.distributed.device_mesh import _mesh_resources, DeviceMesh, init_device_mesh |
| |
| from torch.distributed.distributed_c10d import ( |
| get_global_rank, |
| get_world_size, |
| init_process_group, |
| is_initialized, |
| is_nccl_available, |
| ProcessGroup, |
| ) |
| from torch.testing._internal.common_distributed import run_with_both_funcol_impls |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| run_tests, |
| ) |
| from torch.testing._internal.distributed._tensor.common_dtensor import ( |
| DTensorTestBase, |
| skip_if_lt_x_gpu, |
| skip_unless_torch_gpu, |
| with_comms, |
| ) |
| from torch.testing._internal.distributed.fake_pg import FakeStore |
| |
| |
| def _get_device_type(world_size): |
| if ( |
| torch.cuda.is_available() |
| and torch.cuda.device_count() >= world_size |
| and is_nccl_available() |
| ): |
| device_type = "cuda" |
| else: |
| device_type = "cpu" |
| return device_type |
| |
| |
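| # MASTER_ADDR, MASTER_PORT, WORLD_SIZE, and RANK are consumed by the |
| # default env:// rendezvous when the process group is initialized. |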
| def _set_env_var(addr="localhost", port="25364", world_size=1, rank=0): |
| os.environ["MASTER_ADDR"] = addr |
| os.environ["MASTER_PORT"] = port |
| os.environ["WORLD_SIZE"] = f"{world_size}" |
| os.environ["RANK"] = f"{rank}" |
| |
| |
| @instantiate_parametrized_tests |
| class DeviceMeshTest(DTensorTestBase): |
| @property |
| def world_size(self): |
| return 4 |
| |
| @run_with_both_funcol_impls |
| def test_init_process_group(self): |
| device_type = _get_device_type(self.world_size) |
| mesh_tensor = torch.arange(4).reshape(2, 2) |
| self.assertFalse(is_initialized()) |
| _set_env_var(world_size=self.world_size, rank=self.rank) |
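| # constructing a DeviceMesh initializes the default process group if |
| # one does not exist yet |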
| DeviceMesh(device_type, mesh_tensor) |
| self.assertTrue(is_initialized()) |
| self.destroy_pg() |
| |
| @with_comms |
| @skip_unless_torch_gpu |
| def test_assert_invalid_mesh_tensor(self): |
| mesh = torch.arange(self.world_size).to(self.rank) |
| with self.assertRaises(ValueError): |
| device_mesh = DeviceMesh(self.device_type, mesh) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| @skip_unless_torch_gpu |
| def test_get_group(self): |
| # TODO: `test_get_group` still periodically times out on cpu; |
| # remove `@skip_unless_torch_gpu` once the problem is fixed. |
| mesh_shape = (2, self.world_size // 2) |
| mesh_2d = init_device_mesh( |
| self.device_type, mesh_shape, mesh_dim_names=("dp", "tp") |
| ) |
| |
| tp_mesh = mesh_2d["tp"] |
| dp_mesh = mesh_2d["dp"] |
| |
| self.assertEqual(len(mesh_2d.get_group()), 2) |
| self.assertEqual(mesh_2d.get_group()[0], mesh_2d.get_group("dp")) |
| self.assertEqual(mesh_2d.get_group()[1], mesh_2d.get_group("tp")) |
| |
| self.assertEqual(mesh_2d.get_group(0), mesh_2d.get_group("dp")) |
| self.assertEqual(mesh_2d.get_group(1), mesh_2d.get_group("tp")) |
| |
| self.assertEqual(mesh_2d.get_group("dp"), dp_mesh.get_group()) |
| self.assertEqual(mesh_2d.get_group("tp"), tp_mesh.get_group()) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| @skip_unless_torch_gpu |
| def test_get_local_rank_raises_exception(self): |
| # TODO: `test_get_local_rank_raises_exception` still periodically times out on cpu; |
| # remove `@skip_unless_torch_gpu` once the problem is fixed. |
| mesh_shape = (2, self.world_size // 2) |
| mesh_2d = init_device_mesh( |
| self.device_type, mesh_shape, mesh_dim_names=("dp", "tp") |
| ) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, |
| "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", |
| ): |
| local_rank = mesh_2d.get_local_rank() |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| @skip_unless_torch_gpu |
| def test_get_local_rank(self): |
| # TODO: `test_get_local_rank` still periodically times out on cpu; |
| # remove `@skip_unless_torch_gpu` once the problem is fixed. |
| mesh_shape = (2, self.world_size // 2) |
| mesh_2d = init_device_mesh( |
| self.device_type, mesh_shape, mesh_dim_names=("dp", "tp") |
| ) |
| self.assertEqual(mesh_2d.get_local_rank("dp"), mesh_2d.get_local_rank(0)) |
| self.assertEqual(mesh_2d.get_local_rank("tp"), mesh_2d.get_local_rank(1)) |
| |
| dp_mesh = mesh_2d["dp"] |
| tp_mesh = mesh_2d["tp"] |
| self.assertEqual(dp_mesh.get_local_rank(), mesh_2d.get_local_rank("dp")) |
| self.assertEqual(tp_mesh.get_local_rank(), mesh_2d.get_local_rank("tp")) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_device_mesh_2d(self): |
| mesh_tensor = torch.arange(4).reshape(2, 2) |
| # construct a 2D device mesh |
| mesh = DeviceMesh(self.device_type, mesh_tensor) |
| |
| # check all dim groups |
| dim_to_subgroups = mesh.get_group() |
| |
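| # for mesh [[0, 1], [2, 3]]: dim-0 groups are the columns, dim-1 groups are the rows |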
| expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]] |
| for dim, dim_group in enumerate(dim_to_subgroups): |
| self.assertTrue(dim < 2) |
| dim_ranks = expected_ranks_by_dim[dim] |
| |
| dim_group_size = get_world_size(dim_group) |
| self.assertIsInstance(dim_group, ProcessGroup) |
| self.assertEqual(dim_group_size, 2) |
| global_ranks = [ |
| get_global_rank(dim_group, i) for i in range(dim_group_size) |
| ] |
| current_rank_expected_group_ranks = ( |
| dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1] |
| ) |
| self.assertEqual(global_ranks, current_rank_expected_group_ranks) |
| |
| @run_with_both_funcol_impls |
| def test_fake_pg_device_mesh(self): |
| fake_store = FakeStore() |
| init_process_group("fake", store=fake_store, rank=0, world_size=self.world_size) |
| device_type = "cuda" if torch.cuda.is_available() else "cpu" |
| mesh = DeviceMesh(device_type, torch.arange(self.world_size)) |
| |
| local_tensor = torch.randn(2, 8) |
| global_tensor = funcol.all_gather_tensor( |
| local_tensor, gather_dim=0, group=(mesh, 0) |
| ) |
| self.assertEqual(global_tensor.shape, (self.world_size * 2, 8)) |
| |
| |
| class DeviceMeshTestNDim(DTensorTestBase): |
| @property |
| def world_size(self): |
| return 8 |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_device_mesh_nd(self): |
| # construct a 3D device mesh |
| mesh_tensor = torch.arange(8).reshape(2, 2, 2) |
| mesh = DeviceMesh(self.device_type, mesh_tensor) |
| |
| # check all dim groups |
| dim_to_subgroups = mesh.get_group() |
| |
| for dim, dim_group in enumerate(dim_to_subgroups): |
| self.assertTrue(dim < mesh_tensor.ndim) |
| dim_ranks = mesh_tensor.swapdims(-1, dim).reshape(-1, 2) |
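| # each row of dim_ranks is one rank subgroup along `dim`; e.g. for |
| # dim=0 the rows are [0, 4], [2, 6], [1, 5], [3, 7] |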
| |
| dim_group_size = get_world_size(dim_group) |
| self.assertIsInstance(dim_group, ProcessGroup) |
| self.assertEqual(dim_group_size, 2) |
| global_ranks = [ |
| get_global_rank(dim_group, i) for i in range(dim_group_size) |
| ] |
| for ranks in dim_ranks: |
| if self.rank in ranks: |
| self.assertEqual(global_ranks, ranks.tolist()) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_device_mesh_hash(self): |
| mesh_tensor_2d = torch.arange(8).reshape(4, 2) |
| mesh = DeviceMesh(self.device_type, mesh_tensor_2d) |
| mesh2 = DeviceMesh(self.device_type, mesh_tensor_2d) |
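| # hashing is not purely value-based: separately constructed meshes hash |
| # differently even when their mesh tensors are identical |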
| self.assertNotEqual(hash(mesh), hash(mesh2)) |
| mesh_tensor_3d = torch.arange(8).reshape(2, 2, 2) |
| mesh3 = DeviceMesh(self.device_type, mesh_tensor_3d) |
| self.assertNotEqual(hash(mesh), hash(mesh3)) |
| self.assertNotEqual(hash(mesh2), hash(mesh3)) |
| |
| |
| class InitDeviceMeshTest(DTensorTestBase): |
| @property |
| def world_size(self): |
| return 8 |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_init_device_mesh(self): |
| mesh_shape = (2, 4) |
| ref_mesh = DeviceMesh(self.device_type, torch.arange(8).view(mesh_shape)) |
| |
| # test init_device_mesh with mesh_dim_names |
| mesh_dim_names = ("DP", "TP") |
| mesh_2d = init_device_mesh( |
| self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names |
| ) |
| self.assertEqual(mesh_2d, ref_mesh) |
| self.assertEqual(mesh_2d.mesh_dim_names, mesh_dim_names) |
| |
| # test init_device_mesh without mesh_dim_names |
| mesh_2d = init_device_mesh(self.device_type, mesh_shape) |
| self.assertEqual(mesh_2d, ref_mesh) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_raises_duplicate_mesh_dim_names(self): |
| with self.assertRaisesRegex( |
| RuntimeError, |
| "Each mesh_dim_name must be unique.", |
| ): |
| mesh = init_device_mesh( |
| self.device_type, |
| (2, 4), |
| mesh_dim_names=["dp", "dp"], |
| ) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_raises_mesh_shape_mesh_dim_names_mismatch(self): |
| with self.assertRaisesRegex( |
| RuntimeError, |
| "mesh_shape and mesh_dim_names should have same length!", |
| ): |
| mesh = init_device_mesh( |
| self.device_type, |
| (8,), |
| mesh_dim_names=["dp", "tp"], |
| ) |
| |
| |
| @instantiate_parametrized_tests |
| class TestDeviceMeshGetItem(DTensorTestBase): |
| @property |
| def world_size(self): |
| return 8 |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_raises_invalid_mesh_dim_names(self): |
| error_msg = "Invalid mesh_dim_name" |
| # Case 1: the DeviceMesh does not have a mesh_dim_names attribute |
| with self.assertRaisesRegex( |
| RuntimeError, "Cannot slice a DeviceMesh without mesh_dim_names." |
| ): |
| mesh = init_device_mesh(self.device_type, (2, 4)) |
| child_mesh = mesh["DP"] |
| |
| child_mesh_dim_names = "PP" |
| with self.assertRaisesRegex(ValueError, error_msg): |
| mesh_dim_names = ("DP", "TP") |
| mesh = init_device_mesh( |
| self.device_type, (2, 4), mesh_dim_names=mesh_dim_names |
| ) |
| child_mesh = mesh[child_mesh_dim_names] |
| |
| # Case 3: none of the given mesh_dim_names exists in the parent mesh's mesh_dim_names. |
| child_mesh_dim_names = ["PP", "CP"] |
| with self.assertRaisesRegex(ValueError, error_msg): |
| mesh_dim_names = ("DP", "TP") |
| mesh = init_device_mesh( |
| self.device_type, (2, 4), mesh_dim_names=mesh_dim_names |
| ) |
| child_mesh = mesh[child_mesh_dim_names] |
| |
| # Case 4: the given mesh_dim_names are not in the same order as the parent mesh's mesh_dim_names. |
| child_mesh_dim_names = ("TP", "DP") |
| with self.assertRaisesRegex(ValueError, error_msg): |
| mesh_dim_names = ("DP", "TP") |
| mesh = init_device_mesh( |
| self.device_type, (2, 4), mesh_dim_names=mesh_dim_names |
| ) |
| child_mesh = mesh[child_mesh_dim_names] |
| |
| # Case 5: the given mesh_dim_names are not a contiguous subset of the parent mesh's mesh_dim_names. |
| child_mesh_dim_names = ("PP", "TP") |
| with self.assertRaisesRegex(ValueError, error_msg): |
| mesh_dim_names = ("PP", "DP", "TP") |
| mesh = init_device_mesh( |
| self.device_type, (2, 2, 2), mesh_dim_names=mesh_dim_names |
| ) |
| child_mesh = mesh[child_mesh_dim_names] |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| @skip_if_lt_x_gpu(8) |
| def test_get_item_2d(self): |
| # TODO: `test_get_item_2d` still periodically times out on cpu; |
| # remove `@skip_if_lt_x_gpu` once the problem is fixed. |
| mesh_shape = (2, 4) |
| mesh_dim_names = ("DP", "TP") |
| mesh_2d = init_device_mesh( |
| self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names |
| ) |
| |
| tp_mesh = mesh_2d["TP"] |
| tp_group = [[0, 1, 2, 3], [4, 5, 6, 7]] |
| tp_group_idx = self.rank // 4 |
| self.assertEqual(tp_mesh.mesh.tolist(), tp_group[tp_group_idx]) |
| |
| dp_mesh = mesh_2d["DP"] |
| dp_group = [[0, 4], [1, 5], [2, 6], [3, 7]] |
| dp_group_idx = self.rank % 4 |
| self.assertEqual(dp_mesh.mesh.tolist(), dp_group[dp_group_idx]) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_get_item_1d(self): |
| mesh = init_device_mesh(self.device_type, (8,), mesh_dim_names=("dp",)) |
| # Make sure slicing a 1D mesh out of a 1D mesh works. |
| # The mesh itself is returned, without registering a parent mesh. |
| dp_mesh = mesh["dp"] |
| self.assertEqual(dp_mesh, mesh) |
| |
| with self.assertRaisesRegex(ValueError, "Invalid mesh_dim_name"): |
| dp_mesh = mesh["dim0"] |
| |
| @with_comms |
| @skip_if_lt_x_gpu(8) |
| def test_get_item_3d(self): |
| # TODO: `test_get_item_3d` still periodically times out on cpu; |
| # remove `@skip_if_lt_x_gpu` once the problem is fixed. |
| mesh_shape = (2, 2, 2) |
| mesh_dim_names = ("Replicate", "Shard", "TP") |
| mesh_3d = init_device_mesh( |
| self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names |
| ) |
| |
| tp_group = [[0, 1], [2, 3], [4, 5], [6, 7]] |
| tp_group_idx = self.rank // 2 |
| self.assertEqual(mesh_3d["TP"].mesh.tolist(), tp_group[tp_group_idx]) |
| |
| shard_group = [[0, 2], [1, 3], [4, 6], [5, 7]] |
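| # the Shard-group index is fixed by this rank's other two coordinates: |
| # Replicate = rank // 4 and TP = rank % 2 |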
| shard_group_idx = self.rank % 2 + self.rank // 4 * 2 |
| self.assertEqual(mesh_3d["Shard"].mesh.tolist(), shard_group[shard_group_idx]) |
| |
| replicate_group = [[0, 4], [1, 5], [2, 6], [3, 7]] |
| replicate_group_idx = self.rank % 4 |
| self.assertEqual( |
| mesh_3d["Replicate"].mesh.tolist(), replicate_group[replicate_group_idx] |
| ) |
| |
| # Both slicing syntaxes are supported for nD slicing: |
| # mesh_3d[["Replicate", "Shard"]] or mesh_3d["Replicate", "Shard"] |
| hsdp_mesh_1 = mesh_3d[["Replicate", "Shard"]] |
| hsdp_mesh_2 = mesh_3d["Replicate", "Shard"] |
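| # slicing out ("Replicate", "Shard") fixes the TP coordinate (rank % 2), |
| # which selects the expected 2x2 sub-mesh below |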
| hsdp_group = [[[0, 2], [4, 6]], [[1, 3], [5, 7]]] |
| hsdp_group_idx = self.rank % 2 |
| self.assertEqual(hsdp_mesh_1.mesh.tolist(), hsdp_group[hsdp_group_idx]) |
| self.assertEqual(hsdp_mesh_2.mesh.tolist(), hsdp_group[hsdp_group_idx]) |
| self.assertEqual(hsdp_mesh_1, hsdp_mesh_2) |
| |
| |
| @instantiate_parametrized_tests |
| class TestMeshEnv(DTensorTestBase): |
| @with_comms |
| @run_with_both_funcol_impls |
| @skip_unless_torch_gpu |
| def test_get_parent_mesh(self): |
| mesh_shape = (2, self.world_size // 2) |
| mesh_dim_names = ("DP", "TP") |
| mesh_2d = init_device_mesh( |
| self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names |
| ) |
| |
| self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["DP"]), mesh_2d) |
| self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["TP"]), mesh_2d) |
| |
| mesh_0_2 = DeviceMesh(self.device_type, [0, 2]) |
| mesh_1_3 = DeviceMesh(self.device_type, [1, 3]) |
| |
| self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["DP"]), mesh_2d) |
| self.assertEqual(_mesh_resources.get_parent_mesh(mesh_2d["TP"]), mesh_2d) |
| self.assertEqual(_mesh_resources.get_parent_mesh(mesh_0_2), None) |
| self.assertEqual(_mesh_resources.get_parent_mesh(mesh_1_3), None) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| @skip_unless_torch_gpu |
| def test_get_parent_mesh_dim_exist(self): |
| mesh_shape = (2, self.world_size // 2) |
| mesh_dim_names = ("DP", "TP") |
| mesh_2d = init_device_mesh( |
| self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names |
| ) |
| |
| self.assertEqual(_mesh_resources.get_parent_mesh_dim(mesh_2d["DP"]), 0) |
| self.assertEqual(_mesh_resources.get_parent_mesh_dim(mesh_2d["TP"]), 1) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| @skip_unless_torch_gpu |
| def test_get_parent_mesh_dim_not_exist(self): |
| mesh_shape = (self.world_size,) |
| mesh = init_device_mesh(self.device_type, mesh_shape) |
| |
| self.assertEqual(_mesh_resources.get_parent_mesh_dim(mesh), None) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| @skip_unless_torch_gpu |
| def test_get_mesh_dim_by_name(self): |
| mesh_shape = (2, self.world_size // 2) |
| mesh_dim_names = ("DP", "TP") |
| mesh_2d = init_device_mesh( |
| self.device_type, mesh_shape, mesh_dim_names=mesh_dim_names |
| ) |
| |
| self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "DP"), 0) |
| self.assertEqual(_mesh_resources.get_mesh_dim_by_name(mesh_2d, "TP"), 1) |
| |
| |
| @instantiate_parametrized_tests |
| class DeviceMeshCollectiveTest(DTensorTestBase): |
| @property |
| def world_size(self): |
| return 8 |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_broadcast_1d(self): |
| mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) |
| local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank |
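| # the broadcast source is the first rank of the mesh dim (global rank 0), |
| # so every rank should end up with rank 0's zeros |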
| mesh_broadcast(local_tensor, mesh, mesh_dim=0) |
| self.assertEqual(local_tensor, torch.zeros(3, 3)) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_scatter_1d(self): |
| mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) |
| scatter_tensor_shape = [3, 3, 3] |
| for scatter_dim in range(len(scatter_tensor_shape)): |
| shard_placement = Shard(scatter_dim) |
| scatter_tensor_shape[scatter_dim] *= self.world_size |
| # use the same random seed across ranks |
| torch.manual_seed(0) |
| global_tensor = torch.randn(scatter_tensor_shape, device=self.device_type) |
| splitted_list, _ = shard_placement._split_tensor( |
| global_tensor, mesh.size(), with_padding=True, contiguous=True |
| ) |
| recv_tensor = torch.empty_like(splitted_list[mesh.get_rank()]) |
| # scattering on dim > 0 produces non-contiguous tensors; verify that this works |
| mesh_scatter(recv_tensor, splitted_list, mesh, mesh_dim=0) |
| self.assertEqual(recv_tensor, splitted_list[mesh.get_rank()]) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_scatter_uneven(self): |
| device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) |
| my_rank = device_mesh.get_rank() |
| tensor_to_split = torch.randn( |
| device_mesh.size() + 3, device_mesh.size() + 1, device=self.device_type |
| ) |
| |
| for shard_dim in range(tensor_to_split.ndim): |
| shard_placement = Shard(shard_dim) |
| |
| tensor_to_scatter = tensor_to_split.clone() |
| tensor_splitted_list = list( |
| torch.chunk(tensor_to_split, self.world_size, dim=shard_dim) |
| ) |
| for _ in range(self.world_size - len(tensor_splitted_list)): |
| tensor_splitted_list.append(torch.tensor([], device=self.device_type)) |
| |
| padded_tensor_list, pad_sizes = shard_placement._split_tensor( |
| tensor_to_scatter, |
| device_mesh.size(), |
| with_padding=True, |
| contiguous=True, |
| ) |
| |
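| # with_padding=True pads every chunk to a uniform size so a single |
| # collective can be used; pad_sizes records how much to strip afterwards |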
| scattered_tensor = torch.empty_like(padded_tensor_list[my_rank]) |
| mesh_scatter(scattered_tensor, padded_tensor_list, device_mesh, mesh_dim=0) |
| |
| if pad_sizes[my_rank] != 0: |
| scattered_tensor = shard_placement._unpad_tensor( |
| scattered_tensor, pad_sizes[my_rank] |
| ) |
| |
| if scattered_tensor.numel() == 0: |
| # Compare numel() instead of size() when the tensor is empty ([]) after |
| # unpadding, since the unpadded size could still be, e.g., ([0, 8]). |
| self.assertEqual( |
| scattered_tensor.numel(), tensor_splitted_list[my_rank].numel() |
| ) |
| else: |
| self.assertEqual( |
| scattered_tensor.size(), tensor_splitted_list[my_rank].size() |
| ) |
| self.assertEqual(scattered_tensor, tensor_splitted_list[my_rank]) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_all_gather_uneven(self): |
| device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) |
| my_rank = device_mesh.get_rank() |
| tensor_to_split = torch.ones( |
| device_mesh.size() + 3, |
| device_mesh.size() + 1, |
| device=self.device_type, |
| ) |
| |
| for shard_dim in range(tensor_to_split.ndim): |
| shard_placement = Shard(shard_dim) |
| tensor_padded_list, pad_sizes = shard_placement._split_tensor( |
| tensor_to_split, |
| device_mesh.size(), |
| with_padding=True, |
| contiguous=True, |
| ) |
| local_tensor = tensor_padded_list[my_rank] |
| big_tensor = funcol.all_gather_tensor( |
| local_tensor, gather_dim=shard_dim, group=(device_mesh, 0) |
| ) |
| big_tensor_chunks = list( |
| torch.chunk(big_tensor, device_mesh.size(), dim=shard_dim) |
| ) |
| unpadded_list = [ |
| shard_placement._unpad_tensor(chunk, pad_sizes[i]) |
| if pad_sizes[i] > 0 |
| else chunk |
| for i, chunk in enumerate(big_tensor_chunks) |
| ] |
| all_gathered_tensor = torch.cat(unpadded_list, dim=shard_dim) |
| |
| self.assertEqual(all_gathered_tensor.size(), tensor_to_split.size()) |
| self.assertEqual(all_gathered_tensor, tensor_to_split) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_reduce_scatter_contiguous(self): |
| device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) |
| my_rank = device_mesh.get_rank() |
| |
| # Init the tensor |
| step = self.world_size * 2 |
| total_elem = step**2 |
| tensor = torch.arange(0, total_elem).view(step, -1).to(device=self.device_type) |
| tensor = tensor * (my_rank + 1) |
| |
| # Get non-contiguous tensor by slicing |
| tensor_to_reduce = tensor[::2, :2] |
| tensor_contiguous = tensor_to_reduce.clone().contiguous() |
| |
| # Partial to Shard to trigger reduce_scatter |
| tensor_to_reduce = DTensor.from_local( |
| tensor_to_reduce, device_mesh, [_Partial()] |
| ) |
| tensor_contiguous = DTensor.from_local( |
| tensor_contiguous, device_mesh, [_Partial()] |
| ) |
| new_tensor = tensor_to_reduce.redistribute(device_mesh, [Shard(0)]) |
| new_tensor_contiguous = tensor_contiguous.redistribute(device_mesh, [Shard(0)]) |
| |
| # Contiguous and non-contiguous tensors holding the same values should |
| # produce the same reduce_scatter result. |
| new_tensor_local = new_tensor._local_tensor |
| new_tensor_contiguous_local = new_tensor_contiguous._local_tensor |
| self.assertEqual(new_tensor_local, new_tensor_contiguous_local) |
| self.assertEqual(list(new_tensor_local.size()), [1, 2]) |
| |
| # Check the reduce numerical value |
| sum_base = (1 + self.world_size) * self.world_size / 2 |
| first_elem = my_rank * sum_base * step * 2 |
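| # The sliced view keeps rows 0, 2, 4, ... of an arange(step * step) |
| # matrix, so the row scattered to rank r starts at 2 * r * step before |
| # scaling; each rank scales by (rank + 1), so the reduction multiplies |
| # values by sum(1..world_size) = sum_base, and adjacent columns differ |
| # by exactly sum_base. |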
| expected_tensor = torch.tensor( |
| [[first_elem, first_elem + sum_base]], |
| dtype=new_tensor_local.dtype, |
| device=self.device_type, |
| ) |
| self.assertEqual(new_tensor_local, expected_tensor) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_reduce_scatter_uneven(self): |
| device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) |
| my_rank = device_mesh.get_rank() |
| tensor_to_split = ( |
| torch.ones( |
| device_mesh.size() + 3, |
| device_mesh.size() + 1, |
| device=self.device_type, |
| ) |
| * self.rank |
| ) |
| |
| for shard_dim in range(tensor_to_split.ndim): |
| shard_placement = Shard(shard_dim) |
| tensor_to_scatter = tensor_to_split.clone() |
| |
| tensor_splitted_list = list( |
| torch.chunk(tensor_to_split, self.world_size, dim=shard_dim) |
| ) |
| for _ in range(self.world_size - len(tensor_splitted_list)): |
| tensor_splitted_list.append(torch.tensor([], device=self.device_type)) |
| |
| padded_tensor_list, pad_sizes = shard_placement._split_tensor( |
| tensor_to_scatter, |
| device_mesh.size(), |
| with_padding=True, |
| contiguous=True, |
| ) |
| |
| tensor_to_reduce = torch.cat(padded_tensor_list, shard_dim) |
| |
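| # every rank contributes its own rank id, so each reduced element equals |
| # sum(range(world_size)) = (world_size - 1) * world_size / 2 |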
| res_num = ((0 + self.world_size - 1) * self.world_size) / 2 |
| |
| scattered_tensor = funcol.reduce_scatter_tensor( |
| tensor_to_reduce, |
| reduceOp="sum", |
| scatter_dim=shard_dim, |
| group=(device_mesh, 0), |
| ) |
| |
| # unpad scattered_tensor |
| if pad_sizes[my_rank] > 0: |
| scattered_tensor = shard_placement._unpad_tensor( |
| scattered_tensor, pad_sizes[my_rank] |
| ) |
| |
| if scattered_tensor.numel() == 0: |
| # Compare numel() instead of size() when the tensor is empty ([]) after |
| # unpadding, since the unpadded size could still be, e.g., ([0, 8]). |
| self.assertEqual( |
| scattered_tensor.numel(), tensor_splitted_list[my_rank].numel() |
| ) |
| else: |
| self.assertEqual( |
| scattered_tensor.size(), tensor_splitted_list[my_rank].size() |
| ) |
| self.assertEqual( |
| scattered_tensor, |
| torch.ones_like(tensor_splitted_list[my_rank]) * res_num, |
| ) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_broadcast_nd(self): |
| mesh_tensor = torch.arange(8).reshape(2, 2, 2) |
| mesh = DeviceMesh(self.device_type, mesh_tensor) |
| local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank |
| |
| # check all dim groups |
| dim_to_subgroups = mesh.get_group() |
| for dim, dim_group in enumerate(dim_to_subgroups): |
| dim_group_size = get_world_size(dim_group) |
| global_ranks = [ |
| get_global_rank(dim_group, i) for i in range(dim_group_size) |
| ] |
| cloned_local_tensor = local_tensor.clone() |
| mesh_broadcast(cloned_local_tensor, mesh, mesh_dim=dim) |
| res_num = global_ranks[0] |
| self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_scatter_nd(self): |
| mesh_tensor = torch.arange(8).reshape(2, 2, 2) |
| mesh = DeviceMesh(self.device_type, mesh_tensor) |
| |
| # check all dim groups |
| dim_to_subgroups = mesh.get_group() |
| for dim, dim_group in enumerate(dim_to_subgroups): |
| dim_group_size = get_world_size(dim_group) |
| global_ranks = [ |
| get_global_rank(dim_group, i) for i in range(dim_group_size) |
| ] |
| scattered_tensors = [ |
| torch.ones(3, 3, device=self.device_type) * global_rank |
| for global_rank in global_ranks |
| ] |
| received_tensor = torch.empty_like( |
| scattered_tensors[mesh.get_coordinate()[dim]] |
| ) |
| mesh_scatter(received_tensor, scattered_tensors, mesh, mesh_dim=dim) |
| self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_all_to_all_1d(self): |
| # all_to_all on a 1D mesh acts like a transpose: rank r's i-th input |
| # chunk becomes rank i's r-th output chunk |
| mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) |
| tensor_shape = [3, 3] |
| input_tensor_list = [ |
| torch.ones(*tensor_shape, device=self.device_type) |
| * (rank + self.rank * self.world_size) |
| for rank in range(self.world_size) |
| ] |
| expected_tensor_list = [ |
| torch.ones(tensor_shape, device=self.device_type) |
| * (self.rank + rank * self.world_size) # i.e. transpose |
| for rank in range(self.world_size) |
| ] |
| for scatter_dim in range(len(tensor_shape)): |
| output_tensor_list = [ |
| torch.empty_like(input_tensor_list[idx]) |
| for idx in range(len(input_tensor_list)) |
| ] |
| # scattering on dim > 0 produces non-contiguous tensors; verify that this works |
| mesh_all_to_all(output_tensor_list, input_tensor_list, mesh, mesh_dim=0) |
| output_tensor = torch.cat(output_tensor_list, dim=scatter_dim) |
| expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim) |
| |
| self.assertEqual(output_tensor, expected_tensor) |
| |
| @with_comms |
| @run_with_both_funcol_impls |
| def test_all_to_all_nd(self): |
| mesh_tensor = torch.arange(8).reshape(2, 2, 2) |
| mesh = DeviceMesh(self.device_type, mesh_tensor) |
| tensor_shape = [3, 3, 3] |
| # check all dim groups |
| dim_to_subgroups = mesh.get_group() |
| for dim, dim_group in enumerate(dim_to_subgroups): |
| my_coordinate = mesh.get_coordinate()[dim] |
| dim_group_size = get_world_size(dim_group) |
| global_ranks = [ |
| get_global_rank(dim_group, i) for i in range(dim_group_size) |
| ] |
| input_tensor_list = [ |
| torch.ones(*tensor_shape, device=self.device_type) |
| * (i + self.rank * dim_group_size) |
| for i in range(dim_group_size) |
| ] |
| expected_tensor_list = [ |
| torch.ones(*tensor_shape, device=self.device_type) |
| * (my_coordinate + global_rank * dim_group_size) # i.e. transpose |
| for global_rank in global_ranks |
| ] |
| for scatter_dim in range(len(tensor_shape)): |
| output_tensor_list = [ |
| torch.empty_like(input_tensor_list[idx]) |
| for idx in range(len(input_tensor_list)) |
| ] |
| # scattering on dim > 0 produces non-contiguous tensors; verify that this works |
| mesh_all_to_all( |
| output_tensor_list, input_tensor_list, mesh, mesh_dim=dim |
| ) |
| output_tensor = torch.cat(output_tensor_list, dim=scatter_dim) |
| expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim) |
| self.assertEqual(output_tensor, expected_tensor) |
| |
| |
| if __name__ == "__main__": |
| run_tests() |