# Owner(s): ["oncall: distributed"]
import copy
import sys
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import checkpoint, fully_shard, replicate
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_dist_composable import (
    CompositeModel,
    CompositeParamModel,
    UnitModule,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestFSDPCheckpoint(FSDPTest):
    @property
    def world_size(self) -> int:
        return 2

    # TODO: Define `use_same_inputs_across_ranks` for now for BC since some
    # test model configs do not have a simple base model to compare against. In
    # those cases, we use the same inputs across ranks so that the averaged
    # gradient equals the local gradient to check for parity. This means that
    # the gradient reduction is unchecked.
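    # For example, when every rank sees identical inputs, each computes the same
    # local gradient ``g``, so the all-reduced mean ``(g + ... + g) / world_size``
    # equals ``g`` and matches the unsharded baseline, even though the reduction
    # itself is not exercised.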
    def _test_parity(
        self,
        base_model: nn.Module,
        test_model: nn.Module,
        inp_size: torch.Size,
        inp_device: torch.device,
        grad_to_none: bool,
        use_same_inputs_across_ranks: bool,
    ):
        LR = 0.01
        base_optim = torch.optim.Adam(base_model.parameters(), lr=LR)
        test_optim = torch.optim.Adam(test_model.parameters(), lr=LR)

        for _ in range(5):
            if use_same_inputs_across_ranks:
                torch.manual_seed(0)
            x = torch.randn(inp_size, device=inp_device)

            test_loss = test_model(x).sum()
            base_loss = base_model(x).sum()
            self.assertEqual(test_loss, base_loss)

            test_loss.backward()
            test_optim.step()
            test_optim.zero_grad(set_to_none=grad_to_none)

            base_loss.backward()
            base_optim.step()
            base_optim.zero_grad(set_to_none=grad_to_none)

    @skip_if_lt_x_gpu(2)
    @parametrize("use_reentrant", [True, False])
    def test_wrap_same_submodule(self, use_reentrant: bool):
        model = UnitModule(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)
        test_model = copy.deepcopy(model)
        # compose checkpoint and fully_shard
        test_model.seq = checkpoint(test_model.seq, use_reentrant=use_reentrant)
        test_model.seq = fully_shard(
            test_model.seq,
            policy=ModuleWrapPolicy({nn.Linear}),
        )

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    def _test_checkpoint_fsdp_submodules(self, use_reentrant):
        model = CompositeModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)
        test_model = copy.deepcopy(model)

        test_model.u1 = fully_shard(test_model.u1, policy=None)
        test_model.u2 = fully_shard(test_model.u2)

        test_model.u1.seq = checkpoint(test_model.u1.seq, use_reentrant=use_reentrant)
        test_model.u2.seq = checkpoint(test_model.u2.seq, use_reentrant=use_reentrant)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_use_reentrant(self):
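        # Note: composing reentrant activation checkpointing with `fully_shard`
        # on submodules is expected to fail here. The reentrant recomputation
        # appears to touch parameter storage that FSDP has already freed
        # (resized to 0), which surfaces as the error matched below.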
        # Escape the brackets like `\[` since `[` has special meaning in regex
        with self.assertRaisesRegex(
            RuntimeError,
            r"setStorage: sizes \[100, 100\], strides \[100, 1\], storage "
            "offset 0, and itemsize 4 requiring a storage size of 40000 are "
            "out of bounds for storage of size 0",
        ):
            self._test_checkpoint_fsdp_submodules(True)

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_non_reentrant(self):
        self._test_checkpoint_fsdp_submodules(False)

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_with_param(self):
        model = CompositeParamModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)
        test_model = copy.deepcopy(model)

        test_model.u1.seq = checkpoint(test_model.u1.seq, use_reentrant=False)
        test_model.u2.seq = checkpoint(test_model.u2.seq, use_reentrant=False)
        test_model = fully_shard(test_model)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_checkpoint_fsdp_submodules_with_param_no_shard(self):
        model = CompositeParamModel(device=torch.device("cuda"))

        base_model = copy.deepcopy(model)
        test_model = copy.deepcopy(model)

        test_model.u1.seq = checkpoint(test_model.u1.seq, use_reentrant=False)
        test_model.u2.seq = checkpoint(test_model.u2.seq, use_reentrant=False)
        test_model = fully_shard(test_model, strategy=ShardingStrategy.NO_SHARD)

        self.run_subtests(
            {
                "base_model": [base_model],
                "test_model": [test_model],
                "inp_size": [torch.Size((2, 100))],
                "inp_device": [torch.device("cuda")],
                "grad_to_none": [True, False],
                "use_same_inputs_across_ranks": [True],
            },
            self._test_parity,
        )

    @skip_if_lt_x_gpu(2)
    def test_composable_fsdp_replicate(self):
        # Verify how the APIs can be composed: e.g. if both `fully_shard` and
        # `replicate` are applied to the same module, an exception should be
        # raised.
        model = CompositeModel(device=torch.device("cpu"))
        fully_shard(model.l1)
        with self.assertRaisesRegex(AssertionError, "Cannot apply .*replicate"):
            replicate(model.l1)
        replicate(model.l2)  # should not raise

    @skip_if_lt_x_gpu(2)
    def test_fully_shard_replicate_composability(self):
        """
        Tests composing ``fully_shard`` and ``replicate``. To save unit test
        time, we run the different configs in subtests.
        """
        self.run_subtests(
            {
                "config": [
                    "1fm,1r",
                    "1r,1fm",
                    "1r,1fa",
                    "1r1fm,1fm",
                    "1r1fa,1fm",
                    "1fm1fm,1r1r,1fm",
                ]
            },
            self._test_replicate_in_fully_shard,
        )

    def _test_replicate_in_fully_shard(self, config: str):
        """
        To interpret the config, each comma delineates a level in the module
        tree ordered bottom-up; 'r' means ``replicate``; 'f' means
        ``fully_shard``; 'a' means auto wrap; and 'm' means manual wrap.
        """
        # Set the seed to ensure that all ranks initialize the same model
        torch.manual_seed(0)
        if config == "1fm,1r":
            base_model = CompositeModel(device=torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            fully_shard(test_model.l1)
            replicate(test_model)
        elif config == "1r,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model)
        elif config == "1r,1fa":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model, policy=ModuleWrapPolicy({UnitModule}))
        elif config == "1r1fm,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model.u2)
            fully_shard(test_model)
        elif config == "1r1fa,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            replicate(test_model.u1)
            fully_shard(test_model.u2, policy=ModuleWrapPolicy({UnitModule}))
            fully_shard(test_model)
        elif config == "1fm1fm,1r1r,1fm":
            base_model = CompositeParamModel(torch.device("cuda"))
            test_model = copy.deepcopy(base_model)
            fully_shard(test_model.u1.seq)
            fully_shard(test_model.u2.seq)
            replicate(test_model.u1)
            replicate(test_model.u2)
            fully_shard(test_model)
        else:
            raise ValueError(f"Unknown config: {config}")

        # Apply data parallelism to the base model for parity since we apply
        # data parallelism to the test model
        replicate(base_model)

        # Set the seed to ensure that ranks get different input data
        torch.manual_seed(self.rank + 1)
        self._test_parity(
            base_model,
            test_model,
            torch.Size((2, 100)),
            torch.device("cuda"),
            True,
            False,
        )

    @skip_if_lt_x_gpu(2)
    def test_state_dict_fsdp_submodules(self):
        model = CompositeModel(device=torch.device("cuda"))

        full_shard_args = {"strategy": ShardingStrategy.FULL_SHARD}
        no_shard_args = {"strategy": ShardingStrategy.NO_SHARD}

        model.u1 = fully_shard(model.u1, **full_shard_args)
        model.u2 = fully_shard(model.u2, **no_shard_args)

        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )

        state_dict = model.state_dict()
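        # `u1` uses `FULL_SHARD`, so its sharded state dict entries should be
        # `ShardedTensor`s, while `u2` uses `NO_SHARD`, so its entries should
        # remain regular `torch.Tensor`s.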
        for fqn, tensor in state_dict.items():
            if "u1" in fqn:
                self.assertIsInstance(tensor, ShardedTensor)
            elif "u2" in fqn:
                self.assertIsInstance(tensor, torch.Tensor)

        # Ensure that get_state_dict_type can still correctly get the settings.
        _ = FSDP.get_state_dict_type(model)


instantiate_parametrized_tests(TestFSDPCheckpoint)

if __name__ == "__main__":
    run_tests()