[dtensor] tensor ops to use strategy based sharding prop (#100607)
This is the first in a series of PRs that move operator impls to a
strategy-based approach: each op uses OpStrategy and PlacementStrategy
to generate its own strategy. By utilizing the strategy-based
approach along with the op graph, we could enable more advanced op
implementations (decomp is possible), and turn sharding prop into
something more like a constraint satisfaction problem.
This PR alone only adds some basic tensor op strategies, and it directly
works on the op graph that was used for metadata propagation. The tensor ops
added in this PR mainly follow the strategy of one of their args. The next
set of PRs will add more op strategies for other ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100607
Approved by: https://github.com/XilunWu
diff --git a/torch/distributed/_tensor/dispatch.py b/torch/distributed/_tensor/dispatch.py
index c7c3cee..f08e7d2 100644
--- a/torch/distributed/_tensor/dispatch.py
+++ b/torch/distributed/_tensor/dispatch.py
@@ -145,7 +145,7 @@
# unwrap the args/kwargs schema
op_schema = sharding_propagator.prepare_op_schema(op_call, args, kwargs)
- output_sharding = sharding_propagator.propagate_op_sharding(op_call, op_schema)
+ output_sharding = sharding_propagator.propagate(op_call, op_schema)
# first we need to lift some private aten aliases to public calls
if op_call in _CURRENT_DECOMPOSITION_TABLE:
diff --git a/torch/distributed/_tensor/op_schema.py b/torch/distributed/_tensor/op_schema.py
index e77f9a3..e8bb45d 100644
--- a/torch/distributed/_tensor/op_schema.py
+++ b/torch/distributed/_tensor/op_schema.py
@@ -76,6 +76,12 @@
strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies])
return f"OpStrategy: [{strategy_list_str}]"
+ def max_num_shards(self) -> int:
+ """
+ Returns the max number of shards across all placement strategies
+ """
+ return max([strategy.output_spec.num_shards for strategy in self.strategies])
+
class TupleStrategy(StrategyType):
"""
@@ -204,6 +210,19 @@
DTensorSpec, _rebuild_tensor_from_dtensor_meta, self.kwargs_schema
)
+ def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None:
+ suggestion_args_spec = self.args_spec
+ new_arg_schema: List[object] = []
+ idx_of_args_spec = 0
+ for arg in origin_schema.args_schema:
+ if isinstance(arg, DTensorSpec):
+ new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
+ idx_of_args_spec += 1
+ else:
+ new_arg_schema.append(arg)
+ self.args_schema = tuple(new_arg_schema)
+ self.kwargs_schema = origin_schema.kwargs_schema
+
@dataclass
class OutputSharding:
diff --git a/torch/distributed/_tensor/ops/common_rules.py b/torch/distributed/_tensor/ops/common_rules.py
index 8c30e3d..65189f5 100644
--- a/torch/distributed/_tensor/ops/common_rules.py
+++ b/torch/distributed/_tensor/ops/common_rules.py
@@ -13,22 +13,6 @@
return string[:idx] + new_char + string[idx + 1 :]
-def _inplace_rewrap_schema_suggestion(
- suggestion: OpSchema, input_schema: OpSchema
-) -> None:
- suggestion_args_spec = suggestion.args_spec
- new_arg_schema: List[object] = []
- idx_of_args_spec = 0
- for arg in input_schema.args_schema:
- if isinstance(arg, DTensorSpec):
- new_arg_schema.append(suggestion_args_spec[idx_of_args_spec])
- idx_of_args_spec += 1
- else:
- new_arg_schema.append(arg)
- suggestion.args_schema = tuple(new_arg_schema)
- suggestion.kwargs_schema = input_schema.kwargs_schema
-
-
def _gen_reshard_suggestions(
op_schema: OpSchema,
input_dims: List[str],
@@ -48,7 +32,7 @@
)
)
suggested_schema = OpSchema(op_schema.func_schema, tuple(suggested_arg_specs), {})
- _inplace_rewrap_schema_suggestion(suggested_schema, op_schema)
+ suggested_schema._inplace_rewrap_schema_suggestion(op_schema)
return OutputSharding(
None,
schema_suggestions=[suggested_schema],
@@ -350,7 +334,7 @@
input_spec.mesh, reshard_dim_map, [], tensor_meta=input_spec.tensor_meta
)
schema_suggestion = OpSchema(op_schema.func_schema, (no_partial_spec,), {})
- _inplace_rewrap_schema_suggestion(schema_suggestion, op_schema)
+ schema_suggestion._inplace_rewrap_schema_suggestion(op_schema)
return OutputSharding(
output_spec=None, schema_suggestions=[schema_suggestion]
)
diff --git a/torch/distributed/_tensor/ops/pointwise_ops.py b/torch/distributed/_tensor/ops/pointwise_ops.py
index 7e4fc50..bc80923 100644
--- a/torch/distributed/_tensor/ops/pointwise_ops.py
+++ b/torch/distributed/_tensor/ops/pointwise_ops.py
@@ -1,4 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
+
import torch
from torch.distributed._tensor.ops.common_rules import (
diff --git a/torch/distributed/_tensor/ops/tensor_ops.py b/torch/distributed/_tensor/ops/tensor_ops.py
index 8232652..61a5243 100644
--- a/torch/distributed/_tensor/ops/tensor_ops.py
+++ b/torch/distributed/_tensor/ops/tensor_ops.py
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
-from typing import cast, List, Optional, Sequence, Tuple
+from typing import cast, Dict, List, Optional, Sequence, Tuple
import torch
@@ -11,38 +11,92 @@
Replicate,
Shard,
)
-from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
+from torch.distributed._tensor.device_mesh import DeviceMesh
+from torch.distributed._tensor.op_schema import (
+ OpSchema,
+ OpStrategy,
+ OutputSharding,
+ PlacementStrategy,
+ StrategyType,
+)
from torch.distributed._tensor.ops.common_rules import pointwise_rule
-from torch.distributed._tensor.ops.utils import normalize_dim, prod, register_prop_rule
+from torch.distributed._tensor.ops.utils import (
+ normalize_dim,
+ prod,
+ register_op_strategy,
+ register_prop_rule,
+)
+from torch.fx import Node
aten = torch.ops.aten
-# NOTE: the default propagation rule should apply for
-# any operator that does not return a DTensor, i.e.
-# for operators that only returns int/float/bool, we by
-# default still propagate the spec, this is to ensure
-# that we only return None for the case where the sharding
-# propagation failed, and we should do auto-redistribute
-def default_prop_rule(op_schema: OpSchema) -> OutputSharding:
- # by default prop the first arg spec
- return OutputSharding(op_schema.args_spec[0])
-
-
-def prop_create_like(op_schema: OpSchema) -> OutputSharding:
- # For operators that create tensors with same shape as input but
- # with specific content that does not depend on the input, we
- # can propagate Sharding, but we have to make sure we move from
- # partial to replicated.
- input_spec = op_schema.args_spec[0]
- output_spec = DTensorSpec(
- mesh=input_spec.mesh,
- placements=tuple(
- Replicate() if isinstance(p, _Partial) else p for p in input_spec.placements
- ),
+@register_op_strategy(
+ [
+ aten._to_copy.default,
+ aten.clone.default,
+ aten.contiguous.default,
+ aten.copy_.default,
+ aten.detach.default,
+ aten.new_empty_strided.default, # TODO: re-think new_empty_strided
+ ]
+)
+def default_strategy(
+ node: Node, mesh: DeviceMesh, node_to_strategy: Dict[Node, StrategyType]
+) -> StrategyType:
+ # Default strategy by default just propagate the first input strategy
+ select_strategy = node_to_strategy[node.all_input_nodes[0]]
+ assert isinstance(select_strategy, OpStrategy)
+ return OpStrategy(
+ [
+ PlacementStrategy(arg_strategy.output_spec)
+ for arg_strategy in select_strategy.strategies
+ ]
)
- return OutputSharding(output_spec=output_spec)
+
+
+@register_op_strategy(
+ [
+ aten.empty_like.default,
+ aten.fill_.Scalar,
+ aten.full_like.default,
+ aten.ones_like.default,
+ aten.zero_.default,
+ aten.zeros_like.default,
+ ]
+)
+def create_like_strategy(
+ node: Node, mesh: DeviceMesh, node_to_strategy: Dict[Node, StrategyType]
+) -> StrategyType:
+ # create_like_strategy deals with ops that creating tensors with same
+ # shape as input, but with specific content that does not depend on
+ # the input, we can propagate sharding, but we have to make sure we
+ # move from partial to replicated.
+ select_strategy = node_to_strategy[node.all_input_nodes[0]]
+ create_like_strategy = OpStrategy([])
+ assert isinstance(select_strategy, OpStrategy)
+ for arg_strategy in select_strategy.strategies:
+ arg_spec = arg_strategy.output_spec
+ if arg_spec.sums:
+ # if the arg_spec have partial, accept partial
+ # in the input_specs but output replicate for
+ # those corresponding mesh dims
+ output_spec = DTensorSpec(
+ mesh=arg_spec.mesh,
+ placements=tuple(
+ Replicate() if isinstance(p, _Partial) else p
+ for p in arg_spec.placements
+ ),
+ )
+ create_like_strategy.strategies.append(
+ PlacementStrategy(output_spec=output_spec, input_specs=(arg_spec,))
+ )
+
+ else:
+ create_like_strategy.strategies.append(PlacementStrategy(arg_spec))
+
+ return create_like_strategy
@register_prop_rule(aten._local_scalar_dense.default)
@@ -85,36 +139,12 @@
return OutputSharding(output_spec=None)
-default_prop_ops = [
- aten._to_copy.default,
- aten.clone.default,
- aten.contiguous.default,
- aten.copy_.default,
- aten.detach.default,
- aten.new_empty_strided.default,
-]
-
-create_like_ops = [
- aten.empty_like.default,
- aten.fill_.Scalar,
- aten.full_like.default,
- aten.ones_like.default,
- aten.zero_.default,
- aten.zeros_like.default,
-]
-
new_factory_ops = [
aten.new_full.default,
aten.new_ones.default,
aten.new_zeros.default,
]
-for op in default_prop_ops:
- register_prop_rule(op)(default_prop_rule)
-
-for op in create_like_ops:
- register_prop_rule(op)(prop_create_like)
-
for op in new_factory_ops:
register_prop_rule(op)(new_factory_rule)
diff --git a/torch/distributed/_tensor/ops/utils.py b/torch/distributed/_tensor/ops/utils.py
index 8513ab2..1568cb3 100644
--- a/torch/distributed/_tensor/ops/utils.py
+++ b/torch/distributed/_tensor/ops/utils.py
@@ -7,15 +7,6 @@
from torch.distributed._tensor.api import DTensor
-# pyre-fixme[3]: Return type must be annotated.
-# pyre-fixme[2]: Parameter must be annotated.
-def unwrap_single_placement(e):
- if not isinstance(e, DTensor):
- return None
- assert len(e.placements) == 1, "more than one placement!"
- return e.placements[0]
-
-
# convenient wrapper to register sharding propagation rules
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
@@ -32,6 +23,19 @@
return wrapper
+def register_op_strategy(op):
+ # pyre-fixme[53]: Captured variable `func` is not annotated.
+ # pyre-fixme[3]: Return type must be annotated.
+ # pyre-fixme[2]: Parameter must be annotated.
+ def wrapper(impl):
+ overloads = op if isinstance(op, list) else [op]
+ for overload in overloads:
+ DTensor._propagator.register_op_strategy(overload, impl)
+ return impl
+
+ return wrapper
+
+
def as_list(
x: Union[List[object], object]
# pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
diff --git a/torch/distributed/_tensor/placement_types.py b/torch/distributed/_tensor/placement_types.py
index e998cc5..a7750eb 100644
--- a/torch/distributed/_tensor/placement_types.py
+++ b/torch/distributed/_tensor/placement_types.py
@@ -413,6 +413,14 @@
return len(self.tensor_meta.shape)
@property
+ def num_shards(self) -> int:
+ num_shards = 1
+ for i, placement in enumerate(self.placements):
+ if placement.is_shard():
+ num_shards *= self.mesh.size(i)
+ return num_shards
+
+ @property
def dim_map(self) -> List[int]:
"""
dim_map is a property we derive from `placements` of
diff --git a/torch/distributed/_tensor/sharding_prop.py b/torch/distributed/_tensor/sharding_prop.py
index a635a7d..435d03f 100644
--- a/torch/distributed/_tensor/sharding_prop.py
+++ b/torch/distributed/_tensor/sharding_prop.py
@@ -4,9 +4,19 @@
import torch.distributed._tensor.api as dtensor
from torch._ops import OpOverload
from torch._subclasses import FakeTensorMode
-from torch.distributed._tensor.op_schema import DTensorSpec, OpSchema, OutputSharding
+from torch.distributed._tensor.device_mesh import DeviceMesh
+from torch.distributed._tensor.op_schema import (
+ DTensorSpec,
+ OpSchema,
+ OpStrategy,
+ OutputSharding,
+ OutputSpecType,
+ PlacementStrategy,
+ StrategyType,
+)
+from torch.fx import Node
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_flatten, tree_map
"""
Print information on ops input shape and sharding for debugging purposes.
@@ -21,6 +31,10 @@
class ShardingPropagator:
def __init__(self) -> None:
self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}
+ self.op_strategy_funcs: Dict[
+ OpOverload,
+ Callable[[Node, DeviceMesh, Dict[Node, StrategyType]], StrategyType],
+ ] = {}
def register_sharding_prop_rule(
self, op_overload: OpOverload, rule_func: Callable[[OpSchema], OutputSharding]
@@ -30,6 +44,16 @@
"""
self.op_to_rules[op_overload] = rule_func
+ def register_op_strategy(
+ self,
+ op_overload: OpOverload,
+ rule_func: Callable[[Node, DeviceMesh, Dict[Node, StrategyType]], StrategyType],
+ ):
+ """
+ Register a sharding strategy generator for an operator.
+ """
+ self.op_strategy_funcs[op_overload] = rule_func
+
def prepare_op_schema(
self, op_call: OpOverload, args: Tuple[object, ...], kwargs: Dict[str, object]
) -> OpSchema:
@@ -54,6 +78,103 @@
return op_schema
+ def propagate(self, op_overload: OpOverload, op_schema: OpSchema) -> OutputSharding:
+ if op_overload in self.op_strategy_funcs:
+ # generate op strategy for the op, this is done by propagating
+ # the sharding in the graph.
+ op_gm = self._prepare_op_graph(op_overload, op_schema)
+ if op_gm is None:
+ return OutputSharding(None, [op_schema])
+
+ flat_args_sharding, _ = tree_flatten(
+ [op_schema.args_schema, op_schema.kwargs_schema]
+ )
+ node_to_strategy: Dict[Node, StrategyType] = {}
+ output_node = None
+ out_node_strategy = None
+ mesh = flat_args_sharding[0].mesh
+ placeholder_idx = 0
+ for node in op_gm.graph.nodes:
+ if node.op == "placeholder":
+ # set sharding to placeholders if it's Node
+ if isinstance(flat_args_sharding[placeholder_idx], DTensorSpec):
+ strategy = PlacementStrategy(
+ flat_args_sharding[placeholder_idx]
+ )
+ # for eager execution, inputs only have one fixed sharding
+ node_to_strategy[node] = OpStrategy([strategy])
+ placeholder_idx += 1
+ elif node.op == "call_function":
+ if isinstance(node.target, OpOverload):
+ op_strategy_func = self.op_strategy_funcs[op_overload]
+ out_strategies = op_strategy_func(node, mesh, node_to_strategy)
+ node_to_strategy[node] = out_strategies
+ else:
+ raise NotImplementedError(
+ f"Unsupported function: {node.target}"
+ )
+ elif node.op == "output":
+ output_node = node.args[0]
+ out_node_strategy = node_to_strategy[output_node[0]]
+ else:
+ raise NotImplementedError(f"Unsupported node type: {node.op}")
+
+ # NOTE: This had the assumption we only have one call_function op in the
+ # op graph, we need to harden this logic when there're decomposed ops.
+ assert isinstance(out_node_strategy, OpStrategy)
+ # we take the first strategy for now
+ # TODO: add a min cost selection logic
+ output_strategy = out_node_strategy.strategies[0]
+ needs_redistribute = False
+ expected_input_specs = []
+ for idx, input_spec in enumerate(op_schema.args_spec):
+ desired_spec = (
+ output_strategy.output_spec
+ if output_strategy.input_specs is None
+ else output_strategy.input_specs[idx]
+ )
+ expected_input_specs.append(desired_spec)
+ if input_spec != desired_spec:
+ needs_redistribute = True
+
+ if needs_redistribute:
+ suggestion_schema = OpSchema(
+ op_schema.func_schema, tuple(expected_input_specs), {}
+ )
+ suggestion_schema._inplace_rewrap_schema_suggestion(op_schema)
+ else:
+ suggestion_schema = op_schema
+
+ output_sharding = OutputSharding(
+ output_strategy.output_spec,
+ [suggestion_schema],
+ )
+ if output_node is not None:
+ self._wrap_output_spec_meta(output_sharding.output_spec, output_node)
+ return output_sharding
+
+ elif op_overload in self.op_to_rules:
+ return self.propagate_op_sharding(op_overload, op_schema)
+ else:
+ raise NotImplementedError(
+ f"Operator {op_overload} does not have a sharding strategy registered."
+ )
+
+ def _wrap_output_spec_meta(
+ self, output_spec: OutputSpecType, output_nodes: Node
+ ) -> None:
+ """
+ Wrap the output_spec with the metadata from the output node.
+ """
+ if output_spec is not None:
+ assert isinstance(output_nodes, (tuple, list))
+ if isinstance(output_spec, DTensorSpec):
+ output_spec.tensor_meta = output_nodes[0].meta["tensor_meta"]
+ elif isinstance(output_spec, (tuple, list)):
+ for i, spec in enumerate(output_spec):
+ if isinstance(spec, DTensorSpec):
+ spec.tensor_meta = output_nodes[i].meta["tensor_meta"]
+
def propagate_op_sharding(
self, op_overload: OpOverload, op_schema: OpSchema
) -> OutputSharding:
@@ -61,19 +182,17 @@
Propagate the sharding for an operator given the op_schema.
"""
# first we propagate the tensor metadata
- output_node = self._propagate_tensor_meta(op_overload, op_schema)
+ output_node = None
+ op_gm = self._prepare_op_graph(op_overload, op_schema)
+ if op_gm is not None:
+ for node in op_gm.graph.nodes:
+ if node.op == "output":
+ output_node = node.args[0]
# then we propagate the sharding
- sharding_prop_func = self.op_to_rules.get(op_overload, None)
+ sharding_prop_func = self.op_to_rules[op_overload]
- if sharding_prop_func is None:
- # step 1. If there's not even one sharding rule
- # implemented for the operator, we error out.
- raise NotImplementedError(
- f"Operator {op_overload} does not have a DistributedTensor rule registered."
- )
-
- # step 2. there's sharding propagation rule, run
+ # step 1. there's sharding propagation rule, run
# sharding propagation to get the output sharding
try:
output_sharding = sharding_prop_func(op_schema)
@@ -86,7 +205,7 @@
f"Error: {e}"
) from e
- # step 3. if can't get output_spec from sharding
+ # step 2. if can't get output_spec from sharding
# propagation (i.e. no rules apply for input
# placements), we return the output sharding
# with schema suggestions, which can be used to
@@ -110,8 +229,6 @@
# to get an eligible input, which we will pick a
# schema suggestion base on the redistribute cost.
# For now we simply pick the first suggestion.
- # TODO: implement full auto distribute with a
- # simple cost estimation model
suggested_input_schema = output_sharding.schema_suggestions[0]
# run sharding propagation again with suggested schema
propagation_res = sharding_prop_func(suggested_input_schema)
@@ -126,24 +243,15 @@
# associate the output sharding with the output metadata
if output_node is not None:
- output_nodes = output_node.args[0]
- output_spec = output_sharding.output_spec
- if output_spec is not None:
- assert isinstance(output_nodes, (tuple, list))
- if isinstance(output_spec, DTensorSpec):
- output_spec.tensor_meta = output_nodes[0].meta["tensor_meta"]
- elif isinstance(output_spec, (tuple, list)):
- for i, spec in enumerate(output_spec):
- if isinstance(spec, DTensorSpec):
- spec.tensor_meta = output_nodes[i].meta["tensor_meta"]
+ self._wrap_output_spec_meta(output_sharding.output_spec, output_node)
return output_sharding
- def _propagate_tensor_meta(
+ def _prepare_op_graph(
self,
op_overload: OpOverload,
op_schema: OpSchema,
- ) -> Optional[torch.fx.Node]:
+ ) -> Optional[torch.fx.GraphModule]:
# right now we only use the graph for metadata prop, but next we will use
# the graph to do sharding prop together
@@ -163,11 +271,7 @@
fake_kwargs = op_schema.gen_fake_kwargs()
g = get_isolated_graphmodule(op_overload, fake_args, fake_kwargs)
- output = None
- for node in g.graph.nodes:
- if node.op == "output":
- output = node
- return output
+ return g
class _CachingPropagator(ShardingPropagator):
@@ -176,18 +280,16 @@
This is currently experimental for Tensor Parallel usage.
"""
- def __init__(self, op_to_rules=None) -> None:
+ def __init__(self, propagator: ShardingPropagator) -> None:
super().__init__()
- if op_to_rules is not None:
- self.op_to_rules = op_to_rules
+ self.op_to_rules = propagator.op_to_rules
+ self.op_strategy_funcs = propagator.op_strategy_funcs
# cache table for sharding propagation results, we might need to
# limit the size of the cache table in the future
self.cached_prop_results: Dict[OpSchema, OutputSharding] = {}
- def propagate_op_sharding(
- self, op_overload: OpOverload, op_schema: OpSchema
- ) -> OutputSharding:
+ def propagate(self, op_overload: OpOverload, op_schema: OpSchema) -> OutputSharding:
"""
Propagate the sharding for an operator given the op_schema.
Cache the propagation results to avoid running propagation again.
@@ -196,7 +298,7 @@
return self.cached_prop_results[op_schema]
else:
# call DTensor's propagate_op_sharding to get the prop result
- output_sharding = super().propagate_op_sharding(op_overload, op_schema)
+ output_sharding = super().propagate(op_overload, op_schema)
# update cached table
self.cached_prop_results[op_schema] = output_sharding
return output_sharding
diff --git a/torch/distributed/tensor/parallel/api.py b/torch/distributed/tensor/parallel/api.py
index 5af3afa..cdd575e 100644
--- a/torch/distributed/tensor/parallel/api.py
+++ b/torch/distributed/tensor/parallel/api.py
@@ -30,7 +30,7 @@
# switch the DTensor propagator to use the caching propagator to speed up
# the TP eager execution time.
-DTensor._propagator = _CachingPropagator(DTensor._propagator.op_to_rules)
+DTensor._propagator = _CachingPropagator(DTensor._propagator)
def parallelize_module( # type: ignore[return]
module: nn.Module,