[FSDP][Easy] Fix ``Returns:`` sections in docstrings
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75327
Approved by: https://github.com/rohan-varma
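
Google-style ``Returns:`` sections should lead with the return type rather
than a return-value name; writing ``full_osd (Dict[str, Any]): ...`` can make
Sphinx render the name as if it were part of the type. This PR drops the
return-value names from the affected docstrings in
``fully_sharded_data_parallel.py`` and ``optim_utils.py``.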
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 2c6d321..6b38720 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -2405,11 +2405,11 @@
``model.parameters()``. (Default: ``None``)

Returns:
- full_osd (Dict[str, Any]): A :class:`dict` containing the optimizer
- state for ``model`` 's original unflattened parameters and
- including keys "state" and "param_groups" following the
- convention of :meth:`torch.optim.Optimizer.state_dict` if on
- rank 0, and an empty :class:`dict` otherwise.
+ Dict[str, Any]: A :class:`dict` containing the optimizer state for
+ ``model`` 's original unflattened parameters and including keys
+ "state" and "param_groups" following the convention of
+ :meth:`torch.optim.Optimizer.state_dict` if on rank 0, and an empty
+ :class:`dict` otherwise.
"""
osd = optim.state_dict()
osd_state, osd_param_groups = osd["state"], osd["param_groups"] # alias
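
For reference, a minimal usage sketch of the rank-0 semantics described in
this ``Returns:`` section (toy model and shapes, illustrative names; assumes
an initialized process group with a CUDA device per rank):

>>> import torch
>>> import torch.distributed as dist
>>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
>>> model = FSDP(torch.nn.Linear(8, 8).cuda())
>>> optim = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> model(torch.randn(4, 8).cuda()).sum().backward()
>>> optim.step()
>>> full_osd = FSDP.full_optim_state_dict(model, optim)
>>> if dist.get_rank() == 0:
...     assert set(full_osd.keys()) == {"state", "param_groups"}
... else:
...     assert full_osd == {}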
@@ -2535,10 +2535,9 @@
``model.parameters()``. (Default: ``None``)

Returns:
- sharded_optim_state_dict (Dict[str, Any]): The full optimizer
- state dict remapped to flattened parameters instead of
- unflattened parameters and restricted to only include this
- rank's part of the optimizer state.
+ Dict[str, Any]: The full optimizer state dict remapped to flattened
+ parameters instead of unflattened parameters and restricted to only
+ include this rank's part of the optimizer state.
"""
full_osd = full_optim_state_dict # alias
if "state" not in full_osd or "param_groups" not in full_osd:
@@ -2637,8 +2636,8 @@
>>> wrapped_optim.load_state_dict(sharded_osd)

Returns:
- rekeyed_osd (Dict[str, Any]): The optimizer state dict re-keyed
- using the parameter keys specified by ``optim_state_key_type``.
+ Dict[str, Any]: The optimizer state dict re-keyed using the
+ parameter keys specified by ``optim_state_key_type``.
"""
assert optim_state_key_type in \
(OptimStateKeyType.PARAM_NAME, OptimStateKeyType.PARAM_ID)
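
A sketch of the re-keying, using a non-wrapped toy model so that the state
dict starts out keyed by parameter IDs (names are illustrative; the import
path is the module this diff touches):

>>> from torch.distributed.fsdp.fully_sharded_data_parallel import OptimStateKeyType
>>> nonwrapped_model = torch.nn.Linear(8, 8)
>>> nonwrapped_optim = torch.optim.Adam(nonwrapped_model.parameters())
>>> nonwrapped_model(torch.randn(2, 8)).sum().backward()
>>> nonwrapped_optim.step()
>>> osd = nonwrapped_optim.state_dict()  # "state" keyed by parameter IDs
>>> named_osd = FSDP.rekey_optim_state_dict(
...     osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model,
... )
>>> # ... and back to parameter IDs
>>> id_osd = FSDP.rekey_optim_state_dict(
...     named_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model,
... )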
diff --git a/torch/distributed/fsdp/optim_utils.py b/torch/distributed/fsdp/optim_utils.py
index 1a77b75..fcc16c7 100644
--- a/torch/distributed/fsdp/optim_utils.py
+++ b/torch/distributed/fsdp/optim_utils.py
@@ -52,13 +52,12 @@
in the "state" part of the optimizer state dict.

Returns:
- unflat_param_state (List[Dict[str, Any]]): A :class:`list` holding
- the entries in the "state" part of the optimizer state dict
- corresponding to the unflattened parameters comprising the
- flattened parameter ``flat_param`` if on the target rank or an
- empty :class:`list` otherwise. The final optimizer state dict will
- need to map these entries using the proper unflattened parameter
- IDs.
+ List[Dict[str, Any]]: A :class:`list` holding the entries in the
+ "state" part of the optimizer state dict corresponding to the
+ unflattened parameters comprising the flattened parameter
+ ``flat_param`` if on the target rank or an empty :class:`list`
+ otherwise. The final optimizer state dict will need to map these
+ entries using the proper unflattened parameter IDs.
"""
assert sum(p is flat_param for p in fsdp_module.params) == 1, \
"`fsdp_module` must own `flat_param`"
@@ -90,12 +89,12 @@

Args:
flat_param (FlatParameter): The flattened parameter.
- flat_param_state (Dict[str, Any]): The entry in the "state" part of
- the optimizer state dict corresponding to the flattened parameter.
+ flat_param_state (Dict[str, Any]): The entry in the "state" part of the
+ optimizer state dict corresponding to the flattened parameter.

Returns:
- state (ConsolidatedOptimState): Consolidated optimizer state for
- ``flat_param``; the state is not populated for non-target ranks.
+ ConsolidatedOptimState: Consolidated optimizer state for
+ ``flat_param``; the state is not populated for non-target ranks.
"""
param_index = -1
for i, param in enumerate(fsdp_module.params):
@@ -158,12 +157,11 @@
state (ConsolidatedOptimState): Consolidated optimizer state.

Returns:
- unflat_param_state (List[Dict[str, Any]]): A :class:`list` holding
- the entries in the "state" part of the optimizer state dict
- corresponding to the unflattened parameters comprising the
- flattened parameter ``flat_param``. The final optimizer state dict
- will need to map these entries using the proper unflattened
- parameter IDs.
+ List[Dict[str, Any]]: A :class:`list` holding the entries in the
+ "state" part of the optimizer state dict corresponding to the
+ unflattened parameters comprising the flattened parameter
+ ``flat_param``. The final optimizer state dict will need to map these
+ entries using the proper unflattened parameter IDs.
"""
assert sum(p is flat_param for p in fsdp_module.params) == 1, \
"`fsdp_module` must own `flat_param`"
@@ -216,10 +214,10 @@
flat_param (FlatParameter): The flattened parameter.

Returns:
- flat_state (Dict[str, Any]): A :class:`dict` mapping state names to
- their values for a particular flattened parameter. The sharded
- optimizer state dict's "state" part will map the flattened
- parameter ID to this returned value.
+ Dict[str, Any]: A :class:`dict` mapping state names to their values for
+ a particular flattened parameter. The sharded optimizer state dict's
+ "state" part will map the flattened parameter ID to this returned
+ value.
"""
num_unflat_params = len(unflat_param_names)
assert num_unflat_params > 0, \
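
Illustrative shape of the return value (placeholder values): the sharded
optimizer state dict's "state" part keys this mapping by flattened
parameter ID.

>>> flat_state = {"step": 10, "exp_avg": torch.zeros(16), "exp_avg_sq": torch.zeros(16)}
>>> sharded_osd_state = {0: flat_state}  # flattened parameter ID -> state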
@@ -339,11 +337,10 @@
flat_param (FlatParameter): The flattened parameter.

Returns:
- flat_tensor (torch.Tensor): A flattened tensor containing the optimizer
- state corresponding to ``state_name`` constructed by concatenating
- the unflattened parameter tensor states in ``pos_dim_tensors``
- (using zero tensors for any unflattened parameters without the
- state).
+ torch.Tensor: A flattened tensor containing the optimizer state
+ corresponding to ``state_name`` constructed by concatenating the
+ unflattened parameter tensor states in ``pos_dim_tensors`` (using zero
+ tensors for any unflattened parameters without the state).
"""
non_none_tensors = [t for t in pos_dim_tensors if t is not None]
# Check that all are tensors with the same dtype
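
A toy illustration of the concatenation rule (not the internal helper
itself): each unflattened parameter's tensor state is flattened and
concatenated, with zeros standing in for parameters lacking the state.

>>> shapes = [torch.Size([2, 3]), torch.Size([4])]
>>> pos_dim_tensors = [torch.ones(2, 3), None]
>>> torch.cat([
...     (t if t is not None else torch.zeros(s)).flatten()
...     for t, s in zip(pos_dim_tensors, shapes)
... ]).shape
torch.Size([10])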
@@ -429,9 +426,9 @@
parameter names corresponding to the single flattened parameter.

Returns:
- zero_dim_tensor (torch.Tensor): A zero-dimensional tensor giving the
- value of the state ``state_name`` for all unflattened parameters
- corresponding to the names ``unflat_param_names``.
+ torch.Tensor: A zero-dimensional tensor giving the value of the state
+ ``state_name`` for all unflattened parameters corresponding to the
+ names ``unflat_param_names``.
"""
non_none_tensors = [t for t in zero_dim_tensors if t is not None]
# Enforce that all have the same value and dtype
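
A toy illustration of the consensus check performed here: every unflattened
parameter must report the same zero-dimensional value.

>>> zero_dim_tensors = [torch.tensor(3.0), torch.tensor(3.0)]
>>> values = {t.item() for t in zero_dim_tensors if t is not None}
>>> assert len(values) == 1, "unflattened parameters disagree on the state"
>>> torch.tensor(values.pop())
tensor(3.)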
@@ -472,9 +469,9 @@
parameter names corresponding to the single flattened parameter.

Returns:
- non_tensor (Any): A non-tensor giving the value of the state
- ``state_name`` for all unflattened parameters corresponding to the
- names ``unflat_param_names``.
+ Any: A non-tensor giving the value of the state ``state_name`` for all
+ unflattened parameters corresponding to the names
+ ``unflat_param_names``.
"""
non_none_non_tensors = [nt for nt in non_tensors if nt is not None]
# Enforce that all have the same value (same type already checked)
@@ -538,9 +535,8 @@
input was ``model.parameters()``. (Default: ``None``)

Returns:
- param_id_to_param (List[torch.nn.Parameter]): Mapping from parameter
- IDs to parameters, where the parameter ID is implicitly the index
- in the :class:`list`.
+ List[torch.nn.Parameter]: Mapping from parameter IDs to parameters,
+ where the parameter ID is implicitly the index in the :class:`list`.
"""
# Assume the standard case of passing `model.parameters()` to the optimizer
# if `optim_input` is not specified
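
In this standard case the mapping simply mirrors ``torch.optim``'s implicit
ID assignment, as a toy example shows:

>>> lin = torch.nn.Linear(3, 3)
>>> param_id_to_param = list(lin.parameters())
>>> param_id_to_param[0] is lin.weight  # parameter ID 0
True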
@@ -610,9 +606,9 @@
unflattened parameter IDs.

Returns:
- unflat_to_flat_param_ids (List[int]): A mapping from unflattened
- parameter ID to flattened parameter ID, where the unflattened
- parameter ID is the index in the :class:`list`.
+ List[int]: A mapping from unflattened parameter ID to flattened
+ parameter ID, where the unflattened parameter ID is the index in the
+ :class:`list`.
"""
# Construct as a dict and then convert to list
unflat_to_flat_param_ids = {}
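
Toy illustration of the dict-to-list conversion mentioned in the comment
above: unflattened IDs 0 and 1 belong to flattened parameter 0, and
unflattened ID 2 to flattened parameter 1.

>>> unflat_to_flat_param_ids = {0: 0, 1: 0, 2: 1}
>>> [unflat_to_flat_param_ids[i] for i in range(len(unflat_to_flat_param_ids))]
[0, 0, 1]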