[dtensor] refactor and generalize stack strategy (#121869)
This PR rewrite the stack strategy to be more generalized, basically
stack/cat like strategy follow pattern need to be smarter, i.e. it
should be able to identify:
1. PR, PP, RP -> follow PP
2. RR, SR, RS -> follow SS
So this PR refactors how the follow strategy should work, and make sure
we start following the strategy that incurred lowest cost. i.e. for
multiple PR, RP placements, we should be able to further delay the
pending sum reductions
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121869
Approved by: https://github.com/awgu
diff --git a/test/distributed/_tensor/test_tensor_ops.py b/test/distributed/_tensor/test_tensor_ops.py
index 0510d0c..e4d1e3e 100644
--- a/test/distributed/_tensor/test_tensor_ops.py
+++ b/test/distributed/_tensor/test_tensor_ops.py
@@ -5,6 +5,7 @@
from torch.distributed._tensor import DeviceMesh, distribute_tensor, DTensor
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorConverter,
@@ -230,6 +231,44 @@
self.assertEqual(zeros_expected, zeros_like_dt.to_local())
@with_comms
+ @skip_if_lt_x_gpu(4)
+ def test_stack(self):
+ mesh_2d = DeviceMesh(
+ self.device_type, torch.arange(self.world_size).reshape(2, 2)
+ )
+ partial_replicate_placement = [_Partial(), Replicate()]
+ partial_placement = [_Partial(), _Partial()]
+
+ partial_replicate_dt = DTensor.from_local(
+ torch.randn(4, 8), mesh_2d, partial_replicate_placement
+ )
+ partial_dt = DTensor.from_local(torch.randn(4, 8), mesh_2d, partial_placement)
+
+ stack_dt = torch.stack([partial_replicate_dt, partial_dt])
+ self.assertEqual(stack_dt.placements, tuple(partial_placement))
+ self.assertEqual(stack_dt.shape, (2, 4, 8))
+
+ mesh_1d = DeviceMesh(self.device_type, torch.arange(self.world_size))
+ # stack before/after shard dim
+ global_input = torch.randn(8, 8)
+ shard1_input = distribute_tensor(global_input, mesh_1d, [Shard(1)])
+ cloned_shard1_input = shard1_input.clone()
+ stack_shard1_dt = torch.stack([shard1_input, cloned_shard1_input])
+ self.assertEqual(stack_shard1_dt.placements, (Shard(2),))
+ self.assertEqual(stack_shard1_dt.shape, (2, 8, 8))
+ self.assertEqual(
+ stack_shard1_dt.full_tensor(), torch.stack([global_input, global_input])
+ )
+
+ stack_dim1_shard1_dt = torch.stack([shard1_input, cloned_shard1_input], dim=1)
+ self.assertEqual(stack_dim1_shard1_dt.placements, (Shard(2),))
+ self.assertEqual(stack_dim1_shard1_dt.shape, (8, 2, 8))
+ self.assertEqual(
+ stack_dim1_shard1_dt.full_tensor(),
+ torch.stack([global_input, global_input], dim=1),
+ )
+
+ @with_comms
def test_equal(self):
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
shard_spec = [Shard(0)]
diff --git a/torch/distributed/_tensor/ops/math_ops.py b/torch/distributed/_tensor/ops/math_ops.py
index 64d5a09..14258aa 100644
--- a/torch/distributed/_tensor/ops/math_ops.py
+++ b/torch/distributed/_tensor/ops/math_ops.py
@@ -123,6 +123,14 @@
return tensor ** (1.0 / self.norm_type)
return tensor
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, _NormPartial):
+ return False
+ return self.norm_type == other.norm_type
+
+ def __hash__(self) -> int:
+ return 1 + hash(self.norm_type)
+
def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[List[int]]:
if dims_arg is None:
diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py
index 4bfc4c6..71879eb 100644
--- a/torch/distributed/_tensor/ops/tensor_ops.py
+++ b/torch/distributed/_tensor/ops/tensor_ops.py
@@ -399,6 +399,80 @@
return OpStrategy(all_strategies)
+def _derive_follow_placements_from_tuple_strategy(
+ tuple_strategy: TupleStrategy,
+) -> Sequence[Placement]:
+ """
+ derive the placements to follow from the tuple strategy, mainly used by
+ aten.stack, aten.cat, where each operand have the same shape, and correspondingly
+ expecting the same sharding
+ """
+
+ def merge_placement(
+ cur_placement: Placement, new_placement: Placement
+ ) -> Placement:
+ # semantic if we already have a follow placement, we
+ # check each placement for the current arg placement
+ # to see if we want to merge/adjust the placement to follow
+ # the priority: Partial -> Shard -> Replicate
+ if cur_placement == new_placement:
+ return cur_placement
+
+ if cur_placement.is_partial():
+ if new_placement.is_shard():
+ # follow new placement
+ return new_placement
+ elif new_placement.is_partial():
+ # different partial types, we can't merge and have to replicate all here
+ return Replicate()
+ else:
+ # follow partial
+ return cur_placement
+ elif cur_placement.is_shard():
+ if new_placement.is_shard():
+ # cur/new placement are different sharding (i.e. different shard dim)
+ # currently fallback to replicate all args
+ return Replicate()
+ else:
+ # for partial/replicate, follow the current shard placement
+ return cur_placement
+ else:
+ # current replicate, just follow new placement
+ return new_placement
+
+ follow_placements: Optional[List[Placement]] = None
+ for arg_strategy in tuple_strategy.childs:
+ assert isinstance(arg_strategy, OpStrategy)
+ for placement_strategy in arg_strategy.strategies:
+ arg_placements = placement_strategy.output_spec.placements
+ if follow_placements is None:
+ follow_placements = list(arg_placements)
+ continue
+ mesh_ndim = len(follow_placements)
+ assert follow_placements is not None
+ for mesh_idx in range(mesh_ndim):
+ # merge placements with the priority
+ follow_placements[mesh_idx] = merge_placement(
+ follow_placements[mesh_idx], arg_placements[mesh_idx]
+ )
+ assert follow_placements is not None, "follow placements should not be None!"
+ return follow_placements
+
+
+def normalize_shard_for_stack(
+ placements: Sequence[Placement], insert_dim: int = 0
+) -> Sequence[Placement]:
+ # stack op would "insert" new dim, so all sharded dim >= the inserted dim need to
+ # be normalized with the new Shard placement
+ normalized_placements: List[Placement] = []
+ for placement in placements:
+ if isinstance(placement, Shard) and placement.dim >= insert_dim:
+ normalized_placements.append(Shard(placement.dim + 1))
+ else:
+ normalized_placements.append(placement)
+ return normalized_placements
+
+
@register_op_strategy(aten.stack.default, RuntimeSchemaInfo(1, needs_pytree=True))
def stack_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
args_schema = op_schema.args_schema
@@ -406,59 +480,25 @@
assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}"
dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0
- # Follow the 1st child strategy's placement strategies
- child_strategy = input_tuple_strategy.childs[0]
- assert isinstance(child_strategy, OpStrategy), f"{child_strategy}"
- strategies: List[PlacementStrategy] = []
+ follow_placements = _derive_follow_placements_from_tuple_strategy(
+ input_tuple_strategy
+ )
+ follow_placements = normalize_shard_for_stack(follow_placements, dim)
- # For each arg strategy of the child to follow, we check if every other
- # child has an equal strategy. If so, then that is a valid strategy. If
- # there are no such valid strategies, then we replicate.
- for arg_strategy in child_strategy.strategies:
- arg_spec = arg_strategy.output_spec
- # For each arg strategy (whether the one to follow or other), we
- # replicate the stack dim since we cannot stack on a sharded dim
- if is_tensor_dim_sharded(arg_spec, dim):
- arg_spec = DTensorSpec(
- mesh, unshard_tensor_dim(arg_spec.placements, dim=dim)
- )
- all_compatible = True
- for other_child_strategy in input_tuple_strategy.childs[1:]:
- has_compatible_strategy = False
- assert isinstance(
- other_child_strategy, OpStrategy
- ), f"{other_child_strategy}"
- for other_arg_strategy in other_child_strategy.strategies:
- other_arg_spec = other_arg_strategy.output_spec
- if is_tensor_dim_sharded(other_arg_spec, dim):
- other_arg_spec = DTensorSpec(
- mesh, unshard_tensor_dim(other_arg_spec.placements, dim=dim)
- )
- if other_arg_spec.placements == arg_spec.placements:
- has_compatible_strategy = True
- break
- if not has_compatible_strategy:
- all_compatible = False
- break
- if all_compatible:
- input_specs = tuple(
- arg_spec for _ in range(len(input_tuple_strategy.childs))
- )
- strategies.append(
- PlacementStrategy(
- output_specs=DTensorSpec(mesh, arg_spec.placements),
- input_specs=input_specs,
- )
- )
- if not strategies:
- # Arbitrarily use each child strategy's 0th strategy's output spec
- input_specs = tuple(
- cast(OpStrategy, child_strategy).strategies[0].output_spec
- for child_strategy in input_tuple_strategy.childs
+ # create op strategy base on the follow placements
+ op_strategy = OpStrategy([])
+
+ input_specs = tuple(
+ DTensorSpec(mesh, tuple(follow_placements))
+ for _ in range(len(input_tuple_strategy.childs))
+ )
+ op_strategy.strategies.append(
+ PlacementStrategy(
+ output_specs=DTensorSpec(mesh, tuple(follow_placements)),
+ input_specs=input_specs,
)
- replicate_spec = DTensorSpec(mesh, tuple(Replicate() for _ in range(mesh.ndim)))
- strategies.append(PlacementStrategy(output_specs=replicate_spec))
- return OpStrategy(strategies)
+ )
+ return op_strategy
@register_prop_rule(aten.index_select.default, schema_info=RuntimeSchemaInfo(1))