[dtensor] refactor and improve readability of _dispatch.py (#132682)

As titled. It also updates some comments in _op_schema.py to bring them up to date.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132682
Approved by: https://github.com/XilunWu
ghstack dependencies: #131210
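
For context, the non-participating-rank branch that this PR moves below the main path fabricates default local results with the correct dtype. A minimal standalone sketch of that behavior is below (the helper name default_tensor_for is hypothetical and not part of this PR):

import torch

def default_tensor_for(shape: torch.Size, dtype: torch.dtype) -> torch.Tensor:
    # A 0-dim shape gets a scalar zero; anything else gets an empty 1-d tensor
    # of the requested dtype, mirroring the default_tensor logic in the diff.
    return torch.zeros((), dtype=dtype) if len(shape) == 0 else torch.tensor([], dtype=dtype)

print(default_tensor_for(torch.Size([]), torch.float32))    # tensor(0.)
print(default_tensor_for(torch.Size([4, 4]), torch.int64))  # tensor([], dtype=torch.int64)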
diff --git a/torch/distributed/_tensor/_dispatch.py b/torch/distributed/_tensor/_dispatch.py
index 2fc2761..c224f41 100644
--- a/torch/distributed/_tensor/_dispatch.py
+++ b/torch/distributed/_tensor/_dispatch.py
@@ -123,55 +123,11 @@
         assert output_sharding is not None, "output sharding should not be None"
 
         mesh = op_info.mesh
-        if mesh.get_coordinate() is None:
-            # For a non-participating device, we do:
-            #   1. if the return type is scalar, set the local result to None.
-            #   The local results from all devices will then be all-gathered
-            #   and a reduce op will be performed on the list of results
-            #   with appropriate operators:
-            #       for bool type, we by default use AND to reduce;
-            #       we can extend for more ops if necessary.
-            #   2. if the return type is Tensor or List[Tensor], return empty
-            #   tensor(s) with correct dtype.
-            spec = output_sharding.output_spec
-            ret_list = op_info.schema.op._schema.returns
-
-            if spec is None:
-                # For a scalar return type, the non-participating device has None
-                # as its local result
-                local_results: object = None
-            else:
-
-                def default_tensor(spec: DTensorSpec) -> torch.Tensor:
-                    if spec.tensor_meta is not None:
-                        shape = spec.tensor_meta.shape
-                        dtype = spec.tensor_meta.dtype
-                        if len(shape) == 0:
-                            # scalar tensor
-                            return torch.zeros((), dtype=dtype)
-                        else:
-                            # non-scalar tensor
-                            return torch.tensor([], dtype=dtype)
-                    else:
-                        raise RuntimeError(f"{spec} has no tensor metadata.")
-
-                if isinstance(spec, DTensorSpec):
-                    # return a Tensor value
-                    local_results = default_tensor(spec)
-                elif isinstance(spec, Sequence):
-                    # return a List[Tensor] value
-                    local_results = [
-                        default_tensor(s) if s is not None else None for s in spec
-                    ]
-                    assert isinstance(local_results, List)
-                    if None in local_results:
-                        ret_type = str(ret_list[0].type)
-                        raise NotImplementedError(
-                            f"return type {ret_type} in DTensor op is not supported"
-                        )
-        else:
+        if mesh.get_coordinate() is not None:
+            # computation that happens on the current rank of the mesh (the normal case)
             if output_sharding.needs_redistribute:
-                # compute locally with redistribute first if needed
+                # If the sharding propagation decision requires a redistribute, redistribute
+                # the args first, which may modify them (e.g. all-gather a certain arg)
                 assert output_sharding.redistribute_schema is not None
                 self.redistribute_local_args(
                     op_info, output_sharding.redistribute_schema
@@ -201,16 +157,62 @@
                     if random._rng_tracker and not first_local_arg.is_meta
                     else contextlib.nullcontext()
                 )
-
-                # For DTensor random operator, run it within a distribute region
+                # For DTensor random operators, run the op within an RNGTracker context to
+                # ensure the random number generator is properly distributed.
                 with rng_context:
                     local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
             else:
+                # normal case, run local sharded op computation
                 local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
 
-        # communicate the result to all ranks for some operators that return scalar value
+        else:
+            # For a non-participating device (i.e. a rank that does not belong to
+            # the device mesh), we do:
+            #   1. if the return type is scalar, set the local result to None.
+            #   2. if the return type is Tensor or List[Tensor], return empty
+            #   tensor(s) with correct dtype.
+            spec = output_sharding.output_spec
+            ret_list = op_info.schema.op._schema.returns
+
+            if spec is None:
+                # For a scalar return type, the non-participating device has None
+                # as its local result
+                local_results = None
+            else:
+
+                def default_tensor(spec: DTensorSpec) -> torch.Tensor:
+                    if spec.tensor_meta is not None:
+                        shape = spec.tensor_meta.shape
+                        dtype = spec.tensor_meta.dtype
+                        if len(shape) == 0:
+                            # scalar tensor
+                            return torch.zeros((), dtype=dtype)
+                        else:
+                            # non-scalar tensor
+                            return torch.tensor([], dtype=dtype)
+                    else:
+                        raise RuntimeError(f"{spec} has no tensor metadata.")
+
+                if isinstance(spec, DTensorSpec):
+                    # return a Tensor value
+                    local_results = default_tensor(spec)
+                elif isinstance(spec, Sequence):
+                    # return a List[Tensor] value
+                    local_results = [
+                        default_tensor(s) if s is not None else None for s in spec
+                    ]
+                    assert isinstance(local_results, List)
+                    if None in local_results:
+                        ret_type = str(ret_list[0].type)
+                        raise NotImplementedError(
+                            f"return type {ret_type} in DTensor op is not supported"
+                        )
+
         if output_sharding.output_spec is None:
             if op_call == aten.equal.default:
+                # For the equal operator, the local results from all devices are all-gathered
+                # and a reduce op (AND) is performed on the list of results to ensure SPMD
+                # execution. We can extend this to more ops if necessary.
                 obj_list = [None for _ in range(dist.get_world_size())]
                 dist.all_gather_object(obj_list, local_results)  # type: ignore[possibly-undefined]
                 obj_list = list(filter(lambda x: x is not None, obj_list))
@@ -250,9 +252,6 @@
         suggested_input_schema: OpSchema,
     ) -> None:
         # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it
-
-        # TODO: the op schema should probably just remain flattened so that we can avoid this tree flatten
-        # Need to fix all the ops before doing this.
         if op_info.args_tree_spec is not None:
             flatten_args_schema_to_reshard = tuple(
                 pytree.tree_leaves(suggested_input_schema.args_schema)
@@ -283,13 +282,13 @@
         args: Tuple[object, ...],
         kwargs: Dict[str, object],
     ) -> OpInfo:
-        # get runtime schema to determine whether to use pytree to flatten inputs
+        # get runtime schema info to determine whether to use pytree to flatten inputs
         runtime_schema_info = self.sharding_propagator.op_to_schema_info.get(
             op_call, None
         )
 
         if runtime_schema_info is not None and runtime_schema_info.needs_pytree:
-            # flatten args/kwargs when necessary
+            # flatten args/kwargs when the op requires it
             tree_args, args_spec = pytree.tree_flatten(args)
             args_list: Sequence[object] = tree_args
         else:
@@ -301,41 +300,6 @@
         local_kwargs: Dict[str, object] = {}
         mesh: Optional[DeviceMesh] = None
 
-        def try_get_replicate_spec(
-            tensor_arg: torch.Tensor, mesh: "DeviceMesh"
-        ) -> DTensorSpec:
-            # tensor_arg is an instance of torch.Tensor and could be an arg or kwarg.
-            if tensor_arg.numel() == 1 and tensor_arg.ndim == 1:
-                warnings.warn(
-                    "Found a non-scalar tensor with numel=1 and ndim!=0, "
-                    "we are implicitly creating a replicated DTensor for it. "
-                    "However, please consider changing it to a scalar tensor "
-                    "or explicitly create a DTensor under distributed enviroment."
-                )
-
-            # if the arg.numel() == 1, arg.ndim could be 0 or 1.
-            if (
-                tensor_arg.ndim <= 1
-                and tensor_arg.numel() == 1
-                or self._allow_implicit_replication
-            ):
-                # scalar tensor can be safely treated as replicated
-                replication_spec = DTensorSpec(
-                    mesh,
-                    (Replicate(),) * mesh.ndim,
-                    tensor_meta=TensorMeta(
-                        shape=tensor_arg.shape,
-                        stride=tensor_arg.stride(),
-                        dtype=tensor_arg.dtype,
-                    ),
-                )
-            else:
-                raise RuntimeError(
-                    f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
-                    " torch.Tensor to DTensor before calling distributed operators!"
-                )
-            return replication_spec
-
         for arg in args_list:
             if isinstance(arg, dtensor.DTensor):
                 args_schema.append(arg._spec)
@@ -350,7 +314,9 @@
                     mesh = arg.device_mesh
             elif isinstance(arg, torch.Tensor):
                 mesh = mesh or try_find_mesh_from_args(op_call, args_list)
-                args_schema.append(try_get_replicate_spec(arg, mesh))
+                args_schema.append(
+                    self._try_replicate_spec_for_scalar_tensor(op_call, arg, mesh)
+                )
                 local_args.append(arg)
             else:
                 args_schema.append(arg)
@@ -369,7 +335,9 @@
                     mesh = v.device_mesh
             elif isinstance(v, torch.Tensor):
                 mesh = mesh or try_find_mesh_from_args(op_call, args_list)
-                kwargs_schema[k] = try_get_replicate_spec(v, mesh)
+                kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
+                    op_call, v, mesh
+                )
                 local_kwargs[k] = v
             else:
                 kwargs_schema[k] = v
@@ -418,3 +386,36 @@
             # if the res contains only non tensor values (i.e. int/float/none), we simply return it
             # without rewrapping to DTensor.
             return res
+
+    def _try_replicate_spec_for_scalar_tensor(
+        self,
+        op_call: torch._ops.OpOverload,
+        tensor_arg: torch.Tensor,
+        mesh: "DeviceMesh",
+    ) -> DTensorSpec:
+        # util function to produce a replicate spec for a scalar tensor arg/kwarg
+        if tensor_arg.numel() == 1 and tensor_arg.ndim == 1:
+            warnings.warn(
+                "Found a non-scalar tensor with numel=1 and ndim!=0, "
+                "we are implicitly creating a replicated DTensor for it. "
+                "However, please consider changing it to a scalar tensor "
+                "or explicitly create a DTensor under distributed enviroment."
+            )
+
+        if tensor_arg.numel() == 1 or self._allow_implicit_replication:
+            # scalar tensor can be safely treated as replicated
+            replication_spec = DTensorSpec(
+                mesh,
+                (Replicate(),) * mesh.ndim,
+                tensor_meta=TensorMeta(
+                    shape=tensor_arg.shape,
+                    stride=tensor_arg.stride(),
+                    dtype=tensor_arg.dtype,
+                ),
+            )
+        else:
+            raise RuntimeError(
+                f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all"
+                " torch.Tensor to DTensor before calling distributed operators!"
+            )
+        return replication_spec
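
The extracted _try_replicate_spec_for_scalar_tensor helper keeps the old eligibility rule for implicitly replicating plain torch.Tensor args. A rough sketch of just that condition (the standalone function below is illustrative, not the real API):

import torch

def eligible_for_implicit_replication(t: torch.Tensor, allow_implicit: bool = False) -> bool:
    # A single-element tensor (0-dim, or 1-dim with numel == 1) is treated as a scalar
    # and can be safely replicated; anything larger needs explicit opt-in.
    return t.numel() == 1 or allow_implicit

assert eligible_for_implicit_replication(torch.tensor(3.0))             # 0-dim scalar
assert eligible_for_implicit_replication(torch.tensor([3.0]))           # 1-dim, numel == 1 (warns in the real helper)
assert not eligible_for_implicit_replication(torch.tensor([1.0, 2.0]))  # must be converted to DTensor first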
diff --git a/torch/distributed/_tensor/_op_schema.py b/torch/distributed/_tensor/_op_schema.py
index 7c7aea7..7b8af39 100644
--- a/torch/distributed/_tensor/_op_schema.py
+++ b/torch/distributed/_tensor/_op_schema.py
@@ -214,10 +214,10 @@
 @dataclass
 class OpSchema:
     """
-    OpSchema is a data class that describes an operator input schemas, it
-    includes DTensor DTensorSpecs and non-tensor args/kwargs (positional order
-    preserved). It is mainly used by the dispatching logic below to run things like
-    sharding propagation.
+    OpSchema is a data class that describes an operator's input schema. It includes
+    DTensorSpecs (instead of DTensors) and non-tensor args/kwargs (positional order
+    preserved). It is mainly used by DTensor's dispatching logic to perform various
+    actions (e.g. sharding propagation, caching sharding decisions, redistribute, etc.)
 
     NOTE: this should be used as a read only data class
     TODO: make this a frozen dataclass
@@ -225,9 +225,9 @@
     Args:
         op: the operator overload we are intercepting
         args_schema: contains args except that the DTensor args have been replaced
-            with its DTensorSpec
+            with their DTensorSpec or OpStrategy
         kwargs_schema: contains kwargs except that the DTensor kwargs have been replaced
-            with its DTensorSpec
+            with their DTensorSpec or OpStrategy
     """
 
     op: OpOverload
@@ -427,13 +427,13 @@
 @dataclass
 class OutputSharding:
     """
-    OutputSharding is a data class that is used by the sharding propagation
-    rules, it could set the output_spec upon successful propagation, and if
-    it failed, output_spec would become None and sharding propagation rules
-    could give a list of suggestions for inputs to reshard.
+    OutputSharding is a data class used by sharding propagation; it sets the
+    output_spec upon successful propagation. If needs_redistribute is set to True,
+    a redistribute_schema is returned together with it to indicate that the input
+    arguments need to be redistributed before the op executes.
 
-    NOTE: the schema_suggestion generated by sharding propagation should be
-    exactly the same as the operator OpSchema, except the DTensor DTensorSpecs
+    NOTE: the redistribute_schema generated by sharding propagation should be
+    exactly the same as the operator OpSchema, except the DTensorSpecs
     """
 
     output_spec: OutputSpecType
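
To illustrate how dispatch consumes the needs_redistribute/redistribute_schema pair described above, here is a toy stand-in (the class and function below are illustrative only, not the real OutputSharding):

from dataclasses import dataclass
from typing import Optional

@dataclass
class _FakeOutputSharding:
    output_spec: Optional[object]
    redistribute_schema: Optional[object] = None
    needs_redistribute: bool = False

def _dispatch_path(sharding: _FakeOutputSharding) -> str:
    # Mirrors the participating-rank branch: redistribute inputs first when asked to.
    if sharding.needs_redistribute:
        assert sharding.redistribute_schema is not None
        return "redistribute local args, then run the local op"
    return "run the local op directly"

print(_dispatch_path(_FakeOutputSharding(output_spec=object())))
print(_dispatch_path(_FakeOutputSharding(object(), redistribute_schema=object(), needs_redistribute=True)))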