[FSDP] Code simplification

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75227

Small change to reduce code duplication: in the `FULL_STATE_DICT` path of `state_dict`, both arms of an if/else repeated the same body, so the code now selects the context manager up front (`summon_full_params` or a no-op) and writes the body once.

Differential Revision: [D35374608](https://our.internmc.facebook.com/intern/diff/D35374608/)

Approved by: https://github.com/zhaojuanmao
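
For reference, the refactor follows a standard Python pattern: choose the context manager conditionally, then enter it once so the `with` body is not duplicated. A minimal sketch of the pattern (the names `run_guarded`, `body`, `lock`, and `need_lock` are illustrative, not FSDP code):

```python
import contextlib
import threading

def run_guarded(body, lock: threading.Lock, need_lock: bool):
    # Select the context manager up front so the body is written once,
    # rather than duplicated across both branches of an if/else.
    # contextlib.suppress() with no arguments suppresses nothing, so it
    # serves as a no-op context; contextlib.nullcontext() (Python 3.7+)
    # is the more explicit spelling of the same thing.
    ctx = lock if need_lock else contextlib.suppress()
    with ctx:
        body()
```

In the diff below, `summon_full_params(recurse=False, writeback=False)` plays the role of the real context and `contextlib.suppress()` the no-op.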
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index ecc64a9..2c6d321 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -1193,20 +1193,12 @@
 
         self._lazy_init()
         if self._state_dict_type == StateDictType.FULL_STATE_DICT:
-            if self.training_state != TrainingState_.SUMMON_FULL_PARAMS:
-                with self.summon_full_params(recurse=False, writeback=False):
-                    # Since buffers are not sharded and stay casted, restore them to their
-                    # original user module specified types for checkpoint. We take care to
-                    # recast in post_state_dict_hook for consistency with the fact that
-                    # buffers stay casted after forward/backward. We must have the
-                    # call here instead of above because summon_full_params itself
-                    # calls _lazy_init() which would cast the buffers.
-                    if self.mixed_precision is not None and self._is_root:
-                        self._cast_buffers(
-                            dtype=self._orig_buffer_dtypes, recurse=False
-                        )
-                    state_dict = super().state_dict(*args, **kwargs)
-            else:
+            summon_ctx = (
+                self.summon_full_params(recurse=False, writeback=False)
+                if self.training_state != TrainingState_.SUMMON_FULL_PARAMS else
+                contextlib.suppress()
+            )
+            with summon_ctx:
                 # Since buffers are not sharded and stay casted, restore them to their
                 # original user module specified types for checkpoint. We take care to
                 # recast in post_state_dict_hook for consistency with the fact that