[dtensor] PART 5: move DTensor basic tests to core distributed (#88178)

This PR moves the basic DTensor tests into core distributed
(torch.distributed), including the api, device_mesh, dtensor, and
redistribute tests.

Part of https://github.com/pytorch/pytorch/issues/88838
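
For reference, here is a minimal sketch of the API these tests exercise
(illustrative only; it assumes a process group has already been initialized,
e.g. via torchrun, and the device type and tensor sizes are placeholders):

    import torch
    import torch.distributed as dist
    from torch.distributed._tensor import DeviceMesh, DTensor, Shard, distribute_tensor

    world_size = dist.get_world_size()
    # 1-d device mesh over all ranks ("cuda" here is a placeholder device type)
    mesh = DeviceMesh("cuda", list(range(world_size)))

    # shard a global tensor along dim 0; each rank holds a 3 x 3 local shard
    global_tensor = torch.randn(3 * world_size, 3)
    dtensor = distribute_tensor(global_tensor, mesh, [Shard(0)])
    assert dtensor.to_local().shape == (3, 3)

    # or wrap an existing per-rank local tensor into a DTensor
    dtensor_from_local = DTensor.from_local(torch.randn(3, 3), mesh, [Shard(0)])
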
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88178
Approved by: https://github.com/fduwjj
diff --git a/test/distributed/_tensor/README.md b/test/distributed/_tensor/README.md
new file mode 100644
index 0000000..6235f96
--- /dev/null
+++ b/test/distributed/_tensor/README.md
@@ -0,0 +1,11 @@
+## Run distributed tensor tests:
+
+From the repo root, run (on either CPU or GPU):
+
+`pytest test/distributed/_tensor/test_dtensor.py`
+
+`pytest test/distributed/_tensor/test_api.py`
+
+To run a specific test case and print its stdout/stderr:
+
+`pytest test/distributed/_tensor/test_dtensor.py -s -k test_from_local`
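+
+To run the whole suite in this directory:
+
+`pytest test/distributed/_tensor/`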
diff --git a/test/distributed/_tensor/__init__.py b/test/distributed/_tensor/__init__.py
new file mode 100644
index 0000000..087882b
--- /dev/null
+++ b/test/distributed/_tensor/__init__.py
@@ -0,0 +1 @@
+# shut up pylint
diff --git a/test/distributed/_tensor/test_api.py b/test/distributed/_tensor/test_api.py
new file mode 100644
index 0000000..a966f30
--- /dev/null
+++ b/test/distributed/_tensor/test_api.py
@@ -0,0 +1,234 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import torch
+import torch.nn as nn
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import DTensorTestBase, with_comms
+from torch.distributed._tensor import (
+    distribute_tensor,
+    distribute_module,
+    DeviceMesh,
+    DTensor,
+    Shard,
+    Replicate,
+)
+
+
+class MyModel(nn.Module):
+    def __init__(self, n_features, n_layers, device):
+        super().__init__()
+        self.seq = nn.Sequential(
+            *[
+                nn.Linear(n_features, n_features, device=device)
+                for _ in range(n_layers)
+            ]
+        )
+
+    def forward(self, x):
+        return self.seq(x)
+
+    def reset_parameters(self):
+        for m in self.seq:
+            m.reset_parameters()
+
+
+class DTensorAPITest(DTensorTestBase):
+    @property
+    def world_size(self) -> int:
+        # hard-code the world size to 4 as we need to test
+        # with at least a 2d mesh
+        return 4
+
+    @with_comms
+    def test_distribute_tensor(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+
+        for requires_grad in [True, False]:
+
+            tensor_to_shard = torch.randn(
+                3 * self.world_size, 3, requires_grad=requires_grad
+            )
+            dist_tensor = distribute_tensor(
+                tensor_to_shard, device_mesh, shard_spec
+            )
+            self.assertEqual(
+                dist_tensor.size(), torch.Size([3 * self.world_size, 3])
+            )
+            local_tensor = dist_tensor.to_local()
+            self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
+            if requires_grad:
+                self.assertTrue(dist_tensor.requires_grad)
+                self.assertTrue(dist_tensor.is_leaf)
+
+    @with_comms
+    def test_distribute_tensor_errors(self):
+        device_mesh = DeviceMesh(
+            self.device_type, torch.arange(self.world_size).reshape(2, 2)
+        )
+        tensor_shape = [3 * self.world_size, 3 * self.world_size]
+        tensor_to_distribute = torch.randn(*tensor_shape)
+
+        with self.assertRaisesRegex(ValueError, "must have the same length"):
+            shard_spec = [Shard(0)]
+            distribute_tensor(tensor_to_distribute, device_mesh, shard_spec)
+
+        spec = [Shard(0), Shard(1)]
+        dtensor = distribute_tensor(tensor_to_distribute, device_mesh, spec)
+
+        with self.assertRaisesRegex(ValueError, "to a different device mesh"):
+            new_mesh = DeviceMesh(
+                self.device_type, torch.arange(self.world_size)
+            )
+            distribute_tensor(dtensor, new_mesh, [Shard(0)])
+
+        with self.assertRaisesRegex(ValueError, "to a different placements"):
+            new_spec = [Shard(0), Replicate()]
+            distribute_tensor(dtensor, device_mesh, new_spec)
+
+    @with_comms
+    def test_distribute_tensor_uneven_sharding(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        input_sizes_and_shard_dims = [
+            ((self.world_size * 3 + 1, 3, 3), 0),
+            ((self.world_size * 3 + 2, 3, 3), 0),
+            ((3, self.world_size * 3 + 1, 3), 1),
+            ((3, self.world_size * 3 + 2, 3), 1),
+            ((3, 3, self.world_size * 3 + 1), 2),
+            ((3, 3, self.world_size * 3 + 2), 2),
+        ]
+        for input_size, shard_dim in input_sizes_and_shard_dims:
+            shard_spec = [Shard(shard_dim)]
+            tensor_to_shard = torch.randn(input_size)
+            splitted_tensor_list = tensor_to_shard.tensor_split(
+                self.world_size, dim=shard_dim
+            )
+            dist_tensor = distribute_tensor(
+                tensor_to_shard, device_mesh, shard_spec
+            )
+            self.assertEqual(dist_tensor.size(), torch.Size(input_size))
+            local_tensor = dist_tensor.to_local()
+            self.assertEqual(local_tensor, splitted_tensor_list[self.rank])
+
+    @with_comms
+    def test_distribute_module(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        # fully shard all linear modules on dim 0
+        module_to_shard = MyModel(
+            5 * self.world_size, 20, device=self.device_type
+        )
+        shard_spec = [Shard(0)]
+
+        def shard_fn(name, module, device_mesh):
+            if isinstance(module, nn.Linear):
+                for name, param in module.named_parameters():
+                    dist_param = torch.nn.Parameter(
+                        distribute_tensor(param, device_mesh, shard_spec)
+                    )
+                    module.register_parameter(name, dist_param)
+
+        sharded_module = distribute_module(
+            module_to_shard, device_mesh, shard_fn
+        )
+        for param in sharded_module.parameters():
+            self.assertIsInstance(param, DTensor)
+            self.assertEqual(param.placements, shard_spec)
+
+        replica_spec = [Replicate()]
+        # fully replicate all modules without passing in partition_fn
+        module_to_replicate = MyModel(5, 20, device=self.device_type)
+        replica_module = distribute_module(module_to_replicate, device_mesh)
+        for param in replica_module.parameters():
+            self.assertIsInstance(param, DTensor)
+            self.assertEqual(param.placements, replica_spec)
+
+        # fully replicate all modules by passing in partition_fn
+        def replicate_fn(name, module, device_mesh):
+            if isinstance(module, nn.Linear):
+                for name, param in module.named_parameters():
+                    dist_param = torch.nn.Parameter(
+                        distribute_tensor(param, device_mesh, replica_spec)
+                    )
+                    module.register_parameter(name, dist_param)
+
+        module_to_replicate = MyModel(5, 20, device=self.device_type)
+        replica_module = distribute_module(
+            module_to_replicate, device_mesh, replicate_fn
+        )
+        for param in replica_module.parameters():
+            self.assertIsInstance(param, DTensor)
+            self.assertEqual(param.placements, replica_spec)
+
+        # only shard part of the module; the rest of the module should be replicated
+        def shard_fn(name, module, device_mesh):
+            if isinstance(module, nn.Linear) and (
+                name == "seq.0" or name == "seq.8"
+            ):
+                for name, param in module.named_parameters():
+                    dist_param = torch.nn.Parameter(
+                        distribute_tensor(param, device_mesh, shard_spec)
+                    )
+                    module.register_parameter(name, dist_param)
+
+        module_to_distribute = MyModel(
+            5 * self.world_size, 20, device=self.device_type
+        )
+        dist_module = distribute_module(
+            module_to_distribute, device_mesh, shard_fn
+        )
+        for name, param in dist_module.named_parameters():
+            self.assertIsInstance(param, DTensor)
+            if name.startswith("seq.0") or name.startswith("seq.8"):
+                self.assertEqual(param.placements, shard_spec)
+            else:
+                self.assertEqual(param.placements, replica_spec)
+
+    @with_comms
+    def test_distribute_module_input_fn_output_fn(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+
+        # fully replicate all linear modules
+        module_to_replicate = MyModel(20, 1, device=self.device_type)
+
+        # mark input sharding on dim 0
+        def input_fn(inputs, device_mesh):
+            return DTensor.from_local(inputs[0], device_mesh, [Shard(0)])
+
+        def output_fn(outputs, device_mesh):
+            assert isinstance(outputs, DTensor)
+            return outputs.to_local()
+
+        replica_module = distribute_module(
+            module_to_replicate,
+            device_mesh,
+            input_fn=input_fn,
+            output_fn=output_fn,
+        )
+
+        input_tensor = torch.randn(5, 20, device=self.device_type)
+        local_out = replica_module(input_tensor)
+        self.assertIsInstance(local_out, torch.Tensor)
+        self.assertNotIsInstance(local_out, DTensor)
+
+        # fully replicate (even the inputs)
+        model = MyModel(10, 10, device=self.device_type)
+
+        def replicate_input_fn(inputs, device_mesh):
+            return DTensor.from_local(inputs[0], device_mesh, [Replicate()])
+
+        replica_model = distribute_module(
+            model,
+            device_mesh,
+            input_fn=replicate_input_fn,
+        )
+        input = torch.randn(10, 10, requires_grad=True)
+        output = replica_model(input)
+        output.sum().backward()
+        param_grad = list(replica_model.parameters())[0].grad
+        self.assertTrue(isinstance(param_grad, DTensor))
+        self.assertTrue(isinstance(param_grad.placements[0], Replicate))
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/distributed/_tensor/test_device_mesh.py b/test/distributed/_tensor/test_device_mesh.py
new file mode 100644
index 0000000..7088f33
--- /dev/null
+++ b/test/distributed/_tensor/test_device_mesh.py
@@ -0,0 +1,518 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import torch
+
+from torch.distributed.distributed_c10d import (
+    ProcessGroup,
+    new_group,
+    get_global_rank,
+    get_world_size,
+)
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)
+from torch.distributed._tensor.device_mesh import DeviceMesh
+from torch.distributed._tensor.placement_types import Shard
+
+
+class DeviceMeshTest(DTensorTestBase):
+    @property
+    def world_size(self):
+        return 8
+
+    @with_comms
+    def test_device_mesh_2d(self):
+        mesh_tensor = torch.arange(4).reshape(2, 2)
+        # construct a cuda device mesh
+        mesh = DeviceMesh(self.device_type, mesh_tensor)
+
+        # check all dim groups
+        dim_to_subgroups = mesh.get_dim_groups()
+
+        expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]]
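+        # for mesh [[0, 1], [2, 3]], the dim-0 groups are the columns
+        # and the dim-1 groups are the rows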
+        for dim, dim_group in enumerate(dim_to_subgroups):
+            self.assertTrue(dim < 2)
+            dim_ranks = expected_ranks_by_dim[dim]
+
+            dim_group_size = get_world_size(dim_group)
+            self.assertIsInstance(dim_group, ProcessGroup)
+            self.assertEqual(dim_group_size, 2)
+            global_ranks = [
+                get_global_rank(dim_group, i) for i in range(dim_group_size)
+            ]
+            current_rank_expected_group_ranks = (
+                dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1]
+            )
+            self.assertEqual(global_ranks, current_rank_expected_group_ranks)
+
+    @with_comms
+    def test_device_mesh_2d_from_dim_groups(self):
+        # construct two-dimensional subgroups
+        dim_groups = []
+        expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]]
+        for dim_group_ranks in expected_ranks_by_dim:
+            for subgroup_ranks in dim_group_ranks:
+                subgroup = new_group(ranks=subgroup_ranks)
+                if self.rank in subgroup_ranks:
+                    dim_groups.append(subgroup)
+
+        # construct a device mesh from the subgroups
+        mesh = DeviceMesh(
+            self.device_type, [[0, 1], [2, 3]], dim_groups=dim_groups
+        )
+
+        # check all dim groups
+        dim_to_subgroups = mesh.get_dim_groups()
+        for dim, dim_group in enumerate(dim_to_subgroups):
+            self.assertTrue(dim < 2)
+            dim_ranks = expected_ranks_by_dim[dim]
+
+            dim_group_size = get_world_size(dim_group)
+            self.assertIsInstance(dim_group, ProcessGroup)
+            self.assertEqual(dim_group_size, 2)
+            global_ranks = [
+                get_global_rank(dim_group, i) for i in range(dim_group_size)
+            ]
+            current_rank_expected_group_ranks = (
+                dim_ranks[0] if self.rank in dim_ranks[0] else dim_ranks[1]
+            )
+            self.assertEqual(global_ranks, current_rank_expected_group_ranks)
+
+    @with_comms
+    def test_device_mesh_dim_groups_error(self):
+        # construct two-dimensional subgroups
+        dim_groups = []
+        expected_ranks_by_dim = [[[0, 2], [1, 3]], [[0, 1], [2, 3]]]
+        for dim_group_ranks in expected_ranks_by_dim:
+            for subgroup_ranks in dim_group_ranks:
+                subgroup = new_group(ranks=subgroup_ranks)
+                if self.rank in subgroup_ranks:
+                    dim_groups.append(subgroup)
+
+        if len(dim_groups) > 0:
+            # dim_groups is not a list
+            self.assertRaises(
+                RuntimeError,
+                DeviceMesh,
+                self.device_type,
+                [[0, 1], [2, 3]],
+                dim_groups=dim_groups[0],
+            )
+
+            # dim_groups is a list, but not a list of ProcessGroup
+            self.assertRaises(
+                RuntimeError,
+                DeviceMesh,
+                self.device_type,
+                [[0, 1], [2, 3]],
+                dim_groups=[dim_groups[0], "dummy"],
+            )
+
+            # dim_groups has incorrect length
+            self.assertRaises(
+                RuntimeError,
+                DeviceMesh,
+                self.device_type,
+                [[0, 1], [2, 3]],
+                dim_groups=[dim_groups[0]],
+            )
+
+    @with_comms
+    def test_device_mesh_nd(self):
+        # construct a cuda device mesh
+        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+        mesh = DeviceMesh(self.device_type, mesh_tensor)
+
+        # check all dim groups
+        dim_to_subgroups = mesh.get_dim_groups()
+
+        for dim, dim_group in enumerate(dim_to_subgroups):
+            self.assertTrue(dim < mesh_tensor.ndim)
+            dim_ranks = mesh_tensor.swapdims(-1, dim).reshape(-1, 2)
+
+            dim_group_size = get_world_size(dim_group)
+            self.assertIsInstance(dim_group, ProcessGroup)
+            self.assertEqual(dim_group_size, 2)
+            global_ranks = [
+                get_global_rank(dim_group, i) for i in range(dim_group_size)
+            ]
+            for ranks in dim_ranks:
+                if self.rank in ranks:
+                    self.assertEqual(global_ranks, ranks.tolist())
+
+
+class DeviceMeshCollectiveTest(DTensorTestBase):
+    @property
+    def world_size(self):
+        return 8
+
+    @with_comms
+    def test_all_reduce_1d(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
+        mesh.all_reduce(local_tensor, mesh_dim=0)
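+        # all_reduce sums the per-rank tensors, so every element
+        # equals 0 + 1 + ... + (world_size - 1)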
+        res_num = ((0 + self.world_size - 1) * self.world_size) / 2
+        self.assertEqual(local_tensor, torch.ones(3, 3) * res_num)
+
+    @with_comms
+    def test_broadcast_1d(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
+        mesh.broadcast(local_tensor, mesh_dim=0)
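+        # broadcast uses the first rank of the mesh dimension as its source,
+        # so every rank ends up with rank 0's zeros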
+        self.assertEqual(local_tensor, torch.zeros(3, 3))
+
+    @with_comms
+    def test_scatter_1d(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+        scatter_tensor_shape = [3, 3, 3]
+        for scatter_dim in range(len(scatter_tensor_shape)):
+            shard_placement = Shard(scatter_dim)
+            scatter_tensor_shape[scatter_dim] *= self.world_size
+            # make the random seed the same across ranks
+            torch.manual_seed(0)
+            global_tensor = torch.randn(
+                scatter_tensor_shape, device=self.device_type
+            )
+            splitted_list, _ = shard_placement._split_tensor(
+                global_tensor, mesh.size(), with_padding=True, contiguous=True
+            )
+            recv_tensor = torch.empty_like(splitted_list[mesh.get_rank()])
+            # scatter on dim > 0 would generate a non-contiguous tensor; verify that it works
+            mesh.scatter(recv_tensor, splitted_list, mesh_dim=0)
+            self.assertEqual(recv_tensor, splitted_list[mesh.get_rank()])
+
+    @with_comms
+    def test_scatter_uneven(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        my_rank = device_mesh.get_rank()
+        tensor_to_split = torch.randn(
+            device_mesh.size() + 3, device_mesh.size() + 1
+        )
+
+        for shard_dim in range(tensor_to_split.ndim):
+            shard_placement = Shard(shard_dim)
+            tensor_to_scatter = tensor_to_split.clone()
+            tensor_splitted_list = tensor_to_split.tensor_split(
+                device_mesh.size(), dim=shard_dim
+            )
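+            # _split_tensor with with_padding=True pads the trailing chunks to a uniform
+            # size; pad_idx is the first rank whose shard was padded (0 means no padding)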
+            padded_tensor_list, pad_idx = shard_placement._split_tensor(
+                tensor_to_scatter,
+                device_mesh.size(),
+                with_padding=True,
+                contiguous=True,
+            )
+
+            scattered_tensor = torch.empty_like(padded_tensor_list[my_rank])
+            device_mesh.scatter(
+                scattered_tensor, padded_tensor_list, mesh_dim=0
+            )
+            # unpad scattered_tensor
+            if pad_idx != 0 and my_rank >= pad_idx:
+                scattered_tensor = shard_placement._unpad_tensor(
+                    scattered_tensor
+                )
+
+            self.assertEqual(
+                scattered_tensor.size(), tensor_splitted_list[my_rank].size()
+            )
+            self.assertEqual(scattered_tensor, tensor_splitted_list[my_rank])
+
+    @with_comms
+    def test_all_gather_1d(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+        dims_to_gather = [0, 1]
+        for dim in dims_to_gather:
+            output_size = [3, 3]
+            output_size[dim] *= self.world_size
+            # each rank has its own tensor; all_gather gives a list
+            local_tensor = torch.ones(3, 3, device=self.device_type)
+            gathered_list = []
+            for _ in range(self.world_size):
+                gathered_list.append(torch.zeros_like(local_tensor))
+            mesh.all_gather(gathered_list, local_tensor, mesh_dim=0)
+            gathered_tensor = torch.cat(gathered_list, dim=dim)
+            self.assertEqual(gathered_tensor, torch.ones(output_size))
+
+    @with_comms
+    def test_all_gather_uneven(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        my_rank = device_mesh.get_rank()
+        tensor_to_split = torch.ones(
+            device_mesh.size() + 3,
+            device_mesh.size() + 1,
+            device=self.device_type,
+        )
+
+        for shard_dim in range(tensor_to_split.ndim):
+            shard_placement = Shard(shard_dim)
+            tensor_padded_list, pad_idx = shard_placement._split_tensor(
+                tensor_to_split,
+                device_mesh.size(),
+                with_padding=True,
+                contiguous=True,
+            )
+            local_tensor = tensor_padded_list[my_rank]
+            gathered_list = []
+            for _ in range(device_mesh.size()):
+                gathered_list.append(torch.empty_like(local_tensor))
+
+            device_mesh.all_gather(
+                gathered_list,
+                local_tensor,
+                mesh_dim=0,
+            )
+            if pad_idx != 0:
+                gathered_list = [
+                    shard_placement._unpad_tensor(gathered_tensor)
+                    if i >= pad_idx
+                    else gathered_tensor
+                    for i, gathered_tensor in enumerate(gathered_list)
+                ]
+            all_gathered_tensor = torch.cat(gathered_list, dim=shard_dim)
+            self.assertEqual(all_gathered_tensor.size(), tensor_to_split.size())
+            self.assertEqual(all_gathered_tensor, tensor_to_split)
+
+    @with_comms
+    def test_reduce_scatter_1d(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+        dims_to_scatter = [0, 1]
+        for dim in dims_to_scatter:
+            input_size = [3, 3]
+            scattered_tensor = torch.empty(input_size, device=self.device_type)
+            input_size[dim] *= self.world_size
+
+            input_rs_list = (
+                torch.ones(input_size, device=self.device_type) * self.rank
+            ).tensor_split(self.world_size, dim=dim)
+            res_num = ((0 + self.world_size - 1) * self.world_size) / 2
+            mesh.reduce_scatter(scattered_tensor, input_rs_list, mesh_dim=0)
+            self.assertEqual(scattered_tensor, torch.ones(3, 3) * res_num)
+
+    @with_comms
+    def test_reduce_scatter_uneven(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        my_rank = device_mesh.get_rank()
+        tensor_to_split = (
+            torch.ones(
+                device_mesh.size() + 3,
+                device_mesh.size() + 1,
+                device=self.device_type,
+            )
+            * self.rank
+        )
+
+        for shard_dim in range(tensor_to_split.ndim):
+            shard_placement = Shard(shard_dim)
+            tensor_to_scatter = tensor_to_split.clone()
+            tensor_splitted_list = tensor_to_split.tensor_split(
+                device_mesh.size(), dim=shard_dim
+            )
+            padded_tensor_list, pad_idx = shard_placement._split_tensor(
+                tensor_to_scatter,
+                device_mesh.size(),
+                with_padding=True,
+                contiguous=True,
+            )
+
+            res_num = ((0 + self.world_size - 1) * self.world_size) / 2
+            scattered_tensor = torch.empty_like(padded_tensor_list[my_rank])
+            device_mesh.reduce_scatter(
+                scattered_tensor, padded_tensor_list, mesh_dim=0
+            )
+            # unpad scattered_tensor
+            if pad_idx != 0 and my_rank >= pad_idx:
+                scattered_tensor = shard_placement._unpad_tensor(
+                    scattered_tensor
+                )
+
+            self.assertEqual(
+                scattered_tensor.size(), tensor_splitted_list[my_rank].size()
+            )
+            self.assertEqual(
+                scattered_tensor,
+                torch.ones_like(tensor_splitted_list[my_rank]) * res_num,
+            )
+
+    @with_comms
+    def test_all_gather_nd(self):
+        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+        mesh = DeviceMesh(self.device_type, mesh_tensor)
+        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
+
+        dim_to_subgroups = mesh.get_dim_groups()
+        for dim, dim_group in enumerate(dim_to_subgroups):
+            dim_group_size = get_world_size(dim_group)
+            global_ranks = [
+                get_global_rank(dim_group, i) for i in range(dim_group_size)
+            ]
+            gathered_tensor_list = list(
+                torch.empty(
+                    (dim_group_size * 3, 3), device=self.device_type
+                ).tensor_split(dim_group_size, dim=0)
+            )
+            mesh.all_gather(gathered_tensor_list, local_tensor, mesh_dim=dim)
+            gathered_tensor = torch.cat(gathered_tensor_list)
+            exp_tensor = torch.ones(3 * dim_group_size, 3)
+            for i in range(len(global_ranks)):
+                exp_tensor[i * 3 : (i + 1) * 3] = (
+                    torch.ones(3, 3) * global_ranks[i]
+                )
+            self.assertEqual(gathered_tensor, exp_tensor)
+
+    @with_comms
+    def test_reduce_scatter_nd(self):
+        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+        mesh = DeviceMesh(self.device_type, mesh_tensor)
+
+        dim_to_subgroups = mesh.get_dim_groups()
+        for dim, dim_group in enumerate(dim_to_subgroups):
+            dim_group_size = get_world_size(dim_group)
+            local_rs_list = (
+                torch.ones(dim_group_size * 3, 3, device=self.device_type)
+                * self.rank
+            ).tensor_split(dim_group_size, dim=0)
+            scattered_tensor = torch.empty_like(
+                local_rs_list[mesh.get_coordinate_on_dim(dim)],
+                device=self.device_type,
+            )
+            global_ranks = [
+                get_global_rank(dim_group, i) for i in range(dim_group_size)
+            ]
+            mesh.reduce_scatter(scattered_tensor, local_rs_list, mesh_dim=dim)
+            res_num = torch.sum(torch.tensor(global_ranks))
+            self.assertEqual(scattered_tensor, torch.ones(3, 3) * res_num)
+
+    @with_comms
+    def test_all_reduce_nd(self):
+        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+        mesh = DeviceMesh(self.device_type, mesh_tensor)
+        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
+
+        # check all dim groups
+        dim_to_subgroups = mesh.get_dim_groups()
+        for dim, dim_group in enumerate(dim_to_subgroups):
+            dim_group_size = get_world_size(dim_group)
+            global_ranks = [
+                get_global_rank(dim_group, i) for i in range(dim_group_size)
+            ]
+            cloned_local_tensor = local_tensor.clone()
+            mesh.all_reduce(cloned_local_tensor, mesh_dim=dim)
+            res_num = sum(global_ranks)
+            self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num)
+
+    @with_comms
+    def test_broadcast_nd(self):
+        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+        mesh = DeviceMesh(self.device_type, mesh_tensor)
+        local_tensor = torch.ones(3, 3, device=self.device_type) * self.rank
+
+        # check all dim groups
+        dim_to_subgroups = mesh.get_dim_groups()
+        for dim, dim_group in enumerate(dim_to_subgroups):
+            dim_group_size = get_world_size(dim_group)
+            global_ranks = [
+                get_global_rank(dim_group, i) for i in range(dim_group_size)
+            ]
+            cloned_local_tensor = local_tensor.clone()
+            mesh.broadcast(cloned_local_tensor, mesh_dim=dim)
+            res_num = global_ranks[0]
+            self.assertEqual(cloned_local_tensor, torch.ones(3, 3) * res_num)
+
+    @with_comms
+    def test_scatter_nd(self):
+        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+        mesh = DeviceMesh(self.device_type, mesh_tensor)
+
+        # check all dim groups
+        dim_to_subgroups = mesh.get_dim_groups()
+        for dim, dim_group in enumerate(dim_to_subgroups):
+            dim_group_size = get_world_size(dim_group)
+            global_ranks = [
+                get_global_rank(dim_group, i) for i in range(dim_group_size)
+            ]
+            scattered_tensors = [
+                torch.ones(3, 3, device=self.device_type) * global_rank
+                for global_rank in global_ranks
+            ]
+            received_tensor = torch.empty_like(
+                scattered_tensors[mesh.get_coordinate_on_dim(dim)]
+            )
+            mesh.scatter(received_tensor, scattered_tensors, mesh_dim=dim)
+            self.assertEqual(received_tensor, torch.ones(3, 3) * self.rank)
+
+    @with_comms
+    def test_all_to_all_1d(self):
+        # transpose on a 2D tensor distributed over N nodes:
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+        tensor_shape = [3, 3]
+        input_tensor_list = [
+            torch.ones(*tensor_shape, device=self.device_type)
+            * (rank + self.rank * self.world_size)
+            for rank in range(self.world_size)
+        ]
+        expected_tensor_list = [
+            torch.ones(tensor_shape, device=self.device_type)
+            * (self.rank + rank * self.world_size)  # i.e. transpose
+            for rank in range(self.world_size)
+        ]
+        for scatter_dim in range(len(tensor_shape)):
+            output_tensor_list = [
+                torch.empty_like(input_tensor_list[idx])
+                for idx in range(len(input_tensor_list))
+            ]
+            # scatter on dim > 0 would generate a non-contiguous tensor; verify that it works
+            mesh.all_to_all(output_tensor_list, input_tensor_list, mesh_dim=0)
+            output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
+            expected_tensor = torch.cat(expected_tensor_list, dim=scatter_dim)
+
+            self.assertEqual(output_tensor, expected_tensor)
+
+    @with_comms
+    def test_all_to_all_nd(self):
+        mesh_tensor = torch.arange(8).reshape(2, 2, 2)
+        mesh = DeviceMesh(self.device_type, mesh_tensor)
+        tensor_shape = [3, 3, 3]
+        # check all dim groups
+        dim_to_subgroups = mesh.get_dim_groups()
+        for dim, dim_group in enumerate(dim_to_subgroups):
+            my_coordinate = mesh.get_coordinate_on_dim(dim)
+            dim_group_size = get_world_size(dim_group)
+            global_ranks = [
+                get_global_rank(dim_group, i) for i in range(dim_group_size)
+            ]
+            input_tensor_list = [
+                torch.ones(*tensor_shape, device=self.device_type)
+                * (i + self.rank * dim_group_size)
+                for i in range(dim_group_size)
+            ]
+            expected_tensor_list = [
+                torch.ones(*tensor_shape, device=self.device_type)
+                * (
+                    my_coordinate + global_rank * dim_group_size
+                )  # i.e. transpose
+                for global_rank in global_ranks
+            ]
+            for scatter_dim in range(len(tensor_shape)):
+                output_tensor_list = [
+                    torch.empty_like(input_tensor_list[idx])
+                    for idx in range(len(input_tensor_list))
+                ]
+                # scatter on dim > 0 would generate a non-contiguous tensor; verify that it works
+                mesh.all_to_all(
+                    output_tensor_list, input_tensor_list, mesh_dim=dim
+                )
+                output_tensor = torch.cat(output_tensor_list, dim=scatter_dim)
+                expected_tensor = torch.cat(
+                    expected_tensor_list, dim=scatter_dim
+                )
+                self.assertEqual(output_tensor, expected_tensor)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/distributed/_tensor/test_dtensor.py b/test/distributed/_tensor/test_dtensor.py
new file mode 100644
index 0000000..51ce1bd
--- /dev/null
+++ b/test/distributed/_tensor/test_dtensor.py
@@ -0,0 +1,359 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import torch
+
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)
+from torch.distributed._tensor import DeviceMesh, DTensor, distribute_tensor
+from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
+
+
+class DTensorTest(DTensorTestBase):
+    # @with_comms
+    # def test_tensor_constructor(self):
+    #     import torch.distributed._tensor as dist_tensor
+    #     shard_spec = PlacementSpec(device_mesh, strategies=[Shard(0)])
+    #     empty_tensor = dist_tensor.empty((12, 10), placement_spec=shard_spec)
+    #     zero_tensor = dist_tensor.zeros((12, 10), placement_spec=shard_spec)
+    #     one_tensor = dist_tensor.ones((12, 10), placement_spec=shard_spec)
+
+    #     zero_cuda_tensor = dist_tensor.zeros((12, 10), device="cuda", placement_spec=shard_spec)
+
+    #     dist_tensor.empty_like(empty_tensor)
+    #     dist_tensor.zero_like(empty_tensor)
+    #     dist_tensor.one_like(empty_tensor)
+
+    @with_comms
+    def test_dtensor_constructor(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+        local_tensor = torch.randn(3, 3, requires_grad=True)
+        dist_tensor_shape = torch.Size([self.world_size * 3, 3])
+        dist_tensor = DTensor(
+            local_tensor,
+            device_mesh,
+            shard_spec,
+            size=dist_tensor_shape,
+            requires_grad=True,
+        )
+        self.assertEqual(
+            dist_tensor.size(), torch.Size((self.world_size * 3, 3))
+        )
+
+        with self.assertWarnsRegex(UserWarning, "To construct"):
+            DTensor(
+                local_tensor, device_mesh, shard_spec, size=dist_tensor_shape
+            )
+
+        local_tensor = torch.randn(3, 3, requires_grad=False)
+        with self.assertWarnsRegex(UserWarning, "To construct"):
+            dist_tensor = DTensor(
+                local_tensor,
+                device_mesh,
+                shard_spec,
+                size=dist_tensor_shape,
+                requires_grad=True,
+            )
+
+    @with_comms
+    def test_dtensor_stride(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard0_spec = [Shard(0)]
+        local_tensor = torch.randn(4, 8)
+        global_shape = torch.Size([self.world_size * 4, 8])
+        dist_tensor = DTensor(
+            local_tensor, device_mesh, shard0_spec, size=global_shape
+        )
+        # won't affect stride
+        self.assertEqual(dist_tensor.stride(), (8, 1))
+
+        shard1_spec = [Shard(1)]
+        local_tensor = torch.randn(8, 4)
+        global_shape = torch.Size([8, self.world_size * 4])
+        dist_tensor = DTensor(
+            local_tensor, device_mesh, shard1_spec, size=global_shape
+        )
+        # will affect the stride once the DTensor is initialized
+        self.assertEqual(dist_tensor.stride(), (4 * self.world_size, 1))
+
+        # if initialized from a transposed matrix
+        local_tensor = torch.randn(8, 4, 8)
+        local_tensor_t = local_tensor.permute(1, 2, 0)
+        global_shape = torch.Size([4, self.world_size * 8, 8])
+        self.assertEqual(local_tensor_t.stride(), (8, 1, 32))
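+        # sharding tensor dim 1 grows the global extent of dim 1 by world_size,
+        # so every stride larger than dim 1's stride (dims 0 and 2 in this
+        # permuted layout) scales by world_size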
+        dist_tensor = DTensor(
+            local_tensor_t, device_mesh, shard1_spec, size=global_shape
+        )
+        global_stride = (8 * self.world_size, 1, 32 * self.world_size)
+        self.assertEqual(dist_tensor.stride(), global_stride)
+
+    @with_comms
+    def test_from_local(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+        local_tensor = torch.randn(3, 3)
+        sharded_tensor = DTensor.from_local(
+            local_tensor, device_mesh, shard_spec
+        )
+        self.assertEqual(
+            sharded_tensor.size(), torch.Size([self.world_size * 3, 3])
+        )
+
+        replica_spec = [Replicate()]
+        ddp_tensor = DTensor.from_local(local_tensor, device_mesh, replica_spec)
+        self.assertEqual(ddp_tensor.size(), local_tensor.size())
+
+        partial_spec = [_Partial()]
+        partial_tensor = DTensor.from_local(
+            local_tensor, device_mesh, partial_spec
+        )
+        self.assertEqual(partial_tensor.size(), local_tensor.size())
+
+        # test that dist tensor works with torch.Tensor during backward
+        local_tensor_with_grad = torch.randn(3, 3, requires_grad=True)
+        # do some operations on local tensor
+        local_tensor_temp = local_tensor_with_grad * 3
+        # create the dist tensor from a non-leaf local tensor; the dist tensor
+        # created should also be a non-leaf node
+        dist_tensor = DTensor.from_local(
+            local_tensor_temp, device_mesh, shard_spec
+        )
+        self.assertFalse(dist_tensor.is_leaf)
+        # do some random operations on dist tensor
+        output = dist_tensor * 3
+        self.assertIsInstance(output, DTensor)
+        # trigger .backward() on dist tensor directly
+        local_grad = torch.ones(3, 3)
+        grad_output = DTensor.from_local(local_grad, device_mesh, shard_spec)
+        # run backward directly on dist tensor
+        output.backward(grad_output)
+        # check that gradients flow back to the original torch.Tensor
+        self.assertIsNotNone(local_tensor_with_grad.grad)
+        expected_grad = torch.ones(3, 3) * 9
+        self.assertEqual(local_tensor_with_grad.grad, expected_grad)
+
+    @with_comms
+    def test_to_local(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+        dist_tensor_shape = torch.Size([self.world_size * 3, 3])
+        local_tensor_with_grad = torch.randn(
+            3, 3, device=self.device_type, requires_grad=True
+        )
+
+        sharded_tensor = DTensor(
+            local_tensor_with_grad,
+            device_mesh,
+            shard_spec,
+            size=dist_tensor_shape,
+            requires_grad=True,
+        )
+        self.assertEqual(sharded_tensor.size(), dist_tensor_shape)
+        self.assertEqual(sharded_tensor.to_local(), local_tensor_with_grad)
+
+        # test that dist tensor works with torch.Tensor during backward
+        # the dist tensor created is a leaf node; do some operations on it
+        temp_st = sharded_tensor * 3
+
+        # do some operation on local tensor of the dist tensor
+        new_tensor_with_grad = torch.randn(
+            3, 3, device=self.device_type, requires_grad=True
+        )
+        res = temp_st.to_local() + new_tensor_with_grad
+        # call backward directly on torch.Tensor, and see if it works by
+        # propagating through dist tensor
+        res.sum().backward()
+        self.assertIsNotNone(sharded_tensor.grad)
+
+        self.assertEqual(sharded_tensor.grad.to_local(), torch.ones(3, 3) * 3)
+
+    @with_comms
+    def test_from_local_then_to_local(self):
+        # this test ensures the end-to-end flow torch.Tensor -> dist tensor -> torch.Tensor works
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+
+        # step 1. construct the local tensor
+        local_tensor_with_grad = torch.randn(
+            3, 3, device=self.device_type, requires_grad=True
+        )
+        # do some operations on local tensor
+        local_tensor_temp = local_tensor_with_grad + 8
+        # step 2. create the dist tensor from a non-leaf local tensor; the dist
+        # tensor created should also be a non-leaf node
+        dist_tensor = DTensor.from_local(
+            local_tensor_temp, device_mesh, shard_spec
+        )
+        self.assertFalse(dist_tensor.is_leaf)
+        # do some random operations on dist tensor
+        output = dist_tensor * 6
+        self.assertIsInstance(output, DTensor)
+
+        # step 3. do some operation on local tensor of the dist tensor
+        new_tensor_with_grad = torch.randn(
+            3, 3, device=self.device_type, requires_grad=True
+        )
+        res = output.to_local() + new_tensor_with_grad
+        # call backward directly on torch.Tensor, and see if it works by
+        # propagating all the way back to the original torch.Tensor
+        res.sum().backward()
+        self.assertIsNotNone(local_tensor_with_grad.grad)
+
+        expected_grad = torch.ones(3, 3) * 6
+        self.assertEqual(local_tensor_with_grad.grad, expected_grad)
+
+    @with_comms
+    def test_dtensor_spec_read_only_after_set(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+        local_tensor = torch.randn(3, 3)
+        sharded_tensor = DTensor.from_local(
+            local_tensor, device_mesh, shard_spec
+        )
+
+        # modify shard_spec, and dist_tensor's spec should not be changed
+        shard_spec[0] = Replicate()
+        self.assertTrue(sharded_tensor.placements is not shard_spec)
+        self.assertNotEqual(sharded_tensor.placements, shard_spec)
+
+    @with_comms
+    def test_dtensor_properties(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        shard_spec = [Shard(0)]
+        local_tensor = torch.randn(3, 3)
+        sharded_tensor = DTensor.from_local(
+            local_tensor, device_mesh, shard_spec
+        )
+        self.assertEqual(sharded_tensor.device.type, self.device_type)
+
+
+class DTensorMeshTest(DTensorTestBase):
+    @property
+    def world_size(self):
+        return 8
+
+    @with_comms
+    def test_dtensor_device_mesh_device_conversion(self):
+        # construct a cuda device mesh
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        # constructing from a cpu local tensor with a cuda device mesh
+        # should automatically convert the dist tensor to cuda
+        shard_spec = [Shard(0)]
+        local_tensor = torch.randn(3, 3)
+        dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+        self.assertEqual(dist_tensor.device.type, self.device_type)
+        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
+
+    @with_comms
+    def test_dtensor_api_device_mesh_context_manager(self):
+        with DeviceMesh(self.device_type, list(range(self.world_size))) as mesh:
+            shard_spec = [Shard(0)]
+            local_tensor = torch.randn(3, 3)
+            sharded_tensor = DTensor.from_local(
+                local_tensor, device_mesh=mesh, placements=shard_spec
+            )
+
+        with DeviceMesh(self.device_type, list(range(self.world_size))):
+            shard_spec = [Shard(0)]
+            local_tensor = torch.randn(3, 3)
+            sharded_tensor = DTensor.from_local(
+                local_tensor, placements=shard_spec
+            )
+            replica_spec = [Replicate()]
+            replica_tensor = sharded_tensor.redistribute(
+                placements=replica_spec
+            )
+            self.assertEqual(
+                replica_tensor.size(), torch.Size([3 * self.world_size, 3])
+            )
+
+    @with_comms
+    def test_dtensor_2d_mesh(self):
+        mesh_tensor = torch.arange(self.world_size).reshape(2, 4)
+        # construct a cuda device mesh
+        mesh = DeviceMesh(self.device_type, mesh_tensor)
+
+        # construct a dist tensor on a 2d device mesh and test that it works
+        shard_spec = [Shard(0), Shard(1)]
+        local_tensor = torch.randn(3, 3)
+        dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+        self.assertEqual(
+            dist_tensor.size(), torch.Size([3 * mesh.size(0), 3 * mesh.size(1)])
+        )
+        self.assertEqual(dist_tensor.device.type, self.device_type)
+        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
+
+        # if we shard on the same tensor dimension,
+        # we should still correctly construct the global tensor size
+        shard_same_dim_spec = [Shard(0), Shard(0)]
+        local_tensor = torch.randn(3, 3)
+        dist_tensor = DTensor.from_local(
+            local_tensor, mesh, shard_same_dim_spec
+        )
+        self.assertEqual(
+            dist_tensor.size(), torch.Size([3 * self.world_size, 3])
+        )
+
+    @with_comms
+    def test_device_mesh_nd(self):
+        # construct a cuda device mesh
+        mesh_tensor = torch.arange(self.world_size).reshape(2, 2, 2)
+        mesh = DeviceMesh(self.device_type, mesh_tensor)
+        # construct a dist tensor on a 3d device mesh and test that it works
+        shard_spec = [Shard(0), Shard(1), Shard(2)]
+        local_tensor = torch.randn(3, 3, 3)
+        dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+        self.assertEqual(dist_tensor.size(), torch.Size([6, 6, 6]))
+        self.assertEqual(dist_tensor.device.type, self.device_type)
+        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
+
+        # construct a dist tensor on a 3d device mesh with some shards on the same dim
+        shard_spec = [Shard(0), Shard(0), Shard(2)]
+        local_tensor = torch.randn(3, 3, 3)
+        dist_tensor = DTensor.from_local(local_tensor, mesh, shard_spec)
+        self.assertEqual(dist_tensor.size(), torch.Size([12, 3, 6]))
+        self.assertEqual(dist_tensor.device.type, self.device_type)
+        self.assertEqual(dist_tensor.to_local().device.type, self.device_type)
+
+    @with_comms
+    def test_dtensor_spec_local_shard_offset(self):
+        device_mesh = DeviceMesh(
+            self.device_type, torch.arange(self.world_size).reshape(2, 4)
+        )
+        tensor_shape = (3 * self.world_size, 3 * self.world_size)
+        # sharding specs and their corresponding local shard offsets
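+        # on the 2 x 4 mesh, rank // 4 is this rank's coordinate on mesh dim 0
+        # and rank % 4 its coordinate on mesh dim 1; a shard of a dim of length
+        # 3 * world_size spans 3 * (world_size // mesh_dim_size) elements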
+        shard_spec_and_offsets = [
+            (
+                [Shard(0), Replicate()],
+                (3 * (self.world_size // 2) * (self.rank // 4), 0),
+            ),
+            (
+                [Shard(1), Replicate()],
+                (0, 3 * (self.world_size // 2) * (self.rank // 4)),
+            ),
+            (
+                [Replicate(), Shard(0)],
+                (3 * (self.world_size // 4) * (self.rank % 4), 0),
+            ),
+            (
+                [Replicate(), Shard(1)],
+                (0, 3 * (self.world_size // 4) * (self.rank % 4)),
+            ),
+        ]
+
+        # loop through all sharding specs and check local shard offsets
+        logical_tensor = torch.randn(tensor_shape)
+        for shard_spec, expected_shard_offsets in shard_spec_and_offsets:
+            dtensor = distribute_tensor(logical_tensor, device_mesh, shard_spec)
+            self.assertEqual(
+                expected_shard_offsets, dtensor._spec.local_offsets
+            )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/distributed/_tensor/test_redistribute.py b/test/distributed/_tensor/test_redistribute.py
new file mode 100644
index 0000000..78fc991
--- /dev/null
+++ b/test/distributed/_tensor/test_redistribute.py
@@ -0,0 +1,317 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+import itertools
+import torch
+
+from torch.testing._internal.common_utils import run_tests
+
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+    DTensorTestBase,
+    with_comms,
+)
+from torch.distributed._tensor import distribute_tensor, DeviceMesh, DTensor
+from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
+
+
+class RedistributeTest(DTensorTestBase):
+    @with_comms
+    def test_shard_to_replicate_forward_backward(self):
+        # 1) test shard -> replicate forward
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        replica_spec = [Replicate()]
+
+        input_sizes_and_shard_dim = [
+            ((self.world_size * 3, 3), 0),
+            ((self.world_size * 3 + 1, 3), 0),
+            ((self.world_size * 3 + 2, 3), 0),
+            ((3, self.world_size * 3), 1),
+            ((3, self.world_size * 3 + 1), 1),
+            ((3, self.world_size * 3 + 2), 1),
+        ]
+
+        for input_size, shard_dim in input_sizes_and_shard_dim:
+            shard_spec = [Shard(shard_dim)]
+            expected_tensor = torch.randn(
+                input_size, device=self.device_type, requires_grad=True
+            )
+            dtensor = distribute_tensor(
+                expected_tensor.clone(), device_mesh, shard_spec
+            )
+            reshard_dtensor = dtensor.redistribute(device_mesh, replica_spec)
+            self.assertEqual(reshard_dtensor.size(), torch.Size(input_size))
+            self.assertEqual(expected_tensor, reshard_dtensor.to_local())
+
+            # 2) test shard -> replicate backward:
+            # should give gradient as shard
+            grad_output = torch.ones_like(reshard_dtensor)
+            reshard_dtensor.backward(grad_output)
+            grad_input = dtensor.grad
+            self.assertEqual(grad_input.placements, shard_spec)
+            self.assertEqual(
+                grad_input.to_local(), torch.ones(dtensor.to_local().size())
+            )
+
+    @with_comms
+    def test_replicate_to_replicate_forward_backward(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        replica_spec = [Replicate()]
+        local_tensor = torch.randn(
+            12, 3, device=self.device_type, requires_grad=True
+        )
+        # 1) test replicate -> replicate forward
+        replica_tensor = distribute_tensor(
+            local_tensor, device_mesh, replica_spec
+        )
+        reshard_replica_tensor = replica_tensor.redistribute(
+            device_mesh, replica_spec
+        )
+        self.assertEqual(replica_tensor.size(), local_tensor.size())
+        self.assertEqual(replica_tensor, reshard_replica_tensor)
+
+        # 2) test replicate -> replicate backward:
+        # should give gradient as replicate
+        grad_output = torch.ones_like(reshard_replica_tensor)
+        reshard_replica_tensor.backward(grad_output)
+        grad_input = replica_tensor.grad
+        self.assertEqual(grad_input.placements, replica_spec)
+        self.assertEqual(grad_input.to_local(), torch.ones(12, 3))
+
+    @with_comms
+    def test_replicate_to_shard_forward_backward(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        replica_spec = [Replicate()]
+
+        input_sizes_and_shard_dim = [
+            ((self.world_size * 3, 3), 0),
+            ((self.world_size * 3 + 1, 3), 0),
+            ((self.world_size * 3 + 2, 3), 0),
+            ((3, self.world_size * 3), 1),
+            ((3, self.world_size * 3 + 1), 1),
+            ((3, self.world_size * 3 + 2), 1),
+        ]
+        for input_size, shard_dim in input_sizes_and_shard_dim:
+            shard_spec = [Shard(shard_dim)]
+            # 1) test replicate -> shard forward
+            local_replica = torch.randn(
+                input_size, device=self.device_type, requires_grad=True
+            )
+            splitted_list = local_replica.tensor_split(
+                self.world_size, shard_dim
+            )
+            # take the local tensor as the corresponding element of the chunked list
+            local_tensor = splitted_list[self.rank]
+            replica_tensor = distribute_tensor(
+                local_replica, device_mesh, replica_spec
+            )
+            reshard_tensor = replica_tensor.redistribute(
+                device_mesh, shard_spec
+            )
+            self.assertEqual(reshard_tensor.size(), replica_tensor.size())
+            self.assertEqual(reshard_tensor.placements, shard_spec)
+            self.assertEqual(reshard_tensor.to_local(), local_tensor)
+
+            # 2) test replicate -> shard backward:
+            # should give gradient as replicate
+            grad_output = torch.ones_like(reshard_tensor)
+            reshard_tensor.backward(grad_output)
+            grad_input = replica_tensor.grad
+            self.assertEqual(grad_input.placements, replica_spec)
+            self.assertEqual(grad_input.to_local(), torch.ones(input_size))
+
+    @with_comms
+    def test_partial_to_replicate_forward_backward(self):
+        # Although we don't allow a user to reshard to produce a partial
+        # placement (i.e. a user can't reshard to partial), we do allow
+        # replicate to partial internally, and the partial to replicate
+        # backward should work as expected
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        partial_local = torch.randn(
+            12, 3, device=self.device_type, requires_grad=True
+        )
+        partial_spec = [_Partial()]
+        replica_spec = [Replicate()]
+        # test partial -> replicate, which triggers all_reduce
+        partial_tensor = DTensor.from_local(
+            partial_local, device_mesh, partial_spec
+        )
+        global_partial_tensor = partial_tensor.redistribute(
+            device_mesh, replica_spec
+        )
+
+        self.assertEqual(partial_tensor.size(), partial_local.size())
+        self.assertEqual(
+            partial_local * self.world_size, global_partial_tensor.to_local()
+        )
+
+        # test backward to have replicate grad on partial
+        global_partial_tensor.backward(torch.ones_like(global_partial_tensor))
+        self.assertIsNotNone(partial_local.grad)
+        if device_mesh.get_rank() == 0:
+            self.assertEqual(partial_local.grad, torch.ones_like(partial_local))
+
+    @with_comms
+    def test_replicate_to_partial(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        local_tensor = torch.randn(
+            12, 3, device=self.device_type, requires_grad=True
+        )
+        partial_spec = _Partial()
+        replica_spec = Replicate()
+        # 1) test replicate -> partial forward
+        replica_tensor = distribute_tensor(
+            local_tensor, device_mesh, [replica_spec]
+        )
+        with self.assertRaisesRegex(
+            RuntimeError, "Can not redistribute to _Partial"
+        ):
+            partial_tensor = replica_tensor.redistribute(
+                device_mesh, [partial_spec]
+            )
+
+        from torch.distributed._tensor.redistribute import Redistribute
+
+        partial_tensor = Redistribute.apply(
+            replica_tensor, device_mesh, [partial_spec]
+        )
+        self.assertEqual(partial_tensor.size(), local_tensor.size())
+        # test that it successfully zeros out the contents on the other ranks
+        if self.rank == 0:
+            self.assertEqual(
+                replica_tensor.to_local(), partial_tensor.to_local()
+            )
+        else:
+            self.assertEqual(
+                partial_tensor.to_local(), torch.zeros_like(local_tensor)
+            )
+
+        # replicate to partial on sub groups
+        local_tensor = torch.randn(12, 3, device=self.device_type)
+        device_mesh = DeviceMesh(
+            self.device_type,
+            torch.arange(self.world_size).reshape(self.world_size // 2, 2),
+        )
+        # 2) test replicate -> partial on 2d-mesh subgroups
+        replica_tensor = distribute_tensor(
+            local_tensor, device_mesh, [replica_spec, replica_spec]
+        )
+        partial_tensor = Redistribute.apply(
+            replica_tensor, device_mesh, [partial_spec, partial_spec]
+        )
+        self.assertEqual(partial_tensor.size(), local_tensor.size())
+
+        if self.rank != 3:
+            # replicate to partial should only zero out rank 3, and leave
+            # ranks 0/2 (rank 0 on mesh dim 1) and ranks 0/1 (rank 0 on mesh dim 0) untouched
+            self.assertEqual(
+                replica_tensor.to_local(), partial_tensor.to_local()
+            )
+        else:
+            self.assertEqual(
+                replica_tensor.to_local(), torch.zeros_like(local_tensor)
+            )
+
+    @with_comms
+    def test_partial_to_shard(self):
+        device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
+        partial_spec = [_Partial()]
+
+        input_sizes_and_shard_dim = [
+            ((self.world_size * 3, 3), 0),
+            ((self.world_size * 3 + 1, 3), 0),
+            ((self.world_size * 3 + 2, 3), 0),
+            ((3, self.world_size * 3), 1),
+            ((3, self.world_size * 3 + 1), 1),
+            ((3, self.world_size * 3 + 2), 1),
+        ]
+
+        for input_size, shard_dim in input_sizes_and_shard_dim:
+            shard_spec = [Shard(shard_dim)]
+
+            partial_local = torch.ones(input_size, device=self.device_type)
+            partial_tensor = DTensor.from_local(
+                partial_local, device_mesh, partial_spec, run_check=False
+            )
+
+            quot, rem = divmod(input_size[shard_dim], self.world_size)
+            local_shape = list(input_size)
+            local_shape[shard_dim] = quot + (1 if self.rank < rem else 0)
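+            # e.g. with world_size=4 and a shard dim of size 13: quot=3, rem=1,
+            # so rank 0 holds 4 elements on that dim and ranks 1-3 hold 3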
+            # test partial -> shard, which triggers a reduce_scatter
+            scatter_shard_tensor = partial_tensor.redistribute(
+                device_mesh, shard_spec
+            )
+            self.assertEqual(scatter_shard_tensor.size(), partial_tensor.size())
+            self.assertEqual(scatter_shard_tensor.placements, shard_spec)
+            self.assertEqual(
+                scatter_shard_tensor.to_local(),
+                torch.ones(local_shape) * self.world_size,
+            )
+
+
+class MultiDimRedistributeTest(DTensorTestBase):
+    @property
+    def world_size(self) -> int:
+        return 8
+
+    @with_comms
+    def test_multi_dim_mesh(self):
+        devices = torch.arange(self.world_size)
+        for mesh_shape in [devices, devices.view(4, 2), devices.view(2, 2, 2)]:
+            device_mesh = DeviceMesh(self.device_type, mesh_shape)
+            tensor_shape = (16, 24)
+
+            if torch.distributed.get_rank() == 0:
+                full_tensor = torch.randn(*tensor_shape)
+            else:
+                # these should be entirely ignored, because distribute_tensor
+                # is expected to overwrite the shards on ranks != 0
+                full_tensor = torch.ones(*tensor_shape)
+
+            possibilities = [Replicate()] + [
+                Shard(i) for i in range(full_tensor.ndim)
+            ]
+            all_outputs = list(
+                itertools.product(*(mesh_shape.ndim * [possibilities]))
+            )
+            all_inputs = list(
+                itertools.product(
+                    *(mesh_shape.ndim * [possibilities + [_Partial()]])
+                )
+            )
+
+            for inputs in all_inputs:
+                # if an input placement is partial, temporarily make it Replicate,
+                # then reinterpret the replicated entries as partial afterwards
+                repl_inputs = [
+                    Replicate() if s.is_partial() else s for s in inputs
+                ]
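+                # e.g. inputs = (_Partial(), Shard(0)) gives
+                # repl_inputs = [Replicate(), Shard(0)]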
+                dt = distribute_tensor(full_tensor, device_mesh, repl_inputs)
+
+                if repl_inputs != inputs:
+                    # create a new DTensor, reinterpreting some of the replicated entries as "Partial"
+                    dt = DTensor.from_local(
+                        dt.to_local(), device_mesh, inputs, run_check=False
+                    )
+
+                for outputs in all_outputs:
+                    # redistribute on target outputs
+                    dt2 = dt.redistribute(device_mesh, outputs)
+
+                    # replicate and then get first shard
+                    local_full = dt2.redistribute(
+                        device_mesh, device_mesh.ndim * [Replicate()]
+                    ).to_local()
+
+                    if torch.distributed.get_rank() == 0:
+                        self.assertEqual(local_full.shape, full_tensor.shape)
+
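+                        # e.g. on a 4x2 mesh with inputs (_Partial(), Replicate()),
+                        # the full result is 4 * full_tensor: the partial
+                        # placement spans mesh dim 0, which has size 4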
+                        num_sums = 1
+                        for idx, input in enumerate(inputs):
+                            if input.is_partial():
+                                num_sums *= mesh_shape.size(idx)
+                        expected = num_sums * full_tensor
+                        self.assertEqual(local_full, expected)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/torch/testing/_internal/distributed/_tensor/__init__.py b/torch/testing/_internal/distributed/_tensor/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/torch/testing/_internal/distributed/_tensor/__init__.py
diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py
new file mode 100644
index 0000000..cf2abe0
--- /dev/null
+++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py
@@ -0,0 +1,334 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+
+from contextlib import contextmanager
+from dataclasses import dataclass
+import itertools
+import sys
+from functools import wraps
+from typing import (
+    Any,
+    Callable,
+    Generator,
+    Iterator,
+    Tuple,
+    Dict,
+    Optional,
+    List,
+    Sequence,
+    TypeVar,
+    cast,
+)
+
+import torch
+import torch.distributed as dist
+
+from torch.utils._pytree import tree_flatten, tree_unflatten, TreeSpec
+from torch.testing._internal.common_distributed import (
+    MultiProcessTestCase,
+    TEST_SKIPS,
+    skip_if_lt_x_gpu,
+)
+
+from torch.distributed._tensor import (
+    DeviceMesh,
+    Shard,
+    Replicate,
+    distribute_tensor,
+    redistribute,
+)
+from torch.distributed._tensor.api import DTensor
+from torch.distributed._tensor.placement_types import Placement
+
+DEVICE_TYPE = "cuda" if torch.cuda.is_available() else "cpu"
+NUM_DEVICES = 4
+
+# We use this as a proxy for "multiple GPUs exist"
+if torch.cuda.is_available() and torch.cuda.device_count() > 1:
+    # when we actually have multiple GPUs, cap NUM_DEVICES at the available device count
+    NUM_DEVICES = min(NUM_DEVICES, torch.cuda.device_count())
+
+T = TypeVar("T")
+
+
+def skip_unless_torch_gpu(method: T) -> T:
+    """
+    Test decorator which skips the test unless there's a GPU available to torch.
+
+    >>> @skip_unless_torch_gpu
+    >>> def test_some_method(self) -> None:
+    >>>   ...
+    """
+    # We use skip_if_lt_x_gpu instead of the builtin @skip_if_no_gpu,
+    # which relies on os.environ['WORLD_SIZE'] being set.
+    return cast(T, skip_if_lt_x_gpu(NUM_DEVICES)(method))
+
+
+@dataclass
+class RedistributeProfile:
+    num_calls: int
+
+
+@contextmanager
+def redistribute_profiler() -> Generator[RedistributeProfile, None, None]:
+    orig_redistribute_dtensor = redistribute.redistribute_dtensor
+    profile: RedistributeProfile = RedistributeProfile(num_calls=0)
+
+    # pyre-ignore[53]
+    def patched_redistribute_dtensor(
+        input: DTensor,
+        device_mesh: DeviceMesh,
+        placements: Sequence[Placement],
+    ) -> DTensor:
+        result = orig_redistribute_dtensor(input, device_mesh, placements)
+        profile.num_calls += 1
+        return result
+
+    try:
+        # pyre-ignore[9]
+        redistribute.redistribute_dtensor = patched_redistribute_dtensor
+        yield profile
+    finally:
+        redistribute.redistribute_dtensor = orig_redistribute_dtensor
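+
+# Illustrative usage of redistribute_profiler (not an actual test): a test body
+# could wrap an op on a DTensor to count how many redistribute_dtensor calls it
+# triggered, roughly like:
+#
+#     with redistribute_profiler() as profile:
+#         out = some_dtensor_op(dt)  # hypothetical op on a DTensor `dt`
+#     self.assertTrue(profile.num_calls >= 1)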
+
+
+class DTensorTestBase(MultiProcessTestCase):
+    @property
+    def world_size(self) -> int:
+        return NUM_DEVICES
+
+    def build_device_mesh(self) -> DeviceMesh:
+        return DeviceMesh(DEVICE_TYPE, list(range(NUM_DEVICES)))
+
+    def init_pg(self, backend: str = "nccl") -> None:
+        if backend == "nccl" and torch.cuda.device_count() < self.world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+
+        if backend not in ["nccl", "gloo", "mpi"]:
+            raise RuntimeError(f"Backend {backend} not supported!")
+
+        dist.init_process_group(
+            backend=backend,
+            world_size=self.world_size,
+            rank=self.rank,  # pyre-ignore[16]
+            init_method=f"file://{self.file_name}",  # pyre-ignore[16]
+        )
+
+        # set device for nccl pg for collectives
+        if backend == "nccl":
+            torch.cuda.set_device(self.rank)
+
+    def destroy_pg(self) -> None:
+        # Wait for all ranks to reach here before starting shutdown.
+        dist.barrier()
+        dist.destroy_process_group()
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._spawn_processes()
+
+    # pyre-ignore[2]:
+    def _test_op(self, mesh: DeviceMesh, op_call, *args, **kwargs) -> None:
+        with redistribute_profiler() as profile:
+            out = op_call(*args, **kwargs)
+            dtc = DTensorConverter(mesh, args, kwargs)
+            for d_args, d_kwargs in dtc:
+                # pyre can't find assertTrue anymore?
+                self.assertEqual(dtc.successful(), True)
+                d_out = op_call(*d_args, **d_kwargs)
+                self.assertEqual(
+                    d_out.redistribute(
+                        mesh, [Replicate()] * mesh.ndim
+                    ).to_local(),
+                    out,
+                )
+
+
+# wrapper to initialize comms (process group)
+def with_comms(
+    func: Optional[  # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+        Callable
+    ] = None,
+    backend: Optional[str] = None,
+) -> Optional[  # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
+    Callable
+]:
+    assert func is not None
+
+    @wraps(func)  # pyre-ignore[6]
+    def wrapper(
+        self, *args: Tuple[object], **kwargs: Dict[str, Any]  # type: ignore[misc]
+    ) -> None:
+        # if backend not specified, use nccl when cuda is available, else gloo
+        pg_backend = (
+            backend
+            if backend is not None
+            else ("nccl" if torch.cuda.is_available() else "gloo")
+        )
+        if pg_backend == "nccl" and torch.cuda.device_count() < self.world_size:
+            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
+
+        self.device_type = "cuda" if pg_backend == "nccl" else "cpu"
+        self.init_pg(backend=pg_backend)
+        func(self, *args, **kwargs)  # type: ignore[misc]
+        self.destroy_pg()
+
+    return wrapper
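+
+# Typical usage, mirroring the tests in this PR: decorate a DTensorTestBase
+# test method so a process group is set up before the body runs and torn
+# down afterwards, e.g.
+#
+#     class MyDTensorTest(DTensorTestBase):
+#         @with_comms
+#         def test_something(self) -> None:
+#             mesh = self.build_device_mesh()
+#             ...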
+
+
+# This is a class for converting args/kwargs of an op into distributed args/kwargs
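+# e.g. for args = (torch.randn(12, 3),) on a 1-D mesh of 4 devices, iterating
+# the converter yields the same args with that tensor distributed once per
+# sharding choice: Replicate() and Shard(0) (dims not divisible by the mesh
+# size are skipped)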
+class DTensorConverter(object):
+    def __init__(
+        self,
+        mesh: DeviceMesh,
+        args: Tuple[object, ...],
+        kwargs: Dict[str, object],
+    ) -> None:
+        self.hit = 0
+        self.miss = 0
+        self.mesh = mesh
+        self.args = args
+        self.kwargs = kwargs
+        flatten_args, flatten_args_spec = tree_flatten(args)
+        flatten_kwargs, flatten_kwargs_spec = tree_flatten(kwargs)
+
+        self.flatten_args: List[object] = flatten_args
+        self.flatten_args_spec: TreeSpec = flatten_args_spec
+        self.flatten_kwargs: List[object] = flatten_kwargs
+        self.flatten_kwargs_spec: TreeSpec = flatten_kwargs_spec
+
+        choices_for_args = []
+        for arg in self.flatten_args:
+            if isinstance(arg, torch.Tensor):
+                choices_for_args.append(self.gen_sharding_choices_for_arg(arg))
+
+        for arg in self.flatten_kwargs:
+            if isinstance(arg, torch.Tensor):
+                choices_for_args.append(self.gen_sharding_choices_for_arg(arg))
+
+        self.sharding_combs: Iterator[Sequence[Placement]] = iter(
+            itertools.product(*choices_for_args)
+        )
+
+    def successful(self) -> bool:
+        return self.hit > 0 and self.miss == 0
+
+    def is_supported_tensor(self, t: torch.Tensor) -> bool:
+        # TODO: dist tensor needs to support quantized and sparse
+        # tensors; quantized tensors might be relatively easy, but
+        # sparse tensors have special layouts that we may need to
+        # deal with. Until we are clear about them, we don't
+        # officially support them.
+        return not any(
+            [
+                t.is_sparse_csr,
+                t.is_sparse,
+                t.is_mkldnn,
+                t.is_quantized,
+                t.is_nested,
+                torch._is_functional_tensor(t),
+                t.is_neg(),
+                t.is_conj(),
+                t.device.type in ("lazy", "meta"),
+                # We need a way to test if a tensor is batched but there
+                # is no official API to do it
+                # torch._C._is_batched(t),
+            ]
+        )
+
+    def gen_sharding_choices_for_arg(
+        self, arg: torch.Tensor
+    ) -> Sequence[Placement]:
+        mesh_size = self.mesh.size()
+        sharding_choices: List[Placement] = [Replicate()]
+        # c10d collectives do not support bool tensors,
+        # so we treat bool tensors as replicated
+        if arg.dtype != torch.bool:
+            # only generating choices with: replicate, or sharding
+            # evenly on a dimension that could be sharded
+            sharding_choices = sharding_choices + [
+                Shard(i)
+                for i, s in enumerate(arg.shape)
+                if s > 1 and s % mesh_size == 0
+            ]
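+        # e.g. a float tensor of shape (8, 3) on a mesh of size 4 yields
+        # [Replicate(), Shard(0)]: only dim 0 is evenly divisible by the mesh size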
+        # TODO: add multi mesh choices
+        # all_choices = itertools.product(
+        #     *(self.mesh.ndim * [sharding_choices])
+        # )
+        return sharding_choices
+
+    def __iter__(self) -> "DTensorConverter":
+        return self
+
+    def __next__(self) -> Tuple[Tuple[object, ...], Dict[str, object]]:
+        try:
+            next_sharding_choices = next(self.sharding_combs)
+            idx = 0
+
+            new_args: List[object] = []
+            for arg in self.flatten_args:
+                if isinstance(arg, torch.Tensor):
+                    new_args.append(
+                        self.to_dist_tensor(
+                            arg, self.mesh, [next_sharding_choices[idx]]
+                        )
+                    )
+                    idx += 1
+                else:
+                    new_args.append(arg)
+
+            new_kwargs: List[object] = []
+            for arg in self.flatten_kwargs:
+                if isinstance(arg, torch.Tensor):
+                    new_kwargs.append(
+                        self.to_dist_tensor(
+                            arg, self.mesh, [next_sharding_choices[idx]]
+                        )
+                    )
+                    idx += 1
+                else:
+                    new_kwargs.append(arg)
+
+            return (
+                tree_unflatten(new_args, self.flatten_args_spec),
+                tree_unflatten(new_kwargs, self.flatten_kwargs_spec),
+            )
+        except StopIteration:
+            raise
+
+    def to_dist_tensor(
+        self, t: torch.Tensor, mesh: DeviceMesh, placements: List[Placement]
+    ) -> torch.Tensor:
+        if type(t) is torch.Tensor or type(t) is torch.nn.Parameter:
+            if self.is_supported_tensor(t):
+                self.hit += 1
+                # We cannot use distribute_tensor for bool tensors since c10d
+                # collectives do not support the dtype; we assume ops taking
+                # bool tensor args use the same tensor on every rank, so we
+                # don't need to broadcast
+                # TODO: add bool tensor dtype support in c10d collective
+                if t.dtype == torch.bool:
+                    r = DTensor(
+                        t,
+                        mesh,
+                        placements,
+                        size=t.size(),
+                        requires_grad=t.requires_grad,
+                    )
+                else:
+                    r = distribute_tensor(t, mesh, placements)
+                if type(t) is torch.nn.Parameter:
+                    r = torch.nn.Parameter(  # type: ignore[assignment]
+                        r, requires_grad=r.requires_grad
+                    )
+                return r
+            else:
+                self.miss += 1
+                return t
+        elif torch.overrides.is_tensor_like(t):
+            # Blindly converting tensor subclasses to dist tensor can cause
+            # unpredictable problems, so we explicitly disable this conversion
+            # for now (i.e. we don't support DTensor holding a tensor subclass
+            # until there's a strong reason to later).
+            self.miss += 1
+            return t
+        else:
+            raise RuntimeError(
+                f"Trying to convert to DTensor, but got {type(t)}"
+            )