| import torch | 
 | from torch.nn.modules.container import ModuleList, ModuleDict, Module | 
 | from torch.nn.parameter import Parameter | 
 | from torch import Tensor | 
 |  | 
 | import collections | 
 | import copyreg | 
 | from copy import deepcopy | 
 | from contextlib import contextmanager | 
 | from typing import Union, Optional, Dict, Tuple, Sequence | 
 |  | 
 | __all__ = ['cached', 'ParametrizationList', 'register_parametrization', 'is_parametrized', 'remove_parametrizations', | 
 |            'type_before_parametrizations', 'transfer_parametrizations_and_params'] | 
 |  | 
 | _cache_enabled = 0 | 
 | _cache: Dict[Tuple[int, str], Optional[Tensor]] = {} | 
 |  | 
 |  | 
 | @contextmanager | 
 | def cached(): | 
 |     r"""Context manager that enables the caching system within parametrizations | 
 |     registered with :func:`register_parametrization`. | 
 |  | 
 |     The value of the parametrized objects is computed and cached the first time | 
 |     they are required when this context manager is active. The cached values are | 
 |     discarded when leaving the context manager. | 
 |  | 
 |     This is useful when using a parametrized parameter more than once in the forward pass. | 
 |     An example of this is when parametrizing the recurrent kernel of an RNN or when | 
 |     sharing weights. | 
 |  | 
 |     The simplest way to activate the cache is by wrapping the forward pass of the neural network | 
 |  | 
 |     .. code-block:: python | 
 |  | 
 |         import torch.nn.utils.parametrize as P | 
 |         ... | 
 |         with P.cached(): | 
 |             output = model(inputs) | 
 |  | 
    in training and evaluation. One may also wrap the parts of the module that
    use the parametrized tensors several times. For example, the loop of an RNN
    with a parametrized recurrent kernel:
 |  | 
 |     .. code-block:: python | 
 |  | 
 |         with P.cached(): | 
 |             for x in xs: | 
 |                 out_rnn = self.rnn_cell(x, out_rnn) | 
 |     """ | 
 |     global _cache | 
 |     global _cache_enabled | 
 |     _cache_enabled += 1 | 
 |     try: | 
 |         yield | 
 |     finally: | 
 |         _cache_enabled -= 1 | 
 |         if not _cache_enabled: | 
 |             _cache = {} | 
 |  | 
 |  | 
 | def _register_parameter_or_buffer(module, name, X): | 
 |     if isinstance(X, Parameter): | 
 |         module.register_parameter(name, X) | 
 |     else: | 
 |         module.register_buffer(name, X) | 
 |  | 
 |  | 
 | class ParametrizationList(ModuleList): | 
 |     r"""A sequential container that holds and manages the ``original`` or ``original0``, ``original1``, ... | 
 |     parameters or buffers of a parametrized :class:`torch.nn.Module`. | 
 |  | 
 |     It is the type of ``module.parametrizations[tensor_name]`` when ``module[tensor_name]`` | 
 |     has been parametrized with :func:`register_parametrization`. | 
 |  | 
 |     If the first registered parametrization has a ``right_inverse`` that returns one tensor or | 
 |     does not have a ``right_inverse`` (in which case we assume that ``right_inverse`` is the identity), | 
 |     it will hold the tensor under the name ``original``. | 
 |     If it has a ``right_inverse`` that returns more than one tensor, these will be registered as | 
 |     ``original0``, ``original1``, ... | 
 |  | 
 |     .. warning:: | 
 |         This class is used internally by :func:`register_parametrization`. It is documented | 
 |         here for completeness. It shall not be instantiated by the user. | 
 |  | 
 |     Args: | 
 |         modules (sequence): sequence of modules representing the parametrizations | 
 |         original (Parameter or Tensor): parameter or buffer that is parametrized | 
 |         unsafe (bool): a boolean flag that denotes whether the parametrization | 
 |             may change the dtype and shape of the tensor. Default: `False` | 
 |             Warning: the parametrization is not checked for consistency upon registration. | 
 |             Enable this flag at your own risk. | 
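
    For illustration only (instances are created by :func:`register_parametrization`,
    not by the user), after registering a parametrization the list can be inspected
    roughly as follows; ``nn.Identity`` is used here merely as a trivial parametrization:

    .. code-block:: python

        import torch.nn as nn
        import torch.nn.utils.parametrize as P

        m = nn.Linear(5, 5)
        P.register_parametrization(m, "weight", nn.Identity())
        plist = m.parametrizations.weight   # a ParametrizationList
        plist.original                      # the unparametrized Parameter
        len(plist)                          # number of registered parametrizations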
 |     """ | 
 |     original: Tensor | 
 |     unsafe: bool | 
 |  | 
 |     def __init__( | 
 |         self, modules: Sequence[Module], original: Union[Tensor, Parameter], unsafe: bool = False | 
 |     ) -> None: | 
        # We require this because we need to treat the first parametrization differently
 |         # This should never throw, unless this class is used from the outside | 
 |         if len(modules) == 0: | 
 |             raise ValueError("ParametrizationList requires one or more modules.") | 
 |  | 
 |         super().__init__(modules) | 
 |         self.unsafe = unsafe | 
 |  | 
 |         # In plain words: | 
 |         # module.weight must keep its dtype and shape. | 
 |         # Furthermore, if there is no right_inverse or the right_inverse returns a tensor, | 
 |         # this should be of the same dtype as the original tensor | 
 |         # | 
 |         # We check that the following invariants hold: | 
 |         #    X = module.weight | 
 |         #    Y = param.right_inverse(X) | 
 |         #    assert isinstance(Y, Tensor) or | 
 |         #           (isinstance(Y, collections.abc.Sequence) and all(isinstance(t, Tensor) for t in Y)) | 
        #    Z = param(Y) if isinstance(Y, Tensor) else param(*Y)
 |         #    # Consistency checks | 
 |         #    assert X.dtype == Z.dtype and X.shape == Z.shape | 
        #    # If the parametrization takes one input, this allows us to use set_ to
        #    # move data to/from the original tensor without changing its id (which is what the
        #    # optimiser uses to track parameters)
        #    if isinstance(Y, Tensor):
        #      assert X.dtype == Y.dtype
 |         # Below we use original = X, new = Y | 
 |  | 
 |         original_shape = original.shape | 
 |         original_dtype = original.dtype | 
 |  | 
 |         # Compute new | 
 |         with torch.no_grad(): | 
 |             new = original | 
 |             for module in reversed(self):  # type: ignore[call-overload] | 
 |                 if hasattr(module, "right_inverse"): | 
 |                     try: | 
 |                         new = module.right_inverse(new) | 
 |                     except NotImplementedError: | 
 |                         pass | 
                # else, or if it raises NotImplementedError, we assume that right_inverse is the identity
 |  | 
 |         if not isinstance(new, Tensor) and not isinstance(new, collections.abc.Sequence): | 
 |             raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors (list, tuple...). " | 
 |                              f"Got {type(new).__name__}") | 
 |  | 
 |         # Set the number of original tensors | 
 |         self.is_tensor = isinstance(new, Tensor) | 
 |         self.ntensors = 1 if self.is_tensor else len(new) | 
 |  | 
 |         # Register the tensor(s) | 
 |         if self.is_tensor: | 
 |             if original.dtype != new.dtype: | 
 |                 raise ValueError( | 
 |                     "When `right_inverse` outputs one tensor, it may not change the dtype.\n" | 
 |                     f"original.dtype: {original.dtype}\n" | 
 |                     f"right_inverse(original).dtype: {new.dtype}" | 
 |                 ) | 
            # Update the original in-place (via set_) so that the user does not need to re-register
            # the parameter manually in the optimiser
 |             with torch.no_grad(): | 
 |                 original.set_(new)  # type: ignore[call-overload] | 
 |             _register_parameter_or_buffer(self, "original", original) | 
 |         else: | 
 |             for i, originali in enumerate(new): | 
 |                 if not isinstance(originali, Tensor): | 
 |                     raise ValueError("'right_inverse' must return a Tensor or a Sequence of tensors " | 
 |                                      "(list, tuple...). " | 
 |                                      f"Got element {i} of the sequence with type {type(originali).__name__}.") | 
 |  | 
 |                 # If the original tensor was a Parameter that required grad, we expect the user to | 
 |                 # add the new parameters to the optimizer after registering the parametrization | 
 |                 # (this is documented) | 
 |                 if isinstance(original, Parameter): | 
 |                     originali = Parameter(originali) | 
 |                 originali.requires_grad_(original.requires_grad) | 
 |                 _register_parameter_or_buffer(self, f"original{i}", originali) | 
 |  | 
 |         if not self.unsafe: | 
 |             # Consistency checks: | 
 |             # Since f : A -> B, right_inverse : B -> A, Z and original should live in B | 
 |             # Z = forward(right_inverse(original)) | 
 |             Z = self() | 
 |             if not isinstance(Z, Tensor): | 
 |                 raise ValueError( | 
 |                     f"A parametrization must return a tensor. Got {type(Z).__name__}." | 
 |                 ) | 
 |             if Z.dtype != original_dtype: | 
 |                 raise ValueError( | 
 |                     "Registering a parametrization may not change the dtype of the tensor, unless `unsafe` flag is enabled.\n" | 
 |                     f"unparametrized dtype: {original_dtype}\n" | 
 |                     f"parametrized dtype: {Z.dtype}" | 
 |                 ) | 
 |             if Z.shape != original_shape: | 
 |                 raise ValueError( | 
 |                     "Registering a parametrization may not change the shape of the tensor, unless `unsafe` flag is enabled.\n" | 
 |                     f"unparametrized shape: {original_shape}\n" | 
 |                     f"parametrized shape: {Z.shape}" | 
 |                 ) | 
 |  | 
 |     def right_inverse(self, value: Tensor) -> None: | 
 |         r"""Calls the methods ``right_inverse`` (see :func:`register_parametrization`) | 
 |         of the parametrizations in the inverse order they were registered in. | 
 |         Then, it stores the result in ``self.original`` if ``right_inverse`` outputs one tensor | 
 |         or in ``self.original0``, ``self.original1``, ... if it outputs several. | 
 |  | 
 |         Args: | 
            value (Tensor): Value with which to initialize the tensor
 |         """ | 
        # The exceptions in this function should almost never be raised.
        # They could be raised if, for example, a right_inverse function returns a different
        # dtype when given a different input, which would most likely be caused by a
        # bug in the user's code
 |  | 
 |         with torch.no_grad(): | 
 |             # See https://github.com/pytorch/pytorch/issues/53103 | 
 |             for module in reversed(self):  # type: ignore[call-overload] | 
 |                 if hasattr(module, "right_inverse"): | 
 |                     value = module.right_inverse(value) | 
 |                 else: | 
 |                     raise RuntimeError(f"parametrization {type(module).__name__} does not implement " | 
 |                                        "right_inverse.") | 
 |             if self.is_tensor: | 
 |                 # These exceptions should only throw when a right_inverse function does not | 
 |                 # return the same dtype for every input, which should most likely be caused by a bug | 
 |                 if not isinstance(value, Tensor): | 
 |                     raise ValueError( | 
 |                         f"`right_inverse` should return a tensor. Got {type(value).__name__}" | 
 |                     ) | 
 |                 if value.dtype != self.original.dtype: | 
 |                     raise ValueError( | 
 |                         f"The tensor returned by `right_inverse` has dtype {value.dtype} " | 
 |                         f"while `original` has dtype {self.original.dtype}" | 
 |                     ) | 
 |                 # We know that the result is going to have the same dtype | 
 |                 self.original.set_(value)  # type: ignore[call-overload] | 
 |             else: | 
 |                 if not isinstance(value, collections.abc.Sequence): | 
 |                     raise ValueError( | 
 |                         "'right_inverse' must return a sequence of tensors. " | 
 |                         f"Got {type(value).__name__}." | 
 |                     ) | 
 |                 if len(value) != self.ntensors: | 
 |                     raise ValueError( | 
 |                         "'right_inverse' must return a sequence of tensors of length " | 
 |                         f"{self.ntensors}. Got a sequence of length {len(value)}." | 
 |                     ) | 
 |                 for i, tensor in enumerate(value): | 
 |                     original_i = getattr(self, f"original{i}") | 
 |                     if not isinstance(tensor, Tensor): | 
 |                         raise ValueError( | 
 |                             f"`right_inverse` must return a sequence of tensors. " | 
 |                             f"Got element {i} of type {type(tensor).__name__}" | 
 |                         ) | 
 |                     if original_i.dtype != tensor.dtype: | 
 |                         raise ValueError( | 
 |                             f"Tensor {i} returned by `right_inverse` has dtype {tensor.dtype} " | 
 |                             f"while `original{i}` has dtype {original_i.dtype}" | 
 |                         ) | 
 |                     original_i.set_(tensor) | 
 |  | 
 |     def forward(self) -> Tensor: | 
 |         if torch.jit.is_scripting(): | 
            raise RuntimeError('Parametrization is not supported when scripting.')
 |         # Unpack the originals for the first parametrization | 
 |         if self.is_tensor: | 
 |             x = self[0](self.original) | 
 |         else: | 
 |             originals = (getattr(self, f"original{i}") for i in range(self.ntensors)) | 
 |             x = self[0](*originals) | 
 |         # It's not possible to call self[1:] here, so we have to be a bit more cryptic | 
 |         # Also we want to skip all non-integer keys | 
 |         curr_idx = 1 | 
 |         while hasattr(self, str(curr_idx)): | 
 |             x = self[curr_idx](x) | 
 |             curr_idx += 1 | 
 |         return x | 
 |  | 
 |  | 
 | def _inject_new_class(module: Module) -> None: | 
 |     r"""Sets up a module to be parametrized. | 
 |  | 
    This works by replacing the class of the module with a subclass of it
    that allows injecting a property.
 |  | 
 |     Args: | 
 |         module (nn.Module): module into which to inject the property | 
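
    Conceptually (a rough sketch, not the exact generated class), after this call a
    parametrized ``nn.Linear`` behaves as if its class were:

    .. code-block:: python

        class ParametrizedLinear(nn.Linear):
            def __getstate__(self):
                # serialization must go through state_dict()
                raise RuntimeError("Serialization of parametrized modules is only "
                                   "supported through state_dict().")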
 |     """ | 
 |     cls = module.__class__ | 
 |  | 
 |     def default_deepcopy(self, memo): | 
 |         # Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class. | 
 |         obj = memo.get(id(self), None) | 
 |         if obj is not None: | 
 |             return obj | 
 |         replica = self.__new__(self.__class__) | 
 |         memo[id(self)] = replica | 
 |         replica.__dict__ = deepcopy(self.__dict__, memo) | 
 |         # Also save all slots if they exist. | 
 |         slots_to_save = copyreg._slotnames(self.__class__)  # type: ignore[attr-defined] | 
 |         for slot in slots_to_save: | 
 |             if hasattr(self, slot): | 
 |                 setattr(replica, slot, deepcopy(getattr(self, slot), memo)) | 
 |         return replica | 
 |  | 
 |     def getstate(self): | 
 |         raise RuntimeError( | 
 |             "Serialization of parametrized modules is only " | 
 |             "supported through state_dict(). See:\n" | 
 |             "https://pytorch.org/tutorials/beginner/saving_loading_models.html" | 
 |             "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training" | 
 |         ) | 
 |  | 
 |     dct = {"__getstate__": getstate} | 
 |     # We don't allow serialization of parametrized modules but should still allow deepcopying. | 
 |     # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists. | 
 |     if not hasattr(cls, "__deepcopy__"): | 
 |         dct["__deepcopy__"] = default_deepcopy  # type: ignore[assignment] | 
 |  | 
 |     param_cls = type( | 
 |         f"Parametrized{cls.__name__}", | 
 |         (cls,), | 
 |         dct, | 
 |     ) | 
 |  | 
 |     module.__class__ = param_cls | 
 |  | 
 |  | 
 | def _inject_property(module: Module, tensor_name: str) -> None: | 
 |     r"""Injects a property into module[tensor_name]. | 
 |  | 
    It assumes that the class of the module has already been modified from its
    original one using _inject_new_class and that the tensor under :attr:`tensor_name`
    has already been moved out.
 |  | 
 |     Args: | 
 |         module (nn.Module): module into which to inject the property | 
        tensor_name (str): name of the property to create
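
    Conceptually (a rough sketch that ignores caching and scripting checks), the
    injected property behaves like:

    .. code-block:: python

        class ParametrizedLinear(nn.Linear):
            @property
            def weight(self):                       # assuming tensor_name == "weight"
                return self.parametrizations["weight"]()

            @weight.setter
            def weight(self, value):
                self.parametrizations["weight"].right_inverse(value)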
 |     """ | 
 |     # We check the precondition. | 
 |     # This should never fire if register_parametrization is correctly implemented | 
 |     assert not hasattr(module, tensor_name) | 
 |  | 
 |     @torch.jit.unused | 
 |     def get_cached_parametrization(parametrization) -> Tensor: | 
 |         global _cache | 
 |         key = (id(module), tensor_name) | 
 |         tensor = _cache.get(key) | 
 |         if tensor is None: | 
 |             tensor = parametrization() | 
 |             _cache[key] = tensor | 
 |         return tensor | 
 |  | 
 |     def get_parametrized(self) -> Tensor: | 
 |         if torch.jit.is_scripting(): | 
            raise RuntimeError('Parametrization is not supported when scripting.')
 |         parametrization = self.parametrizations[tensor_name] | 
 |         if _cache_enabled: | 
 |             if torch.jit.is_scripting(): | 
 |                 # Scripting | 
 |                 raise RuntimeError('Caching is not implemented for scripting. ' | 
 |                                    'Either disable caching or avoid scripting.') | 
 |             elif torch._C._get_tracing_state() is not None: | 
 |                 # Tracing | 
 |                 raise RuntimeError('Cannot trace a model while caching parametrizations.') | 
 |             else: | 
 |                 return get_cached_parametrization(parametrization) | 
 |         else: | 
 |             # If caching is not active, this function just evaluates the parametrization | 
 |             return parametrization() | 
 |  | 
 |     def set_original(self, value: Tensor) -> None: | 
 |         if torch.jit.is_scripting(): | 
            raise RuntimeError('Parametrization is not supported when scripting.')
 |         self.parametrizations[tensor_name].right_inverse(value) | 
 |  | 
 |     setattr(module.__class__, tensor_name, property(get_parametrized, set_original)) | 
 |  | 
 | def register_parametrization( | 
 |     module: Module, tensor_name: str, parametrization: Module, *, unsafe: bool = False, | 
 | ) -> Module: | 
 |     r"""Adds a parametrization to a tensor in a module. | 
 |  | 
 |     Assume that ``tensor_name="weight"`` for simplicity. When accessing ``module.weight``, | 
 |     the module will return the parametrized version ``parametrization(module.weight)``. | 
 |     If the original tensor requires a gradient, the backward pass will differentiate | 
 |     through :attr:`parametrization`, and the optimizer will update the tensor accordingly. | 
 |  | 
 |     The first time that a module registers a parametrization, this function will add an attribute | 
 |     ``parametrizations`` to the module of type :class:`~ParametrizationList`. | 
 |  | 
 |     The list of parametrizations on the tensor ``weight`` will be accessible under | 
 |     ``module.parametrizations.weight``. | 
 |  | 
 |     The original tensor will be accessible under | 
 |     ``module.parametrizations.weight.original``. | 
 |  | 
 |     Parametrizations may be concatenated by registering several parametrizations | 
 |     on the same attribute. | 
 |  | 
 |     The training mode of a registered parametrization is updated on registration | 
    to match the training mode of the host module.
 |  | 
 |     Parametrized parameters and buffers have an inbuilt caching system that can be activated | 
 |     using the context manager :func:`cached`. | 
 |  | 
 |     A :attr:`parametrization` may optionally implement a method with signature | 
 |  | 
 |     .. code-block:: python | 
 |  | 
 |         def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]] | 
 |  | 
 |     This method is called on the unparametrized tensor when the first parametrization | 
 |     is registered to compute the initial value of the original tensor. | 
 |     If this method is not implemented, the original tensor will be just the unparametrized tensor. | 
 |  | 
    If all the parametrizations registered on a tensor implement ``right_inverse``, it is possible
    to initialize a parametrized tensor by assigning to it, as shown in the example below.
 |  | 
 |     It is possible for the first parametrization to depend on several inputs. | 
    This may be implemented by returning a tuple of tensors from ``right_inverse``
 |     (see the example implementation of a ``RankOne`` parametrization below). | 
 |  | 
 |     In this case, the unconstrained tensors are also located under ``module.parametrizations.weight`` | 
 |     with names ``original0``, ``original1``,... | 
 |  | 
 |     .. note:: | 
 |  | 
        If ``unsafe=False`` (default), both the ``forward`` and ``right_inverse`` methods will be called
        once to perform a number of consistency checks.
        If ``unsafe=True``, then ``right_inverse`` will be called if the tensor is not parametrized,
        and nothing will be called otherwise.
 |  | 
 |     .. note:: | 
 |  | 
 |         In most situations, ``right_inverse`` will be a function such that | 
 |         ``forward(right_inverse(X)) == X`` (see | 
 |         `right inverse <https://en.wikipedia.org/wiki/Inverse_function#Right_inverses>`_). | 
 |         Sometimes, when the parametrization is not surjective, it may be reasonable | 
 |         to relax this. | 
 |  | 
 |     .. warning:: | 
 |  | 
 |         If a parametrization depends on several inputs, :func:`~register_parametrization` | 
        will register a number of new parameters. If such a parametrization is registered
 |         after the optimizer is created, these new parameters will need to be added manually | 
        to the optimizer. See :meth:`torch.optim.Optimizer.add_param_group`.
 |  | 
 |     Args: | 
 |         module (nn.Module): module on which to register the parametrization | 
 |         tensor_name (str): name of the parameter or buffer on which to register | 
 |             the parametrization | 
 |         parametrization (nn.Module): the parametrization to register | 
 |     Keyword args: | 
 |         unsafe (bool): a boolean flag that denotes whether the parametrization | 
 |             may change the dtype and shape of the tensor. Default: `False` | 
 |             Warning: the parametrization is not checked for consistency upon registration. | 
 |             Enable this flag at your own risk. | 
 |  | 
 |     Raises: | 
 |         ValueError: if the module does not have a parameter or a buffer named :attr:`tensor_name` | 
 |  | 
 |     Examples: | 
 |         >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) | 
 |         >>> import torch | 
 |         >>> import torch.nn as nn | 
 |         >>> import torch.nn.utils.parametrize as P | 
 |         >>> | 
 |         >>> class Symmetric(nn.Module): | 
 |         >>>     def forward(self, X): | 
 |         >>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix | 
 |         >>> | 
 |         >>>     def right_inverse(self, A): | 
 |         >>>         return A.triu() | 
 |         >>> | 
 |         >>> m = nn.Linear(5, 5) | 
 |         >>> P.register_parametrization(m, "weight", Symmetric()) | 
 |         >>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric | 
 |         True | 
 |         >>> A = torch.rand(5, 5) | 
 |         >>> A = A + A.T   # A is now symmetric | 
 |         >>> m.weight = A  # Initialize the weight to be the symmetric matrix A | 
 |         >>> print(torch.allclose(m.weight, A)) | 
 |         True | 
 |  | 
 |         >>> class RankOne(nn.Module): | 
 |         >>>     def forward(self, x, y): | 
 |         >>>         # Form a rank 1 matrix multiplying two vectors | 
 |         >>>         return x.unsqueeze(-1) @ y.unsqueeze(-2) | 
 |         >>> | 
 |         >>>     def right_inverse(self, Z): | 
 |         >>>         # Project Z onto the rank 1 matrices | 
 |         >>>         U, S, Vh = torch.linalg.svd(Z, full_matrices=False) | 
 |         >>>         # Return rescaled singular vectors | 
 |         >>>         s0_sqrt = S[0].sqrt().unsqueeze(-1) | 
 |         >>>         return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt | 
 |         >>> | 
 |         >>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne()) | 
 |         >>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item()) | 
 |         1 | 
 |  | 
 |     """ | 
 |     parametrization.train(module.training) | 
 |     if is_parametrized(module, tensor_name): | 
 |         # Correctness checks. | 
 |         # If A is the space of tensors with shape and dtype equal to module.weight | 
 |         # we check that parametrization.forward and parametrization.right_inverse are | 
 |         # functions from A to A | 
 |         if not unsafe: | 
 |             Y = getattr(module, tensor_name) | 
 |             X = parametrization(Y) | 
 |             if not isinstance(X, Tensor): | 
 |                 raise ValueError( | 
 |                     f"A parametrization must return a tensor. Got {type(X).__name__}." | 
 |                 ) | 
 |             if X.dtype != Y.dtype: | 
 |                 raise ValueError( | 
 |                     "Registering a parametrization may not change the dtype of the tensor, unless the `unsafe` flag is enabled.\n" | 
 |                     f"module.{tensor_name}.dtype: {Y.dtype}\n" | 
 |                     f"parametrization(module.{tensor_name}).dtype: {X.dtype}" | 
 |                 ) | 
 |             if X.shape != Y.shape: | 
 |                 raise ValueError( | 
 |                     "Registering a parametrization may not change the shape of the tensor, unless the `unsafe` flag is enabled.\n" | 
 |                     f"module.{tensor_name}.shape: {Y.shape}\n" | 
 |                     f"parametrization(module.{tensor_name}).shape: {X.shape}" | 
 |                 ) | 
 |             if hasattr(parametrization, "right_inverse"): | 
 |                 try: | 
 |                     Z = parametrization.right_inverse(X)  # type: ignore[operator] | 
 |                 except NotImplementedError: | 
 |                     pass | 
 |                 else: | 
 |                     if not isinstance(Z, Tensor): | 
 |                         raise ValueError( | 
 |                             f"parametrization.right_inverse must return a tensor. Got: {type(Z).__name__}" | 
 |                         ) | 
 |                     if Z.dtype != Y.dtype: | 
 |                         raise ValueError( | 
 |                             "The tensor returned by parametrization.right_inverse must have the same dtype " | 
 |                             f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" | 
 |                             f"module.{tensor_name}.dtype: {Y.dtype}\n" | 
 |                             f"returned dtype: {Z.dtype}" | 
 |                         ) | 
 |                     if Z.shape != Y.shape: | 
 |                         raise ValueError( | 
 |                             "The tensor returned by parametrization.right_inverse must have the same shape " | 
 |                             f"as module.{tensor_name}, unless the `unsafe` flag is enabled.\n" | 
 |                             f"module.{tensor_name}.shape: {Y.shape}\n" | 
 |                             f"returned shape: {Z.shape}" | 
 |                         ) | 
 |             # else right_inverse is assumed to be the identity | 
 |  | 
 |         # add the new parametrization to the parametrization list | 
 |         assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy | 
 |         module.parametrizations[tensor_name].append(parametrization) | 
 |         # If unsafe was True in previous parametrization, keep it enabled | 
 |         module.parametrizations[tensor_name].unsafe |= unsafe  # type: ignore[index, union-attr] | 
 |     elif tensor_name in module._buffers or tensor_name in module._parameters: | 
 |         # Set the parametrization mechanism | 
 |         # Fetch the original buffer or parameter | 
 |         original = getattr(module, tensor_name) | 
 |         # We create this early to check for possible errors | 
 |         parametrizations = ParametrizationList([parametrization], original, unsafe=unsafe) | 
 |         # Delete the previous parameter or buffer | 
 |         delattr(module, tensor_name) | 
 |         # If this is the first parametrization registered on the module, | 
 |         # we prepare the module to inject the property | 
 |         if not is_parametrized(module): | 
 |             # Change the class | 
 |             _inject_new_class(module) | 
 |             # Inject a ``ModuleDict`` into the instance under module.parametrizations | 
 |             module.parametrizations = ModuleDict() | 
 |         # Add a property into the class | 
 |         _inject_property(module, tensor_name) | 
 |         # Add a ParametrizationList | 
 |         assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy | 
 |         module.parametrizations[tensor_name] = parametrizations | 
 |     else: | 
 |         raise ValueError( | 
 |             f"Module '{module}' does not have a parameter, a buffer, or a " | 
 |             f"parametrized element with name '{tensor_name}'" | 
 |         ) | 
 |     return module | 
 |  | 
 |  | 
 | def is_parametrized(module: Module, tensor_name: Optional[str] = None) -> bool: | 
 |     r"""Returns ``True`` if module has an active parametrization. | 
 |  | 
 |     If the argument :attr:`tensor_name` is specified, returns ``True`` if | 
 |     ``module[tensor_name]`` is parametrized. | 
 |  | 
 |     Args: | 
 |         module (nn.Module): module to query | 
        tensor_name (str, optional): name of the parameter or buffer in the module to query.
            Default: ``None``
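
    Examples:
        >>> # A minimal usage sketch; ``nn.Identity`` is used here only as a
        >>> # trivial (identity) parametrization.
        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>> m = nn.Linear(5, 5)
        >>> P.is_parametrized(m)
        False
        >>> _ = P.register_parametrization(m, "weight", nn.Identity())
        >>> P.is_parametrized(m)
        True
        >>> P.is_parametrized(m, "weight")
        True
        >>> P.is_parametrized(m, "bias")
        False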
 |     """ | 
 |     parametrizations = getattr(module, "parametrizations", None) | 
 |     if parametrizations is None or not isinstance(parametrizations, ModuleDict): | 
 |         return False | 
 |     if tensor_name is None: | 
 |         # Check that there is at least one parametrized buffer or Parameter | 
 |         return len(parametrizations) > 0 | 
 |     else: | 
 |         return tensor_name in parametrizations | 
 |  | 
 | def remove_parametrizations( | 
 |     module: Module, tensor_name: str, leave_parametrized: bool = True | 
 | ) -> Module: | 
 |     r"""Removes the parametrizations on a tensor in a module. | 
 |  | 
 |     - If ``leave_parametrized=True``, ``module[tensor_name]`` will be set to | 
 |       its current output. In this case, the parametrization shall not change the ``dtype`` | 
 |       of the tensor. | 
 |     - If ``leave_parametrized=False``, ``module[tensor_name]`` will be set to | 
      the unparametrized tensor in ``module.parametrizations[tensor_name].original``.
 |       This is only possible when the parametrization depends on just one tensor. | 
 |  | 
 |     Args: | 
        module (nn.Module): module from which to remove the parametrizations
        tensor_name (str): name of the tensor whose parametrizations will be removed
 |         leave_parametrized (bool, optional): leave the attribute :attr:`tensor_name` parametrized. | 
 |             Default: ``True`` | 
 |  | 
 |     Returns: | 
 |         Module: module | 
 |  | 
 |     Raises: | 
 |         ValueError: if ``module[tensor_name]`` is not parametrized | 
 |         ValueError: if ``leave_parametrized=False`` and the parametrization depends on several tensors | 
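
    Examples:
        >>> # A minimal usage sketch; ``nn.Identity`` is used here only as a
        >>> # trivial (identity) parametrization.
        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>> m = nn.Linear(5, 5)
        >>> _ = P.register_parametrization(m, "weight", nn.Identity())
        >>> P.is_parametrized(m, "weight")
        True
        >>> _ = P.remove_parametrizations(m, "weight", leave_parametrized=True)
        >>> P.is_parametrized(m)
        False
        >>> isinstance(m.weight, nn.Parameter)
        True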
 |     """ | 
 |  | 
 |     if not is_parametrized(module, tensor_name): | 
 |         raise ValueError(f"Module {module} does not have a parametrization on {tensor_name}") | 
 |  | 
 |     # Fetch the original tensor | 
 |     assert isinstance(module.parametrizations, ModuleDict)  # Make mypy happy | 
 |     parametrizations = module.parametrizations[tensor_name] | 
 |     if parametrizations.is_tensor: | 
 |         original = parametrizations.original | 
 |         if leave_parametrized: | 
 |             with torch.no_grad(): | 
 |                 t = getattr(module, tensor_name) | 
 |             # We know they have the same dtype because we have checked this when registering the | 
 |             # parametrizations. As such, we can use set_ | 
            # We do this so that the parameter does not change its id()
 |             # This way the user does not need to update the optimizer | 
 |             with torch.no_grad(): | 
 |                 if type(original) is torch.Tensor: | 
 |                     original.set_(t) | 
 |                 else: | 
 |                     try: | 
 |                         original.set_(t) | 
                    except RuntimeError as e:
                        # TODO: Fix this for tensor subclasses that are parameters:
                        # RuntimeError: set_storage is not allowed on a Tensor created from .data or .detach().
                        raise RuntimeError("Calling remove_parametrizations() with leave_parametrized=True "
                                           "for a parameter that is an instance of a tensor subclass requires "
                                           "set_() to be implemented correctly for the tensor subclass. Either "
                                           "set leave_parametrized=False or provide a working implementation for "
                                           "set_() in the tensor subclass.") from e
 |     else: | 
 |         if leave_parametrized: | 
 |             # We cannot use no_grad because we need to know whether one or more | 
 |             # original tensors required grad | 
 |             t = getattr(module, tensor_name) | 
 |             # We'll have to trust the user to add it to the optimizer | 
 |             original = Parameter(t) if t.requires_grad else t | 
 |         else: | 
 |             raise ValueError("Cannot leave unparametrized (`leave_parametrized=False`) a tensor " | 
 |                              "that is parametrized in terms of a sequence of tensors.") | 
 |  | 
 |     # Delete the property that manages the parametrization | 
 |     delattr(module.__class__, tensor_name) | 
 |     # Delete the ParametrizationList | 
 |     del module.parametrizations[tensor_name] | 
 |  | 
 |     # Restore the parameter / buffer into the main class | 
 |     _register_parameter_or_buffer(module, tensor_name, original) | 
 |  | 
 |     # Roll back the parametrized class if no other buffer or parameter | 
 |     # is currently parametrized in this class | 
 |     if not is_parametrized(module): | 
 |         delattr(module, "parametrizations") | 
 |         # Restore class | 
 |         orig_cls = module.__class__.__bases__[0] | 
 |         module.__class__ = orig_cls | 
 |     return module | 
 |  | 
 | def type_before_parametrizations(module: Module) -> type: | 
 |     r"""Returns the module type before parametrizations were applied and if not, | 
 |     then it returns the module type. | 
 |  | 
 |     Args: | 
 |         module (nn.Module): module to get type of | 
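
    Examples:
        >>> # A minimal usage sketch; ``nn.Identity`` is used here only as a
        >>> # trivial (identity) parametrization.
        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>> m = nn.Linear(5, 5)
        >>> _ = P.register_parametrization(m, "weight", nn.Identity())
        >>> type(m).__name__
        'ParametrizedLinear'
        >>> P.type_before_parametrizations(m) is nn.Linear
        True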
 |     """ | 
 |     if is_parametrized(module): | 
 |         return module.__class__.__bases__[0] | 
 |     else: | 
 |         return type(module) | 
 |  | 
 | def transfer_parametrizations_and_params( | 
 |     from_module: Module, to_module: Module, tensor_name: Optional[str] = None | 
 | ) -> Module: | 
 |     r"""Transfers parametrizations and the parameters they parametrize from from_module | 
 |     to to_module. If tensor_name is specified, only transfers the specified parameter, otherwise | 
 |     transfers all parametrized parameters. If those parameters do not exist in to_module, it will create them. | 
 |     Does nothing if from_module is not parametrized. | 
 |  | 
 |     Args: | 
 |         from_module (nn.Module): module to transfer from | 
 |         to_module (nn.Module): module to transfer to | 
 |         tensor_name (str, optional): parameter to transfer | 
 |  | 
 |     Returns: | 
 |         Module: to_module | 
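
    Examples:
        >>> # A minimal usage sketch; ``nn.Identity`` is used here only as a
        >>> # trivial (identity) parametrization.
        >>> import torch
        >>> import torch.nn as nn
        >>> import torch.nn.utils.parametrize as P
        >>> m1, m2 = nn.Linear(5, 5), nn.Linear(5, 5)
        >>> _ = P.register_parametrization(m1, "weight", nn.Identity())
        >>> _ = P.transfer_parametrizations_and_params(m1, m2)
        >>> P.is_parametrized(m2, "weight")
        True
        >>> print(torch.allclose(m1.weight, m2.weight))
        True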
 |     """ | 
 |     if is_parametrized(from_module): | 
 |         assert isinstance(from_module.parametrizations, ModuleDict)  # for mypy | 
 |  | 
 |         # get list of all params or the single param to transfer | 
 |         parameters_to_transfer: Union[list, ModuleDict] = ( | 
 |             from_module.parametrizations if tensor_name is None else [tensor_name] | 
 |         ) | 
 |  | 
 |         assert hasattr(parameters_to_transfer, "__iter__")  # for mypy | 
 |         for parameter_name in parameters_to_transfer: | 
 |  | 
            # initialize the to-be-transferred param in to_module if it doesn't exist already
 |             if not hasattr(to_module, parameter_name): | 
 |                 setattr( | 
 |                     to_module, | 
 |                     parameter_name, | 
 |                     Parameter(getattr(from_module, parameter_name)), | 
 |                 ) | 
 |  | 
            # apply the param's parametrizations to to_module
 |             for param_func in from_module.parametrizations[parameter_name]: | 
 |                 register_parametrization(to_module, parameter_name, param_func) | 
 |             assert isinstance(to_module.parametrizations, ModuleDict)  # for mypy | 
 |  | 
            # make values match: original values can be stored either in original or in
            # original0, original1, ...; we need to check both cases
 |             if hasattr(from_module.parametrizations[parameter_name], "original"): | 
 |                 to_module.parametrizations[parameter_name].original = \ | 
 |                     from_module.parametrizations[parameter_name].original | 
 |             else: | 
 |                 num = 0 | 
 |                 orig_num = "original" + str(num) | 
 |                 # loop through each original# until all values have been set | 
 |                 while hasattr(from_module.parametrizations[parameter_name], orig_num): | 
 |                     setattr( | 
 |                         to_module.parametrizations[parameter_name], | 
 |                         orig_num, | 
 |                         getattr(from_module.parametrizations[parameter_name], orig_num), | 
 |                     ) | 
 |                     num = num + 1 | 
 |                     orig_num = "original" + str(num) | 
 |  | 
 |     return to_module |