blob: 19abf6865bd0ea1764f4127bc074f4f1973a6518 [file] [log] [blame]
# Owner(s): ["oncall: distributed"]
import functools
import os
import sys
import warnings
from collections import namedtuple
from contextlib import suppress
from copy import deepcopy
from typing import Any, Tuple
import torch
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch.distributed.fsdp import (
CPUOffload,
FlatParameter,
FullyShardedDataParallel as FSDP,
ShardingStrategy,
)
from torch.distributed.fsdp._runtime_utils import HOMOGENEOUS_ATTR_NAMES
from torch.distributed.fsdp.flat_param import _FSDP_USE_UNSAFE_SETATTR
from torch.distributed.fsdp.wrap import (
always_wrap_policy,
ModuleWrapPolicy,
transformer_auto_wrap_policy,
)
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
_assert_module_states,
CUDAInitMode,
FSDPInitMode,
FSDPTest,
NestedWrappedModule,
TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
)
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)
class TestFSDPMisc(FSDPTest):
@property
def world_size(self):
return 2
@property
def process_group(self):
return dist.distributed_c10d._get_default_group()
@skip_if_lt_x_gpu(2)
def test_fsdp_namedtuple(self):
# Ensure namedtuple support, preventing issues such as
# https://github.com/pytorch/pytorch/issues/83053
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.lin = nn.Linear(100, 100)
def forward(self, x):
return x
m = MyModule().cuda()
m = FSDP(m)
t = torch.ones(1, device="cuda", requires_grad=True)
MyOutputType = namedtuple(
"MyOutputType", ["a", "b", "c", "d"], defaults=(t, t, t, t)
)
inp = MyOutputType()
out = m(inp)
# Ensure hooks are registered
for x in out:
self.assertNotEqual([], list(x._backward_hooks.values()))
# TODO: we should check backward() and param is resharded
# as well, but this is blocked by
# https://github.com/pytorch/pytorch/issues/83107 and
# https://github.com/pytorch/pytorch/issues/83129
@skip_if_lt_x_gpu(2)
def test_fsdp_not_all_outputs_used_in_loss(self):
self.run_subtests(
{
"sharding_strategy": [
ShardingStrategy.FULL_SHARD,
ShardingStrategy.SHARD_GRAD_OP,
ShardingStrategy.NO_SHARD,
]
},
self._test_fsdp_not_all_outputs_used_in_loss,
)
def _test_fsdp_not_all_outputs_used_in_loss(
self, sharding_strategy: ShardingStrategy
):
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.lin1 = nn.Linear(4, 4)
self.lin2 = nn.Linear(4, 4)
def forward(self, x):
a = self.lin1(x)
b = self.lin2(x)
return (a, b)
def _check_resharded(fsdp_module):
for handle in fsdp_module._handles:
param = handle.flat_param
if handle.uses_sharded_strategy:
full_param = param._full_param_padded
self.assertEqual(full_param.storage().size(), 0)
self.assertEqual(param.data_ptr(), param._local_shard.data_ptr())
def _check_equal(local, fsdp):
with FSDP.summon_full_params(fsdp):
for p1, p2 in zip(fsdp.parameters(), local.parameters()):
torch.testing.assert_close(p1, p2)
fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
m = MyModule().cuda()
m_local = deepcopy(m)
local_m = m_local
prev_params = [p.clone() for p in m_local.parameters()]
m.lin1 = fsdp_ctor(m.lin1)
m = fsdp_ctor(m)
_check_equal(m_local, m)
opt = torch.optim.SGD(m.parameters(), lr=1e-3)
opt_local = torch.optim.SGD(local_m.parameters(), lr=1e-3)
for i in range(6):
t = torch.ones(4, device="cuda")
a, b = m(t)
local_a, local_b = local_m(t)
if i < 2:
# use both params in loss computation. Later,
# b will go unused and we check grads are the
# same as local training.
loss = (a @ b).sum()
loss_local = (local_a @ local_b).sum()
else:
loss = a.sum()
loss_local = local_a.sum()
loss.backward()
loss_local.backward()
_check_resharded(m)
opt.step()
opt_local.step()
_check_equal(m_local, m)
# Ensure at least some change from previous params, otherwise
# above check would be vacuously true.
self.assertTrue(
any(
not torch.equal(p1, p2)
for p1, p2 in zip(prev_params, m_local.parameters())
)
)
prev_params = [p.clone() for p in local_m.parameters()]
opt.zero_grad()
opt_local.zero_grad()
dist.barrier()
@skip_if_lt_x_gpu(2)
@parametrize("use_second_layer", [True, False])
@parametrize("sharding_strategy", [ShardingStrategy.NO_SHARD, None])
def test_fsdp_module_no_compute_grad(self, use_second_layer, sharding_strategy):
# When use_second_layer=True, b is involved in forward computation but does
# not receive grad in backward. Otherwise, b is not involved in forward
# computation.
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Linear(10, 10)
self.b = nn.Linear(10, 10)
def forward(self, x, y):
out1 = self.a(x)
if use_second_layer:
out2 = self.b(y)
return out1, out2
else:
return out1
fsdp = FSDP(
MyModel().cuda(),
sharding_strategy=sharding_strategy,
auto_wrap_policy=always_wrap_policy,
)
x = torch.randn(10, 10, device="cuda")
y = torch.randn(10, 10, device="cuda")
for i in range(4):
if use_second_layer:
a, b = fsdp(x, y)
else:
a = fsdp(x, y)
loss = a.sum()
loss.backward()
# self.a receives grad, self.b does not
a_grad = fsdp.module.a._handles[0].flat_param.grad
b_grad = fsdp.module.b._handles[0].flat_param.grad
self.assertIsNotNone(a_grad)
self.assertIsNone(b_grad)
@skip_if_lt_x_gpu(2)
def test_device_id_auto_wrap(self):
"""Tests that ``auto_wrap_policy`` propagates ``device_id`` to all
nested FSDP instances."""
self.run_subtests(
{"use_callable": [False, True]},
self._test_device_id_auto_wrap,
)
def _test_device_id_auto_wrap(self, use_callable: bool):
module_classes = {TransformerEncoderLayer, TransformerDecoderLayer}
if use_callable:
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls=module_classes,
)
else:
auto_wrap_policy = ModuleWrapPolicy(module_classes)
fsdp_kwargs = {
"auto_wrap_policy": auto_wrap_policy,
"device_id": torch.cuda.current_device(),
}
fsdp_model = TransformerWithSharedParams.init(
self.process_group,
FSDPInitMode.RECURSIVE,
CUDAInitMode.CUDA_BEFORE,
fsdp_kwargs,
)
for fsdp_module in FSDP.fsdp_modules(fsdp_model):
self.assertEqual(
fsdp_module.compute_device,
torch.device("cuda", torch.cuda.current_device()),
)
@skip_if_lt_x_gpu(2)
def test_fsdp_device_id_cpu_offload(self):
"""
Tests FSDP when specifying both ``device_id`` and parameter CPU
offloading.
"""
self.run_subtests(
{"use_orig_params": [False, True]},
self._test_fsdp_device_id_cpu_offload,
)
def _test_fsdp_device_id_cpu_offload(self, use_orig_params: bool):
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.seq = nn.Sequential(
nn.Linear(10, 10),
nn.Linear(10, 10),
)
self.lin = nn.Linear(10, 10)
def forward(self, x):
return self.lin(self.seq(x))
model = MyModel()
# Choose a wrapping policy such that there are (1) nested FSDP
# instances and (2) the parent FSDP instance has managed parameters
auto_wrap_policy = ModuleWrapPolicy({nn.Sequential})
fsdp_model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
cpu_offload=CPUOffload(offload_params=True),
device_id=torch.cuda.current_device(),
use_orig_params=use_orig_params,
)
cpu_device = torch.device("cpu")
for handle in traversal_utils._get_fsdp_handles(fsdp_model):
self.assertEqual(handle.flat_param.device, cpu_device)
@skip_if_lt_x_gpu(2)
@parametrize("use_index", [True, False])
def test_fsdp_device_id(self, use_index):
"""
Tests the FSDP ``device_id`` argument:
- Wrapping a CPU module should move the module to the GPU matching
``device_id``
- Wrapping a GPU module already on the GPU matching ``device_id``
should not raise an error
- Wrapping a GPU module already on GPU and passing a GPU device
without specifying a device ID (i.e. ``torch.device("cuda")``) warns
"""
dev_id = (
torch.cuda.current_device()
if use_index
else torch.device("cuda", torch.cuda.current_device())
)
def _check_device_matches(module, device_id):
"""Checks that the ``FlatParameter``s in ``module`` have device
matching ``device_id``."""
devices = {
p.device for p in module.parameters() if isinstance(p, FlatParameter)
}
assert len(devices) > 0
self.assertEqual(1, len(devices))
found_device = devices.pop()
if use_index and not isinstance(device_id, torch.device):
device = torch.device("cuda", device_id)
else:
device = device_id
self.assertEqual(found_device, device)
# Check that FSDP parameters are moved to `device_id` for a CPU module
nested_wrapped_module = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.RECURSIVE,
CUDAInitMode.CUDA_NEVER,
fsdp_kwargs={"device_id": dev_id},
)
_check_device_matches(nested_wrapped_module, dev_id)
# Check that specifying `device_id` for a GPU module already on that
# device does not raise an error
nested_wrapped_module = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.RECURSIVE,
CUDAInitMode.CUDA_BEFORE,
fsdp_kwargs={"device_id": dev_id},
)
_check_device_matches(nested_wrapped_module, dev_id)
# Check that passing in `torch.device("cuda")` for a GPU module warns
regex = "does not have an explicit index"
context = self.assertWarnsRegex(
expected_warning=UserWarning, expected_regex=regex
)
with context:
nested_wrapped_module = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.RECURSIVE,
CUDAInitMode.CUDA_BEFORE,
fsdp_kwargs={"device_id": torch.device("cuda")},
)
_check_device_matches(
nested_wrapped_module, torch.device("cuda", torch.cuda.current_device())
)
@skip_if_lt_x_gpu(2)
def test_module_device_mismatches_device_id(self):
"""Tests that specifying a ``device_id`` argument to FSDP for a GPU
module that does not match the GPU device ID raises an error."""
context = (
self.assertRaisesRegex(ValueError, f"cuda:{self.rank} vs cuda:0")
if self.rank != 0
else suppress()
)
with context:
NestedWrappedModule.init(
self.process_group,
FSDPInitMode.RECURSIVE,
# Move wrapped modules to CUDA before wrapping with FSDP
cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
# Should raise error since rank 1 is given `device_id=0` when
# the model is on cuda:1
fsdp_kwargs={"device_id": 0},
)
@skip_if_lt_x_gpu(2)
def test_multi_device_not_supported(self):
"""Tests that wrapping a multi-device module (i.e. with submodules on
both GPU and CPU) with FSDP raises an error."""
class MultiDeviceModule(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Linear(1, 1).cuda()
self.b = nn.Linear(1, 1)
with self.assertRaisesRegex(
RuntimeError, "FSDP only supports single device modules"
):
FSDP(MultiDeviceModule())
@skip_if_lt_x_gpu(2)
def test_no_params(self):
"""
Test that device_id and cpu init work if module has no params
(they are effective noops, but ensure FSDP does not assume module
has parameters during init)
"""
# Test CPU
no_params = nn.ReLU()
module = FSDP(no_params)
# Test CUDA
no_params = nn.ReLU().cuda()
module = FSDP(no_params)
# Test CPU + device_id
no_params = nn.ReLU()
module = FSDP(no_params, device_id=torch.cuda.current_device())
# For modules with no params, wrong device_id will raise error about
# inconsistency between compute_device and device_id, since compute_device
# is computed as torch.cuda.current_device when there are no params.
no_params = nn.ReLU().cuda()
context = (
(
self.assertRaisesRegex(
ValueError, f"Inconsistent.*cuda:{self.rank} vs cuda:0"
)
)
if self.rank != 0
else suppress()
)
with context:
module = FSDP(no_params, device_id=0)
@skip_if_lt_x_gpu(2)
def test_fsdp_cpu_init_stays_on_cpu(self):
"""Tests that passing a CPU module to FSDP preserves that the wrapped
module is on CPU after FSDP initialization, albeit after loging a
warning, and that FSDP moves CPU input to GPU before the forward."""
torch.cuda.set_device(self.rank)
regex = "passed-in `module` is on CPU"
context = self.assertWarnsRegex(
expected_warning=UserWarning, expected_regex=regex
)
with context:
nested_wrapped_module = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.RECURSIVE,
CUDAInitMode.CUDA_NEVER,
)
fsdp_model = FSDP(nested_wrapped_module, self.process_group)
devices = {p.device for p in fsdp_model.parameters()}
self.assertEqual(1, len(devices))
self.assertEqual(torch.device("cpu"), devices.pop())
fsdp_model = fsdp_model.cuda()
# Ensure fwd + backward can be performed after moving to CUDA.
# CPU input also tests that input is correctly moved to appropriate
# CUDA device.
inp = fsdp_model.module.get_input(device=torch.device("cpu"))
fsdp_model(*inp).sum().backward()
@skip_if_lt_x_gpu(2)
def test_cpu_init_with_sync_module_states(self):
"""Tests that passing ``sync_module_states=True`` raises an error for
a CPU module since the synchronization requires GPU communication,
while additionally passing ``device_id`` does not raise an error."""
nested_wrapped_module = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.RECURSIVE,
CUDAInitMode.CUDA_NEVER,
)
with self.assertRaisesRegex(
ValueError, "The module has CPU parameters when `sync_module_states=True`"
):
FSDP(nested_wrapped_module, self.process_group, sync_module_states=True)
# Specifying device_id with sync_module_states=True works.
FSDP(
nested_wrapped_module,
self.process_group,
device_id=torch.cuda.current_device(),
sync_module_states=True,
)
@skip_if_lt_x_gpu(2)
def test_fsdp_same_model_across_ranks(self):
"""
FSDP broadcasts model from rank 0 to ensure it starts off with the same
values.
"""
class MyModel(nn.Module):
def __init__(self, rank):
super().__init__()
# Seed via rank to make model different across ranks
torch.manual_seed(rank)
torch.cuda.manual_seed(rank)
self.lin = nn.Linear(10, 10, bias=False)
self.register_buffer("buffer", torch.ones(1) * rank)
m = MyModel(self.rank).cuda()
_assert_module_states(
m, process_group=self.process_group, assert_fn=self.assertNotEqual
)
# Passing sync_module_states into FSDP makes model the same during init.
fsdp = FSDP(m, sync_module_states=True)
with fsdp.summon_full_params(fsdp):
_assert_module_states(
fsdp, process_group=self.process_group, assert_fn=self.assertEqual
)
# sync_module_states also works with CPU module with device_id passed in
m = MyModel(self.rank)
_assert_module_states(
m, process_group=self.process_group, assert_fn=self.assertNotEqual
)
# Passing sync_module_states into FSDP makes model the same during init.
fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
with fsdp.summon_full_params(fsdp):
_assert_module_states(
fsdp, process_group=self.process_group, assert_fn=self.assertEqual
)
@skip_if_lt_x_gpu(2)
def test_homogeneous_attributes(self):
"""
Tests that passing heterogeneous values for attributes designated as
homogeneous raises an error.
"""
# Manually construct this list but verify against the global list of
# homogeneous attribute names
all_attr_name_and_values = [
("_use_orig_params", False, True),
("limit_all_gathers", False, True),
]
self.assertEqual(
[
attr_name_and_values[0]
for attr_name_and_values in all_attr_name_and_values
],
HOMOGENEOUS_ATTR_NAMES,
)
self.run_subtests(
{"attr_name_and_values": all_attr_name_and_values},
self._test_homogeneous_attributes,
)
def _test_homogeneous_attributes(self, attr_name_and_values: Tuple[str, Any, Any]):
model = NestedWrappedModule.init(
self.process_group,
FSDPInitMode.NO_FSDP,
CUDAInitMode.CUDA_BEFORE,
{},
)
attr_name = attr_name_and_values[0]
fsdp_kwargs_inner = {attr_name.lstrip("_"): attr_name_and_values[1]}
fsdp_kwargs_outer = {attr_name.lstrip("_"): attr_name_and_values[2]}
model.module[1] = FSDP(model.module[1], **fsdp_kwargs_inner)
fsdp_model = FSDP(model, **fsdp_kwargs_outer)
# Run a forward to trigger lazy initialization and the error
with self.assertRaisesRegex(
ValueError, f"Expects one homogeneous value for {attr_name}"
):
inp = fsdp_model.module.get_input(torch.device("cuda"))
fsdp_model(*inp)
class TestFSDPMiscWorldSize1(FSDPTest):
@property
def world_size(self) -> int:
return 1
@skip_if_lt_x_gpu(1)
def test_world_size_1_sharding_strategy_warning(self):
"""
Tests that FSDP issues a warning when it switches to using ``NO_SHARD``
when the world size is 1.
"""
warning_prefix = "FSDP is switching to use `NO_SHARD` instead of"
# If the user already passes `NO_SHARD`, then there should not be a
# warning
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # trigger all warnings
FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.NO_SHARD)
for warning in w:
self.assertTrue(
warning.category != UserWarning
or not str(warning.message).startswith(warning_prefix)
)
# Check that a warning is issued
warning_suffix = " since the world size is 1."
# - Pass `FULL_SHARD` or `None`
expected_regex_full_shard = (
warning_prefix + " " + str(ShardingStrategy.FULL_SHARD) + warning_suffix
)
with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
FSDP(nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.FULL_SHARD)
with self.assertWarnsRegex(UserWarning, expected_regex_full_shard):
FSDP(nn.Linear(3, 3).cuda())
# - Pass `SHARD_GRAD_OP`
expected_regex_shard_grad_op = (
warning_prefix + " " + str(ShardingStrategy.SHARD_GRAD_OP) + warning_suffix
)
with self.assertWarnsRegex(UserWarning, expected_regex_shard_grad_op):
FSDP(
nn.Linear(3, 3).cuda(), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP
)
@skip_if_lt_x_gpu(1)
def test_training_device_mismatch_errors(self):
"""
Tests that, when training starts, if FSDP parameters are not on the
expected device, then an informative error is raised. This applies for
both no parameter CPU offloading and parameter CPU offloading.
"""
# Incorrectly not moving from CPU -> GPU
model = torch.nn.Linear(10, 10)
fsdp_model = FSDP(model)
inp = torch.randn((2, 10))
with self.assertRaisesRegex(
RuntimeError,
"An FSDP-managed module unexpectedly has parameters on cpu. Make "
"sure to move the module to cuda:0 before training.",
):
fsdp_model(inp)
# Incorrectly moving from CPU -> GPU
model = torch.nn.Linear(10, 10)
fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))
fsdp_model.to(torch.device("cuda"))
inp = torch.randn((2, 10))
with self.assertRaisesRegex(
RuntimeError,
"An FSDP-managed module with parameter CPU offloading enabled has "
"parameters on cuda:0. Make sure to not move the module from CPU "
"when offloading parameters.",
):
fsdp_model(inp)
@skip_if_lt_x_gpu(2)
def test_unsafe_setattr(self):
"""
Tests that the environment variable for using unsafe setattr gates as
expected.
"""
self.run_subtests(
{"use_orig_params": [False, True]},
self._test_unsafe_setattr,
)
def _test_unsafe_setattr(self, use_orig_params: bool):
called_setattr_override = False
class SetattrLinear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, device: torch.device) -> None:
super().__init__()
self.weight = nn.Parameter(
torch.randn((in_dim, out_dim), device=device)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x @ self.weight
def __setattr__(self, name: str, value: Any) -> None:
nonlocal called_setattr_override
called_setattr_override = True
return super().__setattr__(name, value)
# Construct FSDP module without changing any environment variables and
# run forward, which triggers both unsharded and sharded view setting
module = SetattrLinear(5, 5, torch.device("cuda"))
fsdp_module = FSDP(module, use_orig_params=use_orig_params)
inp = torch.randn((8, 5), device=torch.device("cuda"))
called_setattr_override = False
fsdp_module(inp)
self.assertTrue(called_setattr_override)
# Repeat with unsafe setattr explicitly enabled
os.environ[_FSDP_USE_UNSAFE_SETATTR] = "1"
module = SetattrLinear(5, 5, torch.device("cuda"))
fsdp_module = FSDP(module, use_orig_params=use_orig_params)
called_setattr_override = False
fsdp_module(inp)
self.assertFalse(called_setattr_override)
# Repeat with unsafe setattr explicitly disabled
os.environ[_FSDP_USE_UNSAFE_SETATTR] = "0"
module = SetattrLinear(5, 5, torch.device("cuda"))
fsdp_module = FSDP(module, use_orig_params=use_orig_params)
called_setattr_override = False
fsdp_module(inp)
self.assertTrue(called_setattr_override)
instantiate_parametrized_tests(TestFSDPMisc)
if __name__ == "__main__":
run_tests()