[dtensor] refactor and improve readability of _dispatch.py (#132682)
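As titled. It also updates some comments in _op_schema.py to bring them up
to date.
Note: the extracted helper _try_replicate_spec_for_scalar_tensor accepts any
tensor with numel() == 1 for implicit replication, whereas the previous inline
try_get_replicate_spec additionally required ndim <= 1 for that case.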
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132682
Approved by: https://github.com/XilunWu
ghstack dependencies: #131210
diff --git a/torch/distributed/_tensor/_dispatch.py b/torch/distributed/_tensor/_dispatch.py
index 2fc2761..c224f41 100644
--- a/torch/distributed/_tensor/_dispatch.py
+++ b/torch/distributed/_tensor/_dispatch.py
@@ -123,55 +123,11 @@
assert output_sharding is not None, "output sharding should not be None"
mesh = op_info.mesh
- if mesh.get_coordinate() is None:
- # For a non-participating device, we do:
- # 1. if the return type is scalar, set the local result to None.
- # The local results from all devices will then be all-gathered
- # and a reduce op will be performed on the list of results
- # with appropriate operators:
- # for bool type, we by default use AND to reduce;
- # we can extend for more ops if necessary.
- # 2. if the return type is Tensor or List[Tensor], return empty
- # tensor(s) with correct dtype.
- spec = output_sharding.output_spec
- ret_list = op_info.schema.op._schema.returns
-
- if spec is None:
- # For a scalar return type, the non-participating device has None
- # as its local result
- local_results: object = None
- else:
-
- def default_tensor(spec: DTensorSpec) -> torch.Tensor:
- if spec.tensor_meta is not None:
- shape = spec.tensor_meta.shape
- dtype = spec.tensor_meta.dtype
- if len(shape) == 0:
- # scalar tensor
- return torch.zeros((), dtype=dtype)
- else:
- # non-scalar tensor
- return torch.tensor([], dtype=dtype)
- else:
- raise RuntimeError(f"{spec} has no tensor metadata.")
-
- if isinstance(spec, DTensorSpec):
- # return a Tensor value
- local_results = default_tensor(spec)
- elif isinstance(spec, Sequence):
- # return a List[Tensor] value
- local_results = [
- default_tensor(s) if s is not None else None for s in spec
- ]
- assert isinstance(local_results, List)
- if None in local_results:
- ret_type = str(ret_list[0].type)
- raise NotImplementedError(
- f"return type {ret_type} in DTensor op is not supported"
- )
- else:
+ if mesh.get_coordinate() is not None:
+ # computation that happens in the current rank of the mesh, normal case
if output_sharding.needs_redistribute:
- # compute locally with redistribute first if needed
+ # If the sharding propagation decision needs redistribute, perform redistribute
+ # on args first, which could potentially modify args (e.g. allgather certain args)
assert output_sharding.redistribute_schema is not None
self.redistribute_local_args(
op_info, output_sharding.redistribute_schema
@@ -201,16 +157,62 @@
if random._rng_tracker and not first_local_arg.is_meta
else contextlib.nullcontext()
)
-
- # For DTensor random operator, run it within a distribute region
+ # For DTensor random operator, run it within a RNGTracker context to
+ # ensure the random number generator is properly distributed.
with rng_context:
local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
else:
+ # normal case, run local sharded op computation
local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
- # communicate the result to all ranks for some operators that return scalar value
+ else:
+ # For a non-participating device (happens on rank that does not belong to
+ # the device mesh), we do:
+ # 1. if the return type is scalar, set the local result to None.
+ # 2. if the return type is Tensor or List[Tensor], return empty
+ # tensor(s) with correct dtype.
+ spec = output_sharding.output_spec
+ ret_list = op_info.schema.op._schema.returns
+
+ if spec is None:
+ # For a scalar return type, the non-participating device has None
+ # as its local result
+ local_results = None
+ else:
+
+ def default_tensor(spec: DTensorSpec) -> torch.Tensor:
+ if spec.tensor_meta is not None:
+ shape = spec.tensor_meta.shape
+ dtype = spec.tensor_meta.dtype
+ if len(shape) == 0:
+ # scalar tensor
+ return torch.zeros((), dtype=dtype)
+ else:
+ # non-scalar tensor
+ return torch.tensor([], dtype=dtype)
+ else:
+ raise RuntimeError(f"{spec} has no tensor metadata.")
+
+ if isinstance(spec, DTensorSpec):
+ # return a Tensor value
+ local_results = default_tensor(spec)
+ elif isinstance(spec, Sequence):
+ # return a List[Tensor] value
+ local_results = [
+ default_tensor(s) if s is not None else None for s in spec
+ ]
+ assert isinstance(local_results, List)
+ if None in local_results:
+ ret_type = str(ret_list[0].type)
+ raise NotImplementedError(
+ f"return type {ret_type} in DTensor op is not supported"
+ )
+
if output_sharding.output_spec is None:
if op_call == aten.equal.default:
+ # For the equal operator, the local results from all devices should be all-gathered
+ # and a reduce op (AND) will be performed on the list of results to ensure SPMD
+ # execution. We can extend this for more ops if necessary.
obj_list = [None for _ in range(dist.get_world_size())]
dist.all_gather_object(obj_list, local_results) # type: ignore[possibly-undefined]
obj_list = list(filter(lambda x: x is not None, obj_list))
@@ -250,9 +252,6 @@
suggested_input_schema: OpSchema,
) -> None:
# NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
-
- # TODO: the op schema should probably just remain flattened so that we can avoid this tree flatten
- # Need to fix all the ops before doing this.
if op_info.args_tree_spec is not None:
flatten_args_schema_to_reshard = tuple(
pytree.tree_leaves(suggested_input_schema.args_schema)
@@ -283,13 +282,13 @@
args: Tuple[object, ...],
kwargs: Dict[str, object],
) -> OpInfo:
- # get runtime schema to determine whether to use pytree to flatten inputs
+ # get runtime schema info to determine whether to use pytree to flatten inputs
runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
op_call, None
)
if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
- # flatten args/kwargs when necessary
+ # flatten args/kwargs when op says necessary
tree_args, args_spec = pytree.tree_flatten(args)
args_list: Sequence[object] = tree_args
else:
@@ -301,41 +300,6 @@
local_kwargs: Dict[str, object] = {}
mesh: Optional[DeviceMesh] = None
- def try_get_replicate_spec(
- tensor_arg: torch.Tensor, mesh: "DeviceMesh"
- ) -> DTensorSpec:
- # tensor_arg is an instance of torch.Tensor and could be an arg or kwarg.
- if tensor_arg.numel() == 1 and tensor_arg.ndim == 1:
- warnings.warn(
- "Found a non-scalar tensor with numel=1 and ndim!=0, "
- "we are implicitly creating a replicated DTensor for it. "
- "However, please consider changing it to a scalar tensor "
- "or explicitly create a DTensor under distributed enviroment."
- )
-
- # if the arg.numel() == 1, arg.ndim could be 0 or 1.
- if (
- tensor_arg.ndim <= 1
- and tensor_arg.numel() == 1
- or self._allow_implicit_replication
- ):
- # scalar tensor can be safely treated as replicated
- replication_spec = DTensorSpec(
- mesh,
- (Replicate(),) * mesh.ndim,
- tensor_meta=TensorMeta(
- shape=tensor_arg.shape,
- stride=tensor_arg.stride(),
- dtype=tensor_arg.dtype,
- ),
- )
- else:
- raise RuntimeError(
- f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
- " torch.Tensor to DTensor before calling distributed operators!"
- )
- return replication_spec
-
for arg in args_list:
if isinstance(arg, dtensor.DTensor):
args_schema.append(arg._spec)
@@ -350,7 +314,9 @@
mesh = arg.device_mesh
elif isinstance(arg, torch.Tensor):
mesh = mesh or try_find_mesh_from_args(op_call, args_list)
- args_schema.append(try_get_replicate_spec(arg, mesh))
+ args_schema.append(
+ self._try_replicate_spec_for_scalar_tensor(op_call, arg, mesh)
+ )
local_args.append(arg)
else:
args_schema.append(arg)
@@ -369,7 +335,9 @@
mesh = v.device_mesh
elif isinstance(v, torch.Tensor):
mesh = mesh or try_find_mesh_from_args(op_call, args_list)
- kwargs_schema[k] = try_get_replicate_spec(v, mesh)
+ kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
+ op_call, v, mesh
+ )
local_kwargs[k] = v
else:
kwargs_schema[k] = v
@@ -418,3 +386,36 @@
# if the res contains only non tensor values (i.e. int/float/none), we simply return it
# without rewrapping to DTensor.
return res
+
+ def _try_replicate_spec_for_scalar_tensor(
+ self,
+ op_call: torch._ops.OpOverload,
+ tensor_arg: torch.Tensor,
+ mesh: "DeviceMesh",
+ ) -> DTensorSpec:
+ # util function to produce a replicate spec for a scalar tensor arg/kwarg
+ if tensor_arg.numel() == 1 and tensor_arg.ndim == 1:
+ warnings.warn(
+ "Found a non-scalar tensor with numel=1 and ndim!=0, "
+ "we are implicitly creating a replicated DTensor for it. "
+ "However, please consider changing it to a scalar tensor "
+ "or explicitly create a DTensor under distributed enviroment."
+ )
+
+ if tensor_arg.numel() == 1 or self._allow_implicit_replication:
+ # a single-element tensor (or any tensor when implicit replication is allowed) is treated as replicated
+ replication_spec = DTensorSpec(
+ mesh,
+ (Replicate(),) * mesh.ndim,
+ tensor_meta=TensorMeta(
+ shape=tensor_arg.shape,
+ stride=tensor_arg.stride(),
+ dtype=tensor_arg.dtype,
+ ),
+ )
+ else:
+ raise RuntimeError(
+ f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
+ " torch.Tensor to DTensor before calling distributed operators!"
+ )
+ return replication_spec
diff --git a/torch/distributed/_tensor/_op_schema.py b/torch/distributed/_tensor/_op_schema.py
index 7c7aea7..7b8af39 100644
--- a/torch/distributed/_tensor/_op_schema.py
+++ b/torch/distributed/_tensor/_op_schema.py
@@ -214,10 +214,10 @@
@dataclass
class OpSchema:
"""
- OpSchema is a data class that describes an operator input schemas, it
- includes DTensor DTensorSpecs and non-tensor args/kwargs (positional order
- preserved). It is mainly used by the dispatching logic below to run things like
- sharding propagation.
+ OpSchema is a data class that describes an operator's input schemas. It includes
+ DTensorSpecs (instead of DTensors) and non-tensor args/kwargs (positional order
+ preserved). It is mainly used by DTensor's dispatching logic to perform various
+ actions (e.g. sharding propagation, caching sharding decisions, redistribute, etc.)
NOTE: this should be used as a read only data class
TODO: make this a frozen dataclass
@@ -225,9 +225,9 @@
Args:
op: the operator overload we are intercepting
args_schema: contains args except that the DTensor args have been replaced
- with its DTensorSpec
+ with its DTensorSpec or OpStrategy
kwargs_schema: contains kwargs except that the DTensor kwargs have been replaced
- with its DTensorSpec
+ with its DTensorSpec or OpStrategy
"""
op: OpOverload
@@ -427,13 +427,13 @@
@dataclass
class OutputSharding:
"""
- OutputSharding is a data class that is used by the sharding propagation
- rules, it could set the output_spec upon successful propagation, and if
- it failed, output_spec would become None and sharding propagation rules
- could give a list of suggestions for inputs to reshard.
+ OutputSharding is a data class that is used by sharding propagation.
+ It could set the output_spec upon successful propagation. If needs_redistribute
+ is set to True, a redistribute_schema would be returned together to indicate
+ that the input arguments need to be redistributed before the op execution.
- NOTE: the schema_suggestion generated by sharding propagation should be
- exactly the same as the operator OpSchema, except the DTensor DTensorSpecs
+ NOTE: the redistribute_schema generated by sharding propagation should be
+ exactly the same as the operator OpSchema, except the DTensorSpecs
"""
output_spec: OutputSpecType