# Owner(s): ["oncall: distributed"]

from unittest.mock import patch

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.testing._internal.common_utils import run_tests, TestCase


class TestDCPCompatibility(TestCase):
    def test_metadata(self) -> None:
        # Ensure that all new fields of all the metadata classes have default
        # values so that we can always deserialize from legacy metadata.
        try:
            tensor = torch.zeros(4, 4)
            chunk_meta = ChunkStorageMetadata(
                torch.Size((1, 1)),
                torch.Size((1, 1)),
            )
            tensor_meta = TensorStorageMetadata(
                properties=TensorProperties.create_from_tensor(tensor),
                size=tensor.size(),
                chunks=[chunk_meta],
            )
            b_meta = BytesStorageMetadata()
            _ = Metadata(state_dict_metadata={"a": tensor_meta, "b": b_meta})
            _ = MetadataIndex(fqn="a.b.c")
        except Exception as e:
            raise RuntimeError(
                "The change may break the BC of distributed checkpoint."
            ) from e
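
    def test_metadata_pickle_roundtrip(self) -> None:
        # Illustrative sketch, not part of the original suite: DCP persists
        # the Metadata object with pickle, so backward compatibility also
        # implies that a Metadata instance survives a pickle round trip.
        # The field values below are arbitrary.
        import pickle

        tensor = torch.zeros(4, 4)
        tensor_meta = TensorStorageMetadata(
            properties=TensorProperties.create_from_tensor(tensor),
            size=tensor.size(),
            chunks=[ChunkStorageMetadata(torch.Size((1, 1)), torch.Size((1, 1)))],
        )
        metadata = Metadata(state_dict_metadata={"a": tensor_meta})
        roundtripped = pickle.loads(pickle.dumps(metadata))
        self.assertEqual(
            set(roundtripped.state_dict_metadata.keys()),
            set(metadata.state_dict_metadata.keys()),
        )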

    def test_sharded_tensor_dependency(self) -> None:
        # Ensure that we can load existing DCP checkpoints back even if the
        # metadata contains _shard.sharded_tensor.metadata.
        from torch.distributed._shard.sharded_tensor.metadata import (
            TensorProperties as stp,
        )

        with patch("torch.distributed.checkpoint.metadata.TensorProperties", stp):
            dcp.save(
                {"a": torch.zeros(4, 4)},
                dcp.FileSystemWriter("/tmp/dcp_testing"),
            )

        # Load outside the patch so the legacy metadata is deserialized with
        # the current TensorProperties definition.
        dcp.load(
            {"a": torch.zeros(4, 4)},
            dcp.FileSystemReader("/tmp/dcp_testing"),
        )
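
    def test_filesystem_roundtrip_tmpdir(self) -> None:
        # Illustrative sketch, not part of the original suite, using only the
        # public dcp.save/dcp.load APIs exercised above: a save/load round
        # trip into a fresh temporary directory, so concurrent runs do not
        # collide on a shared /tmp path.
        import tempfile

        state_dict = {"a": torch.zeros(4, 4)}
        with tempfile.TemporaryDirectory() as tmpdir:
            dcp.save(state_dict, dcp.FileSystemWriter(tmpdir))
            # dcp.load restores values in place into the provided state dict.
            loaded = {"a": torch.ones(4, 4)}
            dcp.load(loaded, dcp.FileSystemReader(tmpdir))
            self.assertEqual(loaded["a"], state_dict["a"])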


if __name__ == "__main__":
    run_tests()