[composable] Enable replicate + trec_shard overall (#98890)
`replicate` + `trec_shard` works if we shard / replicate individual submodules, as follows:
```
m = TestSparseNN()
shard(m.sparse)
replicate(m.dense)
```
but does not work if users do the following:
```
m = TestSparseNN()
shard(m, sharders=[...])
replicate(m)
```
Many upstream trainers use the latter pattern: sharding is not applied at the individual-module level but to the overall module, by specifying planners that contain the logic for how to shard the different embedding table types.
This diff enables the latter approach (while keeping the former intact), but users need to pass `ignored_modules` to `replicate()` so that the embedding tables are ignored. This is similar to FSDP (class-based and composable) and DDP today.
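For illustration, a minimal sketch of the new usage, mirroring the snippets above (`TestSparseNN`, `shard`, and the `sparse` submodule are placeholders from the earlier example, not part of this diff; only the `ignored_modules` argument is new):
```
m = TestSparseNN()
# Planner / sharders decide how each embedding table type is sharded.
shard(m, sharders=[...])
# Ignore the sharded embedding tables so replicate() does not try to manage their params.
replicate(m, ignored_modules=[m.sparse])
```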
Differential Revision: [D44899155](https://our.internmc.facebook.com/intern/diff/D44899155/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98890
Approved by: https://github.com/mrshenli, https://github.com/yhcharles
diff --git a/test/distributed/_composable/test_replicate.py b/test/distributed/_composable/test_replicate.py
index e5c9f0f..6198819 100644
--- a/test/distributed/_composable/test_replicate.py
+++ b/test/distributed/_composable/test_replicate.py
@@ -8,23 +8,22 @@
import torch.nn.functional as F
from torch import nn
from torch.distributed._composable.replicate import replicate
-from torch.testing._internal.common_distributed import MultiProcessTestCase
+from torch.testing._internal.common_distributed import (
+ MultiProcessTestCase,
+ skip_if_lt_x_gpu,
+)
from torch.testing._internal.common_utils import run_tests
class Net(nn.Module):
def __init__(self):
super().__init__()
- self.fc1 = nn.Linear(2, 10, bias=False)
- self.fc2 = nn.Linear(10, 50, bias=False)
- self.fc3 = nn.Linear(50, 4, bias=False)
- self.relu = nn.ReLU()
+ self.fc1 = nn.Linear(2, 2)
+ self.fc2 = nn.Linear(2, 2)
+ self.fc3 = nn.Linear(2, 2)
def forward(self, x):
- x = self.relu(self.fc1(x))
- x = self.relu(self.fc2(x))
- x = self.fc3(x)
- return F.softmax(x, dim=1)
+ return self.fc3(self.fc2(self.fc1(x)))
class ReplicateStateDictTest(MultiProcessTestCase):
@@ -74,6 +73,10 @@
class ReplicateTest(MultiProcessTestCase):
+ @property
+ def world_size(self) -> int:
+ return 2
+
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
@@ -96,7 +99,7 @@
local_batch_size = 1
global_batch_size = self.world_size * local_batch_size
input = torch.randn(global_batch_size, 2)
- target = torch.randn(global_batch_size, 4)
+ target = torch.randn(global_batch_size, 2)
def step_model(model, input, target):
model.train()
@@ -136,6 +139,40 @@
replicate_model = replicate(deepcopy(model))
self._compare_module(model, replicate_model)
+ @skip_if_lt_x_gpu(2)
+ def test_replicate_ignore_module(self):
+ dist.init_process_group(
+ backend="gloo",
+ rank=self.rank,
+ world_size=self.world_size,
+ store=dist.FileStore(self.file_name, self.world_size),
+ )
+ torch.cuda.set_device(self.rank)
+ # Seed ensures diff input and thus different local grads across ranks.
+ torch.manual_seed(self.rank)
+ torch.cuda.manual_seed(self.rank)
+ model = Net().cuda()
+ replicate(model, ignored_modules=[model.fc1])
+ inp = torch.randn(5, 2, device="cuda") * (self.rank + 1)
+ out = model(inp) * 10
+ out.sum().backward()
+ # FC1 grads should not be synchronized, FC2 and 3 should be.
+ fc1_grad = model.fc1.weight.grad
+ tensor_list = [torch.zeros_like(fc1_grad) for _ in range(dist.get_world_size())]
+ dist.all_gather(tensor_list, fc1_grad)
+ grad, rest = tensor_list[0], tensor_list[1:]
+ for g in rest:
+ self.assertNotEqual(grad, g)
+
+ for dp_grad in [model.fc2.weight.grad, model.fc3.weight.grad]:
+ tensor_list = [
+ torch.zeros_like(dp_grad) for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(tensor_list, dp_grad)
+ grad, rest = tensor_list[0], tensor_list[1:]
+ for g in rest:
+ self.assertEqual(grad, g)
+
def test_replicate_multi_module(self):
model = Net()
replicate_model = deepcopy(model)
diff --git a/torch/distributed/_composable/replicate.py b/torch/distributed/_composable/replicate.py
index 320c31b..09655bc 100644
--- a/torch/distributed/_composable/replicate.py
+++ b/torch/distributed/_composable/replicate.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
import torch
import torch.nn as nn
@@ -11,6 +11,7 @@
@contract()
def replicate(
module: nn.Module, # NOTE: contract now supports single module only
+ ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
**kwargs,
) -> nn.Module:
r"""Replicates a module
@@ -24,24 +25,32 @@
>>> replicate(module)
"""
torch._C._log_api_usage_once("torch.distributed.replicate")
- _ReplicateState().mark_module(module, **kwargs)
+ _ReplicateState(ignored_modules=ignored_modules).mark_module(module, **kwargs)
return module
-def _can_compose(module: nn.Module) -> bool:
- r"""Check if module is composable for `replicate` API."""
- return "fully_shard" not in _get_registry(module)
+def _is_fully_sharded(module: nn.Module) -> bool:
+ r"""Check if module is marked with fully_shard."""
+ return "fully_shard" in _get_registry(module)
class _ReplicateState:
- def __init__(self) -> None:
+ def __init__(self, ignored_modules: Optional[Iterable[torch.nn.Module]]) -> None:
self.module: Optional[nn.Module] = None
self.has_initialized: bool = False
self._param_list: nn.ParameterList = nn.ParameterList()
self.kwargs: dict = {}
+ self.ignored_modules: Set[torch.nn.Module] = (
+ set(ignored_modules) if ignored_modules is not None else set()
+ )
+ self.ignored_params: Set[torch.nn.Parameter] = {
+ p for m in self.ignored_modules for p in m.parameters()
+ }
+ # Only used for testing
+ self._names: List[str] = []
def mark_module(self, module: nn.Module, **kwargs) -> None:
- if not _can_compose(module):
+ if _is_fully_sharded(module):
raise AssertionError(
"Cannot apply `replicate()` on a Module already managed by `fully_shard`"
)
@@ -52,26 +61,20 @@
module.register_forward_hook(self.forward_post_hook) # type: ignore[arg-type]
self.kwargs = kwargs
- def _recursive_collect_params(self, module: nn.Module) -> None:
- # skip if managed by other APIs
- if not _can_compose(module):
+ def _collect_params(self, module: nn.Module) -> None:
+ # skip if managed by fully_sharded API
+ if _is_fully_sharded(module):
return
- # skip if module parameters already collected
- replicate_state = replicate.state(module)
- # if replicate_state is None, `module` is a child module that has not been explicitly
- # tagged as replicate().
- if replicate_state is not None:
- if hasattr(replicate_state, "_params_collected"):
- if replicate_state._params_collected:
- return
- replicate_state._params_collected = True
+ if module in self.ignored_modules:
+ return # if module A is ignored, all of A's children are also ignored.
self._param_list.extend(
- param for param in module.parameters(recurse=False) if param.requires_grad
+ p for p in module.parameters(recurse=False) if p not in self.ignored_params
)
- for child in module.children():
- self._recursive_collect_params(child)
+
+ for child_module in module.children():
+ self._collect_params(child_module)
def init_helper(self) -> None:
if self.has_initialized:
@@ -79,7 +82,9 @@
self.has_initialized = True
- self._recursive_collect_params(self.module) # type: ignore[arg-type]
+ self._collect_params(self.module) # type: ignore[arg-type]
+ # Only saved for testing
+ replicate.state(self.module)._names = self._names
self._ddp = DistributedDataParallel(self._param_list, **self.kwargs)