[dtensor] PART 5: move DTensor basic tests to core distributed (#88178)
This PR moves the basic DTensor tests to core distributed (torch.distributed), including the dtensor and device_mesh tests.
Part of https://github.com/pytorch/pytorch/issues/88838
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88178
Approved by: https://github.com/fduwjj
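For context, below is a minimal usage sketch (not part of this diff) of the DTensor APIs that the relocated tests exercise; the single-rank gloo process group, loopback address, and port are illustrative assumptions only.

```python
# Hypothetical sketch of the APIs covered by these tests; not the PR's code.
import torch
import torch.distributed as dist
from torch.distributed._tensor import (
    DeviceMesh,
    DTensor,
    Replicate,
    Shard,
    distribute_tensor,
)


def main() -> None:
    # single-rank gloo group so the sketch can run standalone on CPU (assumption)
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )
    mesh = DeviceMesh("cpu", [0])

    # shard a global tensor along dim 0 over the mesh (cf. test_distribute_tensor)
    dtensor = distribute_tensor(torch.randn(4, 4), mesh, [Shard(0)])

    # wrap a rank-local tensor as a DTensor (cf. test_from_local)
    wrapped = DTensor.from_local(torch.randn(4, 4), mesh, [Shard(0)])

    # change placements, then convert back to a plain torch.Tensor (cf. test_redistribute.py)
    replicated = dtensor.redistribute(mesh, [Replicate()])
    print(replicated.to_local().shape, wrapped.size())

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```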
diff --git a/test/distributed/_tensor/README.md b/test/distributed/_tensor/README.md
new file mode 100644
index 0000000..6235f96
--- /dev/null
+++ b/test/distributed/_tensor/README.md
@@ -0,0 +1,11 @@
+## Run distributed tensor tests:
+
+From the repo root, run (on either CPU or GPU):
+
+`pytest test/distributed/_tensor/test_dtensor.py`
+
+`pytest test/distributed/_tensor/test_api.py`
+
+Run a specific test case and print stdout/stderr:
+
+`pytest test/distributed/_tensor/test_dtensor.py -s -k test_from_local`
diff --git a/test/distributed/_tensor/__init__.py b/test/distributed/_tensor/__init__.py
new file mode 100644
index 0000000..087882b
--- /dev/null
+++ b/test/distributed/_tensor/__init__.py
@@ -0,0 +1 @@
+# shut up pylint
diff --git a/test/distributed/_tensor/test_api.py b/test/distributed/_tensor/test_api.py
new file mode 100644
index 0000000..a966f30
--- /dev/null
+++ b/test/distributed/_tensor/test_api.py
@@ -0,0 +1,234 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase, with_comms
+from torch.distributed._tensor import (
+ distribute_tensor,
+ distribute_module,
+ DeviceMesh,
+ DTensor,
+ Shard,
+ Replicate,
+)
+
+
+class MyModel(nn.Module):
+ def __init__(self, n_features, n_layers, device):
+ super().__init__()
+ self.seq = nn.Sequential(
+ *[
+ nn.Linear(n_features, n_features, device=device)
+ for _ in range(n_layers)
+ ]
+ )
+
+ def forward(self, x):
+ return self.seq(x)
+
+ def reset_parameters(self):
+ for m in self.seq:
+ m.reset_parameters()
+
+
+class DTensorAPITest(DTensorTestBase):
+ @property
+ def world_size(self) -> int:
+        # hard-code the world size to 4, as we need to test
+        # with at least a 2d mesh
+ return 4
+
+ @with_comms
+ def test_distribute_tensor(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [Shard(0)]
+
+ for requires_grad in [True, False]:
+
+ tensor_to_shard = torch.randn(
+ 3 * self.world_size, 3, requires_grad=requires_grad
+ )
+ dist_tensor = distribute_tensor(
+ tensor_to_shard, device_mesh, shard_spec
+ )
+ self.assertEqual(
+ dist_tensor.size(), torch.Size([3 * self.world_size, 3])
+ )
+ local_tensor = dist_tensor.to_local()
+ self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
+ if requires_grad:
+ self.assertTrue(dist_tensor.requires_grad)
+ self.assertTrue(dist_tensor.is_leaf)
+
+ @with_comms
+ def test_distribute_tensor_errors(self):
+ device_mesh = DeviceMesh(
+ self.device_type, torch.arange(self.world_size).reshape(2, 2)
+ )
+ tensor_shape = [3 * self.world_size, 3 * self.world_size]
+ tensor_to_distribute = torch.randn(*tensor_shape)
+
+ with self.assertRaisesRegex(ValueError, "must have the same length"):
+ shard_spec = [Shard(0)]
+ distribute_tensor(tensor_to_distribute, device_mesh, shard_spec)
+
+ spec = [Shard(0), Shard(1)]
+ dtensor = distribute_tensor(tensor_to_distribute, device_mesh, spec)
+
+ with self.assertRaisesRegex(ValueError, "to a different device mesh"):
+ new_mesh = DeviceMesh(
+ self.device_type, torch.arange(self.world_size)
+ )
+ distribute_tensor(dtensor, new_mesh, [Shard(0)])
+
+ with self.assertRaisesRegex(ValueError, "to a different placements"):
+ new_spec = [Shard(0), Replicate()]
+ distribute_tensor(dtensor, device_mesh, new_spec)
+
+ @with_comms
+ def test_distribute_tensor_uneven_sharding(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ input_sizes_and_shard_dims = [
+ ((self.world_size * 3 + 1, 3, 3), 0),
+ ((self.world_size * 3 + 2, 3, 3), 0),
+ ((3, self.world_size * 3 + 1, 3), 1),
+ ((3, self.world_size * 3 + 2, 3), 1),
+ ((3, 3, self.world_size * 3 + 1), 2),
+ ((3, 3, self.world_size * 3 + 2), 2),
+ ]
+ for input_size, shard_dim in input_sizes_and_shard_dims:
+ shard_spec = [Shard(shard_dim)]
+ tensor_to_shard = torch.randn(input_size)
+ splitted_tensor_list = tensor_to_shard.tensor_split(
+ self.world_size, dim=shard_dim
+ )
+ dist_tensor = distribute_tensor(
+ tensor_to_shard, device_mesh, shard_spec
+ )
+ self.assertEqual(dist_tensor.size(), torch.Size(input_size))
+ local_tensor = dist_tensor.to_local()
+ self.assertEqual(local_tensor, splitted_tensor_list[self.rank])
+
+ @with_comms
+ def test_distribute_module(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ # fully shard all linear modules on dim 0
+ module_to_shard = MyModel(
+ 5 * self.world_size, 20, device=self.device_type
+ )
+ shard_spec = [Shard(0)]
+
+ def shard_fn(name, module, device_mesh):
+ if isinstance(module, nn.Linear):
+ for name, param in module.named_parameters():
+ dist_param = torch.nn.Parameter(
+ distribute_tensor(param, device_mesh, shard_spec)
+ )
+ module.register_parameter(name, dist_param)
+
+ sharded_module = distribute_module(
+ module_to_shard, device_mesh, shard_fn
+ )
+ for param in sharded_module.parameters():
+ self.assertIsInstance(param, DTensor)
+ self.assertEqual(param.placements, shard_spec)
+
+ replica_spec = [Replicate()]
+ # fully replicate all modules without passing in partition_fn
+ module_to_replicate = MyModel(5, 20, device=self.device_type)
+ replica_module = distribute_module(module_to_replicate, device_mesh)
+ for param in replica_module.parameters():
+ self.assertIsInstance(param, DTensor)
+ self.assertEqual(param.placements, replica_spec)
+
+ # fully replicate all modules by passing in partition_fn
+ def replicate_fn(name, module, device_mesh):
+ if isinstance(module, nn.Linear):
+ for name, param in module.named_parameters():
+ dist_param = torch.nn.Parameter(
+ distribute_tensor(param, device_mesh, replica_spec)
+ )
+ module.register_parameter(name, dist_param)
+
+ module_to_replicate = MyModel(5, 20, device=self.device_type)
+ replica_module = distribute_module(
+ module_to_replicate, device_mesh, replicate_fn
+ )
+ for param in replica_module.parameters():
+ self.assertIsInstance(param, DTensor)
+ self.assertEqual(param.placements, replica_spec)
+
+        # only shard part of the module; the rest of the module should be replicated
+ def shard_fn(name, module, device_mesh):
+ if isinstance(module, nn.Linear) and (
+ name == "seq.0" or name == "seq.8"
+ ):
+ for name, param in module.named_parameters():
+ dist_param = torch.nn.Parameter(
+ distribute_tensor(param, device_mesh, shard_spec)
+ )
+ module.register_parameter(name, dist_param)
+
+ module_to_distribute = MyModel(
+ 5 * self.world_size, 20, device=self.device_type
+ )
+ dist_module = distribute_module(
+ module_to_distribute, device_mesh, shard_fn
+ )
+ for name, param in dist_module.named_parameters():
+ self.assertIsInstance(param, DTensor)
+ if name.startswith("seq.0") or name.startswith("seq.8"):
+ self.assertEqual(param.placements, shard_spec)
+ else:
+ self.assertEqual(param.placements, replica_spec)
+
+ @with_comms
+ def test_distribute_module_input_fn_output_fn(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+
+ # fully replicate all linear modules
+ module_to_replicate = MyModel(20, 1, device=self.device_type)
+
+ # mark input sharding on dim 0
+ def input_fn(inputs, device_mesh):
+ return DTensor.from_local(inputs[0], device_mesh, [Shard(0)])
+
+ def output_fn(outputs, device_mesh):
+ assert isinstance(outputs, DTensor)
+ return outputs.to_local()
+
+ replica_module = distribute_module(
+ module_to_replicate,
+ device_mesh,
+ input_fn=input_fn,
+ output_fn=output_fn,
+ )
+
+ input_tensor = torch.randn(5, 20, device=self.device_type)
+ local_out = replica_module(input_tensor)
+ self.assertIsInstance(local_out, torch.Tensor)
+ self.assertNotIsInstance(local_out, DTensor)
+
+ # full replicate (even on inputs)
+ model = MyModel(10, 10, device=self.device_type)
+
+ def replicate_input_fn(inputs, device_mesh):
+ return DTensor.from_local(inputs[0], device_mesh, [Replicate()])
+
+ replica_model = distribute_module(
+ model,
+ device_mesh,
+ input_fn=replicate_input_fn,
+ )
+ input = torch.randn(10, 10, requires_grad=True)
+ output = replica_model(input)
+ output.sum().backward()
+ param_grad = list(replica_model.parameters())[0].grad
+ self.assertTrue(isinstance(param_grad, DTensor))
+ self.assertTrue(isinstance(param_grad.placements[0], Replicate))
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/distributed/_tensor/test_device_mesh.py b/test/distributed/_tensor/test_device_mesh.py
new file mode 100644
index 0000000..7088f33
--- /dev/null
+++ b/test/distributed/_tensor/test_device_mesh.py
@@ -0,0 +1,518 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import torch
+
+from torch.distributed.distributed_c10d import (
+ ProcessGroup,
+ new_group,
+ get_global_rank,
+ get_world_size,
+)
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+ DTensorTestBase,
+ with_comms,
+)
+from torch.distributed._tensor.device_mesh import DeviceMesh
+from torch.distributed._tensor.placement_types import Shard
+
+
+class DeviceMeshTest(DTensorTestBase):
+ @property
+ def world_size(self):
+ return 8
+
+ @with_comms
+ def test_device_mesh_2d(self):
+ mesh_tensor = torch.arange(4).reshape(2, 2)
+ # construct a cuda device mesh
+ mesh = DeviceMesh(self.device_type, mesh_tensor)
+
+ # check all dim groups
+ dim_to_subgroups = mesh.get_dim_groups()
+
+ expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]]
+ for dim, dim_group in enumerate(dim_to_subgroups):
+ self.assertTrue(dim < 2)
+ dim_ranks = expected_ranks_by_dim[dim]
+
+ dim_group_size = get_world_size(dim_group)
+ self.assertIsInstance(dim_group, ProcessGroup)
+ self.assertEqual(dim_group_size, 2)
+ global_ranks = [
+ get_global_rank(dim_group, i) for i in range(dim_group_size)
+ ]
+ current_rank_expected_group_ranks = (
+ dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1]
+ )
+ self.assertEqual(global_ranks, current_rank_expected_group_ranks)
+
+ @with_comms
+ def test_device_mesh_2d_from_dim_groups(self):
+        # construct two-dimensional subgroups
+ dim_groups = []
+ expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]]
+ for dim_group_ranks in expected_ranks_by_dim:
+ for subgroup_ranks in dim_group_ranks:
+ subgroup = new_group(ranks=subgroup_ranks)
+ if self.rank in subgroup_ranks:
+ dim_groups.append(subgroup)
+
+ # construct a device mesh from the subgroups
+ mesh = DeviceMesh(
+ self.device_type, [[0, 1], [2, 3]], dim_groups=dim_groups
+ )
+
+ # check all dim groups
+ dim_to_subgroups = mesh.get_dim_groups()
+ for dim, dim_group in enumerate(dim_to_subgroups):
+ self.assertTrue(dim < 2)
+ dim_ranks = expected_ranks_by_dim[dim]
+
+ dim_group_size = get_world_size(dim_group)
+ self.assertIsInstance(dim_group, ProcessGroup)
+ self.assertEqual(dim_group_size, 2)
+ global_ranks = [
+ get_global_rank(dim_group, i) for i in range(dim_group_size)
+ ]
+ current_rank_expected_group_ranks = (
+ dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1]
+ )
+ self.assertEqual(global_ranks, current_rank_expected_group_ranks)
+
+ @with_comms
+ def test_device_mesh_dim_groups_error(self):
+        # construct two-dimensional subgroups
+ dim_groups = []
+ expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]]
+ for dim_group_ranks in expected_ranks_by_dim:
+ for subgroup_ranks in dim_group_ranks:
+ subgroup = new_group(ranks=subgroup_ranks)
+ if self.rank in subgroup_ranks:
+ dim_groups.append(subgroup)
+
+ if len(dim_groups) > 0:
+ # dim_groups is not a list
+ self.assertRaises(
+ RuntimeError,
+ DeviceMesh,
+ self.device_type,
+ [[0, 1], [2, 3]],
+ dim_groups=dim_groups[0],
+ )
+
+ # dim_groups is a list, but not a list of ProcessGroup
+ self.assertRaises(
+ RuntimeError,
+ DeviceMesh,
+ self.device_type,
+ [[0, 1], [2, 3]],
+ dim_groups=[dim_groups[0], "dummy"],
+ )
+
+ # dim_groups has incorrect length
+ self.assertRaises(
+ RuntimeError,
+ DeviceMesh,
+ self.device_type,
+ [[0, 1], [2, 3]],
+ dim_groups=[dim_groups[0]],
+ )
+
+ @with_comms
+ def test_device_mesh_nd(self):
+ # construct a cuda device mesh
+ mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+ mesh = DeviceMesh(self.device_type, mesh_tensor)
+
+ # check all dim groups
+ dim_to_subgroups = mesh.get_dim_groups()
+
+ for dim, dim_group in enumerate(dim_to_subgroups):
+ self.assertTrue(dim < mesh_tensor.ndim)
+ dim_ranks = mesh_tensor.swapdims(-1, dim).reshape(-1, 2)
+
+ dim_group_size = get_world_size(dim_group)
+ self.assertIsInstance(dim_group, ProcessGroup)
+ self.assertEqual(dim_group_size, 2)
+ global_ranks = [
+ get_global_rank(dim_group, i) for i in range(dim_group_size)
+ ]
+ for ranks in dim_ranks:
+ if self.rank in ranks:
+ self.assertEqual(global_ranks, ranks.tolist())
+
+
+class DeviceMeshCollectiveTest(DTensorTestBase):
+ @property
+ def world_size(self):
+ return 8
+
+ @with_comms
+ def test_all_reduce_1d(self):
+ mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+ local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
+ mesh.all_reduce(local_tensor, mesh_dim=0)
+ res_num = ((0 + self.world_size - 1) * self.world_size) / 2
+ self.assertEqual(local_tensor, torch.ones(3, 3) * res_num)
+
+ @with_comms
+ def test_broadcast_1d(self):
+ mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+ local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
+ mesh.broadcast(local_tensor, mesh_dim=0)
+ self.assertEqual(local_tensor, torch.zeros(3, 3))
+
+ @with_comms
+ def test_scatter_1d(self):
+ mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+ scatter_tensor_shape = [3, 3, 3]
+ for scatter_dim in range(len(scatter_tensor_shape)):
+ shard_placement = Shard(scatter_dim)
+ scatter_tensor_shape[scatter_dim] *= self.world_size
+            # make the random seed the same across ranks
+ torch.manual_seed(0)
+ global_tensor = torch.randn(
+ scatter_tensor_shape, device=self.device_type
+ )
+ splitted_list, _ = shard_placement._split_tensor(
+ global_tensor, mesh.size(), with_padding=True, contiguous=True
+ )
+ recv_tensor = torch.empty_like(splitted_list[mesh.get_rank()])
+            # scatter on dim > 0 would generate a non-contiguous tensor; verify that it works
+ mesh.scatter(recv_tensor, splitted_list, mesh_dim=0)
+ self.assertEqual(recv_tensor, splitted_list[mesh.get_rank()])
+
+ @with_comms
+ def test_scatter_uneven(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ my_rank = device_mesh.get_rank()
+ tensor_to_split = torch.randn(
+ device_mesh.size() + 3, device_mesh.size() + 1
+ )
+
+ for shard_dim in range(tensor_to_split.ndim):
+ shard_placement = Shard(shard_dim)
+ tensor_to_scatter = tensor_to_split.clone()
+ tensor_splitted_list = tensor_to_split.tensor_split(
+ device_mesh.size(), dim=shard_dim
+ )
+ padded_tensor_list, pad_idx = shard_placement._split_tensor(
+ tensor_to_scatter,
+ device_mesh.size(),
+ with_padding=True,
+ contiguous=True,
+ )
+
+ scattered_tensor = torch.empty_like(padded_tensor_list[my_rank])
+ device_mesh.scatter(
+ scattered_tensor, padded_tensor_list, mesh_dim=0
+ )
+ # unpad scattered_tensor
+ if pad_idx != 0 and my_rank >= pad_idx:
+ scattered_tensor = shard_placement._unpad_tensor(
+ scattered_tensor
+ )
+
+ self.assertEqual(
+ scattered_tensor.size(), tensor_splitted_list[my_rank].size()
+ )
+ self.assertEqual(scattered_tensor, tensor_splitted_list[my_rank])
+
+ @with_comms
+ def test_all_gather_1d(self):
+ mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+ dims_to_gather = [0, 1]
+ for dim in dims_to_gather:
+ output_size = [3, 3]
+ output_size[dim] *= self.world_size
+            # each rank has its own tensor; all_gather gives a list
+ local_tensor = torch.ones(3, 3, device=self.device_type)
+ gathered_list = []
+ for _ in range(self.world_size):
+ gathered_list.append(torch.zeros_like(local_tensor))
+ mesh.all_gather(gathered_list, local_tensor, mesh_dim=0)
+ gathered_tensor = torch.cat(gathered_list, dim=dim)
+ self.assertEqual(gathered_tensor, torch.ones(output_size))
+
+ @with_comms
+ def test_all_gather_uneven(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ my_rank = device_mesh.get_rank()
+ tensor_to_split = torch.ones(
+ device_mesh.size() + 3,
+ device_mesh.size() + 1,
+ device=self.device_type,
+ )
+
+ for shard_dim in range(tensor_to_split.ndim):
+ shard_placement = Shard(shard_dim)
+ tensor_padded_list, pad_idx = shard_placement._split_tensor(
+ tensor_to_split,
+ device_mesh.size(),
+ with_padding=True,
+ contiguous=True,
+ )
+ local_tensor = tensor_padded_list[my_rank]
+ gathered_list = []
+ for _ in range(device_mesh.size()):
+ gathered_list.append(torch.empty_like(local_tensor))
+
+ device_mesh.all_gather(
+ gathered_list,
+ local_tensor,
+ mesh_dim=0,
+ )
+ if pad_idx != 0:
+ gathered_list = [
+ shard_placement._unpad_tensor(gathered_tensor)
+ if i >= pad_idx
+ else gathered_tensor
+ for i, gathered_tensor in enumerate(gathered_list)
+ ]
+ all_gathered_tensor = torch.cat(gathered_list, dim=shard_dim)
+ self.assertEqual(all_gathered_tensor.size(), tensor_to_split.size())
+ self.assertEqual(all_gathered_tensor, tensor_to_split)
+
+ @with_comms
+ def test_reduce_scatter_1d(self):
+ mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+ dims_to_scatter = [0, 1]
+ for dim in dims_to_scatter:
+ input_size = [3, 3]
+ scattered_tensor = torch.empty(input_size, device=self.device_type)
+ input_size[dim] *= self.world_size
+
+ input_rs_list = (
+ torch.ones(input_size, device=self.device_type) * self.rank
+ ).tensor_split(self.world_size, dim=dim)
+ res_num = ((0 + self.world_size - 1) * self.world_size) / 2
+ mesh.reduce_scatter(scattered_tensor, input_rs_list, mesh_dim=0)
+ self.assertEqual(scattered_tensor, torch.ones(3, 3) * res_num)
+
+ @with_comms
+ def test_reduce_scatter_uneven(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ my_rank = device_mesh.get_rank()
+ tensor_to_split = (
+ torch.ones(
+ device_mesh.size() + 3,
+ device_mesh.size() + 1,
+ device=self.device_type,
+ )
+ * self.rank
+ )
+
+ for shard_dim in range(tensor_to_split.ndim):
+ shard_placement = Shard(shard_dim)
+ tensor_to_scatter = tensor_to_split.clone()
+ tensor_splitted_list = tensor_to_split.tensor_split(
+ device_mesh.size(), dim=shard_dim
+ )
+ padded_tensor_list, pad_idx = shard_placement._split_tensor(
+ tensor_to_scatter,
+ device_mesh.size(),
+ with_padding=True,
+ contiguous=True,
+ )
+
+ res_num = ((0 + self.world_size - 1) * self.world_size) / 2
+ scattered_tensor = torch.empty_like(padded_tensor_list[my_rank])
+ device_mesh.reduce_scatter(
+ scattered_tensor, padded_tensor_list, mesh_dim=0
+ )
+ # unpad scattered_tensor
+ if pad_idx != 0 and my_rank >= pad_idx:
+ scattered_tensor = shard_placement._unpad_tensor(
+ scattered_tensor
+ )
+
+ self.assertEqual(
+ scattered_tensor.size(), tensor_splitted_list[my_rank].size()
+ )
+ self.assertEqual(
+ scattered_tensor,
+ torch.ones_like(tensor_splitted_list[my_rank]) * res_num,
+ )
+
+ @with_comms
+ def test_all_gather_nd(self):
+ mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+ mesh = DeviceMesh(self.device_type, mesh_tensor)
+ local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
+
+ dim_to_subgroups = mesh.get_dim_groups()
+ for dim, dim_group in enumerate(dim_to_subgroups):
+ dim_group_size = get_world_size(dim_group)
+ global_ranks = [
+ get_global_rank(dim_group, i) for i in range(dim_group_size)
+ ]
+ gathered_tensor_list = list(
+ torch.empty(
+ (dim_group_size * 3, 3), device=self.device_type
+ ).tensor_split(dim_group_size, dim=0)
+ )
+ mesh.all_gather(gathered_tensor_list, local_tensor, mesh_dim=dim)
+ gathered_tensor = torch.cat(gathered_tensor_list)
+ exp_tensor = torch.ones(3 * dim_group_size, 3)
+ for i in range(len(global_ranks)):
+ exp_tensor[i * 3 : (i + 1) * 3] = (
+ torch.ones(3, 3) * global_ranks[i]
+ )
+ self.assertEqual(gathered_tensor, exp_tensor)
+
+ @with_comms
+ def test_reduce_scatter_nd(self):
+ mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+ mesh = DeviceMesh(self.device_type, mesh_tensor)
+
+ dim_to_subgroups = mesh.get_dim_groups()
+ for dim, dim_group in enumerate(dim_to_subgroups):
+ dim_group_size = get_world_size(dim_group)
+ local_rs_list = (
+ torch.ones(dim_group_size * 3, 3, device=self.device_type)
+ * self.rank
+ ).tensor_split(dim_group_size, dim=0)
+ scattered_tensor = torch.empty_like(
+ local_rs_list[mesh.get_coordinate_on_dim(dim)],
+ device=self.device_type,
+ )
+ global_ranks = [
+ get_global_rank(dim_group, i) for i in range(dim_group_size)
+ ]
+ mesh.reduce_scatter(scattered_tensor, local_rs_list, mesh_dim=dim)
+ res_num = torch.sum(torch.tensor(global_ranks))
+ self.assertEqual(scattered_tensor, torch.ones(3, 3) * res_num)
+
+ @with_comms
+ def test_all_reduce_nd(self):
+ mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+ mesh = DeviceMesh(self.device_type, mesh_tensor)
+ local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
+
+ # check all dim groups
+ dim_to_subgroups = mesh.get_dim_groups()
+ for dim, dim_group in enumerate(dim_to_subgroups):
+ dim_group_size = get_world_size(dim_group)
+ global_ranks = [
+ get_global_rank(dim_group, i) for i in range(dim_group_size)
+ ]
+ cloned_local_tensor = local_tensor.clone()
+ mesh.all_reduce(cloned_local_tensor, mesh_dim=dim)
+ res_num = sum(global_ranks)
+ self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num)
+
+ @with_comms
+ def test_broadcast_nd(self):
+ mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+ mesh = DeviceMesh(self.device_type, mesh_tensor)
+ local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
+
+ # check all dim groups
+ dim_to_subgroups = mesh.get_dim_groups()
+ for dim, dim_group in enumerate(dim_to_subgroups):
+ dim_group_size = get_world_size(dim_group)
+ global_ranks = [
+ get_global_rank(dim_group, i) for i in range(dim_group_size)
+ ]
+ cloned_local_tensor = local_tensor.clone()
+ mesh.broadcast(cloned_local_tensor, mesh_dim=dim)
+ res_num = global_ranks[0]
+ self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num)
+
+ @with_comms
+ def test_scatter_nd(self):
+ mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+ mesh = DeviceMesh(self.device_type, mesh_tensor)
+
+ # check all dim groups
+ dim_to_subgroups = mesh.get_dim_groups()
+ for dim, dim_group in enumerate(dim_to_subgroups):
+ dim_group_size = get_world_size(dim_group)
+ global_ranks = [
+ get_global_rank(dim_group, i) for i in range(dim_group_size)
+ ]
+ scattered_tensors = [
+ torch.ones(3, 3, device=self.device_type) * global_rank
+ for global_rank in global_ranks
+ ]
+ received_tensor = torch.empty_like(
+ scattered_tensors[mesh.get_coordinate_on_dim(dim)]
+ )
+ mesh.scatter(received_tensor, scattered_tensors, mesh_dim=dim)
+ self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank)
+
+ @with_comms
+ def test_all_to_all_1d(self):
+ # transpose on a 2D tensor distributed over N nodes:
+ mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+ tensor_shape = [3, 3]
+ input_tensor_list = [
+ torch.ones(*tensor_shape, device=self.device_type)
+ * (rank + self.rank * self.world_size)
+ for rank in range(self.world_size)
+ ]
+ expected_tensor_list = [
+ torch.ones(tensor_shape, device=self.device_type)
+ * (self.rank + rank * self.world_size) # i.e. transpose
+ for rank in range(self.world_size)
+ ]
+ for scatter_dim in range(len(tensor_shape)):
+ output_tensor_list = [
+ torch.empty_like(input_tensor_list[idx])
+ for idx in range(len(input_tensor_list))
+ ]
+            # scatter on dim > 0 would generate a non-contiguous tensor; verify that it works
+ mesh.all_to_all(output_tensor_list, input_tensor_list, mesh_dim=0)
+ output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
+ expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim)
+
+ self.assertEqual(output_tensor, expected_tensor)
+
+ @with_comms
+ def test_all_to_all_nd(self):
+ mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+ mesh = DeviceMesh(self.device_type, mesh_tensor)
+ tensor_shape = [3, 3, 3]
+ # check all dim groups
+ dim_to_subgroups = mesh.get_dim_groups()
+ for dim, dim_group in enumerate(dim_to_subgroups):
+ my_coordinate = mesh.get_coordinate_on_dim(dim)
+ dim_group_size = get_world_size(dim_group)
+ global_ranks = [
+ get_global_rank(dim_group, i) for i in range(dim_group_size)
+ ]
+ input_tensor_list = [
+ torch.ones(*tensor_shape, device=self.device_type)
+ * (i + self.rank * dim_group_size)
+ for i in range(dim_group_size)
+ ]
+ expected_tensor_list = [
+ torch.ones(*tensor_shape, device=self.device_type)
+ * (
+ my_coordinate + global_rank * dim_group_size
+ ) # i.e. transpose
+ for global_rank in global_ranks
+ ]
+ for scatter_dim in range(len(tensor_shape)):
+ output_tensor_list = [
+ torch.empty_like(input_tensor_list[idx])
+ for idx in range(len(input_tensor_list))
+ ]
+                # scatter on dim > 0 would generate a non-contiguous tensor; verify that it works
+ mesh.all_to_all(
+ output_tensor_list, input_tensor_list, mesh_dim=dim
+ )
+ output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
+ expected_tensor = torch.cat(
+ expected_tensor_list, dim=scatter_dim
+ )
+ self.assertEqual(output_tensor, expected_tensor)
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py
new file mode 100644
index 0000000..51ce1bd
--- /dev/null
+++ b/test/distributed/_tensor/test_dtensor.py
@@ -0,0 +1,359 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import torch
+
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+ DTensorTestBase,
+ with_comms,
+)
+from torch.distributed._tensor import DeviceMesh, DTensor, distribute_tensor
+from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
+
+
+class DTensorTest(DTensorTestBase):
+ # @with_comms
+ # def test_tensor_constructor(self):
+ # import torch.distributed._tensor as dist_tensor
+ # shard_spec = PlacementSpec(device_mesh, strategies=[Shard(0)])
+ # empty_tensor = dist_tensor.empty((12, 10), placement_spec=shard_spec)
+ # zero_tensor = dist_tensor.zeros((12, 10), placement_spec=shard_spec)
+ # one_tensor = dist_tensor.ones((12, 10), placement_spec=shard_spec)
+
+ # zero_cuda_tensor = dist_tensor.zeros((12, 10), device="cuda", placement_spec=shard_spec)
+
+ # dist_tensor.empty_like(empty_tensor)
+ # dist_tensor.zero_like(empty_tensor)
+ # dist_tensor.one_like(empty_tensor)
+
+ @with_comms
+ def test_dtensor_constructor(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [Shard(0)]
+ local_tensor = torch.randn(3, 3, requires_grad=True)
+ dist_tensor_shape = torch.Size([self.world_size * 3, 3])
+ dist_tensor = DTensor(
+ local_tensor,
+ device_mesh,
+ shard_spec,
+ size=dist_tensor_shape,
+ requires_grad=True,
+ )
+ self.assertEqual(
+ dist_tensor.size(), torch.Size((self.world_size * 3, 3))
+ )
+
+ with self.assertWarnsRegex(UserWarning, "To construct"):
+ DTensor(
+ local_tensor, device_mesh, shard_spec, size=dist_tensor_shape
+ )
+
+ local_tensor = torch.randn(3, 3, requires_grad=False)
+ with self.assertWarnsRegex(UserWarning, "To construct"):
+ dist_tensor = DTensor(
+ local_tensor,
+ device_mesh,
+ shard_spec,
+ size=dist_tensor_shape,
+ requires_grad=True,
+ )
+
+ @with_comms
+ def test_dtensor_stride(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard0_spec = [Shard(0)]
+ local_tensor = torch.randn(4, 8)
+ global_shape = torch.Size([self.world_size * 4, 8])
+ dist_tensor = DTensor(
+ local_tensor, device_mesh, shard0_spec, size=global_shape
+ )
+ # won't affect stride
+ self.assertEqual(dist_tensor.stride(), (8, 1))
+
+ shard1_spec = [Shard(1)]
+ local_tensor = torch.randn(8, 4)
+ global_shape = torch.Size([8, self.world_size * 4])
+ dist_tensor = DTensor(
+ local_tensor, device_mesh, shard1_spec, size=global_shape
+ )
+        # will affect stride after the DTensor is initialized
+ self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 1))
+
+ # if initialized from a transposed mat
+ local_tensor = torch.randn(8, 4, 8)
+ local_tensor_t = local_tensor.permute(1, 2, 0)
+ global_shape = torch.Size([4, self.world_size * 8, 8])
+ self.assertEqual(local_tensor_t.stride(), (8, 1, 32))
+ dist_tensor = DTensor(
+ local_tensor_t, device_mesh, shard1_spec, size=global_shape
+ )
+ global_stride = (8 * self.world_size, 1, 32 * self.world_size)
+ self.assertEqual(dist_tensor.stride(), global_stride)
+
+ @with_comms
+ def test_from_local(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [Shard(0)]
+ local_tensor = torch.randn(3, 3)
+ sharded_tensor = DTensor.from_local(
+ local_tensor, device_mesh, shard_spec
+ )
+ self.assertEqual(
+ sharded_tensor.size(), torch.Size([self.world_size * 3, 3])
+ )
+
+ replica_spec = [Replicate()]
+ ddp_tensor = DTensor.from_local(local_tensor, device_mesh, replica_spec)
+ self.assertEqual(ddp_tensor.size(), local_tensor.size())
+
+ partial_spec = [_Partial()]
+ partial_tensor = DTensor.from_local(
+ local_tensor, device_mesh, partial_spec
+ )
+ self.assertEqual(partial_tensor.size(), local_tensor.size())
+
+        # test that dist tensor works with torch.Tensor during backwards
+ local_tensor_with_grad = torch.randn(3, 3, requires_grad=True)
+ # do some operations on local tensor
+ local_tensor_temp = local_tensor_with_grad * 3
+        # create the dist tensor with a non-leaf local tensor; the dist tensor
+        # created should also be a non-leaf node
+ dist_tensor = DTensor.from_local(
+ local_tensor_temp, device_mesh, shard_spec
+ )
+ self.assertFalse(dist_tensor.is_leaf)
+ # do some random operations on dist tensor
+ output = dist_tensor * 3
+ self.assertIsInstance(output, DTensor)
+ # trigger .backward() on dist tensor directly
+ local_grad = torch.ones(3, 3)
+ grad_output = DTensor.from_local(local_grad, device_mesh, shard_spec)
+ # run backward directly on dist tensor
+ output.backward(grad_output)
+        # check that gradients flow back to the original torch.Tensor
+ self.assertIsNotNone(local_tensor_with_grad.grad)
+ expected_grad = torch.ones(3, 3) * 9
+ self.assertEqual(local_tensor_with_grad.grad, expected_grad)
+
+ @with_comms
+ def test_to_local(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [Shard(0)]
+ dist_tensor_shape = torch.Size([self.world_size * 3, 3])
+ local_tensor_with_grad = torch.randn(
+ 3, 3, device=self.device_type, requires_grad=True
+ )
+
+ sharded_tensor = DTensor(
+ local_tensor_with_grad,
+ device_mesh,
+ shard_spec,
+ size=dist_tensor_shape,
+ requires_grad=True,
+ )
+ self.assertEqual(sharded_tensor.size(), dist_tensor_shape)
+ self.assertEqual(sharded_tensor.to_local(), local_tensor_with_grad)
+
+        # test that dist tensor works with torch.Tensor during backwards;
+        # the dist tensor created is a leaf node, do some operation on it
+ temp_st = sharded_tensor * 3
+
+ # do some operation on local tensor of the dist tensor
+ new_tensor_with_grad = torch.randn(
+ 3, 3, device=self.device_type, requires_grad=True
+ )
+ res = temp_st.to_local() + new_tensor_with_grad
+ # call backward directly on torch.Tensor, and see if it works by
+ # propagating through dist tensor
+ res.sum().backward()
+ self.assertIsNotNone(sharded_tensor.grad)
+
+ self.assertEqual(sharded_tensor.grad.to_local(), torch.ones(3, 3) * 3)
+
+ @with_comms
+ def test_from_local_then_to_local(self):
+        # this test ensures the end-to-end flow torch.Tensor -> dist tensor -> torch.Tensor works
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [Shard(0)]
+
+        # step 1. construct a local tensor
+ local_tensor_with_grad = torch.randn(
+ 3, 3, device=self.device_type, requires_grad=True
+ )
+ # do some operations on local tensor
+ local_tensor_temp = local_tensor_with_grad + 8
+        # step 2. create the dist tensor with a non-leaf local tensor; the dist
+        # tensor created should also be a non-leaf node
+ dist_tensor = DTensor.from_local(
+ local_tensor_temp, device_mesh, shard_spec
+ )
+ self.assertFalse(dist_tensor.is_leaf)
+ # do some random operations on dist tensor
+ output = dist_tensor * 6
+ self.assertIsInstance(output, DTensor)
+
+ # step 3. do some operation on local tensor of the dist tensor
+ new_tensor_with_grad = torch.randn(
+ 3, 3, device=self.device_type, requires_grad=True
+ )
+ res = output.to_local() + new_tensor_with_grad
+ # call backward directly on torch.Tensor, and see if it works by
+ # propagating all the way back to the original torch.Tensor
+ res.sum().backward()
+ self.assertIsNotNone(local_tensor_with_grad.grad)
+
+ expected_grad = torch.ones(3, 3) * 6
+ self.assertEqual(local_tensor_with_grad.grad, expected_grad)
+
+ @with_comms
+ def test_dtensor_spec_read_only_after_set(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [Shard(0)]
+ local_tensor = torch.randn(3, 3)
+ sharded_tensor = DTensor.from_local(
+ local_tensor, device_mesh, shard_spec
+ )
+
+ # modify shard_spec, and dist_tensor's spec should not be changed
+ shard_spec[0] = Replicate()
+ self.assertTrue(sharded_tensor.placements is not shard_spec)
+ self.assertNotEqual(sharded_tensor.placements, shard_spec)
+
+ @with_comms
+ def test_dtensor_properties(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ shard_spec = [Shard(0)]
+ local_tensor = torch.randn(3, 3)
+ sharded_tensor = DTensor.from_local(
+ local_tensor, device_mesh, shard_spec
+ )
+ self.assertEqual(sharded_tensor.device.type, self.device_type)
+
+
+class DTensorMeshTest(DTensorTestBase):
+ @property
+ def world_size(self):
+ return 8
+
+ @with_comms
+ def test_dtensor_device_mesh_device_conversion(self):
+ # construct a cuda device mesh
+ mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+ # construct from a cpu local tensor with cuda device mesh
+ # should automatically convert the dist tensor to cuda
+ shard_spec = [Shard(0)]
+ local_tensor = torch.randn(3, 3)
+ dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+ self.assertEqual(dist_tensor.device.type, self.device_type)
+ self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
+
+ @with_comms
+ def test_dtensor_api_device_mesh_context_manager(self):
+ with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh:
+ shard_spec = [Shard(0)]
+ local_tensor = torch.randn(3, 3)
+ sharded_tensor = DTensor.from_local(
+ local_tensor, device_mesh=mesh, placements=shard_spec
+ )
+
+ with DeviceMesh(self.device_type, list(range(self.world_size))):
+ shard_spec = [Shard(0)]
+ local_tensor = torch.randn(3, 3)
+ sharded_tensor = DTensor.from_local(
+ local_tensor, placements=shard_spec
+ )
+ replica_spec = [Replicate()]
+ replica_tensor = sharded_tensor.redistribute(
+ placements=replica_spec
+ )
+ self.assertEqual(
+ replica_tensor.size(), torch.Size([3 * self.world_size, 3])
+ )
+
+ @with_comms
+ def test_dtensor_2d_mesh(self):
+ mesh_tensor = torch.arange(self.world_size).reshape(2, 4)
+ # construct a cuda device mesh
+ mesh = DeviceMesh(self.device_type, mesh_tensor)
+
+        # construct a dist tensor on a 2d device mesh and test if it works
+ shard_spec = [Shard(0), Shard(1)]
+ local_tensor = torch.randn(3, 3)
+ dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+ self.assertEqual(
+ dist_tensor.size(), torch.Size([3 * mesh.size(0), 3 * mesh.size(1)])
+ )
+ self.assertEqual(dist_tensor.device.type, self.device_type)
+ self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
+
+        # if we shard on the same tensor dimension,
+        # we should correctly construct the global tensor size
+ shard_same_dim_spec = [Shard(0), Shard(0)]
+ local_tensor = torch.randn(3, 3)
+ dist_tensor = DTensor.from_local(
+ local_tensor, mesh, shard_same_dim_spec
+ )
+ self.assertEqual(
+ dist_tensor.size(), torch.Size([3 * self.world_size, 3])
+ )
+
+ @with_comms
+ def test_device_mesh_nd(self):
+ # construct a cuda device mesh
+ mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2)
+ mesh = DeviceMesh(self.device_type, mesh_tensor)
+        # construct a dist tensor on a 3d device mesh and test if it works
+ shard_spec = [Shard(0), Shard(1), Shard(2)]
+ local_tensor = torch.randn(3, 3, 3)
+ dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+ self.assertEqual(dist_tensor.size(), torch.Size([6, 6, 6]))
+ self.assertEqual(dist_tensor.device.type, self.device_type)
+ self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
+
+ # construct a dist tensor on 3d device mesh with some shards on same dim
+ shard_spec = [Shard(0), Shard(0), Shard(2)]
+ local_tensor = torch.randn(3, 3, 3)
+ dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+ self.assertEqual(dist_tensor.size(), torch.Size([12, 3, 6]))
+ self.assertEqual(dist_tensor.device.type, self.device_type)
+ self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
+
+ @with_comms
+ def test_dtensor_spec_local_shard_offset(self):
+ device_mesh = DeviceMesh(
+ self.device_type, torch.arange(self.world_size).reshape(2, 4)
+ )
+ tensor_shape = (3 * self.world_size, 3 * self.world_size)
+        # sharding specs and their corresponding local shard offsets
+ shard_spec_and_offsets = [
+ (
+ [Shard(0), Replicate()],
+ (3 * (self.world_size // 2) * (self.rank // 4), 0),
+ ),
+ (
+ [Shard(1), Replicate()],
+ (0, 3 * (self.world_size // 2) * (self.rank // 4)),
+ ),
+ (
+ [Replicate(), Shard(0)],
+ (3 * (self.world_size // 4) * (self.rank % 4), 0),
+ ),
+ (
+ [Replicate(), Shard(1)],
+ (0, 3 * (self.world_size // 4) * (self.rank % 4)),
+ ),
+ ]
+
+ # loop through all sharding specs and check local shard offsets
+ logical_tensor = torch.randn(tensor_shape)
+ for shard_spec, expected_shard_offsets in shard_spec_and_offsets:
+ dtensor = distribute_tensor(logical_tensor, device_mesh, shard_spec)
+ self.assertEqual(
+ expected_shard_offsets, dtensor._spec.local_offsets
+ )
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/distributed/_tensor/test_redistribute.py b/test/distributed/_tensor/test_redistribute.py
new file mode 100644
index 0000000..78fc991
--- /dev/null
+++ b/test/distributed/_tensor/test_redistribute.py
@@ -0,0 +1,317 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import itertools
+import torch
+
+from torch.testing._internal.common_utils import run_tests
+
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+ DTensorTestBase,
+ with_comms,
+)
+from torch.distributed._tensor import distribute_tensor, DeviceMesh, DTensor
+from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
+
+
+class RedistributeTest(DTensorTestBase):
+ @with_comms
+ def test_shard_to_replicate_forward_backward(self):
+ # 1) test shard -> replicate forward
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ replica_spec = [Replicate()]
+
+ input_sizes_and_shard_dim = [
+ ((self.world_size * 3, 3), 0),
+ ((self.world_size * 3 + 1, 3), 0),
+ ((self.world_size * 3 + 2, 3), 0),
+ ((3, self.world_size * 3), 1),
+ ((3, self.world_size * 3 + 1), 1),
+ ((3, self.world_size * 3 + 2), 1),
+ ]
+
+ for input_size, shard_dim in input_sizes_and_shard_dim:
+ shard_spec = [Shard(shard_dim)]
+ expected_tensor = torch.randn(
+ input_size, device=self.device_type, requires_grad=True
+ )
+ dtensor = distribute_tensor(
+ expected_tensor.clone(), device_mesh, shard_spec
+ )
+ reshard_dtensor = dtensor.redistribute(device_mesh, replica_spec)
+ self.assertEqual(reshard_dtensor.size(), torch.Size(input_size))
+ self.assertEqual(expected_tensor, reshard_dtensor.to_local())
+
+ # 2) test shard -> replicate backward:
+ # should give gradient as shard
+ grad_output = torch.ones_like(reshard_dtensor)
+ reshard_dtensor.backward(grad_output)
+ grad_input = dtensor.grad
+ self.assertEqual(grad_input.placements, shard_spec)
+ self.assertEqual(
+ grad_input.to_local(), torch.ones(dtensor.to_local().size())
+ )
+
+ @with_comms
+ def test_replicate_to_replicate_forward_backward(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ replica_spec = [Replicate()]
+ local_tensor = torch.randn(
+ 12, 3, device=self.device_type, requires_grad=True
+ )
+ # 1) test replicate -> replicate forward
+ replica_tensor = distribute_tensor(
+ local_tensor, device_mesh, replica_spec
+ )
+ reshard_replica_tensor = replica_tensor.redistribute(
+ device_mesh, replica_spec
+ )
+ self.assertEqual(replica_tensor.size(), local_tensor.size())
+ self.assertEqual(replica_tensor, reshard_replica_tensor)
+
+ # 2) test replicate -> replicate backward:
+ # should give gradient as replicate
+ grad_output = torch.ones_like(reshard_replica_tensor)
+ reshard_replica_tensor.backward(grad_output)
+ grad_input = replica_tensor.grad
+ self.assertEqual(grad_input.placements, replica_spec)
+ self.assertEqual(grad_input.to_local(), torch.ones(12, 3))
+
+ @with_comms
+ def test_replicate_to_shard_forward_backward(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ replica_spec = [Replicate()]
+
+ input_sizes_and_shard_dim = [
+ ((self.world_size * 3, 3), 0),
+ ((self.world_size * 3 + 1, 3), 0),
+ ((self.world_size * 3 + 2, 3), 0),
+ ((3, self.world_size * 3), 1),
+ ((3, self.world_size * 3 + 1), 1),
+ ((3, self.world_size * 3 + 2), 1),
+ ]
+ for input_size, shard_dim in input_sizes_and_shard_dim:
+ shard_spec = [Shard(shard_dim)]
+ # 1) test replicate -> shard forward
+ local_replica = torch.randn(
+ input_size, device=self.device_type, requires_grad=True
+ )
+ splitted_list = local_replica.tensor_split(
+ self.world_size, shard_dim
+ )
+            # the local tensor is the chunk of the split list that corresponds to this rank
+ local_tensor = splitted_list[self.rank]
+ replica_tensor = distribute_tensor(
+ local_replica, device_mesh, replica_spec
+ )
+ reshard_tensor = replica_tensor.redistribute(
+ device_mesh, shard_spec
+ )
+ self.assertEqual(reshard_tensor.size(), replica_tensor.size())
+ self.assertEqual(reshard_tensor.placements, shard_spec)
+ self.assertEqual(reshard_tensor.to_local(), local_tensor)
+
+ # 2) test replicate -> shard backward:
+ # should give gradient as replicate
+ grad_output = torch.ones_like(reshard_tensor)
+ reshard_tensor.backward(grad_output)
+ grad_input = replica_tensor.grad
+ self.assertEqual(grad_input.placements, replica_spec)
+ self.assertEqual(grad_input.to_local(), torch.ones(input_size))
+
+ @with_comms
+ def test_partial_to_replicate_forward_backward(self):
+        # Although we don't allow users to reshard to produce a partial
+        # placement (i.e. users can't reshard to partial), we do allow
+        # replicate-to-partial internally, and the partial-to-replicate
+        # backward should work as expected
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ partial_local = torch.randn(
+ 12, 3, device=self.device_type, requires_grad=True
+ )
+ partial_spec = [_Partial()]
+ replica_spec = [Replicate()]
+        # test partial -> replicate, which triggers all_reduce
+ partial_tensor = DTensor.from_local(
+ partial_local, device_mesh, partial_spec
+ )
+ global_partial_tensor = partial_tensor.redistribute(
+ device_mesh, replica_spec
+ )
+
+ self.assertEqual(partial_tensor.size(), partial_local.size())
+ self.assertEqual(
+ partial_local * self.world_size, global_partial_tensor.to_local()
+ )
+
+ # test backward to have replicate grad on partial
+ global_partial_tensor.backward(torch.ones_like(global_partial_tensor))
+ self.assertIsNotNone(partial_local.grad)
+ if device_mesh.get_rank() == 0:
+ self.assertEqual(partial_local.grad, torch.ones_like(partial_local))
+
+ @with_comms
+ def test_replicate_to_partial(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ local_tensor = torch.randn(
+ 12, 3, device=self.device_type, requires_grad=True
+ )
+ partial_spec = _Partial()
+ replica_spec = Replicate()
+ # 1) test replicate -> partial forward
+ replica_tensor = distribute_tensor(
+ local_tensor, device_mesh, [replica_spec]
+ )
+ with self.assertRaisesRegex(
+ RuntimeError, "Can not redistribute to _Partial"
+ ):
+ partial_tensor = replica_tensor.redistribute(
+ device_mesh, [partial_spec]
+ )
+
+ from torch.distributed._tensor.redistribute import Redistribute
+
+ partial_tensor = Redistribute.apply(
+ replica_tensor, device_mesh, [partial_spec]
+ )
+ self.assertEqual(partial_tensor.size(), local_tensor.size())
+        # test that it successfully zeros out the contents on other ranks
+ if self.rank == 0:
+ self.assertEqual(
+ replica_tensor.to_local(), partial_tensor.to_local()
+ )
+ else:
+ self.assertEqual(
+ partial_tensor.to_local(), torch.zeros_like(local_tensor)
+ )
+
+ # replicate to partial on sub groups
+ local_tensor = torch.randn(12, 3, device=self.device_type)
+ device_mesh = DeviceMesh(
+ self.device_type,
+ torch.arange(self.world_size).reshape(self.world_size // 2, 2),
+ )
+ # 1) test replicate -> partial on 2d-mesh subgroups
+ replica_tensor = distribute_tensor(
+ local_tensor, device_mesh, [replica_spec, replica_spec]
+ )
+ partial_tensor = Redistribute.apply(
+ replica_tensor, device_mesh, [partial_spec, partial_spec]
+ )
+ self.assertEqual(partial_tensor.size(), local_tensor.size())
+
+ if self.rank != 3:
+            # replicate to partial should only zero out rank 3, and leave
+            # ranks 0/2 (rank 0 on mesh dim 1) and ranks 0/1 (rank 0 on mesh dim 0) untouched
+ self.assertEqual(
+ replica_tensor.to_local(), partial_tensor.to_local()
+ )
+ else:
+ self.assertEqual(
+ replica_tensor.to_local(), torch.zeros_like(local_tensor)
+ )
+
+ @with_comms
+ def test_partial_to_shard(self):
+ device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+ partial_spec = [_Partial()]
+
+ input_sizes_and_shard_dim = [
+ ((self.world_size * 3, 3), 0),
+ ((self.world_size * 3 + 1, 3), 0),
+ ((self.world_size * 3 + 2, 3), 0),
+ ((3, self.world_size * 3), 1),
+ ((3, self.world_size * 3 + 1), 1),
+ ((3, self.world_size * 3 + 2), 1),
+ ]
+
+ for input_size, shard_dim in input_sizes_and_shard_dim:
+ shard_spec = [Shard(shard_dim)]
+
+ partial_local = torch.ones(input_size, device=self.device_type)
+ partial_tensor = DTensor.from_local(
+ partial_local, device_mesh, partial_spec, run_check=False
+ )
+
+ quot, rem = divmod(input_size[shard_dim], self.world_size)
+ local_shape = list(input_size)
+ local_shape[shard_dim] = quot + (1 if self.rank < rem else 0)
+ # test partial to shard, trigger reduce_scatter
+ scatter_shard_tensor = partial_tensor.redistribute(
+ device_mesh, shard_spec
+ )
+ self.assertEqual(scatter_shard_tensor.size(), partial_tensor.size())
+ self.assertEqual(scatter_shard_tensor.placements, shard_spec)
+ self.assertEqual(
+ scatter_shard_tensor.to_local(),
+ torch.ones(local_shape) * self.world_size,
+ )
+
+
+class MultiDimRedistributeTest(DTensorTestBase):
+ @property
+ def world_size(self) -> int:
+ return 8
+
+ @with_comms
+ def test_multi_dim_mesh(self):
+ devices = torch.arange(self.world_size)
+ for mesh_shape in [devices, devices.view(4, 2), devices.view(2, 2, 2)]:
+ device_mesh = DeviceMesh(self.device_type, mesh_shape)
+ tensor_shape = (16, 24)
+
+ if torch.distributed.get_rank() == 0:
+ full_tensor = torch.randn(*tensor_shape)
+ else:
+ # these should be entirely ignored
+ # because distribute_tensor is expected to override shards in ranks != 0
+ full_tensor = torch.ones(*tensor_shape)
+
+ possibilities = [Replicate()] + [
+ Shard(i) for i in range(full_tensor.ndim)
+ ]
+ all_outputs = list(
+ itertools.product(*(mesh_shape.ndim * [possibilities]))
+ )
+ all_inputs = list(
+ itertools.product(
+ *(mesh_shape.ndim * [possibilities + [_Partial()]])
+ )
+ )
+
+ for inputs in all_inputs:
+ # if partial, temporarily make it Replicated, then replace replicated with partial afterwards
+ repl_inputs = [
+ Replicate() if s.is_partial() else s for s in inputs
+ ]
+ dt = distribute_tensor(full_tensor, device_mesh, repl_inputs)
+
+ if repl_inputs != inputs:
+                    # create a new DTensor reinterpreting some of the replicated entries as "Partial"
+ dt = DTensor.from_local(
+ dt.to_local(), device_mesh, inputs, run_check=False
+ )
+
+ for outputs in all_outputs:
+ # redistribute on target outputs
+ dt2 = dt.redistribute(device_mesh, outputs)
+
+ # replicate and then get first shard
+ local_full = dt2.redistribute(
+ device_mesh, device_mesh.ndim * [Replicate()]
+ ).to_local()
+
+ if torch.distributed.get_rank() == 0:
+ self.assertEqual(local_full.shape, full_tensor.shape)
+
+ num_sums = 1
+ for idx, input in enumerate(inputs):
+ if input.is_partial():
+ num_sums *= mesh_shape.size(idx)
+ expected = num_sums * full_tensor
+ self.assertEqual(local_full, expected)
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/torch/testing/_internal/distributed/_tensor/__init__.py b/torch/testing/_internal/distributed/_tensor/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/torch/testing/_internal/distributed/_tensor/__init__.py
diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py
new file mode 100644
index 0000000..cf2abe0
--- /dev/null
+++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py
@@ -0,0 +1,334 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+from contextlib import contextmanager
+from dataclasses import dataclass
+import itertools
+import sys
+from functools import wraps
+from typing import (
+ Any,
+ Callable,
+ Generator,
+ Iterator,
+ Tuple,
+ Dict,
+ Optional,
+ List,
+ Sequence,
+ TypeVar,
+ cast,
+)
+
+import torch
+import torch.distributed as dist
+
+from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
+from torch.testing._internal.common_distributed import (
+ MultiProcessTestCase,
+ TEST_SKIPS,
+ skip_if_lt_x_gpu,
+)
+
+from torch.distributed._tensor import (
+ DeviceMesh,
+ Shard,
+ Replicate,
+ distribute_tensor,
+ redistribute,
+)
+from torch.distributed._tensor.api import DTensor
+from torch.distributed._tensor.placement_types import Placement
+
+DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"
+NUM_DEVICES = 4
+
+# We use this as a proxy for "multiple GPUs exist"
+if torch.cuda.is_available() and torch.cuda.device_count() > 1:
+ # when we actually have multiple GPUs, relax the requirement to smaller counts.
+ NUM_DEVICES = min(NUM_DEVICES, torch.cuda.device_count())
+
+T = TypeVar("T")
+
+
+def skip_unless_torch_gpu(method: T) -> T:
+ """
+ Test decorator which skips the test unless there's a GPU available to torch.
+
+ >>> @skip_unless_torch_gpu
+ >>> def test_some_method(self) -> None:
+ >>> ...
+ """
+ # The builtin @skip_if_no_gpu relies on os.environ['WORLD_SIZE'] being set.
+ return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method))
+
+
+@dataclass
+class RedistributeProfile:
+ num_calls: int
+
+
+@contextmanager
+def redistribute_profiler() -> Generator[RedistributeProfile, None, None]:
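+    """Temporarily patch redistribute.redistribute_dtensor so that the yielded
+    RedistributeProfile counts how many times redistribution is triggered."""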
+
+ orig_redistribute_dtensor = redistribute.redistribute_dtensor
+ profile: RedistributeProfile = RedistributeProfile(num_calls=0)
+
+ # pyre-ignore[53]
+ def patched_redistribute_dtensor(
+ input: DTensor,
+ device_mesh: DeviceMesh,
+ placements: Sequence[Placement],
+ ) -> DTensor:
+ result = orig_redistribute_dtensor(input, device_mesh, placements)
+ profile.num_calls += 1
+ return result
+
+ try:
+ # pyre-ignore[9]
+ redistribute.redistribute_dtensor = patched_redistribute_dtensor
+ yield profile
+ finally:
+ redistribute.redistribute_dtensor = orig_redistribute_dtensor
+
+
+class DTensorTestBase(MultiProcessTestCase):
+ @property
+ def world_size(self) -> int:
+ return NUM_DEVICES
+
+ def build_device_mesh(self) -> DeviceMesh:
+ return DeviceMesh(DEVICE_TYPE, list(range(NUM_DEVICES)))
+
+ def init_pg(self, backend: str = "nccl") -> None:
+ if backend == "nccl" and torch.cuda.device_count() < self.world_size:
+ sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+
+ if backend not in ["nccl", "gloo", "mpi"]:
+ raise RuntimeError(f"Backend {backend} not supported!")
+
+ dist.init_process_group(
+ backend=backend,
+ world_size=self.world_size,
+ rank=self.rank, # pyre-ignore[16]
+ init_method=f"file://{self.file_name}", # pyre-ignore[16]
+ )
+
+ # set device for nccl pg for collectives
+ if backend == "nccl":
+ torch.cuda.set_device(self.rank)
+
+ def destroy_pg(self) -> None:
+ # Wait for all ranks to reach here before starting shutdown.
+ dist.barrier()
+ dist.destroy_process_group()
+
+ def setUp(self) -> None:
+ super().setUp()
+ self._spawn_processes()
+
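+    # Run op_call with the original args/kwargs, then with every DTensor
+    # sharding combination produced by DTensorConverter, and check that each
+    # DTensor result matches the eager result once replicated back to local.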
+ # pyre-ignore[2]:
+ def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None:
+ with redistribute_profiler() as profile:
+ out = op_call(*args, **kwargs)
+ dtc = DTensorConverter(mesh, args, kwargs)
+ for d_args, d_kwargs in dtc:
+ # pyre can't find assertTrue anymore?
+ self.assertEqual(dtc.successful(), True)
+ d_out = op_call(*d_args, **d_kwargs)
+ self.assertEqual(
+ d_out.redistribute(
+ mesh, [Replicate()] * mesh.ndim
+ ).to_local(),
+ out,
+ )
+
+
+# wrapper to initialize comms (process group)
+def with_comms(
+ func: Optional[ # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+ Callable
+ ] = None,
+ backend: Optional[str] = None,
+) -> Optional[ # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+ Callable
+]:
+ assert func is not None
+
+ @wraps(func) # pyre-ignore[6]
+ def wrapper(
+ self, *args: Tuple[object], **kwargs: Dict[str, Any] # type: ignore[misc]
+ ) -> None:
+ # if backend not specified, and cuda available, then use nccl, else gloo
+ pg_backend = (
+ "nccl" if backend is None and torch.cuda.is_available() else "gloo"
+ )
+ if pg_backend == "nccl" and torch.cuda.device_count() < self.world_size:
+ sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+
+ self.device_type = "cuda" if pg_backend == "nccl" else "cpu"
+ self.init_pg(backend=pg_backend)
+ func(self) # type: ignore[misc]
+ self.destroy_pg()
+
+ return wrapper
+
+
+# This is a class for converting args/kwargs of an op into distributed args/kwargs
+class DTensorConverter(object):
+ def __init__(
+ self,
+ mesh: DeviceMesh,
+ args: Tuple[object, ...],
+ kwargs: Dict[str, object],
+ ) -> None:
+ self.hit = 0
+ self.miss = 0
+ self.mesh = mesh
+ self.args = args
+ self.kwargs = kwargs
+ flatten_args, flatten_args_spec = tree_flatten(args)
+ flatten_kwargs, flatten_kwargs_spec = tree_flatten(kwargs)
+
+ self.flatten_args: List[object] = flatten_args
+ self.flatten_args_spec: TreeSpec = flatten_args_spec
+ self.flatten_kwargs: List[object] = flatten_kwargs
+ self.flatten_kwargs_spec: TreeSpec = flatten_kwargs_spec
+
+ choices_for_args = []
+ for arg in self.flatten_args:
+ if isinstance(arg, torch.Tensor):
+ choices_for_args.append(self.gen_sharding_choices_for_arg(arg))
+
+ for arg in self.flatten_kwargs:
+ if isinstance(arg, torch.Tensor):
+ choices_for_args.append(self.gen_sharding_choices_for_arg(arg))
+
+ self.sharding_combs: Iterator[Sequence[Placement]] = iter(
+ itertools.product(*choices_for_args)
+ )
+
+ def successful(self) -> bool:
+ return self.hit > 0 and self.miss == 0
+
+ def is_supported_tensor(self, t: torch.Tensor) -> bool:
+        # TODO: dist tensor needs to support quantized and sparse
+        # tensors. Quantized tensors might be relatively easy, but
+        # sparse tensors have special layouts that we may need to
+        # deal with; until we are clear about them, we don't officially
+        # support them.
+ return not any(
+ [
+ t.is_sparse_csr,
+ t.is_sparse,
+ t.is_mkldnn,
+ t.is_quantized,
+ t.is_nested,
+ torch._is_functional_tensor(t),
+ t.is_neg(),
+ t.is_conj(),
+ t.device.type in ("lazy", "meta"),
+                # We need a way to test if a tensor is batched but there
+                # is no official API to do it
+ # torch._C._is_batched(t),
+ ]
+ )
+
+ def gen_sharding_choices_for_arg(
+ self, arg: torch.Tensor
+ ) -> Sequence[Placement]:
+ mesh_size = self.mesh.size()
+ sharding_choices: List[Placement] = [Replicate()]
+        # c10d collectives do not support bool tensors;
+        # for bool tensors we treat them as replicated
+ if arg.dtype != torch.bool:
+ # only generating choices with: replicate, or sharding
+ # evenly on a dimension that could be sharded
+ sharding_choices = sharding_choices + [
+ Shard(i)
+ for i, s in enumerate(arg.shape)
+ if s > 1 and s % mesh_size == 0
+ ]
+ # TODO: add multi mesh choices
+ # all_choices = itertools.product(
+ # *(self.mesh.ndim * [sharding_choices])
+ # )
+ return sharding_choices
+
+ def __iter__(self) -> "DTensorConverter":
+ return self
+
+ def __next__(self) -> Tuple[Tuple[object, ...], Dict[str, object]]:
+ try:
+ next_sharding_choices = next(self.sharding_combs)
+ idx = 0
+
+ new_args: List[object] = []
+ for arg in self.flatten_args:
+ if isinstance(arg, torch.Tensor):
+ new_args.append(
+ self.to_dist_tensor(
+ arg, self.mesh, [next_sharding_choices[idx]]
+ )
+ )
+ idx += 1
+ else:
+ new_args.append(arg)
+
+ new_kwargs: List[object] = []
+ for arg in self.flatten_kwargs:
+ if isinstance(arg, torch.Tensor):
+ new_kwargs.append(
+ self.to_dist_tensor(
+ arg, self.mesh, [next_sharding_choices[idx]]
+ )
+ )
+ idx += 1
+ else:
+ new_kwargs.append(arg)
+
+ return (
+ tree_unflatten(new_args, self.flatten_args_spec),
+ tree_unflatten(new_kwargs, self.flatten_kwargs_spec),
+ )
+ except StopIteration:
+ raise StopIteration
+
+ def to_dist_tensor(
+ self, t: torch.Tensor, mesh: DeviceMesh, placements: List[Placement]
+ ) -> torch.Tensor:
+ if type(t) is torch.Tensor or type(t) is torch.nn.Parameter:
+ if self.is_supported_tensor(t):
+ self.hit += 1
+                # We cannot use distribute_tensor for bool tensors as c10d
+                # collectives do not support the dtype; we assume ops with
+                # bool tensor args use the same tensor across ranks, so no broadcast is needed
+ # TODO: add bool tensor dtype support in c10d collective
+ if t.dtype == torch.bool:
+ r = DTensor(
+ t,
+ mesh,
+ placements,
+ size=t.size(),
+ requires_grad=t.requires_grad,
+ )
+ else:
+ r = distribute_tensor(t, mesh, placements)
+ if type(t) is torch.nn.Parameter:
+ r = torch.nn.Parameter( # type: ignore[assignment]
+ r, requires_grad=r.requires_grad
+ )
+ return r
+ else:
+ self.miss += 1
+ return t
+ elif torch.overrides.is_tensor_like(t):
+            # Blindly converting tensor subclasses to dist tensor can cause
+            # unpredictable problems, so we explicitly disable this conversion
+            # for now (i.e. we don't support DTensor holding a tensor subclass
+            # until there's a strong reason to later).
+ self.miss += 1
+ return t
+ else:
+ raise RuntimeError(
+ f"Trying to convert to DTensor, but got {type(t)}"
+ )