fix _MaskPartial when multiple embeddings coexist (#131264)
Previously, using `_MaskPartial` when multiple embeddings coexist had the following issues:
1. Suppose an `nn.Embedding` has shape `[vocab_size, emb_size]`. When multiple embeddings share the same `vocab_size` but have different `emb_size`s, they do not share an `OpStrategy`, since each produces a different `OpSchema` when involved in computation. However, there would still be a cache hit during redistribute (specifically in `_gen_transform_infos` in `torch/distributed/_tensor/_redistribute.py` when doing `Replicate` -> `_MaskPartial`), because `_MaskPartial` only carried `vocab_size` as `logical_dim_size` and did not record `emb_size`. This cache hit is undesirable and causes trouble when doing all-reduce/reduce-scatter on the new `_MaskPartial` in a separate `OpStrategy`; the error was reported in #130725. In this PR, we introduce `offset_shape` to hold the embedding's full shape, so embeddings of different shapes no longer share a cache entry (see the first sketch after this list).
2. The second issue arises when two `nn.Embedding`s `emb1` and `emb2` have the same shape. There is a cache hit not only in `_gen_transform_infos` but also in `OpStrategy` generation. Previously, if we sequentially did `Replicate` -> `_MaskPartial` for both `emb1` and `emb2` and then performed the reduction on `emb1`'s `_MaskPartial`, the shared `MaskBuffer` was destroyed and `emb2` would hit an error. This PR adds a `refcount` to the `MaskBuffer` so that it can be properly shared by multiple `nn.Embedding`s (see the second sketch after this list).
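To illustrate the first fix, here is a minimal sketch (not part of the PR) showing how keying `_MaskPartial` on the full `offset_shape` keeps embeddings that only share `vocab_size` apart; the concrete shapes are arbitrary and the placements are constructed standalone, outside any distributed run:

```python
import torch
from torch.distributed._tensor.ops._embedding_ops import _MaskPartial

# Two rowwise-sharded embeddings that share vocab_size=10 but have emb_size 23 vs 29.
# Before this PR both would reduce to _MaskPartial(logical_dim_size=10), compare equal,
# and hit the redistribute cache in _gen_transform_infos for Replicate -> _MaskPartial.
p1 = _MaskPartial(offset_shape=torch.Size([10, 23]), offset_dim=0)
p2 = _MaskPartial(offset_shape=torch.Size([10, 29]), offset_dim=0)

assert p1 != p2              # different full shapes -> distinct placements
assert hash(p1) != hash(p2)  # -> no spurious cache hit across the two embeddings
```

For the second fix, a similarly minimal sketch of the new `MaskBuffer` refcount semantics, mirroring what happens when two same-shaped embeddings materialize and then release the shared mask (the boolean tensor here is only a stand-in for the real index-derived mask):

```python
import torch
from torch.distributed._tensor.ops._embedding_ops import MaskBuffer

buf = MaskBuffer()
mask = torch.tensor([True, False, True])

buf.materialize_mask(mask)   # emb1's forward: refcount 0 -> 1, data is set
buf.materialize_mask(mask)   # emb2's forward: same data, refcount 1 -> 2

buf.release_mask()           # reduction on emb1's output: refcount 2 -> 1, data kept
assert buf.data is not None  # previously the buffer was already cleared here, breaking emb2
buf.release_mask()           # reduction on emb2's output: refcount 1 -> 0, data freed
assert buf.data is None
```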
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131264
Approved by: https://github.com/wanchaol
diff --git a/test/distributed/_tensor/test_embedding_ops.py b/test/distributed/_tensor/test_embedding_ops.py
index 7822962..3c366d5 100644
--- a/test/distributed/_tensor/test_embedding_ops.py
+++ b/test/distributed/_tensor/test_embedding_ops.py
@@ -184,6 +184,49 @@
self.assertEqual(comm_mode.get_total_counts(), 1)
self.assertEqual(comm_mode.get_comm_counts()[funcol.all_reduce], 1)
+ @with_comms
+ def test_multiple_embeddings_rowwise(self):
+ mesh = self.build_device_mesh()
+
+ inp = torch.randint(0, 10, (4, 4), device=self.device_type)
+ replicated_inp = DTensor.from_local(inp, mesh, [Replicate()], run_check=False)
+
+ from torch.distributed._tensor.ops._embedding_ops import _MaskPartial
+
+ # case 1: two embeddings with the same shape, thus sharing the underlying _MaskPartial
+ # and MaskBuffer, because of cache hit from sharding propagation
+
+ emb1 = torch.nn.Embedding(10, 23, device=self.device_type)
+ sharded_emb1 = self._apply_sharding(emb1, 0, mesh)
+ output1 = sharded_emb1(replicated_inp)
+
+ emb2 = torch.nn.Embedding(10, 23, device=self.device_type)
+ sharded_emb2 = self._apply_sharding(emb2, 0, mesh)
+ output2 = sharded_emb2(replicated_inp)
+
+ partial_placement1 = output1.placements[0]
+ self.assertIsInstance(partial_placement1, _MaskPartial)
+ output1.full_tensor()
+
+ partial_placement2 = output2.placements[0]
+ self.assertIsInstance(partial_placement2, _MaskPartial)
+ output2.full_tensor()
+
+ self.assertEqual(id(partial_placement1), id(partial_placement2))
+
+ # case 2: two embeddings with the same logical_dim_size, but different logical_shape
+ # thus they will have different _MaskPartial placements (with no cache hit)
+
+ emb3 = torch.nn.Embedding(10, 29, device=self.device_type)
+ sharded_emb3 = self._apply_sharding(emb3, 0, mesh)
+ output3 = sharded_emb3(replicated_inp)
+ partial_placement3 = output3.placements[0]
+ self.assertIsInstance(partial_placement3, _MaskPartial)
+ output3.full_tensor()
+
+ # not equal because of different logical_shape, despite the same logical_dim_size
+ self.assertNotEqual(partial_placement1, partial_placement3)
+
if __name__ == "__main__":
run_tests()
diff --git a/torch/distributed/_tensor/ops/_embedding_ops.py b/torch/distributed/_tensor/ops/_embedding_ops.py
index a374022..15b2af2 100644
--- a/torch/distributed/_tensor/ops/_embedding_ops.py
+++ b/torch/distributed/_tensor/ops/_embedding_ops.py
@@ -32,21 +32,29 @@
@dataclass
class MaskBuffer:
data: Optional[torch.Tensor] = None
+ # refcount allows shared usage of the MaskBuffer, as long as all users have the same data
+ refcount: int = 0
def materialize_mask(self, mask):
- if self.data is not None:
- raise RuntimeError("MaskBuffer has already been materialized")
- self.data = mask
+ if self.refcount == 0:
+ self.data = mask
+ else:
+ assert self.data is not None
+ if not torch.equal(self.data, mask):
+ raise RuntimeError(
+ "MaskBuffer has been materialized with conflicting data"
+ )
+ self.refcount += 1
def release_mask(self):
- # TODO: evaluate if we need to release the mask buffer or the buffer
- # can just have the same lifetime as the Partial placement
- if self.data is None:
+ if self.refcount == 0 or self.data is None:
raise RuntimeError("MaskBuffer has not been materialized")
- self.data = None
+ self.refcount -= 1
+ if self.refcount == 0:
+ self.data = None
def apply_mask(self, tensor):
- if self.data is None:
+ if self.refcount == 0 or self.data is None:
raise RuntimeError("MaskBuffer has not been materialized")
# NOTE: _MaskPartial is being used by the embedding op and the gather op.
@@ -70,17 +78,23 @@
lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor.
"""
- logical_dim_size: int = -1
mask_buffer: MaskBuffer = field(default_factory=MaskBuffer)
+ # required fields for computing the local offset and deriving the mask
+ offset_shape: Optional[torch.Size] = None
+ offset_dim: int = 0
+
def _partition_value(
self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
) -> torch.Tensor:
# override parent logic to perform partial mask for embedding
num_chunks = mesh.size(mesh_dim)
# get local shard size and offset on the embedding_dim
+ assert (
+ self.offset_shape is not None
+ ), "offset_shape needs to be set for _MaskPartial"
local_shard_size, local_offset_on_dim = Shard._local_shard_size_on_dim(
- self.logical_dim_size,
+ self.offset_shape[self.offset_dim],
num_chunks,
mesh.get_local_rank(mesh_dim),
return_offset=True,
@@ -146,19 +160,24 @@
return (
self.reduce_op == other.reduce_op
- and self.logical_dim_size == other.logical_dim_size
+ and self.offset_shape == other.offset_shape
+ and self.offset_dim == other.offset_dim
)
def __hash__(self) -> int:
return 1 + hash(
- (self.logical_dim_size, id(self.mask_buffer.data), self.reduce_op)
+ (
+ self.reduce_op,
+ self.offset_shape,
+ self.offset_dim,
+ )
)
def __repr__(self) -> str:
"""
machine readable representation of the MaskPartial placement
"""
- return f"_MaskPartial(logical_dim_size={self.logical_dim_size})"
+ return f"_MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})"
def __str__(self) -> str:
"""
@@ -192,7 +211,7 @@
single_mesh_dim_strategies.append(colwise_sharding)
# rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial
- embedding_partial_placement = _MaskPartial(logical_dim_size=weight_shape[0])
+ embedding_partial_placement = _MaskPartial(offset_shape=weight_shape, offset_dim=0)
# NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates
# from the input indices and use it for output reduction
diff --git a/torch/distributed/_tensor/ops/_tensor_ops.py b/torch/distributed/_tensor/ops/_tensor_ops.py
index 223ff06..335773f 100644
--- a/torch/distributed/_tensor/ops/_tensor_ops.py
+++ b/torch/distributed/_tensor/ops/_tensor_ops.py
@@ -408,7 +408,7 @@
# this only works when the input is sharded on the gather dimension, and
# index has size 1 on the gather dimension
if index_shape[dim] == 1:
- index_partial_placement = _MaskPartial(logical_dim_size=input_shape[dim])
+ index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim)
input_sharding: PlacementList = [
index_partial_placement,
Shard(dim),
diff --git a/torch/distributed/tensor/parallel/loss.py b/torch/distributed/tensor/parallel/loss.py
index ead6cca..79f6f08 100644
--- a/torch/distributed/tensor/parallel/loss.py
+++ b/torch/distributed/tensor/parallel/loss.py
@@ -200,7 +200,8 @@
local_weight: Optional[Tensor],
reduction: int,
ignore_index: int,
- channel_dim_size: int,
+ input_shape: torch.Size,
+ channel_dim: int,
mesh: DeviceMesh,
mesh_dim: int,
) -> Tuple[Tensor, Tensor]:
@@ -230,7 +231,7 @@
# The following code block is a distributed version of
# result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim)
- partial_placement = _MaskPartial(logical_dim_size=channel_dim_size)
+ partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
safe_target_partial_ = partial_placement._partition_value(
safe_target_, mesh, mesh_dim
)
@@ -317,7 +318,8 @@
local_weight,
reduction,
ignore_index,
- channel_dim_size,
+ x.shape,
+ channel_dim,
spec.mesh,
mesh_dim,
)
@@ -348,7 +350,8 @@
reduction: int,
ignore_index: int,
total_weight: Tensor,
- channel_dim_size: int,
+ input_shape: torch.Size,
+ channel_dim: int,
mesh: DeviceMesh,
mesh_dim: int,
) -> Tensor:
@@ -362,7 +365,7 @@
# The following code block is a distributed version of
# grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0)
- partial_placement = _MaskPartial(logical_dim_size=channel_dim_size)
+ partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim)
safe_target = safe_target.squeeze(channel_dim).flatten()
masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim)
# only update grad_input to -1 if not masked
@@ -422,7 +425,6 @@
total_weight = cast(Tensor, args[6])
channel_dim = 1 if x.dim() >= 2 else 0
- channel_dim_size = x.shape[channel_dim]
spec = x._spec
mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim)
@@ -449,7 +451,8 @@
reduction,
ignore_index,
total_weight,
- channel_dim_size,
+ x.shape,
+ channel_dim,
spec.mesh,
mesh_dim,
)