[FSDP][Easy] Fix return in docstrings

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75327

Approved by: https://github.com/rohan-varma
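
For reference, the convention this patch applies: in PyTorch's Google-style (Sphinx napoleon) docstrings, the ``Returns:`` section should lead with the return type rather than a local variable name, and wrapped lines should not be indented like an argument body. A minimal before/after sketch with a hypothetical function, for illustration only:

```python
from typing import Any, Dict

def before_example() -> Dict[str, Any]:
    """Before: naming the returned variable makes napoleon render
    "full_osd (Dict[str, Any]) --" instead of documenting the return type.

    Returns:
        full_osd (Dict[str, Any]): A dict containing the optimizer
            state for the model's original unflattened parameters.
    """
    ...

def after_example() -> Dict[str, Any]:
    """After: lead with the type; continuation lines align flush under it.

    Returns:
        Dict[str, Any]: A dict containing the optimizer state for the
        model's original unflattened parameters.
    """
    ...
```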
diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py
index 2c6d321..6b38720 100644
--- a/torch/distributed/fsdp/fully_sharded_data_parallel.py
+++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -2405,11 +2405,11 @@
                 ``model.parameters()``. (Default: ``None``)
 
         Returns:
-            full_osd (Dict[str, Any]): A :class:`dict` containing the optimizer
-                state for ``model`` 's original unflattened parameters and
-                including keys "state" and "param_groups" following the
-                convention of :meth:`torch.optim.Optimizer.state_dict` if on
-                rank 0, and an empty :class:`dict` otherwise.
+            Dict[str, Any]: A :class:`dict` containing the optimizer state for
+            ``model`` 's original unflattened parameters and including keys
+            "state" and "param_groups" following the convention of
+            :meth:`torch.optim.Optimizer.state_dict` if on rank 0, and an empty
+            :class:`dict` otherwise.
         """
         osd = optim.state_dict()
         osd_state, osd_param_groups = osd["state"], osd["param_groups"]  # alias
@@ -2535,10 +2535,9 @@
                 ``model.parameters()``. (Default: ``None``)
 
         Returns:
-            sharded_optim_state_dict (Dict[str, Any]): The full optimizer
-                state dict remapped to flattened parameters instead of
-                unflattened parameters and restricted to only include this
-                rank's part of the optimizer state.
+            Dict[str, Any]: The full optimizer state dict remapped to flattened
+            parameters instead of unflattened parameters and restricted to only
+            include this rank's part of the optimizer state.
         """
         full_osd = full_optim_state_dict  # alias
         if "state" not in full_osd or "param_groups" not in full_osd:
@@ -2637,8 +2636,8 @@
             >>> wrapped_optim.load_state_dict(sharded_osd)
 
         Returns:
-            rekeyed_osd (Dict[str, Any]): The optimizer state dict re-keyed
-                using the parameter keys specified by ``optim_state_key_type``.
+            Dict[str, Any]: The optimizer state dict re-keyed using the
+            parameter keys specified by ``optim_state_key_type``.
         """
         assert optim_state_key_type in \
             (OptimStateKeyType.PARAM_NAME, OptimStateKeyType.PARAM_ID)
diff --git a/torch/distributed/fsdp/optim_utils.py b/torch/distributed/fsdp/optim_utils.py
index 1a77b75..fcc16c7 100644
--- a/torch/distributed/fsdp/optim_utils.py
+++ b/torch/distributed/fsdp/optim_utils.py
@@ -52,13 +52,12 @@
             in the "state" part of the optimizer state dict.
 
     Returns:
-        unflat_param_state (List[Dict[str, Any]]): A :class:`list` holding
-            the entries in the "state" part of the optimizer state dict
-            corresponding to the unflattened parameters comprising the
-            flattened parameter ``flat_param`` if on the target rank or an
-            empty :class:`list` otherwise. The final optimizer state dict will
-            need to map these entries using the proper unflattened parameter
-            IDs.
+        List[Dict[str, Any]]: A :class:`list` holding the entries in the
+        "state" part of the optimizer state dict corresponding to the
+        unflattened parameters comprising the flattened parameter
+        ``flat_param`` if on the target rank or an empty :class:`list`
+        otherwise. The final optimizer state dict will need to map these
+        entries using the proper unflattened parameter IDs.
     """
     assert sum(p is flat_param for p in fsdp_module.params) == 1, \
         "`fsdp_module` must own `flat_param`"
@@ -90,12 +89,12 @@
 
     Args:
         flat_param (FlatParameter): The flattened parameter.
-        flat_param_state (Dict[str, Any]): The entry in the "state" part of
-        the optimizer state dict corresponding to the flattened parameter.
+        flat_param_state (Dict[str, Any]): The entry in the "state" part of the
+            optimizer state dict corresponding to the flattened parameter.
 
     Returns:
-        state (ConsolidatedOptimState): Consolidated optimizer state for
-            ``flat_param``; the state is not populated for non-target ranks.
+        ConsolidatedOptimState: Consolidated optimizer state for
+        ``flat_param``; the state is not populated for non-target ranks.
     """
     param_index = -1
     for i, param in enumerate(fsdp_module.params):
@@ -158,12 +157,11 @@
         state (ConsolidatedOptimState): Consolidated optimizer state.
 
     Returns:
-        unflat_param_state (List[Dict[str, Any]]): A :class:`list` holding
-            the entries in the "state" part of the optimizer state dict
-            corresponding to the unflattened parameters comprising the
-            flattened parameter ``flat_param``. The final optimizer state dict
-            will need to map these entries using the proper unflattened
-            parameter IDs.
+        List[Dict[str, Any]]: A :class:`list` holding the entries in the
+        "state" part of the optimizer state dict corresponding to the
+        unflattened parameters comprising the flattened parameter
+        ``flat_param``. The final optimizer state dict will need to map these
+        entries using the proper unflattened parameter IDs.
     """
     assert sum(p is flat_param for p in fsdp_module.params) == 1, \
         "`fsdp_module` must own `flat_param`"
@@ -216,10 +214,10 @@
         flat_param (FlatParameter): The flattened parameter.
 
     Returns:
-        flat_state (Dict[str, Any]): A :class:`dict` mapping state names to
-            their values for a particular flattened parameter. The sharded
-            optimizer state dict's "state" part will map the flattened
-            parameter ID to this returned value.
+        Dict[str, Any]: A :class:`dict` mapping state names to their values for
+        a particular flattened parameter. The sharded optimizer state dict's
+        "state" part will map the flattened parameter ID to this returned
+        value.
     """
     num_unflat_params = len(unflat_param_names)
     assert num_unflat_params > 0, \
@@ -339,11 +337,10 @@
         flat_param (FlatParameter): The flattened parameter.
 
     Returns:
-        flat_tensor (torch.Tensor): A flattened tensor containing the optimizer
-            state corresponding to ``state_name`` constructed by concatenating
-            the unflattened parameter tensor states in ``pos_dim_tensors``
-            (using zero tensors for any unflattened parameters without the
-            state).
+        torch.Tensor: A flattened tensor containing the optimizer state
+        corresponding to ``state_name`` constructed by concatenating the
+        unflattened parameter tensor states in ``pos_dim_tensors`` (using zero
+        tensors for any unflattened parameters without the state).
     """
     non_none_tensors = [t for t in pos_dim_tensors if t is not None]
     # Check that all are tensors with the same dtype
@@ -429,9 +426,9 @@
             parameter names corresponding to the single flattened parameter.
 
     Returns:
-        zero_dim_tensor (torch.Tensor): A zero-dimensional tensor giving the
-            value of the state ``state_name`` for all unflattened parameters
-            corresponding to the names ``unflat_param_names``.
+        torch.Tensor: A zero-dimensional tensor giving the value of the state
+        ``state_name`` for all unflattened parameters corresponding to the
+        names ``unflat_param_names``.
     """
     non_none_tensors = [t for t in zero_dim_tensors if t is not None]
     # Enforce that all have the same value and dtype
@@ -472,9 +469,9 @@
             parameter names corresponding to the single flattened parameter.
 
     Returns:
-        non_tensor (Any): A non-tensor giving the value of the state
-            ``state_name`` for all unflattened parameters corresponding to the
-            names ``unflat_param_names``.
+        Any: A non-tensor giving the value of the state ``state_name`` for all
+        unflattened parameters corresponding to the names
+        ``unflat_param_names``.
     """
     non_none_non_tensors = [nt for nt in non_tensors if nt is not None]
     # Enforce that all have the same value (same type already checked)
@@ -538,9 +535,8 @@
             input was ``model.parameters()``. (Default: ``None``)
 
     Returns:
-        param_id_to_param (List[torch.nn.Parameter]): Mapping from parameter
-            IDs to parameters, where the parameter ID is implicitly the index
-            in the :class:`list`.
+        List[torch.nn.Parameter]: Mapping from parameter IDs to parameters,
+        where the parameter ID is implicitly the index in the :class:`list`.
     """
     # Assume the standard case of passing `model.parameters()` to the optimizer
     # if `optim_input` is not specified
@@ -610,9 +606,9 @@
             unflattened parameter IDs.
 
     Returns:
-        unflat_to_flat_param_ids (List[int]): A mapping from unflattened
-            parameter ID to flattened parameter ID, where the unflattened
-            parameter ID is the index in the :class:`list`.
+        List[int]: A mapping from unflattened parameter ID to flattened
+        parameter ID, where the unflattened parameter ID is the index in the
+        :class:`list`.
     """
     # Construct as a dict and then convert to list
     unflat_to_flat_param_ids = {}
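
As a companion to the docstring changes above, here is a minimal usage sketch of the three public optimizer-state APIs whose ``Returns:`` sections are updated in ``fully_sharded_data_parallel.py``. The method names and the ``OptimStateKeyType`` import are assumed from PyTorch's FSDP API of this era (the diff shows the return docstrings and the enum members, not the method signatures), and the training setup is hypothetical; it assumes a single-GPU launch via ``torchrun``:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import OptimStateKeyType

# Assumed: launched with torchrun so a process group can be initialized
# (a single GPU suffices for this sketch).
dist.init_process_group("nccl")
model = FSDP(torch.nn.Linear(8, 8).cuda())
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# Run one step so the optimizer has state to consolidate.
model(torch.randn(4, 8).cuda()).sum().backward()
optim.step()

# Dict[str, Any]: unflattened optimizer state on rank 0, an empty dict otherwise.
full_osd = FSDP.full_optim_state_dict(model, optim)

# Dict[str, Any]: the full state dict remapped to flattened parameters and
# restricted to this rank's shard; loadable by the wrapped optimizer.
sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, model)
optim.load_state_dict(sharded_osd)

# Dict[str, Any]: the full state dict re-keyed by parameter name instead of ID.
rekeyed_osd = FSDP.rekey_optim_state_dict(
    full_osd, OptimStateKeyType.PARAM_NAME, model
)
```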