| # Copyright (c) Facebook, Inc. and its affiliates. |
| # |
| # This source code is licensed under the BSD license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import contextlib |
| import copy |
| from abc import ABC, abstractmethod |
| from typing import ( |
| Any, |
| Callable, |
| cast, |
| Dict, |
| Generator, |
| Iterable, |
| Optional, |
| Sequence, |
| Set, |
| Tuple, |
| Type, |
| Union, |
| ) |
| |
| import torch.nn as nn |
| |
| __all__ = [ |
| "always_wrap_policy", |
| "lambda_auto_wrap_policy", |
| "transformer_auto_wrap_policy", |
| "size_based_auto_wrap_policy", |
| "enable_wrap", |
| "wrap", |
| "CustomPolicy", |
| "ModuleWrapPolicy", |
| ] |
| |
| |
| # NOTE: We intentionally keep this function simple and isolate the complexity |
| # to `fn` to enable using this function generically. We may move this to a |
| # non-FSDP-specific folder and/or make it public in the future. |
| def _post_order_apply( |
| root_module: nn.Module, |
| fn: Callable[[nn.Module], Optional[nn.Module]], |
| ): |
| """ |
| This applies ``fn`` to every module in the module tree of ``root_module`` |
| following a post-order traversal. If ``fn`` returns an :class:`nn.Module`, |
| then this replaces the original module with the newly returned one in the |
| tree. Otherwise, ``fn`` should return ``None``, in which case the module is |
| not changed. |
| """ |
| # Track visited modules to avoid visiting shared modules multiple times |
| visited_modules: Set[nn.Module] = {root_module} |
| |
| def _post_order_apply_inner( |
| module: nn.Module, |
| module_name: str, |
| parent_module: Optional[nn.Module], |
| ): |
| for child_module_name, child_module in module.named_children(): |
| if child_module not in visited_modules: |
| visited_modules.add(child_module) |
| _post_order_apply_inner(child_module, child_module_name, module) |
| optional_module = fn(module) |
| if optional_module is not None: |
| assert isinstance(parent_module, nn.Module), ( |
| "Non-root modules should have their parent module set but got " |
| f"{parent_module} for {module}" |
| ) |
| assert module_name, ( |
| "Non-root modules should have their module name set but got " |
| f"an empty module name for {module}" |
| ) |
| assert isinstance( |
| optional_module, nn.Module |
| ), f"fn should return None or an nn.Module but got {optional_module}" |
| setattr(parent_module, module_name, optional_module) |
| |
| _post_order_apply_inner(root_module, "", None) |
| |
| |
| def _construct_wrap_fn( |
| root_module: nn.Module, |
| target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]], |
| fsdp_fn: Callable, |
| ) -> Callable[[nn.Module], Optional[nn.Module]]: |
| """ |
| This constructs the "wrap" function to pass to :func:`_post_order_apply` |
| based on ``target_module_to_kwargs``, which should be constructed from the |
| wrapping policy. |
| """ |
| |
| def fn(module: nn.Module) -> Optional[nn.Module]: |
| # Explicitly avoid wrapping the root module since for FSDP, it is |
| # handled by the caller |
| if module in target_module_to_kwargs and module is not root_module: |
| kwargs = target_module_to_kwargs[module] |
| return fsdp_fn(module, **kwargs) |
| return None |
| |
| return fn |
| |
| |
| def _run_mixed_precision_override_policy( |
| root_module: nn.Module, |
| module_classes: Iterable[Type[nn.Module]], |
| ignored_modules: Set[nn.Module], |
| root_kwargs: Dict[str, Any], |
| target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]], |
| ): |
| module_classes_tuple = tuple(set(module_classes)) |
| for module in root_module.modules(): |
| if module in ignored_modules: |
| continue |
| elif isinstance(module, module_classes_tuple): |
| # This policy overrides any existing policy |
| if module not in target_module_to_kwargs: |
| # Only inherit from the root kwargs if not already specified |
| target_module_to_kwargs[module] = root_kwargs |
| target_module_to_kwargs[module]["mixed_precision"] = None |
| return target_module_to_kwargs |
| |
| |
| def always_wrap_policy(*args, **kwargs) -> bool: |
| """ |
| A simple recursive wrap policy that always returns ``True``. This means |
| that every submodule is wrapped by the wrapper class in |
| :func:`_recursive_wrap`. |
| """ |
| return True |
| |
| |
| class _Policy(ABC): |
| """ |
| This defines an abstract base class that represents a policy for applying |
| a module-level API. |
| """ |
| |
| @abstractmethod |
| def _run_policy( |
| self, |
| root_module: nn.Module, |
| ignored_modules: Set[nn.Module], |
| root_kwargs: Dict[str, Any], |
| ) -> Dict[nn.Module, Dict[str, Any]]: |
| """ |
| This should return a dict ``target_module_to_kwargs`` that maps from |
| each target module to wrap to its kwargs. |
| """ |
| ... |
| |
| |
| def _module_wrap_policy( |
| module: nn.Module, |
| recurse: bool, |
| nonwrapped_numel: int, |
| module_classes: Set[Type[nn.Module]], |
| ) -> bool: |
| """ |
| This auto wrap policy wraps every module that is an instance of any type in |
| ``module_classes`` as its own FSDP instance. The root module given by |
| ``module`` is always wrapped as an FSDP instance regardless. Since the |
| wrapping proceeds bottom up, each FSDP instance manages the parameters in |
| its subtree excluding any already managed by a child FSDP instance. |
| |
| Args: |
| module (nn.Module): Current module being considered. |
| recurse (bool): If ``False``, then this function must decide whether |
| ``module`` should be wrapped as an FSDP instance or not. If |
| ``True``, then the function is still recursing down the module |
| tree as a part of the DFS. |
| nonwrapped_numel (int): Parameter numel not yet wrapped. |
| module_classes (Set[Type[nn.Module]]): Set of module classes that are |
| wrapped as FSDP instances. |
| |
| Returns: |
| ``True`` if ``recurse=True``, and whether ``module`` should be wrapped |
| if ``recurse=False``. |
| """ |
| if recurse: |
| return True # always recurse |
| return isinstance(module, tuple(module_classes)) |
| |
| |
| class ModuleWrapPolicy(_Policy): |
| """ |
| This policy applies to every module of the specified module classes, |
| passing in the kwargs given to the root. |
| """ |
| |
| def __init__(self, module_classes: Iterable[Type[nn.Module]]): |
| module_classes_set = set(module_classes) |
| self._module_classes = module_classes_set |
| self._module_classes_str = str(module_classes_set) |
| |
| def _run_policy( |
| self, |
| root_module: nn.Module, |
| ignored_modules: Set[nn.Module], |
| root_kwargs: Dict[str, Any], |
| ) -> Dict[nn.Module, Dict[str, Any]]: |
| module_classes = tuple(self._module_classes) |
| target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {} |
| for module in root_module.modules(): |
| if module in ignored_modules: |
| continue |
| elif isinstance(module, module_classes): |
| # Shallow copy to avoid coupling changes across modules |
| target_module_to_kwargs[module] = copy.copy(root_kwargs) |
| return target_module_to_kwargs |
| |
| def __call__(self, module, recurse, *args, **kwargs): |
| # nonwrapped_numel is not used. |
| return _module_wrap_policy( |
| module, recurse, nonwrapped_numel=-1, module_classes=self._module_classes |
| ) |
| |
| def __repr__(self) -> str: |
| return super().__repr__() + f"({self._module_classes_str})" |
| |
| |
| class CustomPolicy(_Policy): |
| """ |
| This policy takes in a lambda function that maps a given ``nn.Module`` to |
| either ``False``, ``True``, or a kwarg dictionary. |
| - If the function returns ``False`` or an empty dictionary, then the module |
| does not have the API applied. |
| - If the function returns ``True``, then the module has the API applied |
| with the root's kwargs. |
| - If the function returns a non-empty dictionary, then the module has the |
| API applied, and the dictionary overrides the root's kwargs. |
| |
| Example:: |
| |
| >>> # xdoctest: +SKIP("undefined variables") |
| >>> model = init_transformer_model(...) |
| >>> def lambda_fn(module: nn.Module): |
| >>> if module is model.lm_head: |
| >>> return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP} |
| >>> elif isinstance(module, TransformerBlock): |
| >>> return True |
| >>> return False |
| >>> policy = CustomPolicy(lambda_fn) |
| >>> fsdp_model = FSDP(model, auto_wrap_policy=policy) |
| """ |
| |
| def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, Dict[str, Any]]]): |
| self._lambda_fn = lambda_fn |
| |
| def _run_policy( |
| self, |
| root_module: nn.Module, |
| ignored_modules: Set[nn.Module], |
| root_kwargs: Dict[str, Any], |
| ) -> Dict[nn.Module, Dict[str, Any]]: |
| target_module_to_kwargs: Dict[nn.Module, Dict[str, Any]] = {} |
| for module in root_module.modules(): |
| if module in ignored_modules: |
| continue |
| res = self._lambda_fn(module) |
| if not isinstance(res, (dict, bool)): |
| raise ValueError( |
| "The lambda_fn passed to CustomPolicy should return " |
| f"False/True or a kwarg dict, but it returned {res}" |
| ) |
| if not res: |
| continue |
| kwargs = copy.copy(root_kwargs) |
| if isinstance(res, dict): |
| # Override the root kwargs with the ones specified by the |
| # lambda function |
| kwargs.update(res) |
| target_module_to_kwargs[module] = kwargs |
| return target_module_to_kwargs |
| |
| |
| def lambda_auto_wrap_policy( |
| module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable |
| ) -> bool: |
| """ |
| A convenient auto wrap policy to wrap submodules based on an arbitrary user |
| function. If `lambda_fn(submodule) == True``, the submodule will be wrapped as |
| a `wrapper_cls` unit. |
| |
| Return if a module should be wrapped during auto wrapping. |
| |
| The first three parameters are required by :func:`_recursive_wrap`. |
| |
| Args: |
| module (nn.Module): Current module being considered. |
| recurse (bool): If ``False``, then this function must decide whether |
| ``module`` should be wrapped as an FSDP instance or not. If |
| ``True``, then the function is still recursing down the module |
| tree as a part of the DFS. |
| nonwrapped_numel (int): Parameter numel not yet wrapped. |
| |
| lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then |
| this module will be wrapped. |
| """ |
| if recurse: |
| return True # always recurse |
| return lambda_fn(module) |
| |
| |
| def transformer_auto_wrap_policy( |
| module: nn.Module, |
| recurse: bool, |
| nonwrapped_numel: int, |
| transformer_layer_cls: Set[Type[nn.Module]], |
| ) -> bool: |
| """ |
| See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the |
| same as ``module_classes``. Note that shared parameters must be wrapped in |
| the same FSDP instance, so this auto wrap policy can help wrap shared |
| embeddings into the same FSDP instance for transformer models. |
| """ |
| return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls) |
| |
| |
| def _wrap_module_cls_individually( |
| module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs |
| ): |
| if recurse: |
| # always recurse |
| return True |
| else: |
| # if not recursing, decide whether we should wrap based on whether the type of module |
| # is in `module_classes`. |
| return isinstance(module, tuple(module_classes)) |
| |
| |
| def _or_policy( |
| module: nn.Module, |
| recurse: bool, |
| nonwrapped_numel: int, |
| policies, |
| ) -> bool: |
| """ |
| A policy that wraps ``module`` if any policy in the passed in iterable of |
| ``policies`` returns ``True``. |
| """ |
| return any( |
| policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel) |
| for policy in policies |
| ) |
| |
| |
| def size_based_auto_wrap_policy( |
| module: nn.Module, |
| recurse: bool, |
| nonwrapped_numel: int, |
| # Additional custom arguments |
| min_num_params: int = int(1e8), |
| force_leaf_modules: Optional[Set[Type[nn.Module]]] = None, |
| exclude_wrap_modules: Optional[Set[Type[nn.Module]]] = None, |
| ) -> bool: |
| """ |
| A size-based auto wrap policy. |
| |
| Args: |
| module (nn.Module): Current module being considered. |
| recurse (bool): If ``False``, then this function must decide whether |
| ``module`` should be wrapped as an FSDP instance or not. If |
| ``True``, then the function is still recursing down the module |
| tree as a part of the DFS. |
| nonwrapped_numel (int): Parameter numel not yet wrapped. |
| |
| min_num_params (int): Customizable policy input that controls the size |
| threshold over which a module is ready to be wrapped. This is in |
| units of numel. |
| force_leaf_modules (Set[Type[nn.Module]]): Set of module types to keep |
| as leaves, i.e. their children will never be wrapped. |
| exclude_wrap_modules (Set[Type[nn.Module]]): Set of module types to be |
| excluded in wrapping. |
| |
| Returns: |
| Whether ``module`` should be wrapped. |
| """ |
| force_leaf_modules = ( |
| size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined] |
| if force_leaf_modules is None |
| else force_leaf_modules |
| ) |
| exclude_wrap_modules = ( |
| size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES # type: ignore[attr-defined] |
| if exclude_wrap_modules is None |
| else exclude_wrap_modules |
| ) |
| |
| # Keep the argument `min_num_params` for BC for now, but it represents the |
| # minimum non-wrapped *numel* before triggering a wrapping |
| min_nonwrapped_numel = min_num_params |
| is_large = nonwrapped_numel >= min_nonwrapped_numel |
| if recurse: |
| # We should recurse if the module is big enough but not in force_leaf_modules list. |
| return is_large and not isinstance(module, tuple(force_leaf_modules)) |
| else: |
| # If we are not recursing, determine if we should wrap. |
| return is_large and not isinstance(module, tuple(exclude_wrap_modules)) |
| |
| |
| # Set those defaults to the size_based_auto_wrap_policy function. Make them easy to be imported. |
| size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict} # type: ignore[attr-defined] |
| size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} # type: ignore[attr-defined] |
| |
| |
| @contextlib.contextmanager |
| def enable_wrap( |
| *, wrapper_cls: Any, **wrapper_kwargs: Any |
| ) -> Generator[None, None, None]: |
| """ |
| Context manager to wrap modules using a wrapper. |
| |
| Useful for when you'd like to apply the same configuration arguments to all |
| child modules that you wrap. A particularly important use case is wrapping |
| large layers so that they get sharded (in-place) during initialization, to |
| avoid running out of system memory. Large layers can indicate that they |
| should be sharded via the ``wrap`` annotation and this context manager can |
| provide the exact configuration for these nested instances. |
| |
| Usage:: |
| |
| with enable_wrap(wrapper_cls, **params): |
| # Wraps layer in FSDP by default if within context |
| self.l1 = wrap(torch.nn.Linear(5, 5)) |
| |
| Args: |
| wrapper_cls: |
| Class that `wrap` annotation will `wrap` modules with, such as |
| `FullyShardedDataParallel`. |
| **wrapper_kwargs: |
| Configuration settings that will be passed to all ``wrap`` |
| instances inside the context |
| """ |
| kwargs = { |
| "wrapper_cls": wrapper_cls, |
| **wrapper_kwargs, |
| } |
| with _ConfigAutoWrap(**kwargs): |
| yield |
| |
| |
| def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module: |
| """ |
| Annotate that a module should be wrapped. Annotated modules will only be |
| wrapped if inside of an :func:`enable_wrap` context manager. This allows |
| a module to be initialized both with and without a wrapper without code |
| change. |
| |
| The class that this function wraps the passed in ``nn.Module`` with is the |
| passed in ``wrapper_cls`` argument into ``enable_wrap``. Both |
| ``enable_wrap`` and ``wrap`` can take in kwargs specifying how to construct |
| the ``wrapper_cls`` instance. In the case of duplicate kwargs in |
| ``enable_wrap`` and ``wrap``, the argument passed into ``wrap`` will be |
| respected. |
| |
| Usage:: |
| |
| with enable_wrap(wrapper_cls=FSDP, **fsdp_config): |
| # Wraps layer in FSDP by default if within context |
| self.l1 = wrap(torch.nn.Linear(5, 5)) |
| |
| Args: |
| module (nn.Module): module to wrap (if in :func:`enable_wrap` context) |
| **wrap_overrides: configuration overrides that will take priority over |
| the values provided by the :func:`enable_wrap` context |
| """ |
| if _ConfigAutoWrap.in_autowrap_context: |
| assert _ConfigAutoWrap.wrapper_cls is not None |
| |
| wrap_overrides = {**_ConfigAutoWrap.kwargs, **wrap_overrides} |
| return _wrap( |
| module, |
| _ConfigAutoWrap.wrapper_cls, |
| **wrap_overrides, |
| ) |
| return module |
| |
| |
| def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module: |
| assert wrapper_cls is not None |
| if hasattr(module, "_wrap_overrides"): |
| # If module has a _wrap_overrides attribute, we force overriding the |
| # FSDP config with these attributes for this module. Currently this |
| # is only used to disable mixed precision for BatchNorm when |
| # auto_wrapping. |
| overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type] |
| return wrapper_cls(module, **overrides) |
| |
| return wrapper_cls(module, **kwargs) |
| |
| |
| def _recursive_wrap( |
| module: nn.Module, |
| auto_wrap_policy: Callable, |
| wrapper_cls: Callable, |
| ignored_modules: Set[nn.Module], |
| ignored_params: Set[nn.Parameter], |
| only_wrap_children: bool = False, |
| **kwargs: Any, |
| ) -> Tuple[nn.Module, int]: |
| """ |
| Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns |
| ``True`` with ``wrapper_cls``. |
| |
| Args: |
| module (nn.Module): Module to recursively wrap. |
| auto_wrap_policy (Callable): A callable representing a policy that |
| determines which modules to recursively wrap with ``wrapper_cls``. |
| ignored_modules (Set[torch.nn.Module]): Modules to ignore when |
| wrapping. |
| ignored_params (Set[torch.nn.Parameter]): Parameters to ignore when |
| wrapping; these should be the parameters contained in the modules |
| in ``ignored_modules``. |
| Returns: |
| (nn.Module, int): |
| ``module`` after wrapping and the numel recursively wrapped. |
| """ |
| assert auto_wrap_policy is not None, "Must specify auto_wrap_policy." |
| assert wrapper_cls is not None, "Must specify wrapper_cls" |
| # Make sure no child is already wrapped. |
| for _, child in module.named_modules(): |
| if child in ignored_modules: |
| continue |
| try: |
| assert not isinstance(child, cast(type, wrapper_cls)) |
| except TypeError: |
| # wrapper_cls is a function as opposed to a class type, just bypass above check. |
| pass |
| |
| # We count all params, assuming none of them are already wrapped. |
| nonwrapped_numel = sum( |
| p.numel() for p in module.parameters() if p not in ignored_params |
| ) |
| |
| assert auto_wrap_policy is not None |
| if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel): |
| total_wrapped_numel = 0 |
| # Iterate through the children, recursively wrap if necessary |
| for name, child in module.named_children(): |
| if child in ignored_modules: |
| continue |
| wrapped_child, num_wrapped_params = _recursive_wrap( |
| module=child, |
| auto_wrap_policy=auto_wrap_policy, |
| wrapper_cls=wrapper_cls, |
| ignored_modules=ignored_modules, |
| ignored_params=ignored_params, |
| **kwargs, |
| ) |
| setattr(module, name, wrapped_child) |
| # Keep track of how many parameters have been wrapped |
| total_wrapped_numel += num_wrapped_params |
| # decide if we need to wrap the current module, |
| # since the left over parameters exceed the number of params to wrap |
| remainder = nonwrapped_numel - total_wrapped_numel |
| if not only_wrap_children and auto_wrap_policy( |
| module=module, recurse=False, nonwrapped_numel=remainder |
| ): |
| # Leaf node or final wrapping of the remainder both happen here. |
| return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel |
| else: |
| return module, total_wrapped_numel |
| return module, 0 |
| |
| |
| class _ConfigAutoWrap: |
| """ |
| Helper class to wrap modules based on default config args via a context manager. |
| See :func:`enable_wrap` for more information. |
| """ |
| |
| in_autowrap_context: bool = False # Context flag |
| wrapper_cls: Optional[Callable] = None # The wrapper class |
| kwargs: Dict[str, Any] = {} # Wrapper's args |
| |
| def __init__(self, **kwargs: Dict[str, Any]): |
| self.kwargs = kwargs |
| |
| @staticmethod |
| def enable_autowrap_context(kwargs: Any) -> None: |
| if _ConfigAutoWrap.in_autowrap_context: |
| raise NotImplementedError( |
| "You are already within an autowrap context and we currently do not supported nested autowrap." |
| ) |
| _ConfigAutoWrap.in_autowrap_context = True |
| # Get and save the wrapper cls for the context. |
| assert ( |
| "wrapper_cls" in kwargs.keys() |
| ), "Expected to pass in wrapper_cls arg into _ConfigAutoWrap." |
| _ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"]) |
| del kwargs["wrapper_cls"] |
| # Save the rest. |
| _ConfigAutoWrap.kwargs = kwargs |
| |
| @staticmethod |
| def disable_autowrap_context() -> None: |
| _ConfigAutoWrap.in_autowrap_context = False |
| _ConfigAutoWrap.wrapper_cls = None |
| _ConfigAutoWrap.kwargs = {} |
| |
| def __enter__(self) -> None: |
| self.enable_autowrap_context(self.kwargs) |
| |
| def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: |
| self.disable_autowrap_context() |