import torch
from ..utils import parametrize
from ..modules import Module
from .. import functional as F

from typing import Optional

class _SpectralNorm(Module):
    def __init__(
        self,
        weight: torch.Tensor,
        n_power_iterations: int = 1,
        dim: int = 0,
        eps: float = 1e-12
    ) -> None:
        super().__init__()
        ndim = weight.ndim
        if dim >= ndim or dim < -ndim:
            raise IndexError("Dimension out of range (expected to be in range of "
                             f"[-{ndim}, {ndim - 1}] but got {dim})")

        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(n_power_iterations))
        self.dim = dim if dim >= 0 else dim + ndim
        self.eps = eps
        if ndim > 1:
            # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward)
            self.n_power_iterations = n_power_iterations
            weight_mat = self._reshape_weight_to_matrix(weight)
            h, w = weight_mat.size()

            u = weight_mat.new_empty(h).normal_(0, 1)
            v = weight_mat.new_empty(w).normal_(0, 1)
            self.register_buffer('_u', F.normalize(u, dim=0, eps=self.eps))
            self.register_buffer('_v', F.normalize(v, dim=0, eps=self.eps))

            # Start with u, v initialized to some reasonable values by performing a number
            # of iterations of the power method
            self._power_method(weight_mat, 15)

    def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
        # Precondition
        assert weight.ndim > 1

        if self.dim != 0:
            # permute dim to front
            weight = weight.permute(self.dim, *(d for d in range(weight.dim()) if d != self.dim))

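        # E.g. with dim == 0, a conv weight of shape (C_out, C_in, kH, kW) is
        # flattened below into a matrix of shape (C_out, C_in * kH * kW)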
        return weight.flatten(1)

    @torch.autograd.no_grad()
    def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None:
        # See original note at torch/nn/utils/spectral_norm.py
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        # updated in power iteration **in-place**. This is very important
        # because in `DataParallel` forward, the vectors (being buffers) are
        # broadcast from the parallelized module to each module replica,
        # which is a new module object created on the fly. And each replica
        # runs its own spectral norm power iteration. So simply assigning
        # the updated vectors to the module this function runs on will cause
        # the update to be lost forever. And the next time the parallelized
        # module is replicated, the same randomly initialized vectors are
        # broadcast and used!
        #
        # Therefore, to make the change propagate back, we rely on two
        # important behaviors (also enforced via tests):
        # 1. `DataParallel` doesn't clone storage if the broadcast tensor
        #    is already on the correct device; and it makes sure that the
        #    parallelized module is already on `device[0]`.
        # 2. If the out tensor in the `out=` kwarg has the correct shape, it
        #    will just fill in the values.
        # Therefore, since the same power iteration is performed on all
        # devices, simply updating the tensors in-place will make sure that
        # the module replica on `device[0]` will update the _u vector on the
        # parallelized module (by shared storage).
        #
        # However, after we update `u` and `v` in-place, we need to **clone**
        # them before using them to normalize the weight. This is to support
        # backpropagating through two forward passes, e.g., the common pattern in
        # GAN training: loss = D(real) - D(fake). Otherwise, the autograd engine
        # will complain that variables needed to do backward for the first forward
        # (i.e., the `u` and `v` vectors) are changed in the second forward.

        # Precondition
        assert weight_mat.ndim > 1
        for _ in range(n_power_iterations):
            # The spectral norm of the weight equals `u^T W v`, where `u` and `v`
            # are the first left and right singular vectors.
            # This power iteration produces approximations of `u` and `v`.
            self._u = F.normalize(torch.mv(weight_mat, self._v),     # type: ignore[has-type]
                                  dim=0, eps=self.eps, out=self._u)  # type: ignore[has-type]
            self._v = F.normalize(torch.mv(weight_mat.t(), self._u),
                                  dim=0, eps=self.eps, out=self._v)  # type: ignore[has-type]
        # See above on why we need to clone
        self._u = self._u.clone(memory_format=torch.contiguous_format)
        self._v = self._v.clone(memory_format=torch.contiguous_format)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        if weight.ndim == 1:
            # Faster and more exact path, no need to approximate anything
            return F.normalize(weight, dim=0, eps=self.eps)
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            if self.training:
                self._power_method(weight_mat, self.n_power_iterations)
            # The proper way of computing this should be through F.bilinear, but
            # it seems to have some efficiency issues:
            # https://github.com/pytorch/pytorch/issues/58093
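            # `sigma` approximates the largest singular value of `weight_mat`
            # (estimated via the power iteration above); dividing by it rescales
            # the weight so its spectral norm is approximately one.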
            sigma = torch.dot(self._u, torch.mv(weight_mat, self._v))
            return weight / sigma

    def right_inverse(self, value: torch.Tensor) -> torch.Tensor:
        # we may want to assert here that the passed value already
        # satisfies constraints
        return value


def spectral_norm(module: Module,
                  name: str = 'weight',
                  n_power_iterations: int = 1,
                  eps: float = 1e-12,
                  dim: Optional[int] = None) -> Module:
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    When applied on a vector, it simplifies to

    .. math::
        \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant
    of the model. :math:`\sigma` is approximated by performing :attr:`n_power_iterations`
    iterations of the `power method`_ every time the weight is accessed. If the weight
    tensor has more than two dimensions, it is reshaped to 2D for the power iteration.
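
    For instance, with the default ``dim=0`` the ``(16, 3, 5, 5)`` weight of an
    ``nn.Conv2d(3, 16, kernel_size=5)`` is treated as a ``16 x 75`` matrix during
    the power iteration. A small illustrative sketch (``_u`` and ``_v`` are
    internal buffers holding the current singular vector estimates)::

        >>> snm = spectral_norm(nn.Conv2d(3, 16, kernel_size=5))
        >>> snm.parametrizations.weight[0]._u.shape
        torch.Size([16])
        >>> snm.parametrizations.weight[0]._v.shape
        torch.Size([75])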

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration
    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    .. note::
        This function is implemented using the new parametrization functionality
        in :func:`torch.nn.utils.parametrize.register_parametrization`. It is a
        reimplementation of :func:`torch.nn.utils.spectral_norm`.

    .. note::
        When this constraint is registered, the singular vectors associated with the largest
        singular value are estimated rather than sampled at random. These are then updated
        by performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor
        is accessed with the module in training mode.

    .. note::
        If the `_SpectralNorm` module, i.e., `module.parametrizations.weight[idx]`,
        is in training mode on removal, it will perform another power iteration.
        If you'd like to avoid this iteration, set the module to eval mode
        before its removal.
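
        A minimal sketch of removing the parametrization without the extra
        iteration (``snm`` being a module returned by this function)::

            snm.eval()
            torch.nn.utils.parametrize.remove_parametrizations(snm, 'weight')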

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, in which case it is ``1``

    Returns:
        The original module with a new parametrization registered to the specified
        weight

    Example::

        >>> snm = spectral_norm(nn.Linear(20, 40))
        >>> snm
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _SpectralNorm()
            )
          )
        )
        >>> torch.linalg.matrix_norm(snm.weight, 2)
        tensor(1.0000, grad_fn=<CopyBackwards>)
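
    A one dimensional tensor is simply normalized, as in the vector formula above.
    A small sketch (``nn.Linear`` also carries a 1-D ``bias`` parameter)::

        >>> layer = spectral_norm(nn.Linear(20, 40), name='bias')
        >>> bool(torch.isclose(torch.linalg.vector_norm(layer.bias), torch.tensor(1.)))
        True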
| """ |
| if not hasattr(module, name): |
| raise ValueError( |
| "Module '{}' has no attribute with name '{}'".format(module, name) |
| ) |
| # getattr should get the correct parametrized weight if there |
| # is already an parametrization registered |
| weight = getattr(module, name) |
| |
| if dim is None: |
| if isinstance(module, (torch.nn.ConvTranspose1d, |
| torch.nn.ConvTranspose2d, |
| torch.nn.ConvTranspose3d)): |
| dim = 1 |
| else: |
| dim = 0 |
| parametrize.register_parametrization(module, name, _SpectralNorm(weight, n_power_iterations, dim, eps)) |
| return module |