# Owner(s): ["oncall: distributed"]
import copy
import functools
import sys
from itertools import chain
from typing import Callable, Tuple, Type, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable import fully_shard, replicate
# Import the new fully_shard as FSDP2 since the original composable fully_shard
# is also used in this test.
# TODO: remove the old composable fully_shard so that we don't have to import
# the new fully_shard as FSDP2.
from torch.distributed._composable.fsdp import (
fully_shard as FSDP2,
fully_shard as fsdp_fully_shard,
)
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._tensor import DTensor, init_device_mesh
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)
from torch.distributed.checkpoint import state_dict as ptd_state_dict
from torch.distributed.checkpoint.state_dict import (
_patch_model_state_dict,
_patch_optimizer_state_dict,
get_model_state_dict,
get_optimizer_state_dict,
get_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
StateDictOptions,
)
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
StateDictType,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.optim import _apply_optimizer_in_backward
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.testing._internal.common_dist_composable import (
CompositeParamModel,
UnitModule,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
MultiProcessTestCase,
with_comms,
)
from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin
from torch.utils._pytree import tree_all, tree_all_only
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)
class TestStateDict(DTensorTestBase, VerifyStateDictMixin):
"""Tests state_dict and load_state_dict"""
@property
def world_size(self) -> int:
return min(4, torch.cuda.device_count())
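    # Shared save/load driver: train both the plain and the distributed model,
    # compare their state_dicts, reload the saved state_dicts into a freshly
    # initialized distributed model, and verify again (including the patched
    # state_dict()/load_state_dict() APIs).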
def _test_save_load(
self,
init_model_optim: Callable,
test_frozen: bool = False,
) -> None:
options = StateDictOptions(ignore_frozen_params=test_frozen)
# Initialize original model and distributed model.
model, optim, copy_optim, dist_model, dist_optim = init_model_optim()
# Train 10 steps.
for i in range(10):
batch = torch.rand(8, 100, device="cuda")
model(batch).sum().backward()
optim.step()
dist_model(batch).sum().backward()
if not isinstance(dist_optim, list):
dist_optim.step()
dist_optim.zero_grad()
else:
for _dist_optim in dist_optim:
_dist_optim.zero_grad()
optim.zero_grad()
        # Get the state_dicts and compare the results.
msd = model.state_dict()
osd = optim.state_dict()
dist_msd, dist_osd = get_state_dict(
dist_model, optimizers=dist_optim, options=options
)
self._verify_msd(msd, dist_msd, options)
self._verify_osd_by_load(model, optim, copy_optim, dist_osd)
self._verify_osd(model, optim, osd, dist_osd)
# Initialize a completely new model to simulate checkpoint load.
_, _, _, dist_model, dist_optim = init_model_optim()
        # Simulate a DCP distributed load: we first need to get the state_dict,
        # pass it to DCP to load the saved state_dict from storage, and then
        # finally call set_state_dict().
if not isinstance(dist_optim, list):
dist_optim = [dist_optim]
if test_frozen:
# We won't be able to load the partial state_dict back.
return
        # Since we already have the state_dict saved above, there is no need to
        # call DCP; we can load it back directly. The assert below is to ensure
        # that the optimizer state storage is initialized.
# self.assertEqual(len(curr_dist_osd[STATE]), len(dist_osd[STATE]))
set_model_state_dict(
dist_model,
model_state_dict=dist_msd,
options=options,
)
set_optimizer_state_dict(
dist_model,
optimizers=dist_optim,
optim_state_dict=dist_osd,
options=options,
)
        # Check that the new state_dicts are the same.
dist_msd, dist_osd = get_state_dict(
dist_model, optimizers=dist_optim, options=options
)
self._verify_msd(msd, dist_msd, options)
# TODO: Ditto
# self._verify_osd_by_load(model, optim, copy_optim, dist_osd)
self._verify_osd(model, optim, osd, dist_osd)
# Test _patch_model_state_dict, and _patch_optimizer_state_dict
_patch_model_state_dict(dist_model, options=options)
_patch_optimizer_state_dict(dist_model, optimizers=dist_optim, options=options)
dist_msd = dist_model.state_dict()
dist_osd = dist_optim[0].state_dict()
self._verify_msd(msd, dist_msd, options)
self._verify_osd_by_load(model, optim, copy_optim, dist_osd)
self._verify_osd(model, optim, osd, dist_osd)
def _test_fsdp(
self,
*,
use_orig_params: bool,
use_composable: bool,
use_dtensor: bool,
        wrapping: Tuple[Type[nn.Module], ...] = (),
compile_model: bool = False,
optimizer_class: Type[Optimizer],
) -> None:
if not use_orig_params and use_composable:
return
# TODO: remove this return after we complete the composable API side change for device_mesh
if use_composable and use_dtensor:
return
def init_model_optim():
if use_dtensor:
device_mesh = init_device_mesh("cuda", (self.world_size,))
orig_model = CompositeParamModel(device=torch.device("cuda"))
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
if wrapping:
strategy = set(wrapping)
else:
strategy = {UnitModule}
if use_composable:
dist_model = fully_shard(
copy.deepcopy(orig_model), policy=ModuleWrapPolicy(strategy)
)
else:
if use_dtensor:
device_mesh = init_device_mesh("cuda", (self.world_size,))
dist_model = FSDP(
copy.deepcopy(orig_model),
auto_wrap_policy=ModuleWrapPolicy(strategy),
use_orig_params=use_orig_params,
device_mesh=device_mesh,
)
else:
dist_model = FSDP(
copy.deepcopy(orig_model),
auto_wrap_policy=ModuleWrapPolicy(strategy),
use_orig_params=use_orig_params,
)
if compile_model:
dist_model = torch.compile(dist_model)
dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3)
return orig_model, orig_optim, copy_optim, dist_model, dist_optim
self._test_save_load(init_model_optim)
@with_comms
@skip_if_lt_x_gpu(2)
def test_fsdp(self) -> None:
self.run_subtests(
{
"use_orig_params": [True, False],
"use_composable": [True, False],
"use_dtensor": [True, False],
"wrapping": [(), (nn.Linear, UnitModule)],
"optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
},
self._test_fsdp,
)
@with_comms
@skip_if_lt_x_gpu(2)
def test_compiled_fsdp(self) -> None:
self.run_subtests(
{
"use_orig_params": [True],
"use_composable": [False],
"use_dtensor": [False],
"wrapping": [()],
"optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
},
self._test_fsdp,
)
def _test_fsdp2(
self,
*,
reshard_after_forward: Union[bool, int],
optimizer_class: Type[Optimizer],
compile_model: bool,
foreach: bool = True,
):
def init_model_optim():
orig_model = CompositeParamModel(device=torch.device("cuda"))
orig_optim = optimizer_class(
orig_model.parameters(), lr=1e-3, foreach=foreach
)
copy_optim = optimizer_class(
orig_model.parameters(), lr=1e-3, foreach=foreach
)
dist_model = FSDP2(
copy.deepcopy(orig_model),
reshard_after_forward=reshard_after_forward,
)
if compile_model:
dist_model = torch.compile(dist_model)
dist_optim = optimizer_class(
dist_model.parameters(), lr=1e-3, foreach=foreach
)
return orig_model, orig_optim, copy_optim, dist_model, dist_optim
self._test_save_load(init_model_optim)
@with_comms
@skip_if_lt_x_gpu(2)
def test_fsdp2(self) -> None:
self.run_subtests(
{
"reshard_after_forward": [True, False],
"optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
"compile_model": [True, False],
},
self._test_fsdp2,
)
def _test_ddp(self, use_composable: bool, optimizer_class: Type[Optimizer]) -> None:
def init_model_optim():
orig_model = CompositeParamModel(device=torch.device("cuda"))
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
if use_composable:
dist_model = replicate(copy.deepcopy(orig_model))
else:
dist_model = DDP(copy.deepcopy(orig_model))
dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3)
return orig_model, orig_optim, copy_optim, dist_model, dist_optim
self._test_save_load(init_model_optim)
@with_comms
@skip_if_lt_x_gpu(2)
def test_ddp(self) -> None:
self.run_subtests(
{
"use_composable": [True, False],
"optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
},
self._test_ddp,
)
def _test_fsdp_ddp(
self,
use_composable: bool,
optimizer_class: Type[Optimizer],
optim_in_backward: bool = False,
test_frozen: bool = False,
) -> None:
def init_model_optim():
orig_model = CompositeParamModel(device=torch.device("cuda"))
if test_frozen:
for param in chain(
orig_model.u1.parameters(), orig_model.u2.parameters()
):
param.requires_grad = False
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
dist_model = copy.deepcopy(orig_model)
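            # Replicate (DDP) the linear submodule and fully shard (FSDP) the
            # rest of the model.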
if use_composable:
replicate(dist_model.l)
fully_shard(dist_model, policy=ModuleWrapPolicy({UnitModule}))
else:
dist_model.l = DDP(dist_model.l)
dist_model = FSDP(
copy.deepcopy(orig_model),
auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
use_orig_params=optim_in_backward,
ignored_modules=[dist_model.l],
)
if optim_in_backward:
_apply_optimizer_in_backward(
optimizer_class, dist_model.parameters(), {"lr": 1e-3}
)
dist_optim = [
p._in_backward_optimizers[0] for p in dist_model.parameters()
]
else:
dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3)
return orig_model, orig_optim, copy_optim, dist_model, dist_optim
self._test_save_load(init_model_optim, test_frozen)
@with_comms
@skip_if_lt_x_gpu(2)
def test_fsdp_ddp(self) -> None:
self.run_subtests(
{
"use_composable": [True, False],
"optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
},
self._test_fsdp_ddp,
)
@with_comms
@skip_if_lt_x_gpu(2)
def test_frozen_parameters(self) -> None:
self.run_subtests(
{
"use_composable": [True],
"optimizer_class": [torch.optim.Adam, torch.optim.AdamW],
"test_frozen": [True],
},
self._test_fsdp_ddp,
)
# TODO: enable use_dtensor once 2D device_mesh support is fully landed.
"""
@with_comms
@skip_if_lt_x_gpu(2)
def test_use_dtensor(self) -> None:
self._test_fsdp_ddp(use_composable=False, use_dtensor=True)
"""
    # TODO: enable this test once FSDP + apply_optimizer_in_backward works.
    # It is disabled because it is broken after
    # https://github.com/pytorch/pytorch/pull/108298.
"""
@with_comms
@skip_if_lt_x_gpu(2)
def test_apply_optimizer_in_backward(self) -> None:
self.run_subtests(
{"use_composable": [True, False]},
self._test_fsdp_ddp,
optim_in_backward=True,
)
"""
def _test_single_gpu(self, optimizer_class: Type[Optimizer]) -> None:
def init_model_optim():
orig_model = CompositeParamModel(device=torch.device("cuda"))
orig_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
copy_optim = optimizer_class(orig_model.parameters(), lr=1e-3)
model_copy = copy.deepcopy(orig_model)
optim_copy = optimizer_class(model_copy.parameters(), lr=1e-3)
return orig_model, orig_optim, copy_optim, model_copy, optim_copy
self._test_save_load(init_model_optim)
@with_comms
@skip_if_lt_x_gpu(1)
def test_single_gpu(self) -> None:
self.run_subtests(
{"optimizer_class": [torch.optim.Adam, torch.optim.AdamW]},
self._test_single_gpu,
)
@with_comms
@skip_if_lt_x_gpu(1)
def test_strict(self) -> None:
model = CompositeParamModel(device=torch.device("cuda"))
model_state_dict = get_model_state_dict(model)
key = next(iter(model_state_dict.keys()))
model_state_dict["abc"] = torch.zeros(10)
with self.assertRaisesRegex(RuntimeError, "Unexpected key"):
set_model_state_dict(model, model_state_dict=model_state_dict)
model_state_dict.pop(key)
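        # With strict=False, the mismatches are returned as incompatible keys
        # instead of raising.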
incompatible_keys = set_model_state_dict(
model,
model_state_dict=model_state_dict,
options=StateDictOptions(strict=False),
)
self.assertEqual(incompatible_keys.missing_keys, [key])
self.assertEqual(incompatible_keys.unexpected_keys, ["abc"])
model_state_dict.pop("abc")
with self.assertRaisesRegex(RuntimeError, "Missing key"):
set_model_state_dict(model, model_state_dict=model_state_dict)
def _test_cpu_offload_full_state_dict(
self, optimizer_class: Type[Optimizer]
) -> None:
orig_model = CompositeParamModel(device=torch.device("cuda"))
device_mesh = init_device_mesh("cuda", (self.world_size,))
dist_model = FSDP(
copy.deepcopy(orig_model),
auto_wrap_policy=ModuleWrapPolicy({UnitModule}),
use_orig_params=True,
device_mesh=device_mesh,
)
dist_optim = optimizer_class(dist_model.parameters(), lr=1e-3)
mst, ost = get_state_dict(
dist_model,
dist_optim,
options=StateDictOptions(cpu_offload=True),
)
cpu_device = torch.device("cpu")
def is_cpu(v):
if isinstance(v, DTensor):
return v.device == cpu_device
elif isinstance(v, ShardedTensor):
shards = v.local_shards()
if not shards:
return True
return shards[0].tensor.device == cpu_device
else:
return v.device == cpu_device
self.assertTrue(
tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, mst)
)
self.assertTrue(
tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, ost)
)
mst, ost = get_state_dict(
dist_model, dist_optim, options=StateDictOptions(full_state_dict=True)
)
self.assertTrue(
tree_all(lambda v: not isinstance(v, (DTensor, ShardedTensor)), mst)
)
self.assertTrue(
tree_all(lambda v: not isinstance(v, (DTensor, ShardedTensor)), ost)
)
mst, ost = get_state_dict(
dist_model,
dist_optim,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
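        # With full_state_dict + cpu_offload, only rank 0 receives the offloaded
        # state_dict; all other ranks get empty dicts.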
if self.rank == 0:
self.assertTrue(
tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, mst)
)
self.assertTrue(
tree_all_only((torch.Tensor, DTensor, ShardedTensor), is_cpu, ost)
)
else:
self.assertEqual(mst, {})
self.assertEqual(ost, {})
@with_comms
@skip_if_lt_x_gpu(2)
def test_cpu_offload_full_state_dict(self) -> None:
self.run_subtests(
{"optimizer_class": [torch.optim.Adam, torch.optim.AdamW]},
self._test_cpu_offload_full_state_dict,
)
@with_comms
@skip_if_lt_x_gpu(1)
def test_activation_ckpt_fqns_ddp(self) -> None:
"""Tests that activation checkpointing prefixes are removed from module names"""
model = CompositeParamModel(device=torch.device("cuda"))
original_keys = get_model_state_dict(model).keys()
apply_activation_checkpointing(model)
model = DDP(model)
new_keys = get_model_state_dict(model).keys()
self.assertEqual(original_keys, new_keys)
@with_comms
@skip_if_lt_x_gpu(1)
def test_activation_ckpt_fqns_fsdp1(self) -> None:
self.run_subtests(
{"use_orig_params": [True, False]},
self._test_activation_ckpt_fqns_fsdp1,
)
def _test_activation_ckpt_fqns_fsdp1(self, use_orig_params: bool) -> None:
"""Tests that activation checkpointing prefixes are removed from module names"""
model = CompositeParamModel(device=torch.device("cuda"))
original_keys = get_model_state_dict(model).keys()
apply_activation_checkpointing(model)
model = FSDP(model, use_orig_params=use_orig_params)
new_keys = get_model_state_dict(model).keys()
self.assertEqual(original_keys, new_keys)
@with_comms
@skip_if_lt_x_gpu(1)
def test_extra_state(self) -> None:
model = CompositeParamModel(device=torch.device("cuda"))
def get_extra_state(self):
return "MyState"
def set_extra_state(self, state):
return
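        # Monkeypatch UnitModule so its state_dict includes an _extra_state entry.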
UnitModule.get_extra_state = get_extra_state
UnitModule.set_extra_state = set_extra_state
ddp_model = DDP(copy.deepcopy(model))
set_model_state_dict(ddp_model, get_model_state_dict(ddp_model))
self.assertEqual(model.state_dict()["u1._extra_state"], "MyState")
self.assertEqual(model.state_dict(), get_model_state_dict(ddp_model))
@with_comms
@skip_if_lt_x_gpu(1)
def test_non_persistent_buffers(self) -> None:
model = CompositeParamModel(device=torch.device("cuda"))
model.register_buffer(
"dont_save_me", torch.rand(100, device="cuda"), persistent=False
)
ddp_model = DDP(copy.deepcopy(model))
set_model_state_dict(ddp_model, get_model_state_dict(ddp_model))
self.assertEqual(model.state_dict(), get_model_state_dict(ddp_model))
def _test_broadcast_from_rank0(self, wrapper) -> None:
model = CompositeParamModel(device=torch.device("cuda"))
optim = torch.optim.Adam(model.parameters())
fsdp_model = wrapper(copy.deepcopy(model))
fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
batch = torch.rand(8, 100, device="cuda")
model(batch).sum().backward()
optim.step()
states, optim_states = get_state_dict(model, optim)
fsdp_model(batch).sum().backward()
fsdp_optim.step()
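        # Compare the full (unsharded) FSDP model/optim state_dicts against the
        # non-distributed baseline.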
def check(equal):
fsdp_states = get_model_state_dict(
fsdp_model,
options=StateDictOptions(full_state_dict=True),
)
fsdp_optim_states = get_optimizer_state_dict(
fsdp_model,
fsdp_optim,
options=StateDictOptions(full_state_dict=True),
)
if equal:
self.assertEqual(states, fsdp_states)
self.assertEqual(optim_states, fsdp_optim_states)
else:
self.assertNotEqual(states, fsdp_states)
self.assertNotEqual(optim_states, fsdp_optim_states)
check(equal=True)
fsdp_model(batch).sum().backward()
fsdp_optim.step()
check(equal=False)
        # Drop the states on non-zero ranks to simulate broadcasting from rank0.
if dist.get_rank() > 0:
load_states = {}
load_states2 = {}
load_optim_states = {}
else:
load_states = copy.deepcopy(states)
load_states2 = copy.deepcopy(states)
load_optim_states = copy.deepcopy(optim_states)
set_model_state_dict(
fsdp_model,
model_state_dict=load_states,
options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
)
set_optimizer_state_dict(
fsdp_model,
fsdp_optim,
optim_state_dict=load_optim_states,
options=StateDictOptions(broadcast_from_rank0=True, full_state_dict=True),
)
check(equal=True)
# Verify the `strict` flag.
load_states = load_states2
if load_states:
key = next(iter(load_states.keys()))
load_states.pop(key)
with self.assertRaisesRegex(RuntimeError, "Missing key"):
set_model_state_dict(
fsdp_model,
model_state_dict=load_states,
options=StateDictOptions(
broadcast_from_rank0=True, full_state_dict=True
),
)
@with_comms
@skip_if_lt_x_gpu(2)
def test_broadcast_from_rank0(self) -> None:
device_mesh = init_device_mesh("cuda", (self.world_size,))
self.run_subtests(
{
"wrapper": [
functools.partial(FSDP2, mesh=device_mesh),
functools.partial(FSDP, device_mesh=device_mesh),
]
},
self._test_broadcast_from_rank0,
)
@with_comms
@skip_if_lt_x_gpu(4)
def test_broadcast_from_rank0_hsdp(self) -> None:
device_mesh = init_device_mesh("cuda", (2, self.world_size // 2))
self.run_subtests(
{
"wrapper": [
functools.partial(
FSDP,
device_mesh=device_mesh,
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
),
]
},
self._test_broadcast_from_rank0,
)
@with_comms
@skip_if_lt_x_gpu(2)
def test_fsdp_root_not_initialized(self) -> None:
        # This test verifies that even though the FSDP root is not initialized,
        # we can still get the state_dict without errors because
        # fsdp_model.state_dict() triggers the FSDP lazy initialization.
device_mesh = init_device_mesh("cuda", (self.world_size,))
model = CompositeParamModel(device=torch.device("cuda"))
fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh)
fsdp_optim = torch.optim.Adam(fsdp_model.parameters())
get_model_state_dict(fsdp_model)
get_optimizer_state_dict(fsdp_model, fsdp_optim)
@with_comms
@skip_if_lt_x_gpu(2)
def test_optim_state_dict_param_matching(self) -> None:
        # This test verifies that the parameters in optim and optim_state_dict
        # are matched correctly. The LR scheduler adds "initial_lr" to
        # optim_state_dict but not to the newly created optim; we check that
        # "initial_lr" appears in optim after set_optimizer_state_dict.
device = "cuda"
torch.manual_seed(0)
model = nn.Sequential(
*[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)]
)
for layer in model:
fully_shard(layer)
fully_shard(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
torch.optim.lr_scheduler.LambdaLR(
optim, lr_lambda=[lambda epoch: 0.95**epoch]
)
opt_state_dict = ptd_state_dict.get_optimizer_state_dict(
model,
optim,
options=ptd_state_dict.StateDictOptions(
full_state_dict=True, cpu_offload=True
),
)
if dist.get_rank() == 0:
self.assertTrue("initial_lr" in opt_state_dict["param_groups"][0])
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
self.assertTrue("initial_lr" not in optim.param_groups[0])
ptd_state_dict.set_optimizer_state_dict(
model,
optim,
optim_state_dict=opt_state_dict,
options=ptd_state_dict.StateDictOptions(
broadcast_from_rank0=True, full_state_dict=True
),
)
if dist.get_rank() == 0:
self.assertTrue("initial_lr" in optim.param_groups[0])
@with_comms
@skip_if_lt_x_gpu(2)
def test_optim_state_dict_tensor_matching(self) -> None:
device = "cuda"
torch.manual_seed(0)
model = nn.Sequential(
*[nn.Linear(4, 4, device=device, bias=False) for _ in range(2)]
)
for layer in model:
fsdp_fully_shard(layer)
fsdp_fully_shard(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
x = torch.randn((4, 4), device=device)
model(x).sum().backward()
optim.step()
optim.zero_grad()
self.assertIsInstance(
list(optim.state.values())[0]["exp_avg"], DTensor # noqa: RUF015
)
opt_state_dict = ptd_state_dict.get_optimizer_state_dict(
model,
optim,
options=ptd_state_dict.StateDictOptions(full_state_dict=True),
)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
ptd_state_dict.set_optimizer_state_dict(
model,
optim,
optim_state_dict=opt_state_dict,
options=ptd_state_dict.StateDictOptions(full_state_dict=True),
)
self.assertIsInstance(
list(optim.state.values())[0]["exp_avg"], DTensor # noqa: RUF015
)
@with_comms
@skip_if_lt_x_gpu(2)
def test_flattened_osd(self) -> None:
device_mesh = init_device_mesh("cuda", (self.world_size,))
model = CompositeParamModel(device=torch.device("cuda"))
fsdp_model = FSDP2(copy.deepcopy(model), mesh=device_mesh)
fsdp_optim = torch.optim.AdamW(fsdp_model.parameters())
batch = torch.rand(8, 100, device="cuda")
fsdp_model(batch).sum().backward()
fsdp_optim.step()
fsdp_optim.zero_grad()
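        # Fetch the optimizer state_dict in both the default (nested) and the
        # flattened layout; loading either one should reproduce the same
        # optimizer state.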
osd1 = get_optimizer_state_dict(fsdp_model, fsdp_optim)
osd2 = get_optimizer_state_dict(
fsdp_model,
fsdp_optim,
options=StateDictOptions(flatten_optimizer_state_dict=True),
)
fsdp_optim2 = torch.optim.AdamW(fsdp_model.parameters())
set_optimizer_state_dict(
fsdp_model, optimizers=fsdp_optim2, optim_state_dict=osd2
)
self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict())
set_optimizer_state_dict(
fsdp_model, optimizers=fsdp_optim2, optim_state_dict=osd1
)
self.assertEqual(fsdp_optim.state_dict(), fsdp_optim2.state_dict())
@with_comms
@skip_if_lt_x_gpu(1)
def test_deprecate_partial(self) -> None:
model = CompositeParamModel(device=torch.device("cuda"))
model_state_dict1 = get_model_state_dict(model)
model_state_dict1 = copy.deepcopy(model_state_dict1)
with self.assertWarnsRegex(
FutureWarning,
"Getting submodules only model/optim state_dict is deprecated",
):
model_state_dict2 = get_model_state_dict(model, submodules={model.l})
model_state_dict2 = copy.deepcopy(model_state_dict2)
with self.assertWarnsRegex(
FutureWarning,
"Getting submodules only model/optim state_dict is deprecated",
):
model_state_dict3 = get_model_state_dict(
model,
submodules={model.l},
options=StateDictOptions(keep_submodule_prefixes=False),
)
model_state_dict3 = copy.deepcopy(model_state_dict3)
self.assertEqual(len(model_state_dict2), 2)
self.assertEqual(len(model_state_dict3), 2)
for key in model_state_dict3.keys():
full_fqn = f"l.{key}"
value1 = model_state_dict1[full_fqn]
value2 = model_state_dict2[full_fqn]
value3 = model_state_dict3[key]
self.assertEqual(value1, value2)
self.assertEqual(value2, value3)
zeros_state_dict = {
k: torch.zeros_like(v) for k, v in model_state_dict1.items()
}
model.load_state_dict(zeros_state_dict)
set_model_state_dict(
model,
model_state_dict=model_state_dict2,
options=StateDictOptions(strict=False),
)
self.assertEqual(model.l.weight, model_state_dict1["l.weight"])
self.assertEqual(model.l.bias, model_state_dict1["l.bias"])
model.load_state_dict(zeros_state_dict)
with self.assertWarnsRegex(FutureWarning, "Passing model_state_dict as a "):
set_model_state_dict(
model,
model_state_dict={model.l: model_state_dict3},
options=StateDictOptions(strict=False),
)
self.assertEqual(model.l.weight, model_state_dict1["l.weight"])
self.assertEqual(model.l.bias, model_state_dict1["l.bias"])
@with_comms
@skip_if_lt_x_gpu(1)
def test_deprecate_fsdp_api(self) -> None:
device_mesh = init_device_mesh("cuda", (self.world_size,))
model = CompositeParamModel(device=torch.device("cuda"))
fsdp_model = FSDP(copy.deepcopy(model), device_mesh=device_mesh)
with self.assertWarnsRegex(
FutureWarning,
r"FSDP.state_dict_type\(\) and FSDP.set_state_dict_type\(\) are being deprecated",
):
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
fsdp_model.state_dict()
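        # get_model_state_dict() should not emit the deprecation warning, so the
        # inner assertWarnsRegex fails with "FutureWarning not triggered".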
with self.assertRaisesRegex(AssertionError, "FutureWarning not triggered"):
with self.assertWarnsRegex(
FutureWarning,
r"FSDP.state_dict_type\(\) and FSDP.set_state_dict_type\(\) are being deprecated",
):
get_model_state_dict(model)
@with_comms
@skip_if_lt_x_gpu(2)
def test_shared_weight(self):
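        # Model whose decoder weight is tied to the embedding weight, used to
        # verify that shared parameters round-trip through the state_dict APIs.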
class TiedEmbeddingModel(nn.Module):
def __init__(self, vocab_size, embedding_dim):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.decoder = nn.Linear(embedding_dim, vocab_size)
self.decoder.weight = self.embedding.weight # Tying weights
def forward(self, input):
input = (input * 10).to(torch.int)
embedded = self.embedding(input)
output = self.decoder(embedded)
return output
def init_model_optim():
device_mesh = init_device_mesh("cuda", (self.world_size,))
orig_model = TiedEmbeddingModel(10000, 300).to(torch.device("cuda"))
orig_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
copy_optim = torch.optim.AdamW(orig_model.parameters(), lr=1e-3)
dist_model = FSDP(copy.deepcopy(orig_model), device_mesh=device_mesh)
dist_optim = torch.optim.AdamW(dist_model.parameters(), lr=1e-3)
return orig_model, orig_optim, copy_optim, dist_model, dist_optim
self._test_save_load(init_model_optim)
class TestNoComm(MultiProcessTestCase):
def setUp(self) -> None:
super().setUp()
self._spawn_processes()
@skip_if_lt_x_gpu(1)
def test_no_dist(self) -> None:
model = CompositeParamModel(device=torch.device("cuda"))
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
self.assertFalse(dist.is_initialized())
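        # Without an initialized process group, the state_dict APIs should still
        # work and match the plain nn.Module state_dict.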
msd = get_model_state_dict(
model, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
)
for v in msd.values():
self.assertFalse(v.is_cuda)
self.assertEqual(model.state_dict(), msd)
set_model_state_dict(model, model.state_dict())
osd = get_optimizer_state_dict(
model,
optim,
options=StateDictOptions(full_state_dict=True, cpu_offload=True),
)
set_optimizer_state_dict(model, optim, osd)
set_optimizer_state_dict(model, optim, optim.state_dict())
if __name__ == "__main__":
run_tests()