fix _MaskPartial when multiple embeddings coexist (#131264)

Previously, using `_MaskPartial` when multiple embeddings coexist had the following issues:
1. Suppose an `nn.Embedding` has shape `[vocab_size, emb_size]`, and there are multiple embeddings that share the same `vocab_size` but have different `emb_size`s. They would not share an `OpStrategy`, since each would produce a different `OpSchema` when involved in computation; however, there would be a cache hit for redistribute (specifically in `_gen_transform_infos` in `torch/distributed/_tensor/_redistribute.py` when doing `Replicate` -> `_MaskPartial`), because `_MaskPartial` only carried `vocab_size` as `logical_dim_size` and did not carry `emb_size` as an attribute. This cache hit is undesirable and would cause trouble when doing all-reduce/reduce-scatter on the new `_MaskPartial` in a separate `OpStrategy`; the error was reported in #130725. In this PR, we introduce `offset_shape` to represent the embedding's full shape, so that embeddings of different shapes no longer hit the cache (see the first sketch after this list).
2. The second issue arises when we have two `nn.Embedding`s `emb1` and `emb2` with the same shape. There will be a cache hit not only in `_gen_transform_infos` but also in `OpStrategy` generation, so both embeddings share the same `_MaskPartial` and its `MaskBuffer`. Previously, if we did `Replicate` -> `_MaskPartial` for both `emb1` and `emb2` and then performed the reduction on the `_MaskPartial` of `emb1`, it would destroy the `MaskBuffer` and `emb2` would hit an error. This PR adds a `refcount` to the `MaskBuffer` so that it can be properly shared by multiple `nn.Embedding`s (see the second sketch after this list).
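
A minimal sketch of the cache-key change behind item 1 (the class names `_MaskPartialKeyBefore` and `_MaskPartialKeyAfter` are illustrative only; the real `_MaskPartial` in the diff below additionally carries a `reduce_op` and a `MaskBuffer`):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass(frozen=True)
class _MaskPartialKeyBefore:
    # old: only the vocab dimension participates in equality/hashing
    logical_dim_size: int = -1


@dataclass(frozen=True)
class _MaskPartialKeyAfter:
    # new: the embedding's full shape plus the offset dim participate
    offset_shape: Optional[torch.Size] = None
    offset_dim: int = 0


# two embeddings with shapes (10, 23) and (10, 29): same vocab_size, different emb_size
assert _MaskPartialKeyBefore(10) == _MaskPartialKeyBefore(10)  # old: spurious cache hit
assert _MaskPartialKeyAfter(torch.Size([10, 23]), 0) != _MaskPartialKeyAfter(
    torch.Size([10, 29]), 0
)  # new: distinct placements, so no cache hit across shapes
```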

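A minimal, standalone sketch of the refcounted `MaskBuffer` semantics from item 2, mirroring the change to `torch/distributed/_tensor/ops/_embedding_ops.py` in the diff below (only `materialize_mask` and `release_mask` are shown; the driver at the bottom is illustrative):

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MaskBuffer:
    data: Optional[torch.Tensor] = None
    # refcount allows shared usage of the MaskBuffer, as long as all users have the same data
    refcount: int = 0

    def materialize_mask(self, mask: torch.Tensor) -> None:
        if self.refcount == 0:
            self.data = mask
        else:
            assert self.data is not None
            if not torch.equal(self.data, mask):
                raise RuntimeError("MaskBuffer has been materialized with conflicting data")
        self.refcount += 1

    def release_mask(self) -> None:
        if self.refcount == 0 or self.data is None:
            raise RuntimeError("MaskBuffer has not been materialized")
        self.refcount -= 1
        if self.refcount == 0:
            self.data = None


buf = MaskBuffer()
mask = torch.tensor([True, False, True])
buf.materialize_mask(mask)   # emb1's forward materializes the mask (refcount -> 1)
buf.materialize_mask(mask)   # emb2 shares the same mask instead of erroring (refcount -> 2)
buf.release_mask()           # emb1's reduction no longer destroys the shared buffer
assert buf.data is not None  # emb2 can still use the mask for its own reduction
buf.release_mask()
assert buf.data is None      # the buffer is freed once the last user releases it
```
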
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131264
Approved by: https://github.com/wanchaol
diff --git a/test/distributed/_tensor/test_embedding_ops.py b/test/distributed/_tensor/test_embedding_ops.py
index 7822962..3c366d5 100644
--- a/test/distributed/_tensor/test_embedding_ops.py
+++ b/test/distributed/_tensor/test_embedding_ops.py
@@ -184,6 +184,49 @@
             self.assertEqual(comm_mode.get_total_counts(), 1)
             self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)
 
+    @with_comms
+    def test_multiple_embeddings_rowwise(self):
+        mesh = self.build_device_mesh()
+
+        inp = torch.randint(0, 10, (4, 4), device=self.device_type)
+        replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
+
+        from torch.distributed._tensor.ops._embedding_ops import _MaskPartial
+
+        # case 1: two embeddings with the same shape, thus sharing the underlying _MaskPartial
+        # and MaskBuffer, because of cache hit from sharding propagation
+
+        emb1 = torch.nn.Embedding(10, 23, device=self.device_type)
+        sharded_emb1 = self._apply_sharding(emb1, 0, mesh)
+        output1 = sharded_emb1(replicated_inp)
+
+        emb2 = torch.nn.Embedding(10, 23, device=self.device_type)
+        sharded_emb2 = self._apply_sharding(emb2, 0, mesh)
+        output2 = sharded_emb2(replicated_inp)
+
+        partial_placement1 = output1.placements[0]
+        self.assertIsInstance(partial_placement1, _MaskPartial)
+        output1.full_tensor()
+
+        partial_placement2 = output2.placements[0]
+        self.assertIsInstance(partial_placement2, _MaskPartial)
+        output2.full_tensor()
+
+        self.assertEqual(id(partial_placement1), id(partial_placement2))
+
+        # case 2: two embeddings with the same logical_dim_size, but different logical_shape
+        # thus they will have different _MaskPartial placements (with no cache hit)
+
+        emb3 = torch.nn.Embedding(10, 29, device=self.device_type)
+        sharded_emb3 = self._apply_sharding(emb3, 0, mesh)
+        output3 = sharded_emb3(replicated_inp)
+        partial_placement3 = output3.placements[0]
+        self.assertIsInstance(partial_placement3, _MaskPartial)
+        output3.full_tensor()
+
+        # not equal because of different logical_shape, despite the same logical_dim_size
+        self.assertNotEqual(partial_placement1, partial_placement3)
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/distributed/_tensor/ops/_embedding_ops.py b/torch/distributed/_tensor/ops/_embedding_ops.py
index a374022..15b2af2 100644
--- a/torch/distributed/_tensor/ops/_embedding_ops.py
+++ b/torch/distributed/_tensor/ops/_embedding_ops.py
@@ -32,21 +32,29 @@
 @dataclass
 class MaskBuffer:
     data: Optional[torch.Tensor] = None
+    # refcount allows shared usage of the MaskBuffer, as long as all users have the same data
+    refcount: int = 0
 
     def materialize_mask(self, mask):
-        if self.data is not None:
-            raise RuntimeError("MaskBuffer has already been materialized")
-        self.data = mask
+        if self.refcount == 0:
+            self.data = mask
+        else:
+            assert self.data is not None
+            if not torch.equal(self.data, mask):
+                raise RuntimeError(
+                    "MaskBuffer has been materialized with conflicting data"
+                )
+        self.refcount += 1
 
     def release_mask(self):
-        # TODO: evaluate if we need to release the mask buffer or the buffer
-        # can just have the same lifetime as the Partial placement
-        if self.data is None:
+        if self.refcount == 0 or self.data is None:
             raise RuntimeError("MaskBuffer has not been materialized")
-        self.data = None
+        self.refcount -= 1
+        if self.refcount == 0:
+            self.data = None
 
     def apply_mask(self, tensor):
-        if self.data is None:
+        if self.refcount == 0 or self.data is None:
             raise RuntimeError("MaskBuffer has not been materialized")
 
         # NOTE: _MaskPartial is being used by the embedding op and the gather op.
@@ -70,17 +78,23 @@
     lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor.
     """
 
-    logical_dim_size: int = -1
     mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
 
+    # required fields for computing the local offset and deriving the mask
+    offset_shape: Optional[torch.Size] = None
+    offset_dim: int = 0
+
     def _partition_value(
         self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
     ) -> torch.Tensor:
         # override parent logic to perform partial mask for embedding
         num_chunks = mesh.size(mesh_dim)
         # get local shard size and offset on the embedding_dim
+        assert (
+            self.offset_shape is not None
+        ), "offset_shape needs to be set for _MaskPartial"
         local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim(
-            self.logical_dim_size,
+            self.offset_shape[self.offset_dim],
             num_chunks,
             mesh.get_local_rank(mesh_dim),
             return_offset=True,
@@ -146,19 +160,24 @@
 
         return (
             self.reduce_op == other.reduce_op
-            and self.logical_dim_size == other.logical_dim_size
+            and self.offset_shape == other.offset_shape
+            and self.offset_dim == other.offset_dim
         )
 
     def __hash__(self) -> int:
         return 1 + hash(
-            (self.logical_dim_size, id(self.mask_buffer.data), self.reduce_op)
+            (
+                self.reduce_op,
+                self.offset_shape,
+                self.offset_dim,
+            )
         )
 
     def __repr__(self) -> str:
         """
         machine readable representation of the MaskPartial placement
         """
-        return f"_MaskPartial(logical_dim_size={self.logical_dim_size})"
+        return f"_MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"
 
     def __str__(self) -> str:
         """
@@ -192,7 +211,7 @@
     single_mesh_dim_strategies.append(colwise_sharding)
 
     # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial
-    embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0])
+    embedding_partial_placement = _MaskPartial(offset_shape=weight_shape, offset_dim=0)
 
     # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates
     # from the input indices and use it for output reduction
diff --git a/torch/distributed/_tensor/ops/_tensor_ops.py b/torch/distributed/_tensor/ops/_tensor_ops.py
index 223ff06..335773f 100644
--- a/torch/distributed/_tensor/ops/_tensor_ops.py
+++ b/torch/distributed/_tensor/ops/_tensor_ops.py
@@ -408,7 +408,7 @@
     # this only works when the input is sharded on the gather dimension, and
     # index has size 1 on the gather dimension
     if index_shape[dim] == 1:
-        index_partial_placement = _MaskPartial(logical_dim_size=input_shape[dim])
+        index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
         input_sharding: PlacementList = [
             index_partial_placement,
             Shard(dim),
diff --git a/torch/distributed/tensor/parallel/loss.py b/torch/distributed/tensor/parallel/loss.py
index ead6cca..79f6f08 100644
--- a/torch/distributed/tensor/parallel/loss.py
+++ b/torch/distributed/tensor/parallel/loss.py
@@ -200,7 +200,8 @@
     local_weight: Optional[Tensor],
     reduction: int,
     ignore_index: int,
-    channel_dim_size: int,
+    input_shape: torch.Size,
+    channel_dim: int,
     mesh: DeviceMesh,
     mesh_dim: int,
 ) -> Tuple[Tensor, Tensor]:
@@ -230,7 +231,7 @@
 
     # The following code block is a distributed version of
     # result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
-    partial_placement = _MaskPartial(logical_dim_size=channel_dim_size)
+    partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
     safe_target_partial_ = partial_placement._partition_value(
         safe_target_, mesh, mesh_dim
     )
@@ -317,7 +318,8 @@
         local_weight,
         reduction,
         ignore_index,
-        channel_dim_size,
+        x.shape,
+        channel_dim,
         spec.mesh,
         mesh_dim,
     )
@@ -348,7 +350,8 @@
     reduction: int,
     ignore_index: int,
     total_weight: Tensor,
-    channel_dim_size: int,
+    input_shape: torch.Size,
+    channel_dim: int,
     mesh: DeviceMesh,
     mesh_dim: int,
 ) -> Tensor:
@@ -362,7 +365,7 @@
 
     # The following code block is a distributed version of
     # grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
-    partial_placement = _MaskPartial(logical_dim_size=channel_dim_size)
+    partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
     safe_target = safe_target.squeeze(channel_dim).flatten()
     masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim)
     # only update grad_input to -1 if not masked
@@ -422,7 +425,6 @@
     total_weight = cast(Tensor, args[6])
 
     channel_dim = 1 if x.dim() >= 2 else 0
-    channel_dim_size = x.shape[channel_dim]
     spec = x._spec
     mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)
 
@@ -449,7 +451,8 @@
         reduction,
         ignore_index,
         total_weight,
-        channel_dim_size,
+        x.shape,
+        channel_dim,
         spec.mesh,
         mesh_dim,
     )