[shard] make state_dict hook keys consistent with module.state_dict()

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79650

For the root module we shouldn't accidentally prepend a "." to state_dict keys; the prefix should be empty instead, to match the `module.state_dict()` behavior.
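As a minimal sketch of the key construction the diff below adopts (the standalone `make_key` helper is hypothetical; `mod_prefix` is the name used in the actual change), the separator is only inserted when the accumulated prefix is non-empty:

```python
# Hypothetical helper illustrating the fixed key construction.
# For the root module, named_modules() yields submodule_name == "",
# so the key must not pick up a spurious leading ".".
def make_key(prefix: str, submodule_name: str, attr_name: str) -> str:
    mod_prefix = prefix + submodule_name
    return mod_prefix + ('.' if mod_prefix else '') + attr_name

assert make_key("", "", "sharded_tensor1") == "sharded_tensor1"
assert make_key("", "submodule", "sharded_tensor2") == "submodule.sharded_tensor2"
```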

Differential Revision: [D37191203](https://our.internmc.facebook.com/intern/diff/D37191203/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D37191203/)!

Approved by: https://github.com/pritamdamania87, https://github.com/fduwjj
diff --git a/test/distributed/_shard/checkpoint/test_checkpoint.py b/test/distributed/_shard/checkpoint/test_checkpoint.py
index 88ade9f..3e562d3 100644
--- a/test/distributed/_shard/checkpoint/test_checkpoint.py
+++ b/test/distributed/_shard/checkpoint/test_checkpoint.py
@@ -148,12 +148,12 @@
 
         # we make the first stored shard smaller
         self.assertTrue(
-            ".sharded" in metadata.state_dict_metadata,
+            "sharded" in metadata.state_dict_metadata,
             f"keys: {metadata.state_dict_metadata.keys()}",
         )
 
         sizes = (
-            metadata.state_dict_metadata[".sharded"]
+            metadata.state_dict_metadata["sharded"]
             .storage_metadata[0]
             .shard_metadata.shard_sizes
         )
@@ -172,12 +172,12 @@
 
         # we make the first stored shard smaller
         self.assertTrue(
-            ".sharded" in metadata.state_dict_metadata,
+            "sharded" in metadata.state_dict_metadata,
             f"keys: {metadata.state_dict_metadata.keys()}",
         )
 
         sizes = (
-            metadata.state_dict_metadata[".sharded"]
+            metadata.state_dict_metadata["sharded"]
             .storage_metadata[0]
             .shard_metadata.shard_sizes
         )
@@ -197,12 +197,12 @@
 
         metadata = self.gen_metadata()
         regular = metadata.state_dict_metadata["regular"]
-        metadata.state_dict_metadata[".sharded"] = regular
+        metadata.state_dict_metadata["sharded"] = regular
         with self.assertRaisesRegex(ValueError, "ShardedTensorStorageMetadata but found"):
             validate_metadata(module.state_dict(), metadata)
 
         metadata = self.gen_metadata()
-        sharded = metadata.state_dict_metadata[".sharded"]
+        sharded = metadata.state_dict_metadata["sharded"]
         metadata.state_dict_metadata["regular"] = sharded
         with self.assertRaisesRegex(ValueError, "TensorStorageMetadata but found"):
             validate_metadata(module.state_dict(), metadata)
diff --git a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
index 8192f24..07d2a9e 100644
--- a/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
+++ b/test/distributed/_shard/sharded_tensor/test_sharded_tensor.py
@@ -1098,7 +1098,11 @@
         # Test save
         m._register_state_dict_hook(state_dict_hook)
         buffer = io.BytesIO()
-        torch.save(m.state_dict(), buffer)
+        mod_state_dict = m.state_dict()
+        mod_state_keys = mod_state_dict.keys()
+        self.assertTrue("sharded_tensor1" in mod_state_keys)
+        self.assertTrue("submodule.sharded_tensor2" in mod_state_keys)
+        torch.save(mod_state_dict, buffer)
 
         # Test load.
         module_load = MyShardedModel1()
@@ -1108,6 +1112,10 @@
         state_dict_deser = torch.load(buffer)
         module_load.load_state_dict(state_dict_deser, strict=False)
 
+        module_load._register_state_dict_hook(state_dict_hook)
+        loaded_dict_keys = module_load.state_dict().keys()
+        self.assertTrue("sharded_tensor1" in loaded_dict_keys)
+        self.assertTrue("submodule.sharded_tensor2" in loaded_dict_keys)
         # Verify after load.
         self.assertTrue(torch.equal(m.sharded_tensor1, module_load.sharded_tensor1))
         self.assertTrue(torch.equal(m.submodule.sharded_tensor2, module_load.submodule.sharded_tensor2))
diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py
index 2457aa2..2d5d1cb 100644
--- a/torch/distributed/_shard/sharded_tensor/__init__.py
+++ b/torch/distributed/_shard/sharded_tensor/__init__.py
@@ -399,7 +399,9 @@
     for submodule_name, submodule in module.named_modules():
         for attr_name, attr in submodule.__dict__.items():
             if isinstance(attr, ShardedTensor):
-                destination[prefix + submodule_name + '.' + attr_name] = attr
+                mod_prefix = prefix + submodule_name
+                key = mod_prefix + ('.' if mod_prefix else '') + attr_name
+                destination[key] = attr
 
 def pre_load_state_dict_hook(module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
     """
@@ -407,7 +409,8 @@
     """
     for submodule_name, submodule in module.named_modules():
         for attr_name, attr in submodule.__dict__.items():
-            key = prefix + submodule_name + '.' + attr_name
+            mod_prefix = prefix + submodule_name
+            key = mod_prefix + ('.' if mod_prefix else '') + attr_name
             if key in state_dict:
                 if isinstance(state_dict[key], ShardedTensor):
                     setattr(submodule, attr_name, state_dict[key])
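
For context, a hedged end-to-end sketch of the intended usage, mirroring the test above. `MyShardedModel1` is the fixture from `test_sharded_tensor.py`, and the `_register_load_state_dict_pre_hook(..., True)` registration is an assumption based on the hook's `(module, state_dict, prefix, ...)` signature:

```python
import io

import torch
from torch.distributed._shard.sharded_tensor import (
    pre_load_state_dict_hook,
    state_dict_hook,
)

# Assumes MyShardedModel1 is the test fixture above: a ShardedTensor
# attribute `sharded_tensor1` on the root module and `sharded_tensor2`
# on its `submodule`.
m = MyShardedModel1()
m._register_state_dict_hook(state_dict_hook)

state = m.state_dict()
# With the fix, the root-level key no longer carries a leading ".".
assert "sharded_tensor1" in state  # previously ".sharded_tensor1"
assert "submodule.sharded_tensor2" in state

buffer = io.BytesIO()
torch.save(state, buffer)
buffer.seek(0)

module_load = MyShardedModel1()
# with_module=True (second argument) passes the module itself as the
# first argument to the pre-load hook.
module_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
module_load.load_state_dict(torch.load(buffer), strict=False)
```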