# Owner(s): ["oncall: distributed"]
from copy import deepcopy
import torch
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._tensor import init_device_mesh, Replicate
from torch.distributed.checkpoint.default_planner import (
DefaultLoadPlanner,
DefaultSavePlanner,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
ShardingStrategy,
StateDictType,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.net1 = nn.Linear(5, 8)
self.relu = nn.ReLU()
self.net2 = nn.Linear(8, 4)
self.net3 = nn.Linear(4, 12)
def forward(self, x):
x = F.relu(self.net1(x))
x = F.relu(self.net2(x))
x = F.relu(self.net3(x))
return x
def get_input(self):
return torch.rand(4, 5, device="cuda")
class SimpleModelUneven(torch.nn.Module):
def __init__(self):
super().__init__()
self.net1 = nn.Linear(5, 10)
self.relu = nn.ReLU()
self.net2 = nn.Linear(10, 15)
self.net3 = nn.Linear(15, 30)
self.net4 = nn.Linear(30, 5)
def forward(self, x):
x = F.relu(self.net1(x))
x = F.relu(self.net2(x))
x = F.relu(self.net3(x))
x = F.relu(self.net4(x))
return x
def get_input(self):
return torch.rand(4, 5, device="cuda")
class TestHSDPCheckpoint(DTensorTestBase):
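    """
    Round-trip tests for distributed checkpoint (DCP) save/load of HSDP
    sharded state_dicts, including resharding an HSDP checkpoint into an
    FSDP-wrapped model.
    """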
@property
def backend(self):
return "cpu:gloo,cuda:nccl"
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
@parametrize("is_even_sharded_model", [True, False])
def test_hsdp_checkpoint(self, is_even_sharded_model) -> None:
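        """
        Save an HSDP sharded state_dict with DCP, update the parameters so the
        model drifts from the checkpoint, then load the checkpoint back and
        verify every parameter is restored to its saved value.
        """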
CHECKPOINT_DIR = self.temp_dir
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
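        # 2D device mesh for HYBRID_SHARD: parameters are replicated across
        # mesh dim 0 and sharded across mesh dim 1.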
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
model = FSDP(
simple_model().cuda(),
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
device_mesh=mesh_2d,
)
optim = torch.optim.Adam(model.parameters(), lr=0.1)
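        # SHARDED_STATE_DICT keeps each rank's local shards (as DTensors on the
        # device mesh) instead of gathering full tensors onto a single rank.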
FSDP.set_state_dict_type(
model,
StateDictType.SHARDED_STATE_DICT,
)
state_dict = {"model": model.state_dict()}
state_dict_to_save = deepcopy(state_dict)
dist_cp.save_state_dict(
state_dict=state_dict_to_save,
storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
planner=DefaultSavePlanner(),
)
        # Update the parameters so the current model state_dict differs from state_dict_to_save.
model(model.get_input()).sum().backward()
optim.step()
# At this point, the current state dict is different from state_dict_to_save.
for (k1, v1), (k2, v2) in zip(
state_dict_to_save["model"].items(), model.state_dict().items()
):
self.assertEqual(k1, k2)
self.assertEqual(v1.device_mesh, v2.device_mesh)
self.assertEqual(v1.placements, v2.placements)
self.assertNotEqual(v1.to_local(), v2.to_local())
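        # Read the checkpoint back from storage into state_dict_to_save in place,
        # then apply it to the model.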
dist_cp.load_state_dict(
state_dict=state_dict_to_save,
storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
planner=DefaultLoadPlanner(),
)
model.load_state_dict(state_dict_to_save["model"])
state_dict_after_load = model.state_dict()
# After loading, the current model state dict should be the same as state_dict_to_save.
for (k1, v1), (k2, v2) in zip(
            state_dict_to_save["model"].items(), state_dict_after_load.items()
):
self.assertEqual(k1, k2)
self.assertEqual(v1.device_mesh, v2.device_mesh)
self.assertEqual(v1.placements, v2.placements)
self.assertEqual(v1.to_local(), v2.to_local())
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
@parametrize("is_even_sharded_model", [True, False])
def test_hsdp_fsdp_checkpoint_conversion(self, is_even_sharded_model) -> None:
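        """
        Save a sharded state_dict from an HSDP model (2D mesh) and load it into
        an FSDP model (1D mesh), verifying that the resharded values match the
        original checkpoint once both sides are all-gathered.
        """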
CHECKPOINT_DIR = self.temp_dir
simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
        # Save the HSDP model state_dict.
mesh_2d = init_device_mesh(self.device_type, (2, self.world_size // 2))
hsdp_model = FSDP(
simple_model().cuda(),
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
device_mesh=mesh_2d,
)
FSDP.set_state_dict_type(
hsdp_model,
StateDictType.SHARDED_STATE_DICT,
)
hsdp_state_dict = {"model": hsdp_model.state_dict()}
dist_cp.save_state_dict(
state_dict=hsdp_state_dict,
storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
planner=DefaultSavePlanner(),
)
        # Initialize an FSDP model to load the checkpoint into.
mesh_1d = init_device_mesh(self.device_type, (self.world_size,))
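        # Without an explicit sharding_strategy, FSDP defaults to FULL_SHARD
        # across the 1D mesh.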
fsdp_model = FSDP(
simple_model().cuda(),
device_mesh=mesh_1d,
)
FSDP.set_state_dict_type(
fsdp_model,
StateDictType.SHARDED_STATE_DICT,
)
fsdp_state_dict = {"model": fsdp_model.state_dict()}
        # At this point, the HSDP model parameters differ from the FSDP model parameters.
for (k1, v1), (k2, v2) in zip(
hsdp_state_dict["model"].items(), fsdp_state_dict["model"].items()
):
self.assertEqual(k1, k2)
self.assertNotEqual(v1.device_mesh, v2.device_mesh)
self.assertNotEqual(v1.placements, v2.placements)
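            # Redistribute both DTensors to Replicate (an all-gather) so the
            # full tensors can be compared even though the meshes differ.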
v1_all_gather = v1.redistribute(
mesh_2d, placements=(Replicate(), Replicate())
)
v2_all_gather = v2.redistribute(mesh_1d, placements=(Replicate(),))
self.assertNotEqual(v1_all_gather.to_local(), v2_all_gather.to_local())
        # Load the FSDP state_dict from storage.
dist_cp.load_state_dict(
state_dict=fsdp_state_dict,
storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
planner=DefaultLoadPlanner(),
)
fsdp_model.load_state_dict(fsdp_state_dict["model"])
state_dict_after_load = fsdp_model.state_dict()
# After loading, the current model state dict should be the same as hsdp_state_dict.
for (k1, v1), (k2, v2) in zip(
hsdp_state_dict["model"].items(), state_dict_after_load.items()
):
self.assertEqual(k1, k2)
self.assertNotEqual(v1.device_mesh, v2.device_mesh)
self.assertNotEqual(v1.placements, v2.placements)
v1_all_gather = v1.redistribute(
mesh_2d, placements=(Replicate(), Replicate())
)
v2_all_gather = v2.redistribute(mesh_1d, placements=(Replicate(),))
self.assertEqual(v1_all_gather.to_local(), v2_all_gather.to_local())
instantiate_parametrized_tests(TestHSDPCheckpoint)
if __name__ == "__main__":
run_tests()