| import torch |
| from torch.nn.modules.container import ModuleList, ModuleDict, Module |
| from torch.nn.parameter import Parameter |
| from torch import Tensor |
| from typing import Union, Optional, Iterable, Dict, Tuple |
| from contextlib import contextmanager |
| |
| |
| # Number of nested ``cached()`` contexts currently active; caching is on while it is nonzero
| _cache_enabled = 0
| # Maps (id(module), tensor_name) to the computed parametrized tensor while caching is active
| _cache: Dict[Tuple[int, str], Optional[Tensor]] = {}
| |
| |
| @contextmanager |
| def cached(): |
| r"""Context manager that enables the caching system within parametrizations |
| registered with :func:`register_parametrization`. |
| The values of the parametrized objects are computed and cached the first time
| they are required while this context manager is active. The cached values are
| discarded when leaving the context manager.
| This is useful when using a parametrized parameter more than once in the forward pass. |
| An example of this is when parametrizing the recurrent kernel of an RNN or when |
| sharing weights. |
| The simplest way to activate the cache is by wrapping the forward pass of the neural network |
| |
| .. code-block:: python |
| |
| import torch.nn.utils.parametrize as P |
| ... |
| with P.cached(): |
| output = model(inputs) |
| |
| in training and evaluation. One may also wrap just the parts of the module that
| use the parametrized tensors several times. For example, the loop of an RNN with
| a parametrized recurrent kernel:
| |
| .. code-block:: python |
| |
| with P.cached(): |
| for x in xs: |
| out_rnn = self.rnn_cell(x, out_rnn) |
| """ |
| global _cache |
| global _cache_enabled |
| _cache_enabled += 1 |
| try: |
| yield |
| finally: |
| _cache_enabled -= 1 |
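| # Discard the cached values once the outermost ``cached()`` context exits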
| if not _cache_enabled: |
| _cache = {} |
| |
| |
| class ParametrizationList(ModuleList): |
| r"""A sequential container that holds and manages the ``original`` parameter or
| buffer of a parametrized tensor. It is the type of
| ``module.parametrizations[tensor_name]`` when ``module[tensor_name]`` has been
| parametrized with :func:`register_parametrization`.
| |
| .. note :: |
| This class is used internally by :func:`register_parametrization`. It is documented |
| here for completeness. It should not be instantiated by the user. |
| |
| Args: |
| modules (iterable): an iterable of modules representing the parametrizations |
| original (Parameter or Tensor): parameter or buffer that is parametrized |
| """ |
| original: Tensor |
| |
| def __init__( |
| self, modules: Iterable[Module], original: Union[Tensor, Parameter] |
| ) -> None: |
| super().__init__(modules) |
| if isinstance(original, Parameter): |
| self.register_parameter("original", original) |
| else: |
| self.register_buffer("original", original) |
| |
| def set_original_(self, value: Tensor) -> None: |
| r"""This method is called when assigning to a parametrized tensor. |
| It calls the ``right_inverse`` methods (see :func:`register_parametrization`)
| of the parametrizations in the reverse order in which they were registered.
| Then, it assigns the result to ``self.original``.
|
| Args:
| value (Tensor): Value to which the original tensor is set
|
| Raises:
| RuntimeError: if any of the parametrizations do not implement a ``right_inverse`` method
| """ |
| # See https://github.com/pytorch/pytorch/issues/53103 |
| for module in reversed(self): # type: ignore |
| if not hasattr(module, "right_inverse"): |
| raise RuntimeError( |
| "The parametrization '{}' does not implement a 'right_inverse' method. " |
| "Assigning to a parametrized tensor is only possible when all the parametrizations " |
| "implement a 'right_inverse' method.".format( |
| module.__class__.__name__ |
| ) |
| ) |
| |
| with torch.no_grad(): |
| # See https://github.com/pytorch/pytorch/issues/53103 |
| for module in reversed(self): # type: ignore |
| value = module.right_inverse(value) |
| self.original.copy_(value) |
| |
| def forward(self) -> Tensor: |
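| # Apply every registered parametrization to the original tensor, in registration order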
| x = self.original |
| for module in self: |
| x = module(x) |
| if x.size() != self.original.size(): |
| raise RuntimeError( |
| "The parametrization may not change the size of the parametrized tensor. " |
| "Size of original tensor: {} " |
| "Size of parametrized tensor: {}".format(self.original.size(), x.size()) |
| ) |
| return x |
| |
| |
| def _inject_new_class(module: Module) -> None: |
| r"""Sets up the parametrization mechanism used by parametrizations.
| This works by replacing the class of the module with a dynamically created
| subclass into which a property can be injected.
| |
| Args: |
| module (nn.Module): module into which to inject the property |
| """ |
| cls = module.__class__ |
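|
| # The dynamically generated subclass created below cannot be reconstructed by
| # pickle, so __getstate__ is overridden to direct users to state_dict() instead.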
| |
| def getstate(self): |
| raise RuntimeError( |
| "Serialization of parametrized modules is only " |
| "supported through state_dict(). See:\n" |
| "https://pytorch.org/tutorials/beginner/saving_loading_models.html" |
| "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" |
| ) |
| |
| param_cls = type( |
| "Parametrized{}".format(cls.__name__), |
| (cls,), |
| { |
| "__getstate__": getstate, |
| }, |
| ) |
| |
| module.__class__ = param_cls |
| |
| |
| def _inject_property(module: Module, tensor_name: str) -> None: |
| r"""Injects a property into module[tensor_name].
| It assumes that the class of the module has already been modified from its
| original one using _inject_new_class and that the tensor under ``tensor_name``
| has already been moved out.
| |
| Args: |
| module (nn.Module): module into which to inject the property |
| tensor_name (str): name of the property to create
| """ |
| # We check the precondition. |
| # This should never fire if register_parametrization is correctly implemented |
| assert not hasattr(module, tensor_name) |
| |
| def get_parametrized(self) -> Tensor: |
| global _cache |
| |
| parametrization = self.parametrizations[tensor_name] |
| if _cache_enabled: |
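| # The cache is keyed by the module's id and the tensor name, so each
| # parametrized tensor is evaluated at most once per cached() block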
| key = (id(module), tensor_name) |
| tensor = _cache.get(key) |
| if tensor is None: |
| tensor = parametrization() |
| _cache[key] = tensor |
| return tensor |
| else: |
| # If caching is not active, this function just evaluates the parametrization |
| return parametrization() |
| |
| def set_original(self, value: Tensor) -> None: |
| self.parametrizations[tensor_name].set_original_(value) |
| |
| setattr(module.__class__, tensor_name, property(get_parametrized, set_original)) |
| |
| |
| def register_parametrization( |
| module: Module, tensor_name: str, parametrization: Module |
| ) -> Module: |
| r"""Adds a parametrization to a tensor in a module. |
| When accessing ``module[tensor_name]``, the module will return the |
| parametrized version ``parametrization(module[tensor_name])``. The backward |
| pass will differentiate through the ``parametrization`` and if the original
| tensor is a :class:`~Parameter`, it will be updated accordingly by the optimizer.
| The first time that a module registers a parametrization, this function will add an attribute |
| ``parametrizations`` to the module of type :class:`~ParametrizationList`. |
| The list of parametrizations on a tensor will be accessible under |
| ``module.parametrizations[tensor_name]``. |
| The original tensor will be accessible under |
| ``module.parametrizations[tensor_name].original``. |
| Parametrizations may be composed by registering several parametrizations
| on the same attribute (see the second example below).
| Parametrized parameters and buffers have a built-in caching system that can be activated |
| using :func:`cached`. |
| A ``parametrization`` may optionally implement a method with signature |
| |
| .. code-block:: python |
| |
| def right_inverse(self, X: Tensor) -> Tensor |
| |
| If this method is implemented, it will be possible to assign to the parametrized tensor. |
| This may be used to initialize the tensor: |
| |
| >>> import torch |
| >>> import torch.nn.utils.parametrize as P |
| >>> |
| >>> class Symmetric(torch.nn.Module): |
| >>> def forward(self, X): |
| >>> return X.triu() + X.triu(1).T # Return a symmetric matrix |
| >>> |
| >>> def right_inverse(self, A): |
| >>> return A.triu() |
| >>> |
| >>> m = torch.nn.Linear(5, 5) |
| >>> P.register_parametrization(m, "weight", Symmetric()) |
| >>> print(torch.allclose(m.weight, m.weight.T)) # m.weight is now symmetric |
| True |
| >>> A = torch.rand(5, 5) |
| >>> A = A + A.T # A is now symmetric |
| >>> m.weight = A # Initialize the weight to be the symmetric matrix A |
| >>> print(torch.allclose(m.weight, A)) |
| True |
| |
| In most situations, ``right_inverse`` will be a function such that |
| ``forward(right_inverse(X)) == X`` (see |
| `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_). |
| Sometimes, when the parametrization is not surjective, it may be reasonable |
| to relax this, as we did with ``Symmetric`` in the example above. |
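|
| Registering a second parametrization on a tensor that is already parametrized
| composes it with the existing one. A minimal illustrative sketch, reusing
| ``Symmetric`` from above together with a hypothetical ``Scale`` module (not part
| of this file):
|
| .. code-block:: python
|
|     class Scale(torch.nn.Module):
|         def __init__(self, factor):
|             super().__init__()
|             self.factor = factor
|
|         def forward(self, X):
|             # Applied after Symmetric, so the result stays symmetric
|             return self.factor * X
|
|     m = torch.nn.Linear(5, 5)
|     P.register_parametrization(m, "weight", Symmetric())
|     P.register_parametrization(m, "weight", Scale(2.0))
|     # Parametrizations apply in registration order:
|     # m.weight == Scale(2.0)(Symmetric()(original))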
| |
| Args: |
| module (nn.Module): module on which to register the parametrization |
| tensor_name (str): name of the parameter or buffer on which to register
| the parametrization
| parametrization (nn.Module): the parametrization to register |
| |
| Returns: |
| Module: module |
| |
| Raises: |
| ValueError: if the module does not have a parameter or a buffer named ``tensor_name`` |
| """ |
| if is_parametrized(module, tensor_name): |
| # Just add the new parametrization to the parametrization list |
| module.parametrizations[tensor_name].append(parametrization) # type: ignore |
| elif tensor_name in module._buffers or tensor_name in module._parameters: |
| # Set the parametrization mechanism |
| # Fetch the original buffer or parameter |
| original = getattr(module, tensor_name) |
| # Delete the previous parameter or buffer |
| delattr(module, tensor_name) |
| # If this is the first parametrization of a buffer or parameter of the module, |
| # we prepare the module to inject the property |
| if not is_parametrized(module): |
| # Change the class |
| _inject_new_class(module) |
| # Inject a ``ModuleDict`` into the instance under module.parametrizations
| module.parametrizations = ModuleDict() |
| # Add a property into the class |
| _inject_property(module, tensor_name) |
| # Add a ParametrizationList |
| module.parametrizations[tensor_name] = ParametrizationList( # type: ignore |
| [parametrization], original |
| ) |
| else: |
| raise ValueError( |
| "Module '{}' does not have a parameter, a buffer, nor a " |
| "parametrized element with name '{}'".format(module, tensor_name) |
| ) |
| return module |
| |
| |
| def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool: |
| r"""Returns ``True`` if module has an active parametrization. |
| If the argument ``tensor_name`` is specified, it returns ``True`` if
| ``module[tensor_name]`` is parametrized.
| |
| Args: |
| module (nn.Module): module to query |
| tensor_name (str, optional): attribute in the module to query
| Default: ``None`` |
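|
| A minimal sketch (assuming ``Symmetric`` and the imports from the
| :func:`register_parametrization` example):
|
| >>> m = torch.nn.Linear(5, 5)
| >>> print(P.is_parametrized(m))
| False
| >>> P.register_parametrization(m, "weight", Symmetric())
| >>> print(P.is_parametrized(m))
| True
| >>> print(P.is_parametrized(m, "weight"))
| True
| >>> print(P.is_parametrized(m, "bias"))
| False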
| """ |
| parametrizations = getattr(module, "parametrizations", None) |
| if parametrizations is None or not isinstance(parametrizations, ModuleDict): |
| return False |
| if tensor_name is None: |
| # Check that there is at least one parametrized buffer or Parameter |
| return len(parametrizations) > 0 |
| else: |
| return tensor_name in parametrizations |
| |
| |
| def remove_parametrizations( |
| module: Module, tensor_name: str, leave_parametrized: bool = True |
| ) -> Module: |
| r"""Removes the parametrizations on a tensor in a module. |
| |
| - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to |
| its current output. In this case, the parametrization shall not change the ``dtype`` |
| of the tensor. |
| - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to
| the unparametrized tensor in ``module.parametrizations[tensor_name].original``.
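|
| A minimal sketch (assuming ``Symmetric`` and the imports from the
| :func:`register_parametrization` example):
|
| .. code-block:: python
|
|     m = torch.nn.Linear(5, 5)
|     P.register_parametrization(m, "weight", Symmetric())
|     P.remove_parametrizations(m, "weight")
|     # m.weight is a plain Parameter again; it keeps the last parametrized
|     # (symmetric) value because leave_parametrized defaults to True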
| |
| Args: |
| module (nn.Module): module from which to remove the parametrizations
| tensor_name (str): name of the parametrized tensor
| leave_parametrized (bool, optional): leave the attribute ``tensor_name`` parametrized. |
| Default: ``True`` |
| |
| Returns: |
| Module: module |
| |
| Raises: |
| ValueError: if ``module[tensor_name]`` is not parametrized |
| ValueError: if ``leave_parametrized=True`` and the parametrization changes the ``dtype`` of the tensor |
| """ |
| |
| if not is_parametrized(module, tensor_name): |
| raise ValueError( |
| "Module {} does not have a parametrization on {}".format( |
| module, tensor_name |
| ) |
| ) |
| |
| # Fetch the original tensor |
| original = module.parametrizations[tensor_name].original # type: ignore |
| if leave_parametrized: |
| t = getattr(module, tensor_name) |
| # If they have the same dtype, we reuse the original tensor. |
| # We do this so that the parameter does not change its id()
| # This way the user does not need to update the optimizer |
| if t.dtype == original.dtype: |
| original.set_(t) |
| else: |
| raise ValueError( |
| "The parametrization changes the dtype of the tensor from {} to {}. " |
| "It is not supported to leave the tensor parametrized (`leave_parametrized=True`) " |
| "in this case.".format(original.dtype, t.dtype) |
| ) |
| # Delete the property that manages the parametrization |
| delattr(module.__class__, tensor_name) |
| # Delete the ParametrizationList |
| del module.parametrizations[tensor_name] # type: ignore |
| |
| # Restore the parameter / buffer into the main class |
| if isinstance(original, Parameter): |
| module.register_parameter(tensor_name, original) |
| else: |
| module.register_buffer(tensor_name, original) |
| |
| # Roll back the parametrized class if no other buffer or parameter |
| # is currently parametrized in this class |
| if not is_parametrized(module): |
| delattr(module, "parametrizations") |
| # Restore class |
| orig_cls = module.__class__.__bases__[0] |
| module.__class__ = orig_cls |
| return module |