[state_dict][8/N] Ignore meta parameters (#112167)

This PR let `get_state_dict` ignore the parameters that are on the meta device.

This PR also demonstrates a possible use case of ignoring meta parameters -- checkpointing pipeline parallelism.

Differential Revision: [D50672521](https://our.internmc.facebook.com/intern/diff/D50672521/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112167
Approved by: https://github.com/wz337
diff --git a/test/distributed/checkpoint/e2e/test_pipeline.py b/test/distributed/checkpoint/e2e/test_pipeline.py
new file mode 100644
index 0000000..cbf60b9
--- /dev/null
+++ b/test/distributed/checkpoint/e2e/test_pipeline.py
@@ -0,0 +1,104 @@
+# Owner(s): ["oncall: distributed"]
+
+import sys
+
+import torch
+import torch.distributed as dist
+import torch.distributed.checkpoint as dcp
+import torch.nn as nn
+from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_fsdp import FSDPTest
+from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN
+from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
+
+
+if not dist.is_available():
+    print("Distributed not available, skipping tests", file=sys.stderr)
+    sys.exit(0)
+
+if TEST_WITH_DEV_DBG_ASAN:
+    print(
+        "Skip dev-asan as torch + multiprocessing spawn have known issues",
+        file=sys.stderr,
+    )
+    sys.exit(0)
+
+
+DIM = 500
+
+
+class PipelineModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.layer1 = nn.Linear(DIM, DIM)
+        self.layer2 = nn.Linear(DIM, DIM)
+        self.layer3 = nn.Linear(DIM, DIM)
+        self.layer4 = nn.Linear(DIM, DIM)
+        self.relu = nn.ReLU()
+
+    def forward(self, batch):
+        x = self.relu(self.layer1(batch))
+        x = self.relu(self.layer2(x))
+        x = self.relu(self.layer3(x))
+        x = self.relu(self.layer4(x))
+        return x
+
+
+class TestPipeline(FSDPTest):
+    @property
+    def world_size(self) -> int:
+        return min(4, torch.cuda.device_count())
+
+    def save_with_pipeline(self, pipeline_dir: str) -> None:
+        with torch.device("meta"):
+            model = PipelineModel()
+
+        pipeline_modules = [model.layer1, model.layer2, model.layer3, model.layer4]
+
+        # Materialize the model
+        submodule = pipeline_modules[self.rank]
+        submodule.to_empty(device=torch.device("cuda"))
+        # submodule.reset_parameters()
+        optim = torch.optim.Adam(submodule.parameters(), lr=1e-3)
+
+        # Ignore the training as we don't have a real pipeline parallelism.
+
+        # Save state_dict
+        model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
+        saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict}
+        dcp.save_state_dict(
+            state_dict=saved_state_dict,
+            storage_writer=dcp.FileSystemWriter(pipeline_dir),
+        )
+
+    def load_with_fsdp(self, pipeline_dir: str) -> None:
+        model = FSDP(PipelineModel().cuda())
+        optim = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+        # Load the checkpoint
+        model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
+        dcp.load_state_dict(
+            {"model": model_state_dict, "optim": optim_state_dict},
+            storage_reader=dcp.FileSystemReader(pipeline_dir),
+        )
+        set_state_dict(
+            model,
+            optimizers=optim,
+            model_state_dict=model_state_dict,
+            optim_state_dict=optim_state_dict,
+        )
+
+    @skip_if_lt_x_gpu(4)
+    @with_temp_dir
+    def test_pipeline(self) -> None:
+        self.assertTrue(os.path.exists(self.temp_dir))
+        pipeline_dir = os.path.join(self.temp_dir, "pipeline")
+        if self.rank == 0:
+            os.mkdir(pipeline_dir)
+        os.sync()
+        dist.barrier()
+        self.assertTrue(os.path.exists(pipeline_dir))
+        self.save_with_pipeline(pipeline_dir)
+        self.load_with_fsdp(pipeline_dir)
diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py
index 4386af9..8523ca1 100644
--- a/torch/distributed/checkpoint/state_dict.py
+++ b/torch/distributed/checkpoint/state_dict.py
@@ -354,6 +354,11 @@
             fqns = _get_fqns(model, key)
             for fqn in fqns:
                 state_dict.pop(fqn)
+
+    for key, p in list(state_dict.items()):
+        if p.is_meta:
+            state_dict.pop(key)
+
     return state_dict