[state_dict][7/N] Add a fine tuning e2e test case for distributed.state_dict and DCP (#111111)

Adds an end-to-end fine-tuning test: a pretrained model is trained and checkpointed with `get_state_dict` + DCP, then a fine-tuning model with a frozen pretrained submodule restores that checkpoint via `set_state_dict`, trains its own layers, and saves/restores its own fine-tuning checkpoint.
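
At a glance, the save/restore flow the test exercises looks like the minimal sketch below (illustrative only; it assumes an initialized process group and an FSDP-wrapped `model` with its optimizer `optim`, and the helper names are not part of this PR):

```python
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.state_dict import (
    get_state_dict,
    set_state_dict,
    StateDictOptions,
)


def save_checkpoint(model, optim, checkpoint_dir: str) -> None:
    # Gather DCP-friendly model/optimizer state_dicts and write them with DCP.
    model_sd, optim_sd = get_state_dict(model, optimizers=optim)
    dist_cp.save_state_dict(
        state_dict={"model": model_sd, "optim": optim_sd},
        storage_writer=dist_cp.FileSystemWriter(checkpoint_dir),
    )


def load_checkpoint(model, optim, checkpoint_dir: str) -> None:
    # Read the checkpoint into freshly gathered state_dicts, then push the
    # values back into the model and optimizer.
    model_sd, optim_sd = get_state_dict(model, optimizers=optim)
    dist_cp.load_state_dict(
        {"model": model_sd, "optim": optim_sd},
        storage_reader=dist_cp.FileSystemReader(checkpoint_dir),
    )
    set_state_dict(
        model,
        optimizers=optim,
        model_state_dict=model_sd,
        optim_state_dict=optim_sd,
        options=StateDictOptions(strict=False),
    )
```

The fine-tuning path in the test additionally passes `StateDictOptions(ignore_frozen_params=True)` when saving and reloading its own checkpoint, and `submodules={model.pretrain}` with `keep_submodule_prefixes=False` when restoring only the frozen, pretrained part.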

Differential Revision: [D50209732](https://our.internmc.facebook.com/intern/diff/D50209732/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111111
Approved by: https://github.com/wz337
ghstack dependencies: #111106, #111107, #111275, #111109, #111110, #111120
diff --git a/test/distributed/checkpoint/e2e/test_fine_tuning.py b/test/distributed/checkpoint/e2e/test_fine_tuning.py
new file mode 100644
index 0000000..207c6ad
--- /dev/null
+++ b/test/distributed/checkpoint/e2e/test_fine_tuning.py
@@ -0,0 +1,185 @@
+# Owner(s): ["oncall: distributed"]
+
+import os
+import sys
+
+import torch
+import torch.distributed as dist
+import torch.distributed.checkpoint as dist_cp
+import torch.nn as nn
+from torch.distributed.checkpoint.state_dict import (
+    get_state_dict,
+    set_state_dict,
+    StateDictOptions,
+)
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_fsdp import FSDPTest
+from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN
+from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
+
+
+if not dist.is_available():
+    print("Distributed not available, skipping tests", file=sys.stderr)
+    sys.exit(0)
+
+if TEST_WITH_DEV_DBG_ASAN:
+    print(
+        "Skip dev-asan as torch + multiprocessing spawn have known issues",
+        file=sys.stderr,
+    )
+    sys.exit(0)
+
+
+DIM = 500
+
+
+class PreTrainedModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.layer1 = nn.Linear(DIM, DIM)
+        self.layer2 = nn.Linear(DIM, DIM)
+        self.layer3 = nn.Linear(DIM, DIM)
+        self.sequential = nn.Sequential(nn.Linear(DIM, DIM), nn.ReLU())
+        self.module_list = nn.ModuleList([nn.Linear(DIM, DIM), nn.ReLU()])
+        self.relu = nn.ReLU()
+
+    def forward(self, batch):
+        x = self.relu(self.layer1(batch))
+        x = self.relu(self.layer2(x))
+        x = self.relu(self.layer3(x))
+        x = self.sequential(x)
+        x = self.module_list[1](self.module_list[0](x))
+        return x
+
+
+class FineTuningModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.pretrain = PreTrainedModel()
+        for p in self.pretrain.parameters():
+            p.requires_grad = False
+
+        self.layer1 = nn.Linear(DIM, DIM)
+        self.layer2 = nn.Linear(DIM, DIM)
+        self.layer3 = nn.Linear(DIM, DIM)
+        self.relu = nn.ReLU()
+
+    def forward(self, batch):
+        x = self.relu(self.pretrain(batch))
+        x = self.relu(self.layer1(x))
+        x = self.relu(self.layer2(x))
+        x = self.relu(self.layer3(x))
+        return x
+
+
+class TestFineTuning(FSDPTest):
+    @property
+    def world_size(self) -> int:
+        return min(4, torch.cuda.device_count())
+
+    def pretrain(self, pretrain_dir: str) -> None:
+        model = PreTrainedModel().cuda()
+        model = FSDP(model)
+        optim = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+        # Training
+        for i in range(3):
+            batch = torch.rand(32, DIM, device="cuda")
+            loss = model(batch).sum()
+            loss.backward()
+            optim.step()
+            optim.zero_grad()
+
+        # Save state_dict
+        model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
+        saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict}
+        dist_cp.save_state_dict(
+            state_dict=saved_state_dict,
+            storage_writer=dist_cp.FileSystemWriter(pretrain_dir),
+        )
+
+    def finetune(self, pretrain_dir: str, finetune_dir: str) -> None:
+        model = FineTuningModel().cuda()
+        # TODO: make the parallelism more complicated, e.g., using 2D + DDP.
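+        # use_orig_params=True lets frozen (pretrain) and trainable parameters share an
+        # FSDP flat parameter; without it, FSDP requires uniform requires_grad within
+        # each flat parameter.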
+        model = FSDP(model, use_orig_params=True)
+        optim = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+        # Simulate fine-tuning being interrupted and restarted: each outer pass trains
+        # for 3 iterations and saves a checkpoint; the second pass simulates the restart.
+        for i in range(2):
+            # Load the pretrained submodule's checkpoint. keep_submodule_prefixes=False
+            # returns FQNs without the "pretrain." prefix so that they match the keys
+            # saved by pretrain().
+            pretrain_state_dict, _ = get_state_dict(
+                model,
+                submodules={model.pretrain},
+                options=StateDictOptions(keep_submodule_prefixes=False),
+            )
+            dist_cp.load_state_dict(
+                {"model": pretrain_state_dict},
+                storage_reader=dist_cp.FileSystemReader(pretrain_dir),
+            )
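+            # strict=False: only the pretrain submodule's parameters are restored here;
+            # the fine-tuning layers keep their current (randomly initialized) values.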
+            set_state_dict(
+                model,
+                model_state_dict={model.pretrain: pretrain_state_dict},
+                options=StateDictOptions(strict=False),
+            )
+
+            try:
+                # Load the fine-tuning (trainable submodules) checkpoint if one exists.
+                model_state_dict, optim_state_dict = get_state_dict(
+                    model,
+                    optimizers=optim,
+                    options=StateDictOptions(ignore_frozen_params=True),
+                )
+                dist_cp.load_state_dict(
+                    {"model": model_state_dict, "optim": optim_state_dict},
+                    storage_reader=dist_cp.FileSystemReader(finetune_dir),
+                )
+                set_state_dict(
+                    model,
+                    optimizers=optim,
+                    model_state_dict=model_state_dict,
+                    optim_state_dict=optim_state_dict,
+                    options=StateDictOptions(strict=False),
+                )
+            except (KeyError, FileNotFoundError):
+                # On the first round of fine tuning nothing has been saved yet, so the
+                # load is expected to fail. On a restart the checkpoint should exist
+                # and the load should succeed.
+                self.assertEqual(i, 0)
+
+            # Training
+            for j in range(3):
+                batch = torch.rand(32, DIM, device="cuda")
+                loss = model(batch).sum()
+                loss.backward()
+                optim.step()
+                optim.zero_grad()
+
+            # Save state_dict
+            model_state_dict, optim_state_dict = get_state_dict(
+                model,
+                optimizers=optim,
+                options=StateDictOptions(ignore_frozen_params=True),
+            )
+            saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict}
+            dist_cp.save_state_dict(
+                state_dict=saved_state_dict,
+                storage_writer=dist_cp.FileSystemWriter(finetune_dir),
+            )
+
+    @skip_if_lt_x_gpu(4)
+    @with_temp_dir
+    def test_fine_tuning(self) -> None:
+        self.assertTrue(os.path.exists(self.temp_dir))
+        pretrain_dir = os.path.join(self.temp_dir, "pretrain")
+        finetune_dir = os.path.join(self.temp_dir, "finetune")
+        print(pretrain_dir, finetune_dir)
+        if self.rank == 0:
+            os.mkdir(pretrain_dir)
+            os.mkdir(finetune_dir)
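+        # Make sure the directories created by rank 0 are visible to all ranks before
+        # any rank starts writing checkpoints.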
+        dist.barrier()
+        os.sync()
+        self.assertTrue(os.path.exists(pretrain_dir))
+        self.assertTrue(os.path.exists(finetune_dir))
+
+        self.pretrain(pretrain_dir)
+        self.finetune(pretrain_dir, finetune_dir)