blob: 4c023f1b800497a95fe5354130dd0fc51a908147 [file] [log] [blame]
# Owner(s): ["oncall: distributed"]
import functools
import sys
from collections import namedtuple
from contextlib import suppress
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FlatParameter
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, CPUOffload
from torch.distributed.fsdp.wrap import (
always_wrap_policy,
transformer_auto_wrap_policy,
)
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
CUDAInitMode,
FSDPInitMode,
FSDPTest,
NestedWrappedModule,
TransformerWithSharedParams,
_assert_module_states,
)
from torch.testing._internal.common_utils import (
TEST_WITH_DEV_DBG_ASAN,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
# Bail out early, exiting with status 0 after printing the reason to stderr,
# when this environment cannot run these tests: either torch.distributed is
# unavailable, or we are running under dev-asan, where torch + multiprocessing
# spawn have known issues. The distributed check takes precedence.
_skip_reason = None
if not dist.is_available():
    _skip_reason = "Distributed not available, skipping tests"
elif TEST_WITH_DEV_DBG_ASAN:
    _skip_reason = (
        "Skip dev-asan as torch + multiprocessing spawn have known issues"
    )
if _skip_reason is not None:
    print(_skip_reason, file=sys.stderr)
    sys.exit(0)
# Keep the module namespace free of the temporary helper name.
del _skip_reason
class TestFSDPMisc(FSDPTest):
@property
def world_size(self):
return 2
@property
def process_group(self):
return dist.distributed_c10d._get_default_group()
@skip_if_lt_x_gpu(2)
def test_fsdp_namedtuple(self):
# Ensure namedtuple support, preventing issues such as
# https://github.com/pytorch/pytorch/issues/83053
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.lin = nn.Linear(100, 100)
def forward(self, x):
return x
m = MyModule().cuda()
m = FSDP(m)
t = torch.ones(1, device="cuda", requires_grad=True)
MyOutputType = namedtuple(
"MyOutputType",
["a", "b", "c", "d"],
defaults=(t, t, t, t)
)
inp = MyOutputType()
out = m(inp)
# Ensure hooks are registered
for x in out:
self.assertNotEqual([], list(x._backward_hooks.values()))
# TODO: we should check backward() and param is resharded
# as well, but this is blocked by
# https://github.com/pytorch/pytorch/issues/83107 and
# https://github.com/pytorch/pytorch/issues/83129
@skip_if_lt_x_gpu(2)
@parametrize("use_second_layer", [True, False])
@parametrize("sharding_strategy", [ShardingStrategy.NO_SHARD, None])
def test_fsdp_module_no_compute_grad(self, use_second_layer, sharding_strategy):
# When use_second_layer=True, b is involved in forward computation but does
# not receive grad in backward. Otherwise, b is not involved in forward
# computation.
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Linear(10, 10)
self.b = nn.Linear(10, 10)
def forward(self, x, y):
out1 = self.a(x)
if use_second_layer:
out2 = self.b(y)
return out1, out2
else:
return out1
fsdp = FSDP(
MyModel().cuda(),
sharding_strategy=sharding_strategy,
auto_wrap_policy=always_wrap_policy
)
x = torch.randn(10, 10, device='cuda')
y = torch.randn(10, 10, device='cuda')
for i in range(4):
if use_second_layer:
a, b = fsdp(x, y)
else:
a = fsdp(x, y)
loss = a.sum()
loss.backward()
# self.a receives grad, self.b does not
a_grad = fsdp.module.a._fsdp_wrapped_module.flat_param.grad
b_grad = fsdp.module.b._fsdp_wrapped_module.flat_param.grad
self.assertIsNotNone(a_grad)
self.assertIsNone(b_grad)
@skip_if_lt_x_gpu(2)
def test_device_id_auto_wrap(self):
"""Tests that ``auto_wrap_policy`` propagates ``device_id`` to all
nested FSDP instances."""
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
)
fsdp_kwargs = {
"auto_wrap_policy": auto_wrap_policy,
"device_id": torch.cuda.current_device(),
}
fsdp_model = TransformerWithSharedParams.init(
self.process_group,
FSDPInitMode.RECURSIVE,
CUDAInitMode.CUDA_BEFORE,
fsdp_kwargs,
)
for fsdp_module in FSDP.fsdp_modules(fsdp_model):
self.assertEqual(
fsdp_module.device_id,
torch.device("cuda", torch.cuda.current_device()),
)
@skip_if_lt_x_gpu(2)
def test_fsdp_device_id_cpu_offload(self):
"""
Ensures that even if device_id is specified but we have
CPU offload, module is on CPU after init.
"""
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Linear(10, 10)
self.b = nn.Linear(10, 10)
def forward(self, x):
return self.b(self.a(x))
model = MyModel()
fsdp = FSDP(
model,
auto_wrap_policy=always_wrap_policy,
cpu_offload=CPUOffload(offload_params=True),
device_id=torch.cuda.current_device()
)
cpu_device = torch.device("cpu")
for fsdp_unit in FSDP.fsdp_modules(fsdp):
# This FSDP unit may not directly manage
# any parameters.
if len(fsdp_unit.params) > 0:
fsdp_param = fsdp_unit.params[0]
self.assertEqual(fsdp_param.device, cpu_device)
@skip_if_lt_x_gpu(2)
@parametrize("use_index", [True, False])
def test_fsdp_device_id(self, use_index):
"""
Tests the FSDP ``device_id`` argument:
- Wrapping a CPU module should move the module to the GPU matching
``device_id``
- Wrapping a GPU module already on the GPU matching ``device_id``
should not raise an error
- Wrapping a GPU module already on GPU and passing a GPU device
without specifying a device ID (i.e. ``torch.device("cuda")``) warns
"""
dev_id = (
torch.cuda.current_device() if use_index
else torch.device("cuda", torch.cuda.current_device())
)
def _check_device_matches(module, device_id):
"""Checks that the ``FlatParameter``s in ``module`` have device
matching ``device_id``."""
devices = {
p.device for p in module.parameters()
if isinstance(p, FlatParameter)
}
assert len(devices) > 0
self.assertEqual(1, len(devices))
found_device = devices.pop()
if use_index and not isinstance(device_id, torch.device):
device = torch.device("cuda", device_id)
else:
device = device_id
self.assertEqual(found_device, device)
# Check that FSDP parameters are moved to `device_id` for a CPU module
nested_wrapped_module = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.RECURSIVE,
CUDAInitMode.CUDA_NEVER,
fsdp_kwargs={"device_id": dev_id},
)
_check_device_matches(nested_wrapped_module, dev_id)
# Check that specifying `device_id` for a GPU module already on that
# device does not raise an error
nested_wrapped_module = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.RECURSIVE,
CUDAInitMode.CUDA_BEFORE,
fsdp_kwargs={"device_id": dev_id},
)
_check_device_matches(nested_wrapped_module, dev_id)
# Check that passing in `torch.device("cuda")` for a GPU module warns
regex = "does not have explicit index"
context = self.assertWarnsRegex(
expected_warning=UserWarning, expected_regex=regex
)
with context:
nested_wrapped_module = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.RECURSIVE,
CUDAInitMode.CUDA_BEFORE,
fsdp_kwargs={"device_id": torch.device("cuda")}
)
_check_device_matches(
nested_wrapped_module,
torch.device("cuda", torch.cuda.current_device())
)
@skip_if_lt_x_gpu(2)
def test_module_device_mismatches_device_id(self):
"""Tests that specifying a ``device_id`` argument to FSDP for a GPU
module that does not match the GPU device ID raises an error."""
context = (
self.assertRaisesRegex(
RuntimeError,
f"on rank {self.rank}.*cuda:0, but is on cuda:{self.rank}"
) if self.rank != 0 else suppress()
)
with context:
NestedWrappedModule.init(
self.process_group,
FSDPInitMode.RECURSIVE,
# Move wrapped modules to CUDA before wrapping with FSDP
cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
# Should raise error since rank 1 is given `device_id=0` when
# the model is on cuda:1
fsdp_kwargs={"device_id": 0},
)
@skip_if_lt_x_gpu(2)
def test_multi_device_not_supported(self):
"""Tests that wrapping a multi-device module (i.e. with submodules on
both GPU and CPU) with FSDP raises an error."""
class MultiDeviceModule(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Linear(1, 1).cuda()
self.b = nn.Linear(1, 1)
with self.assertRaisesRegex(
RuntimeError, "FSDP only supports single device modules"
):
FSDP(MultiDeviceModule())
@skip_if_lt_x_gpu(2)
def test_no_params(self):
"""
Test that device_id and cpu init work if module has no params
(they are effective noops, but ensure FSDP does not assume module
has parameters during init)
"""
# Test CPU
no_params = nn.ReLU()
module = FSDP(no_params)
# Test CUDA
no_params = nn.ReLU().cuda()
module = FSDP(no_params)
# Test CPU + device_id
no_params = nn.ReLU()
module = FSDP(no_params, device_id=torch.cuda.current_device())
# For modules with no params, wrong device_id will raise error about
# inconsistency between compute_device and device_id, since compute_device
# is computed as torch.cuda.current_device when there are no params.
no_params = nn.ReLU().cuda()
context = (
self.assertRaisesRegex(
AssertionError,
f"Inconsistent.*cuda:{self.rank} vs cuda:0"
)
) if self.rank != 0 else suppress()
with context:
module = FSDP(no_params, device_id=0)
@skip_if_lt_x_gpu(2)
def test_fsdp_cpu_init_stays_on_cpu(self):
"""Tests that passing a CPU module to FSDP preserves that the wrapped
module is on CPU after FSDP initialization, albeit after loging a
warning, and that FSDP moves CPU input to GPU before the forward."""
torch.cuda.set_device(self.rank)
regex = "Module is put on CPU"
context = self.assertWarnsRegex(
expected_warning=UserWarning, expected_regex=regex
)
with context:
nested_wrapped_module = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.RECURSIVE,
CUDAInitMode.CUDA_NEVER,
)
fsdp_model = FSDP(nested_wrapped_module, self.process_group)
devices = {p.device for p in fsdp_model.parameters()}
self.assertEqual(1, len(devices))
self.assertEqual(torch.device("cpu"), devices.pop())
fsdp_model = fsdp_model.cuda()
# Ensure fwd + backward can be performed after moving to CUDA.
# CPU input also tests that input is correctly moved to appropriate
# CUDA device.
inp = fsdp_model.module.get_input(device=torch.device("cpu"))
fsdp_model(*inp).sum().backward()
@skip_if_lt_x_gpu(2)
def test_cpu_init_with_sync_module_states(self):
"""Tests that passing ``sync_module_states=True`` raises an error for
a CPU module since the synchronization requires GPU communication,
while additionally passing ``device_id`` does not raise an error."""
nested_wrapped_module = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.RECURSIVE,
CUDAInitMode.CUDA_NEVER,
)
with self.assertRaisesRegex(
ValueError,
"Module has CPU parameters, but sync_module_states=True is specified."
):
FSDP(nested_wrapped_module, self.process_group, sync_module_states=True)
# Specifying device_id with sync_module_states=True works.
FSDP(
nested_wrapped_module,
self.process_group,
device_id=torch.cuda.current_device(),
sync_module_states=True,
)
@skip_if_lt_x_gpu(2)
def test_fsdp_same_model_across_ranks(self):
"""
FSDP broadcasts model from rank 0 to ensure it starts off with the same
values.
"""
class MyModel(nn.Module):
def __init__(self, rank):
super().__init__()
# Seed via rank to make model different across ranks
torch.manual_seed(rank)
torch.cuda.manual_seed(rank)
self.lin = nn.Linear(10, 10, bias=False)
self.register_buffer("buffer", torch.ones(1) * rank)
m = MyModel(self.rank).cuda()
_assert_module_states(m, process_group=self.process_group, assert_fn=self.assertNotEqual)
# Passing sync_module_states into FSDP makes model the same during init.
fsdp = FSDP(m, sync_module_states=True)
with fsdp.summon_full_params(fsdp):
_assert_module_states(fsdp, process_group=self.process_group, assert_fn=self.assertEqual)
# sync_module_states also works with CPU module with device_id passed in
m = MyModel(self.rank)
_assert_module_states(m, process_group=self.process_group, assert_fn=self.assertNotEqual)
# Passing sync_module_states into FSDP makes model the same during init.
fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
with fsdp.summon_full_params(fsdp):
_assert_module_states(fsdp, process_group=self.process_group, assert_fn=self.assertEqual)
# Apply the ``@parametrize`` decorators used in TestFSDPMisc via torch's
# common_utils helper so the parametrized tests become runnable test methods.
instantiate_parametrized_tests(TestFSDPMisc)

# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()