[dtensor] implement dim-0 (row) embedding sharding with MaskPartial (#118080)

This PR adds support for rowwise sharded embedding by introducing a
MaskPartial placement that inherits from the default Partial placement
and overrides the Partial contracts to construct the mask and release
it after the reduction.

The MaskPartial placement has the potential to support sharded
computation for other ops that require a mask for semantic correctness.
For now it lives in the embedding ops, but we can move it to a common
place if needed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118080
Approved by: https://github.com/tianyu-l
ghstack dependencies: #118079
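
A minimal usage sketch of what this enables (it mirrors the new rowwise
test in test_embedding_ops.py; the mesh setup, device, and shapes below
are illustrative, not part of this PR):

    import os

    import torch
    from torch.distributed._tensor import DTensor, Replicate, Shard, distribute_tensor
    from torch.distributed.device_mesh import init_device_mesh

    # assumes a job launched with torchrun; build a 1-D mesh over all ranks
    mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

    # shard the embedding table on dim 0 (rows): each rank holds a slice of rows
    emb = torch.nn.Embedding(16, 22, device="cuda")
    emb.weight = torch.nn.Parameter(distribute_tensor(emb.weight, mesh, [Shard(0)]))

    # replicate the indices and run the lookup; the output carries the
    # _MaskPartial placement, and materializing it applies the saved mask and
    # reduces with a single all_reduce
    inp = torch.randint(0, 16, (5, 12), device="cuda")
    replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
    out = torch.nn.functional.embedding(replicated_inp, emb.weight)
    full_out = out.full_tensor()
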
diff --git a/test/distributed/_tensor/test_embedding_ops.py b/test/distributed/_tensor/test_embedding_ops.py
index 3ac264c..3ac61e0 100644
--- a/test/distributed/_tensor/test_embedding_ops.py
+++ b/test/distributed/_tensor/test_embedding_ops.py
@@ -25,9 +25,26 @@
     sys.exit(0)
 
 
+funcol = torch.ops.c10d_functional
+
+
 class TestEmbeddingOp(DTensorTestBase):
+    def _apply_sharding(self, embedding_mod, shard_dim, device_mesh):
+        def shard_embedding_fn(name, module, device_mesh):
+            for name, param in module.named_parameters():
+                dist_param = torch.nn.Parameter(
+                    distribute_tensor(param, device_mesh, [Shard(shard_dim)])
+                )
+                module.register_parameter(name, dist_param)
+
+        sharded_embedding = distribute_module(
+            embedding_mod, device_mesh, shard_embedding_fn
+        )
+        return sharded_embedding
+
     def _run_embedding_op_test(
         self,
+        device_mesh,
         shard_dim,
         input_size,
         num_embeddings,
@@ -35,7 +52,6 @@
         **kwargs,
     ):
         # Use same seed.
-        device_mesh = self.build_device_mesh()
         torch.manual_seed(0)
         local_embedding = torch.nn.Embedding(
             num_embeddings,
@@ -55,15 +71,8 @@
             local_embedding.weight.clone().detach()
         )
 
-        def shard_embedding_fn(name, module, device_mesh):
-            for name, param in module.named_parameters():
-                dist_param = torch.nn.Parameter(
-                    distribute_tensor(param, device_mesh, [Shard(shard_dim)])
-                )
-                module.register_parameter(name, dist_param)
-
-        sharded_embedding = distribute_module(
-            sharded_embedding, device_mesh, shard_embedding_fn
+        sharded_embedding = self._apply_sharding(
+            sharded_embedding, shard_dim, device_mesh
         )
 
         # Run sharded computation
@@ -121,7 +130,7 @@
             **kwargs,
         )
         sharded_output = torch.nn.functional.embedding(
-            DTensor.from_local(inp, device_mesh, [Replicate()]),
+            DTensor.from_local(inp, device_mesh, [Replicate()], run_check=False),
             sharded_embedding.weight,
             **kwargs,
         )
@@ -129,31 +138,50 @@
 
     @with_comms
     def test_sharded_embedding_colwise(self):
-        self._run_embedding_op_test(1, [5, 4], 17, 12)
-        self._run_embedding_op_test(1, [6, 7, 6], 21, 11)
-        self._run_embedding_op_test(1, [8, 6, 5, 4], 23, 13)
-        self._run_embedding_op_test(1, [8, 6, 5, 4, 7], 23, 16)
-        self._run_embedding_op_test(1, [4], 15, 14)
-        self._run_embedding_op_test(1, [34], 15, 14, padding_idx=10)
-        self._run_embedding_op_test(1, [8, 6, 5, 4], 23, 13, padding_idx=12)
+        mesh = self.build_device_mesh()
+        self._run_embedding_op_test(mesh, 1, [5, 4], 17, 12)
+        self._run_embedding_op_test(mesh, 1, [6, 7, 6], 21, 11)
+        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4], 23, 13)
+        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4, 7], 23, 16)
+        self._run_embedding_op_test(mesh, 1, [4], 15, 14)
+        self._run_embedding_op_test(mesh, 1, [34], 15, 14, padding_idx=10)
+        self._run_embedding_op_test(mesh, 1, [8, 6, 5, 4], 23, 13, padding_idx=12)
 
     @with_comms
     def test_sharded_embedding_colwise_max_norm_errors(self):
+        mesh = self.build_device_mesh()
         with self.assertRaisesRegex(
             NotImplementedError,
             "aten.embedding_renorm_.default does not have a sharding strategy registered.",
         ):
             self._run_embedding_op_test(
-                1, [8, 6, 5, 4], 23, 13, padding_idx=12, max_norm=2.0
+                mesh, 1, [8, 6, 5, 4], 23, 13, padding_idx=12, max_norm=2.0
             )
 
     @with_comms
     def test_sharded_embedding_rowwise(self):
-        with self.assertRaisesRegex(
-            NotImplementedError,
-            "row-wise sharded embedding operation yet",
-        ):
-            self._run_embedding_op_test(0, [5, 12], 16, 22)
+        mesh = self.build_device_mesh()
+        # test correctness
+        self._run_embedding_op_test(mesh, 0, [5, 12], 16, 22)
+        self._run_embedding_op_test(mesh, 0, [6, 7, 6], 13, 22)
+        self._run_embedding_op_test(mesh, 0, [34], 15, 14, padding_idx=10)
+
+        from torch.distributed._tensor.ops.embedding_ops import _MaskPartial
+
+        # test collectives
+        embedding_mod = torch.nn.Embedding(10, 20, device=self.device_type)
+        sharded_embedding = self._apply_sharding(embedding_mod, 0, mesh)
+        inp = torch.randint(0, 10, (8, 8), device=self.device_type)
+        replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
+        output = sharded_embedding(replicated_inp)
+        self.assertIsInstance(output.placements[0], _MaskPartial)
+
+        comm_mode = CommDebugMode()
+
+        with comm_mode:
+            output.full_tensor()
+            self.assertEqual(comm_mode.get_total_counts(), 1)
+            self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)
 
 
 if __name__ == "__main__":
diff --git a/torch/distributed/_tensor/ops/embedding_ops.py b/torch/distributed/_tensor/ops/embedding_ops.py
index 54d80f0..869f83d 100644
--- a/torch/distributed/_tensor/ops/embedding_ops.py
+++ b/torch/distributed/_tensor/ops/embedding_ops.py
@@ -1,9 +1,11 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 # implement matrix related ops for distributed tensor
 import itertools
-from typing import cast, List
+from dataclasses import dataclass, field
+from typing import cast, List, Optional
 
 import torch
+import torch.distributed._functional_collectives as funcol
 from torch.distributed._tensor.op_schema import (
     OpSchema,
     OpStrategy,
@@ -29,6 +31,131 @@
 aten = torch.ops.aten
 
 
+@dataclass
+class MaskBuffer:
+    data: Optional[torch.Tensor] = None
+
+    def materialize_mask(self, mask):
+        if self.data is not None:
+            raise RuntimeError("MaskBuffer has already been materialized")
+        self.data = mask
+
+    def release_mask(self):
+        # TODO: evaluate whether we need to release the mask buffer or if the
+        # buffer can just have the same lifetime as the _Partial placement
+        if self.data is None:
+            raise RuntimeError("MaskBuffer has not been materialized")
+        self.data = None
+
+
+@dataclass(frozen=True)
+class _MaskPartial(_Partial):
+    """
+    A partial mask placement devised for the rowwise sharded embedding op, where we
+    need to mask and adjust the indices to the local embedding shard; embedding
+    masking is a special type of the Partial placement.
+
+    NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor
+    lifecycle, i.e. the mask buffer is only alive during the lifetime of the DTensor.
+    """
+
+    logical_dim_size: int = -1
+    mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
+
+    def _partition_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        # override parent logic to perform partial mask for embedding
+        num_chunks = mesh.size(mesh_dim)
+        # get local shard size and offset on the embedding_dim
+        local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim(
+            self.logical_dim_size,
+            num_chunks,
+            mesh.get_local_rank(mesh_dim),
+            return_offset=True,
+        )
+        # Build the input mask and save it on the current partial placement
+        # so that the output of the embedding op can reuse the mask saved on
+        # this partial placement to perform the mask + reduction
+        mask = (tensor < local_offset_on_dim) | (
+            tensor >= local_offset_on_dim + local_shard_size
+        )
+        # mask the input tensor
+        masked_tensor = tensor.clone() - local_offset_on_dim
+        masked_tensor[mask] = 0
+        # materialize the mask buffer to be used for reduction
+        self.mask_buffer.materialize_mask(mask)
+        return masked_tensor
+
+    def _reduce_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        # by the time we need reduction, we should have already saved the mask
+        assert self.mask_buffer.data is not None
+
+        # apply the mask to the tensor that is pending reduction
+        tensor[self.mask_buffer.data, :] = 0.0
+
+        # clear the mask buffer
+        self.mask_buffer.release_mask()
+
+        # perform sum reduction
+        return funcol.all_reduce(
+            tensor, reduceOp=self.reduce_op.name, group=(mesh, mesh_dim)
+        )
+
+    def _reduce_shard_value(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_spec: Placement,
+    ) -> torch.Tensor:
+        # by the time we need reduction, we should have already saved the mask
+        assert self.mask_buffer.data is not None
+
+        # apply the mask to the tensor that is pending reduction
+        tensor[self.mask_buffer.data, :] = 0.0
+
+        # clear the mask buffer
+        self.mask_buffer.release_mask()
+
+        # call reduce_shard_tensor of the shard_spec.
+        shard_spec = cast(Shard, shard_spec)
+        return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim)
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, _MaskPartial):
+            return False
+
+        # if either mask buffer has data, we invalidate the sharding cache, as this indicates
+        # the current MaskPartial placement is still in use and should not produce a cache hit.
+        if self.mask_buffer.data is not None or other.mask_buffer.data is not None:
+            return False
+
+        return (
+            self.reduce_op == other.reduce_op
+            and self.logical_dim_size == other.logical_dim_size
+        )
+
+    def __hash__(self) -> int:
+        return 1 + hash(
+            (self.logical_dim_size, id(self.mask_buffer.data), self.reduce_op)
+        )
+
+    def __repr__(self) -> str:
+        """
+        machine readable representation of the MaskPartial placement
+        """
+        return f"_MaskPartial(logical_dim_size={self.logical_dim_size})"
+
+    def __str__(self) -> str:
+        """
+        human readable representation of the MaskPartial placement
+        """
+        return "MaskP"
+
+
 @register_op_strategy(aten.embedding.default)
 def embedding_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
     """
@@ -43,13 +170,6 @@
     indices_shape = indices_strategy.output_shape
     output_emd_dim = len(indices_shape)
 
-    # guard rowwise sharding not implemented for now
-    weight_spec = weight_strategy.strategies[0].output_spec
-    if any(placement.is_shard(0) for placement in weight_spec.placements):
-        raise NotImplementedError(
-            "DTensor does not support row-wise sharded embedding operation yet!"
-        )
-
     all_mesh_dim_strategies = []
 
     for mesh_dim in range(mesh.ndim):
@@ -64,6 +184,18 @@
         colwise_sharding = [Shard(output_emd_dim), Shard(1), Replicate()]
         single_mesh_dim_strategies.append(colwise_sharding)
 
+        # rowwise sharding: output is embedding partial, weight is sharded on dim 0, input accepts embedding partial
+        embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0])
+
+        # NOTE we want to reuse the same mask partial placement so that we can reuse the mask generated
+        # from the input indices and use it for the output reduction
+        rowwise_sharding = [
+            embedding_partial_placement,
+            Shard(0),
+            embedding_partial_placement,
+        ]
+        single_mesh_dim_strategies.append(rowwise_sharding)
+
         # batch dim sharding, weight replicated, input can shard on any dim, output follows input
         for input_dim in range(len(indices_shape)):
             batch_sharding = [Shard(input_dim), Replicate(), Shard(input_dim)]