[SPMD] Add the default graph module transformation that is applied after tracing and expansion (#98182)
This PR adds the GraphModuleTransformation class that can be used as the
default transformation after `train_step()` is traced and expanded. The
current implementation includes:
1. Wrap the input graph module with IterGraphModule. This enables further graph optimizations, which are all implemented on top of IterGraphModule.
2. The ability to lower the graph module to Inductor; `lower_to_inductor()` is implemented for this purpose (see the usage sketch below).
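
A minimal usage sketch, mirroring `test_basic_transformation` in the added test file; the model, optimizer, and batch setup are assumed to be provided by the caller:

```python
from torch.distributed._spmd.api import compile
from torch.distributed._spmd.gm_transformation import GraphModuleTransformation

num_iters = 5  # number of iterations the IterGraphModule is set up for

# The decorated train_step is traced and expanded, then handed to
# GraphModuleTransformation, which wraps the resulting GM in an IterGraphModule.
@compile(gm_transformation=GraphModuleTransformation(num_iters=num_iters))
def train_step(model, optim, batch):
    model(batch).sum().backward()
    optim.step()
    optim.zero_grad()

# Pass enable_inductor=True (and optionally enable_cudagraphs=True) to the
# GraphModuleTransformation constructor to also lower the expanded graph via
# lower_to_inductor().
```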
TODO:
1. The `override` and `gm_transformation` options have overlapping functionality -- `override.transform` can be used to achieve the same effect as `gm_transformation`. However, the current semantics of `override` is to override and transform partial graphs, while `gm_transformation` transforms the entire expanded GM. The final UX of `compile()` needs some discussion.
2. The current `lower_to_inductor()` assumes that the entire graph can be lowered to Inductor. This assumption is acceptable for integrating graph optimizations but is too restrictive for many models. We should upstream `partial_lowering()`.
Differential Revision: [D44616783](https://our.internmc.facebook.com/intern/diff/D44616783/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98182
Approved by: https://github.com/mrshenli
diff --git a/test/distributed/_spmd/test_transformation.py b/test/distributed/_spmd/test_transformation.py
new file mode 100644
index 0000000..fe13b2e
--- /dev/null
+++ b/test/distributed/_spmd/test_transformation.py
@@ -0,0 +1,139 @@
+# Owner(s): ["oncall: distributed"]
+
+import unittest
+from copy import deepcopy
+from functools import wraps
+
+import torch
+import torch.nn as nn
+from torch._inductor.utils import has_triton
+from torch.distributed._spmd.api import compile
+from torch.distributed._spmd.gm_transformation import GraphModuleTransformation
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+ DTensorTestBase,
+ with_comms as base_with_comms,
+)
+
+
+def with_comms(func):
+ @base_with_comms
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ # make sure we set different random seeds for each rank
+        # otherwise we don't need DDP / SPMD
+ # (we would have the same parameters and inputs everywhere)
+ torch.manual_seed(self.rank)
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+
+class DummyModel(nn.Module):
+ def __init__(self, layers: int, dim: int):
+ super().__init__()
+ modules = []
+ for _ in range(layers):
+ modules.extend([nn.Linear(dim, dim), nn.ReLU()])
+ self.mod = nn.Sequential(*modules)
+
+ def forward(self, x):
+ return self.mod(x)
+
+
+class TransformationTest(DTensorTestBase):
+ @property
+ def world_size(self):
+ return 2
+
+ def _init(self, batch_size, layers, dim):
+ torch.manual_seed(0)
+ model = DummyModel(layers, dim).cuda()
+ ddp_model = DDP(deepcopy(model), device_ids=[self.rank])
+ optim = torch.optim.Adam(
+ model.parameters(), lr=0.01, foreach=True, capturable=True
+ )
+ ddp_optim = torch.optim.Adam(
+ ddp_model.parameters(), lr=0.01, foreach=True, capturable=True
+ )
+ batch = torch.randn(batch_size, dim).cuda()
+
+ # materialize optimizer states
+ out = model(batch)
+ out.sum().backward()
+ optim.step()
+ optim.zero_grad()
+
+ ddp_out = ddp_model(batch)
+ ddp_out.sum().backward()
+ ddp_optim.step()
+ ddp_optim.zero_grad()
+
+ self.assertEqual(ddp_out, out)
+ self.assertEqual(list(ddp_model.parameters()), list(model.parameters()))
+ return model, optim, ddp_model, ddp_optim
+
+ def _test_tran_step_with_ddp(self, train_step, num_iters, batch_size, layers, dim):
+ def _ddp_train_step(model, optim, batch):
+ model(batch).sum().backward()
+ with torch.no_grad():
+ for p in model.parameters():
+ p.grad *= self.world_size
+ optim.step()
+ optim.zero_grad()
+
+ model, optim, ddp_model, ddp_optim = self._init(batch_size, layers, dim)
+ for _ in range(num_iters):
+ batch = torch.randn(batch_size, dim).cuda()
+ out = train_step(model, optim, batch)
+ ddp_out = _ddp_train_step(ddp_model, ddp_optim, batch)
+ self.assertEqual(list(ddp_model.parameters()), list(model.parameters()))
+
+ @skip_if_lt_x_gpu(2)
+ @with_comms
+ def test_basic_transformation(self):
+ batch_size = 100
+ layers = 10
+ dim = 100
+ num_iters = 5
+
+ @compile(gm_transformation=GraphModuleTransformation(num_iters=num_iters))
+ def train_step(model, optim, batch):
+ model(batch).sum().backward()
+ optim.step()
+ optim.zero_grad()
+
+ self._test_tran_step_with_ddp(train_step, num_iters, batch_size, layers, dim)
+
+ @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
+ @skip_if_lt_x_gpu(2)
+ @with_comms
+ def test_inductor(self):
+ batch_size = 100
+ layers = 10
+ dim = 100
+ num_iters = 5
+
+ @compile(
+ gm_transformation=GraphModuleTransformation(
+ num_iters=num_iters, enable_inductor=True
+ )
+ )
+ def train_step(model, optim, batch):
+ model(batch).sum().backward()
+ optim.step()
+ optim.zero_grad()
+
+ # TODO: there are issues when lowering the optimizer. Disable
+ # the test for now.
+ """
+ self._test_tran_step_with_ddp(
+ train_step, num_iters, batch_size, layers, dim
+ )
+ """
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/torch/distributed/_spmd/gm_transformation.py b/torch/distributed/_spmd/gm_transformation.py
new file mode 100644
index 0000000..1227cc4
--- /dev/null
+++ b/torch/distributed/_spmd/gm_transformation.py
@@ -0,0 +1,132 @@
+import operator
+from typing import Any, Callable, Dict, List, Optional
+
+from functorch import make_fx
+
+import torch
+import torch.nn as nn
+
+from torch import fx
+from torch._inductor.compile_fx import compile_fx_inner
+from torch._inductor.decomposition import select_decomp_table
+from torch.distributed._spmd.graph_utils import OP
+from torch.distributed._spmd.iter_graph_module import IterGraphModule
+from torch.utils._pytree import tree_flatten
+
+
+class InductorWrapper(nn.Module):
+ def __init__(self, gm: fx.GraphModule, enable_cudagraphs: bool) -> None:
+ super().__init__()
+ self._gm = gm
+        self._compiled: Optional[Callable] = None
+ self._enable_cudagraphs = enable_cudagraphs
+
+ def forward(self, *args: Any) -> Any:
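+        # Trace with decompositions and compile via Inductor on the first call,
+        # then reuse the compiled callable for subsequent iterations.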
+ if self._compiled is None:
+ gm = make_fx(self._gm, decomposition_table=select_decomp_table())(*args)
+ self._compiled = compile_fx_inner(
+ gm,
+ list(args),
+ cudagraphs=self._enable_cudagraphs,
+ )
+ list_args, _ = tree_flatten(args)
+ return self._compiled(list_args)
+
+
+def lower_to_inductor(
+ gm: torch.fx.GraphModule, enable_cudagraphs: bool
+) -> torch.fx.GraphModule:
+    """
+    Lower the entire `gm` to Inductor; assumes the whole graph can be lowered.
+    """
+ orig_placeholders: List[fx.Node] = []
+ orig_output_args: List[Any] = []
+ output: fx.Node = next(iter(gm.graph.nodes))
+ move_nodes: List[fx.Node] = []
+
+ for node in gm.graph.nodes:
+ if node.op == OP.OUTPUT:
+ output = node
+ orig_output_args, _ = tree_flatten((node.args, node.kwargs))
+ elif node.op == OP.PLACEHOLDER:
+ orig_placeholders.append(node)
+ else:
+ move_nodes.append(node)
+
+ subgraph: torch.fx.Graph = torch.fx.Graph()
+ node_mapping: Dict[torch.fx.Node, torch.fx.Node] = {}
+ attrs = {}
+
+ # Map all the inputs/placeholders first.
+ for p in orig_placeholders:
+ node_mapping[p] = subgraph.node_copy(p)
+
+    # Create all other non-placeholder nodes.
+ for node in move_nodes:
+ if node.op == OP.GET_ATTR:
+ attrs[node.target] = getattr(gm, node.target)
+ node_mapping[node] = subgraph.node_copy(node, lambda n: node_mapping[n])
+
+ output_args = tuple(
+ node_mapping[n] if n is not None else None for n in orig_output_args
+ )
+ subgraph.output(result=output_args)
+
+    # Remove unused placeholders from the subgraph. This is required because
+    # `train_step()` takes the module and optimizer as inputs, which cannot be
+    # lowered to Inductor.
+ placeholders: List[torch.fx.Node] = []
+ for placeholder in orig_placeholders:
+ new_placeholder = node_mapping[placeholder]
+ if len(new_placeholder.users) == 0:
+ subgraph.erase_node(new_placeholder)
+ else:
+ placeholders.append(placeholder)
+
+ # Create the subgraph node in the original graph.
+ sub_gm = torch.fx.GraphModule(root=attrs, graph=subgraph)
+ gm.subgraph = InductorWrapper(sub_gm, enable_cudagraphs)
+ with gm.graph.inserting_after(move_nodes[-1]):
+ subgraph_call = gm.graph.create_node(
+ op=OP.CALL_MODULE, target="subgraph", args=tuple(placeholders)
+ )
+
+    # Rewire the subgraph outputs to the original graph's output node.
+ output_idx = 0
+ for i, node in enumerate(orig_output_args):
+ with gm.graph.inserting_after(subgraph_call):
+ new_node = gm.graph.call_function(
+ operator.getitem, (subgraph_call, output_idx)
+ )
+ output_idx += 1
+ orig_output_args[i] = new_node
+ assert output_idx == len(output_args)
+ gm.graph.erase_node(output)
+ gm.graph.output(result=orig_output_args)
+
+ gm.graph.eliminate_dead_code()
+ gm.recompile()
+
+ return gm
+
+
+class GraphModuleTransformation:
+ def __init__(
+ self,
+ num_iters: int,
+ enable_inductor: bool = False,
+ enable_cudagraphs: bool = False,
+ ) -> None:
+ self.num_iters = num_iters
+ self.enable_inductor = enable_inductor
+ self.enable_cudagraphs = enable_cudagraphs
+
+ def __call__(self, gm: fx.GraphModule) -> Callable:
+ iter_gm = IterGraphModule(gm)
+ iter_gm.freeze_cross_iter_movement()
+ iter_gm.setup(self.num_iters)
+
+ if self.enable_inductor:
+ iter_gm.main_gm = lower_to_inductor(iter_gm.main_gm, self.enable_cudagraphs)
+
+ return iter_gm