[FSDP] Code simplification
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75227
Small change to reduce code duplication in the `FULL_STATE_DICT` branch of `state_dict()`: instead of two copies of the same body (one under `summon_full_params`, one without), the branch now runs a single body under a conditionally chosen context manager.
Differential Revision: [D35374608](https://our.internmc.facebook.com/intern/diff/D35374608/)
Approved by: https://github.com/zhaojuanmao
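The pattern used in the diff is to pick the context manager up front: `summon_full_params(recurse=False, writeback=False)` when full parameters are not already summoned, otherwise a no-op (`contextlib.suppress()` with no exception types). Below is a minimal sketch of the same pattern outside FSDP; `Worker`, `activate`, and `snapshot` are hypothetical names, and `contextlib.nullcontext()` stands in as the no-op context manager.

```python
import contextlib


class Worker:
    def __init__(self) -> None:
        self.already_active = False

    @contextlib.contextmanager
    def activate(self):
        # Stand-in for summon_full_params(): holds a state for the block's duration.
        self.already_active = True
        try:
            yield
        finally:
            self.already_active = False

    def snapshot(self):
        # Choose the context manager conditionally so the body below is
        # written once, rather than duplicated across an if/else.
        ctx = (
            self.activate()
            if not self.already_active
            else contextlib.nullcontext()
        )
        with ctx:
            return {"active": self.already_active}


if __name__ == "__main__":
    w = Worker()
    print(w.snapshot())      # activates just for this call
    with w.activate():
        print(w.snapshot())  # already active; uses the no-op context manager
```

`contextlib.nullcontext()` (Python 3.7+) makes the no-op intent explicit; `contextlib.suppress()` called with no exception types, as in the diff, behaves the same way.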
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index ecc64a9..2c6d321 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -1193,20 +1193,12 @@
self._lazy_init()
if self._state_dict_type == StateDictType.FULL_STATE_DICT:
- if self.training_state != TrainingState_.SUMMON_FULL_PARAMS:
- with self.summon_full_params(recurse=False, writeback=False):
- # Since buffers are not sharded and stay casted, restore them to their
- # original user module specified types for checkpoint. We take care to
- # recast in post_state_dict_hook for consistency with the fact that
- # buffers stay casted after forward/backward. We must have the
- # call here instead of above because summon_full_params itself
- # calls _lazy_init() which would cast the buffers.
- if self.mixed_precision is not None and self._is_root:
- self._cast_buffers(
- dtype=self._orig_buffer_dtypes, recurse=False
- )
- state_dict = super().state_dict(*args, **kwargs)
- else:
+ summon_ctx = (
+ self.summon_full_params(recurse=False, writeback=False)
+ if self.training_state != TrainingState_.SUMMON_FULL_PARAMS else
+ contextlib.suppress()
+ )
+ with summon_ctx:
# Since buffers are not sharded and stay casted, restore them to their
# original user module specified types for checkpoint. We take care to
# recast in post_state_dict_hook for consistency with the fact that