blob: 0d221b8df1a0ac4e7efc9bcff71752f7178b9895 [file] [log] [blame]
from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
import torch
import torch.distributed as dist
# Import the entire FSDP file to avoid circular imports
import torch.distributed.fsdp.fully_sharded_data_parallel as FSDP
from torch.distributed.fsdp.flatten_params_wrapper import FlatParameter
# Rank on which the full optimizer state is saved; the functions in this file
# compare `fsdp_module.rank` against it to decide whether to keep results.
OPTIM_TARGET_RANK = 0
class ConsolidatedOptimState:
    """
    This holds the consolidated optimizer state on the target rank. Positive-
    dimension tensor state is communicated across ranks, while zero-dimension
    tensor state and non-tensor state is taken directly from the target rank.

    PyTorch version 1.12 moved to using zero-dimension tensors for scalar
    values, but user implemented optimizers may still use float (i.e. a
    non-tensor). Thus, we support both and handle them identically.

    Each instance owns its own three dictionaries, so populating the state for
    one flattened parameter never affects the state held for another.

    Attributes:
        tensor_state (Dict[str, torch.Tensor]): Mapping from positive-dimension
            tensor state name to the unsharded flattened tensor representing
            the state.
        zero_dim_tensor_state (Dict[str, torch.Tensor]): Mapping from zero-
            dimension tensor state name to its value.
        non_tensor_state (Dict[str, Any]): Mapping from non-tensor state
            name to its value.
    """

    # Annotations only: the dictionaries themselves are created per instance
    # in `__init__()`. Assigning `{}` here would create a single dict object
    # shared by every instance of the class.
    tensor_state: Dict[str, torch.Tensor]
    zero_dim_tensor_state: Dict[str, torch.Tensor]
    non_tensor_state: Dict[str, Any]

    def __init__(self) -> None:
        self.tensor_state = {}
        self.zero_dim_tensor_state = {}
        self.non_tensor_state = {}
