import itertools
import warnings
from typing import Protocol, Optional, Type, Any

import torch
from ..parameter import is_lazy

__all__ = ['LazyModuleMixin']

class _LazyProtocol(Protocol):
    """This class is used to avoid errors with mypy checks for the attributes in a mixin.

    https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes
    """

    def _register_load_state_dict_pre_hook(self, hook):
        ...

    def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False):
        ...

    def _lazy_load_hook(
            self, state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs):
        ...

    def _get_name(self):
        ...

    def _infer_parameters(self, module, input):
        ...

    @property
    def _parameters(self):
        ...

    @property
    def _buffers(self):
        ...

    @property
    def _non_persistent_buffers_set(self):
        ...

    @property
    def _load_hook(self):
        ...

    @property
    def _initialize_hook(self):
        ...


class LazyModuleMixin:
    r"""A mixin for modules that lazily initialize parameters, also known as "lazy modules".

    .. warning::
        Lazy modules are an experimental new feature under active development,
        and their API is likely to change.

    Modules that lazily initialize parameters, or "lazy modules",
    derive the shapes of their parameters from the first input(s)
    to their forward method. Until that first forward they contain
    :class:`torch.nn.UninitializedParameter` s that should not be accessed
    or used, and afterward they contain regular :class:`torch.nn.Parameter` s.
    Lazy modules are convenient since they don't require computing some
    module arguments, like the :attr:`in_features` argument of a
    typical :class:`torch.nn.Linear`.

    After construction, networks with lazy modules should first
    be converted to the desired dtype and placed on the expected device.
    This is because lazy modules only perform shape inference so the usual dtype
    and device placement behavior applies.
    The lazy modules should then perform "dry runs" to initialize all the components in the module.
    These "dry runs" send inputs of the correct size, dtype, and device through
    the network and to each one of its lazy modules. After this the network can be used as usual.

    >>> # xdoctest: +SKIP
    >>> class LazyMLP(torch.nn.Module):
    ...    def __init__(self):
    ...        super().__init__()
    ...        self.fc1 = torch.nn.LazyLinear(10)
    ...        self.relu1 = torch.nn.ReLU()
    ...        self.fc2 = torch.nn.LazyLinear(1)
    ...        self.relu2 = torch.nn.ReLU()
    ...
    ...    def forward(self, input):
    ...        x = self.relu1(self.fc1(input))
    ...        y = self.relu2(self.fc2(x))
    ...        return y
    >>> # constructs a network with lazy modules
    >>> lazy_mlp = LazyMLP()
    >>> # transforms the network's device and dtype
    >>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs'
    >>> lazy_mlp = lazy_mlp.cuda().double()
    >>> lazy_mlp
    LazyMLP(
      (fc1): LazyLinear(in_features=0, out_features=10, bias=True)
      (relu1): ReLU()
      (fc2): LazyLinear(in_features=0, out_features=1, bias=True)
      (relu2): ReLU()
    )
    >>> # performs a dry run to initialize the network's lazy modules
    >>> lazy_mlp(torch.ones(10, 10).cuda().double())
    >>> # after initialization, LazyLinear modules become regular Linear modules
    >>> lazy_mlp
    LazyMLP(
      (fc1): Linear(in_features=10, out_features=10, bias=True)
      (relu1): ReLU()
      (fc2): Linear(in_features=10, out_features=1, bias=True)
      (relu2): ReLU()
    )
    >>> # attaches an optimizer, since parameters can now be used as usual
    >>> optim = torch.optim.SGD(lazy_mlp.parameters(), lr=0.01)

    A final caveat when using lazy modules is that the order of initialization of a network's
    parameters may change, since the lazy modules are always initialized after other modules.
    For example, if the LazyMLP class defined above had a :class:`torch.nn.LazyLinear` module
    first and then a regular :class:`torch.nn.Linear` second, the second module would be
    initialized on construction and the first module would be initialized during the first dry run.
    This can cause the parameters of a network using lazy modules to be initialized differently
    than the parameters of an equivalent network without lazy modules, because the order of
    parameter initializations, which often depends on a stateful random number generator,
    differs between the two. Check :doc:`/notes/randomness` for more details.
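
    As a minimal sketch of this effect (an illustration only, not part of the documented API),
    note that a regular :class:`torch.nn.Linear` draws from the random number generator at
    construction time, while a :class:`torch.nn.LazyLinear` only draws during its dry run:

    >>> # xdoctest: +SKIP
    >>> torch.manual_seed(0)
    >>> eager_a = torch.nn.Linear(5, 10)   # consumes the RNG now
    >>> eager_b = torch.nn.Linear(10, 1)   # consumes the RNG second
    >>> torch.manual_seed(0)
    >>> lazy_a = torch.nn.LazyLinear(10)   # consumes nothing yet
    >>> other_b = torch.nn.Linear(10, 1)   # consumes the RNG first
    >>> lazy_a(torch.ones(2, 5))           # consumes the RNG during the dry run
    >>> # despite identical seeds, the two Linear(10, 1) modules saw different RNG
    >>> # states and so were (almost surely) initialized with different values
    >>> torch.equal(eager_b.weight, other_b.weight)
    False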

    Lazy modules can be serialized with a state dict like other modules. For example:

    >>> lazy_mlp = LazyMLP()
    >>> # The state dict shows the uninitialized parameters
    >>> lazy_mlp.state_dict()
    OrderedDict([('fc1.weight', Uninitialized parameter),
                 ('fc1.bias',
                  tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30,
                          4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])),
                 ('fc2.weight', Uninitialized parameter),
                 ('fc2.bias', tensor([0.0019]))])

    Lazy modules can load regular :class:`torch.nn.Parameter` s (i.e. you can serialize/deserialize
    initialized LazyModules and they will remain initialized):

    >>> full_mlp = LazyMLP()
    >>> # Dry run to initialize another module
    >>> full_mlp.forward(torch.ones(10, 1))
    >>> # Load an initialized state into a lazy module
    >>> lazy_mlp.load_state_dict(full_mlp.state_dict())
    >>> # The state dict now holds valid values
    >>> lazy_mlp.state_dict()
    OrderedDict([('fc1.weight',
                  tensor([[-0.3837],
                          [ 0.0907],
                          [ 0.6708],
                          [-0.5223],
                          [-0.9028],
                          [ 0.2851],
                          [-0.4537],
                          [ 0.6813],
                          [ 0.5766],
                          [-0.8678]])),
                 ('fc1.bias',
                  tensor([-1.8832e+25, 4.5636e-41, -1.8832e+25, 4.5636e-41, -6.1598e-30,
                          4.5637e-41, -1.8788e+22, 4.5636e-41, -2.0042e-31, 4.5637e-41])),
                 ('fc2.weight',
                  tensor([[ 0.1320, 0.2938, 0.0679, 0.2793, 0.1088, -0.1795, -0.2301, 0.2807,
                            0.2479, 0.1091]])),
                 ('fc2.bias', tensor([0.0019]))])

    Note, however, that parameters that were already initialized when the state was
    loaded will not be replaced or re-initialized by a subsequent "dry run".
    This prevents using initialized modules in different contexts.
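
    For example (a sketch only, reusing the ``lazy_mlp`` and ``full_mlp`` instances from above),
    the values loaded from ``full_mlp`` survive a later dry run:

    >>> # xdoctest: +SKIP
    >>> before = lazy_mlp.fc1.weight.clone()
    >>> lazy_mlp(torch.ones(10, 1))
    >>> # the dry run did not re-initialize the already-materialized parameters
    >>> torch.equal(before, lazy_mlp.fc1.weight)
    True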
| """ |
| |
| # modules inheriting from this will change their __class__ to the specified |
| # one after they are fully initialized |
| cls_to_become: Optional[Type[Any]] = None |
| |
| def __init__(self: _LazyProtocol, *args, **kwargs): |
| # Mypy doesnt like this super call in a mixin |
| super().__init__(*args, **kwargs) # type: ignore[misc] |
| self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook) |
| self._initialize_hook = self.register_forward_pre_hook(self._infer_parameters, with_kwargs=True) |
| warnings.warn('Lazy modules are a new feature under heavy development ' |
| 'so changes to the API or functionality can happen at any moment.') |

    def _save_to_state_dict(self: _LazyProtocol, destination, prefix, keep_vars):
        # This would ideally be implemented as a hook, but that would require
        # overriding `detach` in UninitializedParameter to return itself,
        # which is not clean.
        for name, param in self._parameters.items():
            if param is not None:
                if not (is_lazy(param) or keep_vars):
                    param = param.detach()
                destination[prefix + name] = param
        for name, buf in self._buffers.items():
            if buf is not None and name not in self._non_persistent_buffers_set:
                if not (is_lazy(buf) or keep_vars):
                    buf = buf.detach()
                destination[prefix + name] = buf

    def _lazy_load_hook(
            self: _LazyProtocol, state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs):
        """load_state_dict pre-hook function for lazy buffers and parameters.

        The purpose of this hook is to adjust the current state and/or
        ``state_dict`` being loaded so that a module instance serialized in
        either an uninitialized or an initialized state can be deserialized
        onto a module instance in either state.
        See comment in ``torch.nn.Module._register_load_state_dict_pre_hook``
        for the details of the hook specification.
        """
        for name, param in itertools.chain(self._parameters.items(), self._buffers.items()):
            key = prefix + name
            if key in state_dict and param is not None:
                input_param = state_dict[key]
                if is_lazy(param):
                    # The current parameter is not initialized but the one being loaded is:
                    # create a new parameter based on the uninitialized one
                    if not is_lazy(input_param):
                        with torch.no_grad():
                            param.materialize(input_param.shape)

    def initialize_parameters(self: _LazyProtocol, *args, **kwargs):
        r"""Initialize parameters according to the input batch properties.

        This adds an interface to isolate parameter initialization from the
        forward pass when doing parameter shape inference.
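
        A minimal sketch of a subclass (the ``LazyScale`` module below is a
        hypothetical example, not part of torch) that materializes one scale
        per input feature on its first forward:

        >>> # xdoctest: +SKIP
        >>> class LazyScale(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
        ...    def __init__(self):
        ...        super().__init__()
        ...        # shape is unknown until the first input arrives
        ...        self.weight = torch.nn.UninitializedParameter()
        ...
        ...    def initialize_parameters(self, input):
        ...        if self.has_uninitialized_params():
        ...            with torch.no_grad():
        ...                # materialize from the input's feature dimension
        ...                self.weight.materialize((input.shape[-1],))
        ...                torch.nn.init.ones_(self.weight)
        ...
        ...    def forward(self, input):
        ...        return input * self.weight
        >>> m = LazyScale()
        >>> m(torch.randn(4, 3)).shape  # the dry run materializes ``weight`` with shape (3,)
        torch.Size([4, 3])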
| """ |
| raise NotImplementedError(f'initialize_parameters is not implemented for {self.__class__.__name__}') |
| |
| def has_uninitialized_params(self: _LazyProtocol): |
| r"""Check if a module has parameters that are not initialized.""" |
| # This is to avoid the JIT to track this parameter and force |
| # custom modules __setstate__ to add it |
| params = self._parameters.values() |
| buffers = self._buffers.values() |
| for param in itertools.chain(params, buffers): |
| if is_lazy(param): |
| return True |
| return False |

    def _infer_parameters(self: _LazyProtocol, module, args, kwargs=None):
        r"""Infers the size and initializes the parameters according to the provided input batch.

        Given a module that contains parameters that were declared as inferrable
        using :class:`torch.nn.UninitializedParameter`, uses the provided input
        batch to call ``initialize_parameters`` and materialize all the
        parameters as needed.
        This runs as a forward pre-hook, so the parameters are initialized right
        before the first forward pass is executed.
        """
        kwargs = kwargs if kwargs else {}
        module.initialize_parameters(*args, **kwargs)
        if module.has_uninitialized_params():
            raise RuntimeError(f'module {self._get_name()} has not been fully initialized')
        module._initialize_hook.remove()
        module._load_hook.remove()
        delattr(module, '_initialize_hook')
        delattr(module, '_load_hook')
        if module.cls_to_become is not None:
            module.__class__ = module.cls_to_become

    def _replicate_for_data_parallel(self: _LazyProtocol):
        raise RuntimeError('Modules with uninitialized parameters can\'t be used with `DataParallel`. '
                           'Run a dummy forward pass to correctly initialize the modules')