[state_dict][7/N] Add a fine tuning e2e test case for distributed.state_dict and DCP (#111111)
Adds an end-to-end test covering a pretrain-then-fine-tune workflow: a pretrained FSDP model is saved with the `torch.distributed.checkpoint.state_dict` APIs (`get_state_dict`/`set_state_dict`) plus DCP, and a fine-tuning model with a frozen pretrained submodule then loads that checkpoint, trains, and saves/reloads its own fine-tuning checkpoint.
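
For context, the save/load flow that the new test exercises looks roughly like the following. This is a minimal single-process sketch of the `get_state_dict`/`set_state_dict` + DCP pattern, not the test itself: the gloo process group, the toy `nn.Linear` model, and the `/tmp/dcp_sketch` path are illustrative assumptions, while the test runs the same calls under FSDP on multiple GPUs.

```python
import os
import shutil

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    get_state_dict,
    set_state_dict,
)

# Single-process setup so the sketch runs without torchrun
# (assumption: gloo backend; the real test uses FSDP on GPUs).
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = nn.Linear(8, 8)  # placeholder for the FSDP-wrapped model
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
checkpoint_dir = "/tmp/dcp_sketch"  # placeholder checkpoint directory
shutil.rmtree(checkpoint_dir, ignore_errors=True)
os.makedirs(checkpoint_dir, exist_ok=True)

# Take one optimizer step so the optimizer has state to checkpoint.
loss = model(torch.rand(4, 8)).sum()
loss.backward()
optim.step()
optim.zero_grad()

# Save: materialize model/optimizer state_dicts, then write them with DCP.
model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
dist_cp.save_state_dict(
    state_dict={"model": model_state_dict, "optim": optim_state_dict},
    storage_writer=dist_cp.FileSystemWriter(checkpoint_dir),
)

# Load: regenerate state_dicts for the current model, read the checkpoint
# into them, and apply them back to the model and optimizer.
model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
dist_cp.load_state_dict(
    {"model": model_state_dict, "optim": optim_state_dict},
    storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
)
set_state_dict(
    model,
    optimizers=optim,
    model_state_dict=model_state_dict,
    optim_state_dict=optim_state_dict,
)

dist.destroy_process_group()
```

The fine-tuning path in the test additionally passes `StateDictOptions(ignore_frozen_params=True)` when saving (so the frozen pretrained weights are not duplicated in the fine-tuning checkpoint), `submodules={model.pretrain}` with `keep_submodule_prefixes=False` to load the pretrained weights back into the wrapped submodule, and `strict=False` when applying partial state_dicts.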
Differential Revision: [D50209732](https://our.internmc.facebook.com/intern/diff/D50209732/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111111
Approved by: https://github.com/wz337
ghstack dependencies: #111106, #111107, #111275, #111109, #111110, #111120
diff --git a/test/distributed/checkpoint/e2e/test_fine_tuning.py b/test/distributed/checkpoint/e2e/test_fine_tuning.py
new file mode 100644
index 0000000..207c6ad
--- /dev/null
+++ b/test/distributed/checkpoint/e2e/test_fine_tuning.py
@@ -0,0 +1,189 @@
+# Owner(s): ["oncall: distributed"]
+
+import os
+import sys
+
+import torch
+import torch.distributed as dist
+import torch.distributed.checkpoint as dist_cp
+import torch.nn as nn
+from torch.distributed.checkpoint.state_dict import (
+ get_state_dict,
+ set_state_dict,
+ StateDictOptions,
+)
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_fsdp import FSDPTest
+from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
+from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
+
+
+if not dist.is_available():
+ print("Distributed not available, skipping tests", file=sys.stderr)
+ sys.exit(0)
+
+if TEST_WITH_DEV_DBG_ASAN:
+ print(
+ "Skip dev-asan as torch + multiprocessing spawn have known issues",
+ file=sys.stderr,
+ )
+ sys.exit(0)
+
+
+DIM = 500
+
+
+class PreTrainedModel(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.layer1 = nn.Linear(DIM, DIM)
+ self.layer2 = nn.Linear(DIM, DIM)
+ self.layer3 = nn.Linear(DIM, DIM)
+ self.sequential = nn.Sequential(nn.Linear(DIM, DIM), nn.ReLU())
+ self.module_list = nn.ModuleList([nn.Linear(DIM, DIM), nn.ReLU()])
+ self.relu = nn.ReLU()
+
+ def forward(self, batch):
+ x = self.relu(self.layer1(batch))
+ x = self.relu(self.layer2(x))
+ x = self.relu(self.layer3(x))
+ x = self.sequential(x)
+ x = self.module_list[1](self.module_list[0](x))
+ return x
+
+
+class FineTuningModel(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+ self.pretrain = PreTrainedModel()
+ for p in self.pretrain.parameters():
+ p.requires_grad = False
+
+ self.layer1 = nn.Linear(DIM, DIM)
+ self.layer2 = nn.Linear(DIM, DIM)
+ self.layer3 = nn.Linear(DIM, DIM)
+ self.relu = nn.ReLU()
+
+ def forward(self, batch):
+ x = self.relu(self.pretrain(batch))
+ x = self.relu(self.layer1(x))
+ x = self.relu(self.layer2(x))
+ x = self.relu(self.layer3(x))
+ return x
+
+
+class TestFineTuning(FSDPTest):
+ @property
+ def world_size(self) -> int:
+ return min(4, torch.cuda.device_count())
+
+ def pretrain(self, pretrain_dir: str) -> None:
+ model = PreTrainedModel().cuda()
+ model = FSDP(model)
+ optim = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+        # Training
+ for i in range(3):
+ batch = torch.rand(32, DIM, device="cuda")
+ loss = model(batch).sum()
+ loss.backward()
+ optim.step()
+ optim.zero_grad()
+
+ # Save state_dict
+ model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
+ saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict}
+ dist_cp.save_state_dict(
+ state_dict=saved_state_dict,
+ storage_writer=dist_cp.FileSystemWriter(pretrain_dir),
+ )
+
+ def finetune(self, pretrain_dir: str, finetune_dir: str) -> None:
+ model = FineTuningModel().cuda()
+ # TODO: make the parallelism more complicated, e.g., using 2D + DDP.
+ model = FSDP(model, use_orig_params=True)
+ optim = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+        # Simulate a fine-tuning run that restarts after the first 3 training iterations.
+ for i in range(2):
+            # Load the pretrained submodule's weights from the pretrain checkpoint.
+ pretrain_state_dict, _ = get_state_dict(
+ model,
+ submodules={model.pretrain},
+ options=StateDictOptions(keep_submodule_prefixes=False),
+ )
+ dist_cp.load_state_dict(
+ {"model": pretrain_state_dict},
+ storage_reader=dist_cp.FileSystemReader(pretrain_dir),
+ )
+ set_state_dict(
+ model,
+ model_state_dict={model.pretrain: pretrain_state_dict},
+ options=StateDictOptions(strict=False),
+ )
+
+ try:
+                # Load the checkpoint for the trainable (fine-tuning) submodules.
+ model_state_dict, optim_state_dict = get_state_dict(
+ model,
+ optimizers=optim,
+ options=StateDictOptions(ignore_frozen_params=True),
+ )
+ dist_cp.load_state_dict(
+ {"model": model_state_dict, "optim": optim_state_dict},
+ storage_reader=dist_cp.FileSystemReader(pretrain_dir),
+ )
+ set_state_dict(
+ model,
+ optimizers=optim,
+ model_state_dict=model_state_dict,
+ optim_state_dict=optim_state_dict,
+ options=StateDictOptions(strict=False),
+ )
+ except KeyError:
+                # On the first fine-tuning round nothing has been saved yet, so
+                # the load fails; on a restart the checkpoint should exist.
+ self.assertEqual(i, 0)
+
+            # Training
+ for j in range(3):
+ batch = torch.rand(32, DIM, device="cuda")
+ loss = model(batch).sum()
+ loss.backward()
+ optim.step()
+ optim.zero_grad()
+
+ # Save state_dict
+ model_state_dict, optim_state_dict = get_state_dict(
+ model,
+ optimizers=optim,
+ options=StateDictOptions(ignore_frozen_params=True),
+ )
+ saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict}
+ dist_cp.save_state_dict(
+ state_dict=saved_state_dict,
+ storage_writer=dist_cp.FileSystemWriter(finetune_dir),
+ )
+
+ @skip_if_lt_x_gpu(4)
+ @with_temp_dir
+ def test_fine_tuning(self) -> None:
+ self.assertTrue(os.path.exists(self.temp_dir))
+ pretrain_dir = os.path.join(self.temp_dir, "pretrain")
+ finetune_dir = os.path.join(self.temp_dir, "finetune")
+ print(pretrain_dir, finetune_dir)
+ if self.rank == 0:
+ os.mkdir(pretrain_dir)
+ os.mkdir(finetune_dir)
+ dist.barrier()
+ os.sync()
+ self.assertTrue(os.path.exists(pretrain_dir))
+ self.assertTrue(os.path.exists(finetune_dir))
+
+ self.pretrain(pretrain_dir)
+        self.finetune(pretrain_dir, finetune_dir)
+
+
+if __name__ == "__main__":
+    run_tests()