def _unflatten_optim_state(
    fsdp_module,
    flat_param: FlatParameter,
    flat_param_state: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """
    Converts the optimizer state entry of one flattened parameter into the
    per-unflattened-parameter entries of the "state" part of the optimizer
    state dict. The state is first consolidated via
    :func:`_communicate_optim_state` and then, on the target rank only, split
    via :func:`_unflatten_communicated_optim_state`.

    Args:
        fsdp_module (FullyShardedDataParallel): FSDP module that owns
            ``flat_param``, i.e. holds it in ``self.params``.
        flat_param (FlatParameter): The flattened parameter.
        flat_param_state (Dict[str, Any]): Entry for the flattened parameter
            in the "state" part of the optimizer state dict.

    Returns:
        unflat_param_state (List[Dict[str, Any]]): On the target rank, one
            state entry per unflattened parameter comprising ``flat_param``;
            on every other rank, an empty :class:`list`. The caller still has
            to key these entries by the proper unflattened parameter IDs.
    """
    num_owned = sum(p is flat_param for p in fsdp_module.params)
    assert num_owned == 1, \
        "`fsdp_module` must own `flat_param`"
    # Run the consolidation on every rank (not only the target rank) since it
    # may issue collective communication
    consolidated_state = _communicate_optim_state(
        fsdp_module, flat_param, flat_param_state,
    )
    if fsdp_module.rank != OPTIM_TARGET_RANK:
        return []
    return _unflatten_communicated_optim_state(
        fsdp_module,
        flat_param,
        consolidated_state,
    )
def _communicate_optim_state(
    fsdp_module,
    flat_param: FlatParameter,
    flat_param_state: Dict[str, Any],
) -> ConsolidatedOptimState:
    """
    Communicates the optimizer state for a flattened parameter ``flat_param``
    across ranks so that the target rank holds the entire non-sharded optimizer
    state.

    When ``flat_param`` is sharded, one all-gather is issued per positive-
    dimension tensor state; zero-dimension tensor state and non-tensor state
    are never communicated. When ``flat_param`` is not sharded, no
    communication is issued at all.

    Args:
        fsdp_module (FullyShardedDataParallel): FSDP module that owns
            ``flat_param``, i.e. holds it in ``self.params``; its ``rank`` and
            ``process_group`` are used for the communication.
        flat_param (FlatParameter): The flattened parameter.
        flat_param_state (Dict[str, Any]): The entry in the "state" part of
            the optimizer state dict corresponding to the flattened parameter.

    Returns:
        state (ConsolidatedOptimState): Consolidated optimizer state for
            ``flat_param``. For a sharded ``flat_param`` only the target rank
            populates it; for a non-sharded ``flat_param`` every rank records
            its own positive-dimension tensor state.
    """
    assert any(p is flat_param for p in fsdp_module.params), \
        "`fsdp_module` must own `flat_param`"
    state = ConsolidatedOptimState()
    tensor_state, zero_dim_tensor_state, non_tensor_state = \
        state.tensor_state, state.zero_dim_tensor_state, state.non_tensor_state
    process_group = fsdp_module.process_group
    cpu_device = torch.device("cpu")
    tensor_buffer = None  # initialize lazily in case it is not needed
    to_save = fsdp_module.rank == OPTIM_TARGET_RANK
    for state_name, value in flat_param_state.items():
        # Positive-dimension tensor state: communicate across ranks
        if torch.is_tensor(value) and value.dim() > 0:
            # If the parameter is not sharded (e.g. world size of 1), then
            # neither is the positive-dimension tensor state, so no need to
            # communicate it -- we take this rank's value
            if not flat_param._is_sharded:
                tensor_state[state_name] = value.cpu()
                continue
            if tensor_buffer is None:
                # Size the gather buffer like the padded unsharded flattened
                # parameter; this assumes that each positive-dimension tensor
                # state has the same shape as the sharded flattened parameter
                buffer_size = flat_param._full_param_padded.size()  # type: ignore[attr-defined]
                tensor_buffer = value.new_zeros(*buffer_size)
            dist._all_gather_base(tensor_buffer, value, group=process_group)
            if to_save:
                assert hasattr(flat_param, "_orig_size"), \
                    "Sharded flattened parameter should have `_orig_size` set"
                unpadded_numel = flat_param._orig_size.numel()  # type: ignore[attr-defined]
                # `tensor_buffer` is reused for the next state, so always copy
                # the unpadded prefix out of it (`.cpu()` alone would return a
                # view of the buffer if the buffer already lived on CPU)
                tensor_state[state_name] = \
                    tensor_buffer[:unpadded_numel].to(cpu_device, copy=True)
        # Zero-dimension tensor state and non-tensor state: take this rank's
        # value directly (stored by reference, i.e. not copied)
        elif to_save:
            if _is_zero_dim_tensor(value):
                zero_dim_tensor_state[state_name] = value
            else:
                non_tensor_state[state_name] = value
    return state
def _unflatten_communicated_optim_state(
    fsdp_module,
    flat_param: FlatParameter,
    state: ConsolidatedOptimState,
) -> List[Dict[str, Any]]:
    """
    Splits the consolidated optimizer state ``state`` of a single flattened
    parameter ``flat_param`` into one state entry per unflattened parameter.
    This should only be called on the target rank.

    Args:
        fsdp_module (FullyShardedDataParallel): FSDP module that owns
            ``flat_param``, i.e. holds it in ``self.params``.
        flat_param (FlatParameter): The flattened parameter.
        state (ConsolidatedOptimState): Consolidated optimizer state.

    Returns:
        unflat_param_state (List[Dict[str, Any]]): A :class:`list` with
            ``flat_param._num_unflattened_params`` entries, each mapping state
            names to values for one unflattened parameter. The caller still
            has to key these entries by the proper unflattened parameter IDs.
    """
    num_owned = sum(p is flat_param for p in fsdp_module.params)
    assert num_owned == 1, \
        "`fsdp_module` must own `flat_param`"
    # One view iterator per tensor state name, created on first use and then
    # advanced by one view for each unflattened parameter
    view_iters: Dict[str, Iterator] = {}
    entries: List[Dict[str, Any]] = []
    for _ in range(flat_param._num_unflattened_params):
        entry: Dict[str, Any] = {}
        # Positive-dimension tensor state: the next view of the flat tensor
        for state_name, flat_tensor in state.tensor_state.items():
            if state_name not in view_iters:
                view_iters[state_name] = flat_param.get_param_views(flat_tensor)
            entry[state_name] = next(view_iters[state_name])
        # Zero-dimension tensor state, then non-tensor state: every
        # unflattened parameter receives the very same value object
        entry.update(state.zero_dim_tensor_state)
        entry.update(state.non_tensor_state)
        entries.append(entry)
    return entries
def _flatten_optim_state(
    unflat_osd_state: Dict[str, Dict[str, Any]],
    unflat_param_names: List[str],
    fsdp_module,
    flat_param: FlatParameter,
) -> Dict[str, Any]:
    """
    Builds the optimizer state entry of the flattened parameter ``flat_param``
    from the state entries of the unflattened parameters named in
    ``unflat_param_names``.

    Args:
        unflat_osd_state (Dict[str, Dict[str, Any]]): The "state" part of the
            optimizer state dict corresponding to the unflattened parameters.
        unflat_param_names (List[str]): A :class:`list` of unflattened
            parameter names corresponding to the flattened parameter
            ``flat_param``.
        fsdp_module (FullyShardedDataParallel): FSDP module owning the
            flattened parameter; its ``_get_shard()`` is used to shard the
            flattened tensor state.
        flat_param (FlatParameter): The flattened parameter.

    Returns:
        flat_state (Dict[str, Any]): A :class:`dict` mapping state names to
            their values for the flattened parameter, or an empty
            :class:`dict` if none of the unflattened parameters has state.

    Raises:
        ValueError: If the unflattened parameters that have state disagree on
            their state names, or if the values of one state name do not all
            share a single type and kind (positive-dimension tensor,
            zero-dimension tensor, or non-tensor).
    """
    num_names = len(unflat_param_names)
    assert num_names > 0, \
        "Expects at least one unflattened parameter corresponding to the " \
        "flattened parameter"
    param_shapes = flat_param._param_shapes
    num_shapes = len(param_shapes)
    assert num_names == num_shapes, \
        f"Expects {num_names} shapes but got {num_shapes}"
    # No entry at all is wanted if none of the unflattened parameters has
    # optimizer state
    if all(name not in unflat_osd_state for name in unflat_param_names):
        return {}
    # `None` marks an unflattened parameter without state
    per_param_states = [
        unflat_osd_state.get(name) for name in unflat_param_names
    ]
    # All parameters that do have state must agree on the state names
    state_names = None
    for param_state in per_param_states:
        if param_state is None:
            continue
        names = set(param_state.keys())
        if state_names is None:
            state_names = names
        elif state_names != names:
            raise ValueError(
                "Differing optimizer state names for the unflattened "
                f"parameters: {unflat_param_names}"
            )
    assert state_names is not None
    flat_state: Dict[str, Any] = {}
    for state_name in state_names:
        state_values = [
            None if param_state is None else param_state[state_name]
            for param_state in per_param_states
        ]
        present_values = [v for v in state_values if v is not None]
        all_pos_dim = all(
            torch.is_tensor(v) and v.dim() > 0 for v in present_values
        )
        all_zero_dim = all(_is_zero_dim_tensor(v) for v in present_values)
        all_non_tensor = all(not torch.is_tensor(v) for v in present_values)
        distinct_types = {type(v) for v in present_values}
        same_kind = all_pos_dim or all_zero_dim or all_non_tensor
        if len(distinct_types) != 1 or not same_kind:
            raise ValueError(
                f"Differing optimizer state types for state {state_name}, "
                f"values {present_values}, and unflattened parameter "
                f"names {unflat_param_names}"
            )
        if all_pos_dim:
            flat_tensor = _flatten_tensor_optim_state(
                state_name, state_values, unflat_param_names,
                param_shapes, flat_param,
            )
            # Keep only this rank's shard right away so that the full
            # flattened tensor does not stay alive longer than needed
            shard, _ = fsdp_module._get_shard(flat_tensor)
            flat_state[state_name] = shard
        elif all_zero_dim:
            flat_state[state_name] = _flatten_zero_dim_tensor_optim_state(
                state_name, state_values, unflat_param_names,
            )
        else:
            assert all_non_tensor
            flat_state[state_name] = _flatten_non_tensor_optim_state(
                state_name, state_values, unflat_param_names,
            )
    return flat_state
def _flatten_tensor_optim_state(
    state_name: str,
    pos_dim_tensors: List[torch.Tensor],
    unflat_param_names: List[str],
    unflat_param_shapes: List[torch.Size],
    flat_param: FlatParameter,
) -> torch.Tensor:
    """
    Concatenates the positive-dimension tensor optimizer state
    ``pos_dim_tensors`` of the state ``state_name`` into one flat CPU tensor
    for the flattened parameter ``flat_param``.

    NOTE: Unflattened parameters without state (``None`` entries) are filled
    in with zero tensors of the parameter's shape since some value is required
    for those positions. This assumes that a zero tensor is mathematically
    equivalent to having no state, which is true for Adam's ``exp_avg`` and
    ``exp_avg_sq`` but may not be true for all optimizers.

    Args:
        state_name (str): Optimizer state name.
        pos_dim_tensors (List[torch.Tensor]): Positive-dimension tensor
            optimizer state values for the unflattened parameters
            corresponding to the single flattened parameter; an entry is
            ``None`` if that parameter has no state.
        unflat_param_names (List[str]): A :class:`list` of unflattened
            parameter names corresponding to the single flattened parameter.
        unflat_param_shapes (List[torch.Size]): Unflattened parameter shapes
            corresponding to the single flattened parameter.
        flat_param (FlatParameter): The flattened parameter.

    Returns:
        flat_tensor (torch.Tensor): A 1D CPU tensor made of the flattened
            per-parameter states in order, followed by
            ``flat_param.num_padded`` zeros if that number is positive.

    Raises:
        ValueError: If the present tensors do not share exactly one dtype, if
            a parameter without state has a zero-dimension shape, or if a
            tensor's shape differs from its parameter's shape.
    """
    present_tensors = [t for t in pos_dim_tensors if t is not None]
    # Exactly one dtype must be shared by all present tensors
    dtypes = {t.dtype for t in present_tensors}
    if len(dtypes) != 1:
        raise ValueError(
            "All unflattened parameters comprising a single flattened "
            "parameter must have positive-dimension tensor state with the "
            f"same dtype but got dtypes {dtypes} for state {state_name} and "
            f"unflattened parameter names {unflat_param_names}"
        )
    (dtype,) = dtypes
    # Validate each entry against the shape of its parameter
    for tensor, shape in zip(pos_dim_tensors, unflat_param_shapes):
        if tensor is None:
            if len(shape) == 0:
                raise ValueError(
                    "Flattening a zero-dimension parameter is not supported"
                )
            continue
        if tensor.shape != shape:
            raise ValueError(
                "Tensor optimizer state does not have same shape as its "
                f"parameter: {tensor.shape} {shape}"
            )
    # Flatten every entry on CPU, substituting zeros for missing state
    cpu_device = torch.device("cpu")
    pieces = []
    for tensor, shape in zip(pos_dim_tensors, unflat_param_shapes):
        if tensor is None:
            piece = torch.zeros(size=shape, dtype=dtype, device=cpu_device)
        else:
            piece = tensor.to(cpu_device)
        pieces.append(torch.flatten(piece))
    num_padded = flat_param.num_padded
    if num_padded > 0:
        pieces.append(torch.zeros(num_padded, dtype=dtype, device=cpu_device))
    flat_tensor = torch.cat(pieces)
    if not flat_param._is_sharded:  # currently, only when world size is 1
        # `_full_param_padded` is not used for a non-sharded parameter, so
        # there is nothing to check the result against
        return flat_tensor
    # The result must be 1D and no longer than the padded flattened parameter
    full_padded_dim = flat_param._full_param_padded.dim()  # type: ignore[attr-defined]
    full_padded_shape = flat_param._full_param_padded.shape  # type: ignore[attr-defined]
    assert flat_tensor.dim() == 1, \
        f"`flat_tensor` should be 1D but got {flat_tensor.dim()} dims"
    assert full_padded_dim == 1, \
        f"`_full_param_padded` should be 1D but got {full_padded_dim} dims"
    assert flat_tensor.shape[0] <= full_padded_shape[0], \
        f"tensor optim state: {flat_tensor.shape} " \
        f"parameter: {full_padded_shape}"
    return flat_tensor
def _flatten_zero_dim_tensor_optim_state(
state_name: str,
zero_dim_tensors: List[torch.Tensor],
unflat_param_names: List[str],
) -> torch.Tensor:
"""
Flattens the zero-dimension tensor optimizer state given by the values
``zero_dim_tensors`` for the state ``state_name`` for a single flattened
parameter corresponding to the unflattened parameter names
``unflat_param_names`` by enforcing that all tensors are the same and using
that common value.
NOTE: The requirement that the tensors are the same across all unflattened
parameters comprising the flattened parameter is needed to maintain the
invariant that FSDP performs the same computation as its non-sharded
equivalent. This means that none of the unflattened parameters can be
missing this state since imposing a value may differ from having no value.
For example, for Adam's "step", no value means maximum bias correction,
while having some positive value means less bias correction.
Args:
state_name (str): Optimizer state name.
zero_dim_tensors (List[torch.Tensor]): Zero-dimension optimizer state
for the unflattened parameters corresponding to the single
flattened parameter.
unflat_param_names (List[str]): A :class:`list` of unflattened
parameter names corresponding to the single flattened parameter.
Returns:
zero_dim_tensor (torch.Tensor): A zero-dimensional tensor giving the
value of the state ``state_name`` for all unflattened parameters
corresponding to the names ``unflat_param_names``.
"""
non_none_tensors = [t for t in zero_dim_tensors if t is not None]
# Enforce that all have the same value and dtype
values_set = set(t.item() for t in zero_dim_tensors)
dtypes = set(t.dtype for t in zero_dim_tensors)
if len(non_none_tensors) != len(zero_dim_tensors) or \
len(values_set) != 1 or len(dtypes) != 1:
raise ValueError(
"All unflattened parameters comprising a single flattened "
"parameter must have scalar state with the same value and dtype "
f"but got values {values_set} and dtypes {dtypes} for state "
f"{state_name} and unflattened parameter names "
f"{unflat_param_names}"
)
value = next(iter(values_set))
dtype = next(iter(dtypes))
return torch.tensor(value, dtype=dtype, device=torch.device("cpu"))
def _flatten_non_tensor_optim_state(
state_name: str,
non_tensors: List[Any],
unflat_param_names: List[str],
) -> Any:
"""
Flattens the non-tensor optimizer state given by the values ``non_tensors``
for the state ``state_name`` for a single flattened parameter corresponding
to the unflattened parameter names ``unflat_param_names`` by enforcing that
all values are the same and using that common value.
See the note in :func:`_flatten_zero_dim_tensor_optim_state`.
Args:
state_name (str): Optimizer state name.
non_tensors (List[Any]): Non-tensor optimizer state for the unflattened
parameters corresponding to the single flattened parameter.
unflat_param_names (List[str]): A :class:`list` of unflattened
parameter names corresponding to the single flattened parameter.
Returns:
non_tensor (Any): A non-tensor giving the value of the state
``state_name`` for all unflattened parameters corresponding to the
names ``unflat_param_names``.
"""
non_none_non_tensors = [nt for nt in non_tensors if nt is not None]
# Enforce that all have the same value (same type already checked)
non_tensor_set = set(non_tensors)
if len(non_none_non_tensors) != len(non_tensors) or \
len(non_tensor_set) != 1:
raise ValueError(
"All unflattened parameters comprising a single flattened "
"parameter must have scalar state with the same value and dtype "
f"but got values {non_tensor_set} for state {state_name} and "
f"unflattened parameter names {unflat_param_names}"
)
non_tensor = next(iter(non_tensor_set))
return non_tensor
def _get_flat_param_to_fsdp_module(
model: torch.nn.Module,
):
"""
Constructs a mapping from FSDP flattened parameters to their owning FSDP
modules and ensures that all FSDP modules are initialized.
Args:
model (torch.nn.model): Root module (which may or may not be a
:class:`FullyShardedDataParallel` instance).
"""
flat_param_to_fsdp_module = {}
for module in model.modules():
if isinstance(module, FSDP.FullyShardedDataParallel):
module._lazy_init()
for param in module.params: # may have none
flat_param_to_fsdp_module[param] = module
return flat_param_to_fsdp_module
def _get_param_id_to_param(
model: torch.nn.Module,
optim_input: Optional[Union[
List[Dict[str, Any]], Iterable[torch.nn.Parameter],
]] = None,
) -> List[torch.nn.Parameter]:
"""
Constructs a mapping from parameter IDs to parameters. This may be used
both for models with ``FlatParameter`` s and without.
NOTE: We critically assume that, whether the optimizer input is a list of
parameters or a list of parameter groups, :class:`torch.optim.Optimizer`
enumerates the parameter IDs in order. In other words, for a parameter list
input, the parameter IDs should be in that list order, and for a parameter
groups input, the parameter IDs should be in order within each parameter
group and in order across parameter groups.
Args:
model (torch.nn.Module): Model whose parameters are passed into the
optimizer.
optim_input (Optional[Union[List[Dict[str, Any]],
Iterable[torch.nn.Parameter]]]): Input passed into the optimizer
representing either a :class:`list` of parameter groups or an
iterable of parameters; if ``None``, then this method assumes the
input was ``model.parameters()``. (Default: ``None``)
Returns:
param_id_to_param (List[torch.nn.Parameter]): Mapping from parameter
IDs to parameters, where the parameter ID is implicitly the index
in the :class:`list`.
"""
# Assume the standard case of passing `model.parameters()` to the optimizer
# if `optim_input` is not specified
if optim_input is None:
return list(model.parameters())
try:
params = list(optim_input)
except TypeError:
raise TypeError(
"Optimizer input should be an iterable of Tensors or dicts, "
f"but got {optim_input}"
)
if len(params) == 0:
raise ValueError("Optimizer input should not be empty")
# Check if the optimizer input represents tensors or parameter groups
all_tensors = True
all_dicts = True
for param in params:
all_tensors &= isinstance(param, torch.Tensor)
all_dicts &= isinstance(param, dict)
if not all_tensors and not all_dicts:
raise TypeError(
"Optimizer input should be an iterable of Tensors or dicts"
)
if all_tensors:
return params # type: ignore[return-value]
assert all_dicts
param_id_to_param = []
for param_group in params:
has_params_key = "params" in param_group # type: ignore[operator]
assert has_params_key, \
"A parameter group should map \"params\" to a list of the " \
"parameters in the group"
for param in param_group["params"]: # type: ignore[index]
# Implicitly map `flat_param_id` (current length of the list) to
# `param`
param_id_to_param.append(param)
return param_id_to_param # type: ignore[return-value]
def _get_param_to_param_id(
    model: torch.nn.Module,
    optim_input: Optional[Union[
        List[Dict[str, Any]], Iterable[torch.nn.Parameter],
    ]] = None,
) -> Dict[torch.nn.Parameter, int]:
    """
    Constructs the inverse mapping of :func:`_get_param_id_to_param`, i.e.
    from parameter to parameter ID. ``model`` and ``optim_input`` are
    forwarded to that function unchanged.
    """
    param_to_param_id = {}
    params = _get_param_id_to_param(model, optim_input)
    for param_id, param in enumerate(params):
        param_to_param_id[param] = param_id
    return param_to_param_id
def _get_unflat_to_flat_param_ids(
flat_to_unflat_param_ids: Dict[int, List[int]],
) -> List[int]:
"""
Inverts the mapping ``flat_to_unflat_param_ids`` to be from unflattened
parameter ID to flattened parameter ID, where the unflattened parameter ID
is the index in the returned :class:`list`. There may be multiple
unflattened parameter IDs mapping to the same flattened parameter ID.
Args:
flat_to_unflat_param_ids (Dict[int, List[int]]): A mapping from
flattened parameter ID to a :class:`list` of corresponding
unflattened parameter IDs.
Returns:
unflat_to_flat_param_ids (List[int]): A mapping from unflattened
parameter ID to flattened parameter ID, where the unflattened
parameter ID is the index in the :class:`list`.
"""
# Construct as a dict and then convert to list
unflat_to_flat_param_ids = {}
for flat_param_id, unflat_param_ids in flat_to_unflat_param_ids.items():
for unflat_param_id in unflat_param_ids:
assert unflat_param_id not in unflat_to_flat_param_ids, \
"`flat_to_unflat_param_ids` has the unflattened parameter " \
f"ID {unflat_param_id} mapped to multiple flattened " \
"parameter IDs"
unflat_to_flat_param_ids[unflat_param_id] = flat_param_id
num_unflat_param_ids = len(unflat_to_flat_param_ids)
unflat_param_ids_set = set(unflat_to_flat_param_ids.keys())
assert unflat_param_ids_set == set(range(num_unflat_param_ids)), \
"The set of unflattened parameter IDs should be {0, ..., " + \
str(num_unflat_param_ids - 1) + "} but got " + \
f"{unflat_param_ids_set}"
return [
unflat_to_flat_param_ids[unflat_param_id]
for unflat_param_id in range(num_unflat_param_ids)
]
def _is_zero_dim_tensor(x: Any) -> bool:
return torch.is_tensor(x) and x.dim() == 0