[shard] make state_dict hook consistent with module.state_dict
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79650
For the root module we shouldn't accidentally prepend a "." to state_dict keys; the prefix
should be empty instead, to match the `module.state_dict()` behavior.
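A minimal sketch of the fixed key construction (the helper name `make_key` is hypothetical; `prefix`, `submodule_name`, and `attr_name` are the variables used in the hooks below). The "." separator is only inserted when there is a non-empty prefix, so root-level attributes no longer get a spurious leading dot:

```python
def make_key(prefix, submodule_name, attr_name):
    # For the root module, prefix and submodule_name are both "",
    # so mod_prefix is empty and no "." separator is added.
    mod_prefix = prefix + submodule_name
    return mod_prefix + ('.' if mod_prefix else '') + attr_name

assert make_key('', '', 'sharded_tensor1') == 'sharded_tensor1'        # was '.sharded_tensor1'
assert make_key('', 'submodule', 'sharded_tensor2') == 'submodule.sharded_tensor2'
```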
Differential Revision: [D37191203](https://our.internmc.facebook.com/intern/diff/D37191203/)
**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D37191203/)!
Approved by: https://github.com/pritamdamania87, https://github.com/fduwjj
diff --git a/test/distributed/_shard/checkpoint/test_checkpoint.py b/test/distributed/_shard/checkpoint/test_checkpoint.py
index 88ade9f..3e562d3 100644
--- a/test/distributed/_shard/checkpoint/test_checkpoint.py
+++ b/test/distributed/_shard/checkpoint/test_checkpoint.py
@@ -148,12 +148,12 @@
# we make the first stored shard smaller
self.assertTrue(
- ".sharded" in metadata.state_dict_metadata,
+ "sharded" in metadata.state_dict_metadata,
f"keys: {metadata.state_dict_metadata.keys()}",
)
sizes = (
- metadata.state_dict_metadata[".sharded"]
+ metadata.state_dict_metadata["sharded"]
.storage_metadata[0]
.shard_metadata.shard_sizes
)
@@ -172,12 +172,12 @@
# we make the first stored shard smaller
self.assertTrue(
- ".sharded" in metadata.state_dict_metadata,
+ "sharded" in metadata.state_dict_metadata,
f"keys: {metadata.state_dict_metadata.keys()}",
)
sizes = (
- metadata.state_dict_metadata[".sharded"]
+ metadata.state_dict_metadata["sharded"]
.storage_metadata[0]
.shard_metadata.shard_sizes
)
@@ -197,12 +197,12 @@
metadata = self.gen_metadata()
regular = metadata.state_dict_metadata["regular"]
- metadata.state_dict_metadata[".sharded"] = regular
+ metadata.state_dict_metadata["sharded"] = regular
with self.assertRaisesRegex(ValueError, "ShardedTensorStorageMetadata but found"):
validate_metadata(module.state_dict(), metadata)
metadata = self.gen_metadata()
- sharded = metadata.state_dict_metadata[".sharded"]
+ sharded = metadata.state_dict_metadata["sharded"]
metadata.state_dict_metadata["regular"] = sharded
with self.assertRaisesRegex(ValueError, "TensorStorageMetadata but found"):
validate_metadata(module.state_dict(), metadata)
diff --git a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
index 8192f24..07d2a9e 100644
--- a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
+++ b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
@@ -1098,7 +1098,11 @@
# Test save
m._register_state_dict_hook(state_dict_hook)
buffer = io.BytesIO()
- torch.save(m.state_dict(), buffer)
+ mod_state_dict = m.state_dict()
+ mod_state_keys = mod_state_dict.keys()
+ self.assertTrue("sharded_tensor1" in mod_state_keys)
+ self.assertTrue("submodule.sharded_tensor2" in mod_state_keys)
+ torch.save(mod_state_dict, buffer)
# Test load.
module_load = MyShardedModel1()
@@ -1108,6 +1112,10 @@
state_dict_deser = torch.load(buffer)
module_load.load_state_dict(state_dict_deser, strict=False)
+ module_load._register_state_dict_hook(state_dict_hook)
+ loaded_dict_keys = module_load.state_dict().keys()
+ self.assertTrue("sharded_tensor1" in loaded_dict_keys)
+ self.assertTrue("submodule.sharded_tensor2" in loaded_dict_keys)
# Verify after load.
self.assertTrue(torch.equal(m.sharded_tensor1, module_load.sharded_tensor1))
self.assertTrue(torch.equal(m.submodule.sharded_tensor2, module_load.submodule.sharded_tensor2))
diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py
index 2457aa2..2d5d1cb 100644
--- a/torch/distributed/_shard/sharded_tensor/__init__.py
+++ b/torch/distributed/_shard/sharded_tensor/__init__.py
@@ -399,7 +399,9 @@
for submodule_name, submodule in module.named_modules():
for attr_name, attr in submodule.__dict__.items():
if isinstance(attr, ShardedTensor):
- destination[prefix + submodule_name + '.' + attr_name] = attr
+ mod_prefix = prefix + submodule_name
+ key = mod_prefix + ('.' if mod_prefix else '') + attr_name
+ destination[key] = attr
def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
"""
@@ -407,7 +409,8 @@
"""
for submodule_name, submodule in module.named_modules():
for attr_name, attr in submodule.__dict__.items():
- key = prefix + submodule_name + '.' + attr_name
+ mod_prefix = prefix + submodule_name
+ key = mod_prefix + ('.' if mod_prefix else '') + attr_name
if key in state_dict:
if isinstance(state_dict[key], ShardedTensor):
setattr(submodule, attr_name, state_dict[key])
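For reference, a self-contained sketch (with hypothetical `Root`/`Child` modules) of the plain `nn.Module.state_dict()` key convention these hooks now match: root-level entries get bare keys, nested ones get dotted keys, and none start with ".":

```python
import torch
import torch.nn as nn

class Child(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.zeros(2))

class Root(nn.Module):
    def __init__(self):
        super().__init__()
        self.v = nn.Parameter(torch.zeros(2))
        self.child = Child()

keys = Root().state_dict().keys()
assert "v" in keys and "child.w" in keys       # bare key at the root, dotted key when nested
assert not any(k.startswith(".") for k in keys)
```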