| # mypy: allow-untyped-defs | 
 | import uuid | 
 | from collections import OrderedDict | 
 | from functools import wraps | 
 | from typing import Callable, Dict, List, Optional, Type | 
 |  | 
 | import torch.nn as nn | 
 | from torch.distributed._composable_state import _State | 
 |  | 
 |  | 
 | def generate_state_key(string="__composable_api_state_key"): | 
 |     return f"{string}_{str(uuid.uuid4())}" | 
 |  | 
 |  | 
# Keys under which ``contract`` stores its bookkeeping in ``module.__dict__``.
# Each key is created once, when this module is executed, and carries a random
# UUID4 suffix (see ``generate_state_key``).
#
# STATE_KEY    -> dict mapping each applied API function to its state object.
# REGISTRY_KEY -> dict mapping each applied API's ``__name__`` to a
#                 ``RegistryItem``.
STATE_KEY = generate_state_key()
REGISTRY_KEY = generate_state_key()
 |  | 
 |  | 
 | # TODO: we can add additional info to RegistryItem to share across APIs. E.g., | 
 | # we can add args and kwargs here, and then we can detect whether fully_shard | 
 | # is combined with reentrant activation checkpointing and error out with a clear | 
 | # message. | 
 | class RegistryItem: | 
 |     pass | 
 |  | 
 |  | 
 | def contract(state_cls: Type[_State] = _State): | 
 |     r""" | 
 |     Decorate a function as a composable distributed API, where the first | 
 |     argument of the function must be an :class:`nn.Module` instance. The | 
 |     decorator verifies that the wrapped function does not modify parameter, | 
 |     buffer or sub-module fully-qualified names (FQN). | 
 |  | 
 |     When a function ``func`` is decorated by ``@contract()``, a | 
 |     ``.state(module: nn.Module)`` method will be installed to the decorated | 
 |     function. Then you can retrieve and modify the state on a module by calling | 
 |     ``func.state(module)``. | 
 |  | 
 |     Example:: | 
 |         >>> # xdoctest: +SKIP | 
 |         >>> import torch.nn as nn | 
 |         >>> | 
 |         >>> class MyModel(nn.Module): | 
 |         >>>     def __init__(self): | 
 |         >>>         super().__init__() | 
 |         >>>         self.l1 = nn.Linear(10, 10) | 
 |         >>>         self.l2 = nn.Linear(10, 10) | 
 |         >>> | 
 |         >>>     def forward(self, x): | 
 |         >>>         return self.l2(self.l1(x)) | 
 |         >>> | 
 |         >>> @contract() | 
 |         >>> def my_feature(module: nn.Module) -> nn.Module: | 
 |         >>>     my_feature.state(module).some_state = "any value" | 
 |         >>>     return module | 
 |         >>> | 
 |         >>> model = MyModel() | 
 |         >>> my_feature(model.l1) | 
 |         >>> assert my_feature.state(model.l1).some_state == "any value" | 
 |         >>> my_feature(model.l2) | 
 |         >>> model(torch.randn(2, 10)).sum().backward() | 
 |     """ | 
 |  | 
 |     # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package | 
 |     @wraps(state_cls) | 
 |     def inner(func): | 
 |         @wraps(func) | 
 |         def wrapper(module: nn.Module, *args, **kwargs) -> Optional[nn.Module]: | 
 |             # get existing global states | 
 |             default_all_state: Dict[Callable, _State] = OrderedDict() | 
 |             all_state: Dict[Callable, _State] = module.__dict__.setdefault(  # type: ignore[call-overload] | 
 |                 STATE_KEY, default_all_state | 
 |             ) | 
 |             assert isinstance( | 
 |                 all_state, dict | 
 |             ), "Distributed composable API states corrupted" | 
 |  | 
 |             # get global registry | 
 |             default_registry: Dict[str, RegistryItem] = OrderedDict() | 
 |             registry: Dict[str, RegistryItem] = module.__dict__.setdefault(  # type: ignore[call-overload] | 
 |                 REGISTRY_KEY, default_registry | 
 |             ) | 
 |  | 
 |             assert isinstance( | 
 |                 registry, dict | 
 |             ), "Distributed composable API registry corrupted" | 
 |  | 
 |             # make sure the API func has not been applied to the input module yet. | 
 |             assert func not in all_state and func.__name__ not in registry, ( | 
 |                 "Each distinct composable distributed API can only be applied to a " | 
 |                 f"module once. {func.__name__} has already been applied to the " | 
 |                 f"following module.\n{module}" | 
 |             ) | 
 |  | 
 |             # install states specific to the wrapped ``func`` | 
 |             all_state.setdefault(func, state_cls()) | 
 |             # register ``func`` in the global registry by name | 
 |             registry.setdefault(func.__name__, RegistryItem()) | 
 |  | 
 |             orig_named_params = OrderedDict(module.named_parameters()) | 
 |             orig_named_buffers = OrderedDict( | 
 |                 module.named_buffers(remove_duplicate=False) | 
 |             ) | 
 |             orig_named_modules = OrderedDict( | 
 |                 module.named_modules(remove_duplicate=False) | 
 |             ) | 
 |  | 
 |             updated = func(module, *args, **kwargs) | 
 |  | 
 |             if updated is None: | 
 |                 updated = module | 
 |  | 
 |             new_named_params = OrderedDict(updated.named_parameters()) | 
 |             new_named_buffers = OrderedDict( | 
 |                 updated.named_buffers(remove_duplicate=False) | 
 |             ) | 
 |             new_named_modules = OrderedDict( | 
 |                 updated.named_modules(remove_duplicate=False) | 
 |             ) | 
 |  | 
 |             assert isinstance(updated, nn.Module), ( | 
 |                 "Output of composable distributed APIs must be either None or " | 
 |                 f"nn.Module, but got {type(updated)}" | 
 |             ) | 
 |  | 
 |             def check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str): | 
 |                 if orig_fqns == new_fqns: | 
 |                     return | 
 |  | 
 |                 orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns) | 
 |                 orig_only = orig_fqn_set - new_fqn_set | 
 |                 new_only = new_fqn_set - orig_fqn_set | 
 |                 if len(orig_only) or len(new_only): | 
 |                     raise RuntimeError( | 
 |                         f"{check_key}" | 
 |                         "Composable distributed API implementations cannot modify " | 
 |                         "FQNs.\n" | 
 |                         f"Only in original FQNs: {orig_only},\n" | 
 |                         f"Only in new FQNs: {new_only}" | 
 |                     ) | 
 |                 else: | 
 |                     raise RuntimeError( | 
 |                         f"{check_key}" | 
 |                         "Composable distributed API implementations cannot modify " | 
 |                         "the order of FQNs.\n" | 
 |                         f"Original FQNs: {orig_only}\n" | 
 |                         f"New FQNs: {new_only}" | 
 |                     ) | 
 |  | 
 |             check_fqn( | 
 |                 list(orig_named_params.keys()), | 
 |                 list(new_named_params.keys()), | 
 |                 "Check parameters, ", | 
 |             ) | 
 |             check_fqn( | 
 |                 list(orig_named_buffers.keys()), | 
 |                 list(new_named_buffers.keys()), | 
 |                 "Check buffer, ", | 
 |             ) | 
 |             check_fqn( | 
 |                 list(orig_named_modules.keys()), | 
 |                 list(new_named_modules.keys()), | 
 |                 "Check modules, ", | 
 |             ) | 
 |  | 
 |             # TODO: a stricter verification should also reject changing module | 
 |             # types and monkey-patching forward() method implementations. | 
 |  | 
 |             # TODO: verify that installed distributed paradigms are compatible with | 
 |             # each other. | 
 |  | 
 |             return updated | 
 |  | 
 |         def get_state(module: nn.Module) -> Optional[_State]: | 
 |             return module.__dict__.setdefault(  # type: ignore[call-overload] | 
 |                 STATE_KEY, | 
 |                 {},  # TODO(@yhcharles): this is a temporary fix, need a better way | 
 |             ).get( | 
 |                 func | 
 |             )  # type: ignore[call-overload] | 
 |  | 
 |         wrapper.state = get_state  # type: ignore[attr-defined] | 
 |  | 
 |         return wrapper | 
 |  | 
 |     return inner | 
 |  | 
 |  | 
 | def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]: | 
 |     r""" | 
 |     Get an ``OrderedDict`` of composable APIs that have been applied to the | 
 |     ``module``, indexed by the API name. If no API has been applied, then this | 
 |     returns ``None``. | 
 |     """ | 
 |     return getattr(module, REGISTRY_KEY, None) |