[FSDP] Provide a utility API to allow users easily to set state_dict_type (#73716)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73716

``FSDP.state_dict_type()`` is now a static method that traverses the given module and sets the state_dict_type of every descendant FSDP module, eliminating this burden from users.
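
For reference, a minimal usage sketch of the new static API (not part of this patch; the toy model, sizes, NCCL backend, and launcher setup below are illustrative assumptions):

```python
# Usage sketch only -- the model and launch details are illustrative; the
# FSDP.state_dict_type(module, ...) call is what this change adds.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

# Assumes the usual env vars (RANK, WORLD_SIZE, MASTER_ADDR, ...) from a launcher.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# The root is a plain nn.Module; only a submodule is wrapped with FSDP.
model = nn.Sequential(
    FSDP(nn.Linear(8, 8).cuda()),
    nn.Linear(8, 4).cuda(),
)

# The target passed to state_dict_type does not have to be an FSDP module:
# every descendant FSDP module is switched to LOCAL_STATE_DICT inside the
# context and restored to its previous type on exit.
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
    local_state = model.state_dict()
```
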
ghstack-source-id: 150775052

Test Plan: CI

Reviewed By: rohan-varma

Differential Revision: D34532321

fbshipit-source-id: 339dcb14a692d5d83c266892644393464d77db40
(cherry picked from commit bb4634c77d6b46644a7daad005ff52738753221a)
diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py
index 86734a1..b9e43ab 100644
--- a/test/distributed/fsdp/test_fsdp_state_dict.py
+++ b/test/distributed/fsdp/test_fsdp_state_dict.py
@@ -185,7 +185,7 @@
         except KeyError:
             raise ValueError(f"No state_dict type for {state_dict_type}")
 
-        with model.state_dict_type(enum_val):
+        with FSDP.state_dict_type(model, enum_val):
             return model.state_dict()
 
     @staticmethod
@@ -197,7 +197,7 @@
         except KeyError:
             raise ValueError(f"No state_dict for {state_dict_type}")
 
-        with model.state_dict_type(enum_val):
+        with FSDP.state_dict_type(model, enum_val):
             return model.load_state_dict(state_dict)
 
     def _dist_train(self, wrap_fsdp: bool, state_dict_type: str = ""):
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 85b415d..f32fa7b 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -766,12 +766,15 @@
         else:
             return False
 
+    @staticmethod
     @contextlib.contextmanager
-    def state_dict_type(self, state_dict_type: StateDictType) -> Generator:
+    def state_dict_type(module: nn.Module, state_dict_type: StateDictType) -> Generator:
         """
-        A context manager to set the state_dict_type of this FSDP module and
-        its descendant FSDP modules.
-        .. note:: This API should be called for only the root FSDP module.
+        A context manager that sets the state_dict_type of all the descendant
+        FSDP modules of the target module. The target module does not have to
+        be an FSDP module; if it is an FSDP module, its own state_dict_type is
+        changed as well.
+        .. note:: This API should only be called on the top-level (root) module.
         .. note:: The default state_dict_type is StateDictType.FULL_STATE_DICT.
         .. note:: This API enables users to transparently use the conventional
         ``state_dict`` API to take model checkpoints in cases where the root
@@ -785,13 +788,11 @@
         Args:
             state_dict_type (StateDictType): the desired state_dict_type to set.
         """
-        if not self.check_is_root():
-            raise RuntimeError(
-                "state_dict_type context manager can only be called from the root FSDP module."
-            )
-        prev_state_dict_type = self._state_dict_type
-        for module in self.fsdp_modules(self):
-            if module._state_dict_type != prev_state_dict_type:
+        prev_state_dict_type = None
+        for module in FullyShardedDataParallel.fsdp_modules(module):
+            if prev_state_dict_type is None:
+                prev_state_dict_type = module._state_dict_type
+            if prev_state_dict_type != module._state_dict_type:
                 raise RuntimeError(
                     "All FSDP module should the same state_dict_type."
                 )
@@ -799,7 +800,8 @@
         try:
             yield
         finally:
-            for module in self.fsdp_modules(self):
+            assert prev_state_dict_type is not None  # Avoid mypy warning
+            for module in FullyShardedDataParallel.fsdp_modules(module):
                 module._state_dict_type = prev_state_dict_type
 
     def _full_post_state_dict_hook(
@@ -891,7 +893,7 @@
         ``state_dict`` on every rank, which could result in OOM if the model
         cannot fit on a single GPU. As a result, :func:`state_dict_type` API is
         available to configure between `state_dict` implementations. User can
-        thus use `with self.state_dict_type(StateDictType.LOCAL_STATE_DICT)`
+        thus use `with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT)`
         context manager to perform a local checkpoint that will store only local
         shards of the module. Currently, the only supported implementations are
         ``StateDictType.LOCAL_STATE_DICT`` and ``StateDictType.FULL_STATE_DICT``
@@ -948,7 +950,7 @@
         sharded, so the resulting state_dict can only be loaded after the module
         has been wrapped with FSDP.
         """
-        with self.state_dict_type(StateDictType.LOCAL_STATE_DICT):
+        with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT):
             return self.state_dict(*args, **kwargs)
 
     def _full_pre_load_state_dict_hook(
@@ -1027,7 +1029,7 @@
         FSDP to load the full parameter context on each rank which could result
         in GPU OOM. As a result, :func:`state_dict_type` API is available to
         configure between `load_state_dict` implementations. User can thus use
-        ``with self.state_dict_type(StateDictType.LOCAL_STATE_DICT)`` context
+        ``with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT)`` context
         manager to load a local state dict checkpoint that will restore only
         local shards of the module. Currently, the only supported
         implementations are ``StateDictType.LOCAL_STATE_DICT`` and
@@ -1079,7 +1081,7 @@
         """
         Load states from a flattened, sharded state dictionary.
         """
-        with self.state_dict_type(StateDictType.LOCAL_STATE_DICT):
+        with self.state_dict_type(self, StateDictType.LOCAL_STATE_DICT):
             return self.load_state_dict(state_dict, *args)
 
     def forward(self, *args: Any, **kwargs: Any) -> Any: