[spmd] Introduce ParallelMode and add DTensorExpandMode (#98452)
This PR introduces a ParallelMode interface that defines how to perform
SPMD expansion and how to optimize the captured graph. This lets different
parallelisms expand the graph differently and apply their own
optimization passes.
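As a rough sketch of how a new parallelism could plug in (the mode name and
the method bodies below are hypothetical; only the interface signatures come
from this PR):

```python
# Hypothetical example: `MySimpleMode` is illustrative only, not part of this PR.
from typing import Any, Dict, Optional, Tuple

import torch
from torch.distributed._spmd.parallel_mode import ParallelMode
from torch.fx import GraphModule


class MySimpleMode(ParallelMode):
    def partition(
        self,
        gm: GraphModule,
        model: torch.nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        params_and_buffers: Dict[str, Any],
        named_states: Dict[str, Any],
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
    ) -> GraphModule:
        # Rewrite the single-device graph into a distributed graph here
        # (e.g. by inserting collectives); kept as a no-op for illustration.
        return gm

    def transform_and_compile(self, gm: GraphModule) -> GraphModule:
        # Apply mode-specific optimization passes here; no-op for illustration.
        return gm
```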
Add DTensorExpandMode as the first parallel mode; it wraps the existing
dtensor_expand functionality.
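A minimal usage sketch (the `train_step` signature and names are placeholders;
omitting `parallel_mode` falls back to the module-level DTensorExpandMode, so
existing callers keep the old behavior):

```python
# Sketch only: `mod`, `opt`, and `inp` stand in for a real model, optimizer,
# and input batch in a distributed setup.
from torch.distributed._spmd.api import compile
from torch.distributed._spmd.parallel_mode import DTensorExpandMode

@compile(parallel_mode=DTensorExpandMode())
def train_step(mod, opt, inp):
    mod(inp).sum().backward()
    opt.step()
```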
Differential Revision: [D45174399](https://our.internmc.facebook.com/intern/diff/D45174399)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98452
Approved by: https://github.com/mrshenli
diff --git a/torch/distributed/_spmd/api.py b/torch/distributed/_spmd/api.py
index 2f8baef..840b8d5 100644
--- a/torch/distributed/_spmd/api.py
+++ b/torch/distributed/_spmd/api.py
@@ -28,14 +28,10 @@
import torch.utils._pytree as pytree
from torch import fx
-from torch._subclasses import FakeTensorMode
-from torch.distributed._spmd.distribute import (
- _convert_to_distributed,
- distribute,
- Schema,
-)
+from torch.distributed._spmd.distribute import distribute, Schema
from torch.distributed._spmd.distributed_graph import DistributedGraph
-from torch.distributed._tensor import DeviceMesh, Placement, Replicate, Shard
+from torch.distributed._spmd.parallel_mode import DTensorExpandMode, ParallelMode
+from torch.distributed._tensor import Placement, Replicate
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo, CodeGen
from torch.nn.utils import stateless
from torch.nn.utils._named_member_accessor import NamedMemberAccessor
@@ -179,67 +175,14 @@
return gm
-_placements_override: Dict[int, List[Placement]] = {}
+# Use a dtensor expand mode for now to preserve the old behavior
+# and avoid breaking existing code
+dtensor_expand_mode = DTensorExpandMode()
def _override_placements(t: torch.Tensor, placements: List[Placement]):
- global _placements_override
- _placements_override[id(t)] = placements
-
-
-def _dtensor_expand(
- gm: fx.GraphModule,
- params_and_buffers: Dict[str, Any],
- named_states: Dict[str, Any],
- args: Tuple[Any, ...],
- kwargs: Dict[str, Any],
-) -> fx.GraphModule:
- flat_args, _ = pytree.tree_flatten(list(args) + list(kwargs.values()))
-
- mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()).cuda())
- shard_schema: Schema = Schema(mesh=mesh, placements=[Shard(0)])
- # FIXME: allow other sharding schemas
- replicate_schema: Schema = Schema(mesh=mesh, placements=[Replicate()])
-
- inps, schemas = [], []
-
- for p in pytree.tree_flatten(params_and_buffers)[0]:
- assert isinstance(p, torch.Tensor), f"expecting Tensor but got {type(p)}"
- inps.append(p)
- schemas.append(replicate_schema)
-
- for o in pytree.tree_flatten(named_states)[0]:
- if isinstance(o, torch.Tensor):
- inps.append(o)
- schemas.append(replicate_schema)
- else:
- inps.append(torch.empty(0))
- schemas.append(replicate_schema)
-
- for a in flat_args:
- if isinstance(a, torch.Tensor):
- inps.append(a)
- if id(a) in _placements_override:
- schemas.append(
- Schema(mesh=mesh, placements=_placements_override[id(a)])
- )
- else:
- schemas.append(shard_schema)
- else:
- # Create dummy tensor and schema for non-tensor inputs for
- # the purpose of dtensor expansion. Non-tensor inputs are
- # guaranteed unused in dispatcher graphs produced by make_fx.
- # However, we still need to respect them so that tensor inputs
- # match wtih their placeholders.
- inps.append(torch.empty(0))
- schemas.append(shard_schema)
-
- with FakeTensorMode(allow_non_fake_inputs=True):
- fake_inps = [torch.empty_like(inp) for inp in inps]
-
- return _convert_to_distributed(
- gm, fake_inps, schemas, default_mesh=mesh, _allow_partial=False
- )[0]
+ global dtensor_expand_mode
+ dtensor_expand_mode._placements_override[id(t)] = placements
@contextmanager
@@ -429,6 +372,7 @@
def _compile(
func: Callable,
module_override: Optional[Dict[Union[Type[Any], str], Override]],
+ parallel_mode: ParallelMode,
*args: Any,
**kwargs: Any,
) -> _CompiledResult:
@@ -510,8 +454,16 @@
**buffers,
}
- # 4. Use DTensor to insert collectives
- gm = _dtensor_expand(gm, params_and_buffers, named_states, args, kwargs)
+ # 4. Use the parallel mode to expand a single-device graph into a distributed graph
+ gm = parallel_mode.partition(
+ gm,
+ mod,
+ opt,
+ params_and_buffers,
+ named_states,
+ args,
+ kwargs,
+ )
# 5. Move the responsibility of flattening the input arguments from the
# graph module to the caller. This serves two purposes:
@@ -551,6 +503,7 @@
def compile(
module_override: Optional[Dict[Union[Type[Any], str], Override]] = None,
gm_transformation: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None,
+ parallel_mode: Optional[ParallelMode] = None,
):
r"""
Compile and optimize a callable, which can be a train step within a training
@@ -569,6 +522,10 @@
a callback that will be called after the original callable is
compiled and distributed (usually after the first iteration) to
transform the compiled GraphModule into a new optimized one.
+ parallel_mode (Optional[ParallelMode]): a :class:`ParallelMode` object
+ that specifies how to parallelize the callable. Each ParallelMode
+ defines its own strategy to partition the model and the captured
+ graph. (Default: ``None``)
"""
def inner(func: Callable):
@@ -581,7 +538,12 @@
compiled_obj = wrapper.__dict__.get(COMPILED_OBJECT_KEY, None)
if compiled_obj is None:
first_iter = True
- compiled_obj = _compile(func, module_override, *args, **kwargs)
+ global dtensor_expand_mode
+ mode: ParallelMode = (
+ dtensor_expand_mode if parallel_mode is None else parallel_mode
+ )
+
+ compiled_obj = _compile(func, module_override, mode, *args, **kwargs)
wrapper.__dict__[COMPILED_OBJECT_KEY] = compiled_obj
flat_inps = compiled_obj.flat_state + pytree.tree_flatten([args, kwargs])[0]
diff --git a/torch/distributed/_spmd/parallel_mode.py b/torch/distributed/_spmd/parallel_mode.py
new file mode 100644
index 0000000..2b487c6
--- /dev/null
+++ b/torch/distributed/_spmd/parallel_mode.py
@@ -0,0 +1,133 @@
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+import torch.distributed as dist
+import torch.utils._pytree as pytree
+from torch._subclasses import FakeTensorMode
+from torch.distributed._spmd.distribute import _convert_to_distributed, Schema
+from torch.distributed._tensor import DeviceMesh, Placement, Replicate, Shard
+
+from torch.fx import GraphModule
+
+
+class ParallelMode(ABC):
+ """
+ Basic Parallel Mode interface. Each parallelism pattern should implement
+ this interface to describe how to partition and compile the graph in the
+ spmd compiler.
+ """
+
+ @abstractmethod
+ def partition(
+ self,
+ gm: GraphModule,
+ model: torch.nn.Module,
+ optimizer: Optional[torch.optim.Optimizer],
+ params_and_buffers: Dict[str, Any],
+ named_states: Dict[str, Any],
+ args: Tuple[Any, ...],
+ kwargs: Dict[str, Any],
+ ) -> GraphModule:
+ """
+ Partition a single device graph to a distributed graph.
+
+ TODO(@wanchaol): some of these arguments are not necessary for
+ partitioning; remove the unnecessary ones later.
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def transform_and_compile(self, gm: GraphModule) -> GraphModule:
+ """
+ Transform and compile a distributed graph with a set of graph
+ transformation and optimization passes for each parallel mode.
+
+ The returned result should be a compiled executable graph in
+ the distributed environment.
+ """
+ # TODO: add more necessary arguments to this interface.
+ raise NotImplementedError()
+
+
+class DTensorExpandMode(ParallelMode):
+ """
+ The DTensor expand mode. It replicates the parameters and shards
+ the inputs to represent DDP-like behavior; it is currently a transient
+ mode before we move to the new data parallel expansion.
+ """
+
+ def __init__(
+ self, custom_passes: Optional[Callable[[GraphModule], GraphModule]] = None
+ ):
+ self._placements_override: Dict[int, List[Placement]] = {}
+ if custom_passes is not None:
+ self._gm_passes: Callable[[GraphModule], GraphModule] = custom_passes
+ else:
+ # TODO: add a few default passes here.
+ self._gm_passes = lambda gm: gm
+
+ def partition(
+ self,
+ gm: GraphModule,
+ model: torch.nn.Module,
+ optimizer: Optional[torch.optim.Optimizer],
+ params_and_buffers: Dict[str, Any],
+ named_states: Dict[str, Any],
+ args: Tuple[Any, ...],
+ kwargs: Dict[str, Any],
+ ) -> GraphModule:
+ flat_args, _ = pytree.tree_flatten(list(args) + list(kwargs.values()))
+
+ mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()).cuda())
+ shard_schema: Schema = Schema(mesh=mesh, placements=[Shard(0)])
+ # FIXME: allow other sharding schemas
+ replicate_schema: Schema = Schema(mesh=mesh, placements=[Replicate()])
+
+ inps, schemas = [], []
+
+ for p in pytree.tree_flatten(params_and_buffers)[0]:
+ assert isinstance(p, torch.Tensor), f"expecting Tensor but got {type(p)}"
+ inps.append(p)
+ schemas.append(replicate_schema)
+
+ for o in pytree.tree_flatten(named_states)[0]:
+ if isinstance(o, torch.Tensor):
+ inps.append(o)
+ schemas.append(replicate_schema)
+ else:
+ inps.append(torch.empty(0))
+ schemas.append(replicate_schema)
+
+ for a in flat_args:
+ if isinstance(a, torch.Tensor):
+ inps.append(a)
+ if id(a) in self._placements_override:
+ schemas.append(
+ Schema(mesh=mesh, placements=self._placements_override[id(a)])
+ )
+ else:
+ schemas.append(shard_schema)
+ else:
+ # Create dummy tensor and schema for non-tensor inputs for
+ # the purpose of dtensor expansion. Non-tensor inputs are
+ # guaranteed unused in dispatcher graphs produced by make_fx.
+ # However, we still need to respect them so that tensor inputs
+ # match with their placeholders.
+ inps.append(torch.empty(0))
+ schemas.append(shard_schema)
+
+ with FakeTensorMode(allow_non_fake_inputs=True):
+ fake_inps = [torch.empty_like(inp) for inp in inps]
+
+ return _convert_to_distributed(
+ gm, fake_inps, schemas, default_mesh=mesh, _allow_partial=False
+ )[0]
+
+ def transform_and_compile(self, gm: GraphModule) -> GraphModule:
+ """
+ Transform and compile a distributed graph with a set of graph transformation
+ and optimization passes for the dtensor fallback parallel mode.
+ """
+ # TODO: move the transformation passes to this function
+ return self._gm_passes(gm)