| # Owner(s): ["module: nn"] |
| from copy import deepcopy |
| from itertools import product |
| import pickle |
| |
| import torch |
| |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.nn.init as init |
| import torch.nn.utils.parametrize as parametrize |
| from torch.nn import Parameter |
| from torch.testing._internal.common_utils import run_tests, skipIfNoLapack, \ |
| TemporaryFileName, instantiate_parametrized_tests, set_default_dtype, gradcheck |
| from torch.testing._internal.common_cuda import TEST_MULTIGPU |
| from torch.testing._internal.common_nn import NNTestCase |
| |
| |
| class TestNNParametrization(NNTestCase): |
| _do_cuda_memory_leak_check = True |
| _do_cuda_non_default_stream = True |
| |
| # FIXME: Rewrite this test using functions not depending on LAPACK |
| # and remove the `@skipIfNoLapack` (see #70995) |
| # torch/nn/utils/parametrize |
| @skipIfNoLapack |
| def test_register_and_remove_parametrization(self): |
| r"""Test that it is possible to add a few parametrizations |
| on a parameter or a buffer and that removing them restores the initial state. |
| It also tests that backpropagating through them works as expected |
| """ |
| # Define a couple matrix parametrizations |
| class Skew(nn.Module): |
| def forward(self, X): |
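| # Keep the strictly lower-triangular part and subtract its transpose, |
| # so the result A satisfies A == -A.T (skew-symmetric) |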
| X = X.tril(-1) |
| return X - X.T |
| |
| class Orthogonal(nn.Module): |
| def forward(self, X): |
| # Cayley map |
| # If X is skew-symmetric it returns an orthogonal matrix |
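| # Q = (I + X)^{-1} (I - X); for skew-symmetric X, Q.T @ Q == I |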
| Id = torch.eye(X.size(0), device=X.device) |
| # We call contiguous because solve returns a tensor with strides that are Fortran-contiguous |
| # and autograd raises a performance warning. |
| # This happens when we remove the parametrization with leave_parametrized=True, |
| # which does a set_ with a non-contiguous tensor while the gradient is contiguous |
| return torch.linalg.solve(Id + X, Id - X).contiguous() |
| |
| class Resize(nn.Module): |
| def forward(self, X): |
| return X[[0]] |
| |
| class NoResize(nn.Module): |
| def forward(self, X): |
| return X |
| |
| # Define a couple vector parametrizations |
| class FirstZero(nn.Module): |
| def forward(self, x): |
| return torch.cat([x.new_zeros(1), x[1:]]) |
| |
| class LastZero(nn.Module): |
| def forward(self, x): |
| return torch.cat([x[:-1], x.new_zeros(1)]) |
| |
| model = nn.Linear(8, 8) |
| initial_weight_id = id(model.weight) |
| initial_bias_id = id(model.bias) |
| initial_model = deepcopy(model) |
| |
| # Test unsafe flag |
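| # With the default unsafe=False, registering a parametrization whose output |
| # has a different shape than the original tensor is rejected up front |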
| with self.assertRaisesRegex(ValueError, "Registering a parametrization may not change the shape of the tensor"): |
| parametrize.register_parametrization(model, "weight", Resize()) # default unsafe = False |
| model(torch.ones(8, 8)) |
| |
| # One parametrization with unsafe=True |
| parametrize.register_parametrization(model, "weight", Resize(), unsafe=True) |
| self.assertTrue(hasattr(model, "parametrizations")) |
| self.assertTrue(parametrize.is_parametrized(model)) |
| self.assertTrue(parametrize.is_parametrized(model, "weight")) |
| self.assertFalse(parametrize.is_parametrized(model, "bias")) |
| self.assertNotIn("weight", model._parameters) |
| A = model.weight |
| self.assertTrue(A.shape[0] == 1) |
| parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) |
| self.assertFalse(hasattr(model, "parametrizations")) |
| self.assertEqual(model.weight, initial_model.weight) |
| self.assertEqual(id(model.weight), initial_weight_id) |
| self.assertEqual(model.__class__, nn.Linear) |
| |
| # Two parametrizations with unsafe=True |
| parametrize.register_parametrization(model, "weight", Resize(), unsafe=True) |
| parametrize.register_parametrization(model, "weight", NoResize(), unsafe=False) |
| self.assertTrue(hasattr(model, "parametrizations")) |
| self.assertTrue(parametrize.is_parametrized(model)) |
| self.assertTrue(parametrize.is_parametrized(model, "weight")) |
| self.assertFalse(parametrize.is_parametrized(model, "bias")) |
| self.assertNotIn("weight", model._parameters) |
| A = model.weight |
| self.assertTrue(A.shape[0] == 1) |
| parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) |
| self.assertFalse(hasattr(model, "parametrizations")) |
| self.assertEqual(model.weight, initial_model.weight) |
| self.assertEqual(id(model.weight), initial_weight_id) |
| self.assertEqual(model.__class__, nn.Linear) |
| |
| # Test unsafe flag doesn't change expected behavior |
| parametrize.register_parametrization(model, "weight", Skew(), unsafe=True) |
| self.assertTrue(hasattr(model, "parametrizations")) |
| self.assertTrue(parametrize.is_parametrized(model)) |
| self.assertTrue(parametrize.is_parametrized(model, "weight")) |
| self.assertFalse(parametrize.is_parametrized(model, "bias")) |
| self.assertNotIn("weight", model._parameters) |
| # Result should be skew-symmetric |
| A = model.weight |
| self.assertEqual(A, -A.T) |
| # Remove and check consistency |
| parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) |
| self.assertFalse(hasattr(model, "parametrizations")) |
| self.assertEqual(model.weight, initial_model.weight) |
| self.assertEqual(id(model.weight), initial_weight_id) |
| self.assertEqual(model.__class__, nn.Linear) |
| |
| # Test one parametrization |
| parametrize.register_parametrization(model, "weight", Skew()) |
| self.assertTrue(hasattr(model, "parametrizations")) |
| self.assertTrue(parametrize.is_parametrized(model)) |
| self.assertTrue(parametrize.is_parametrized(model, "weight")) |
| self.assertFalse(parametrize.is_parametrized(model, "bias")) |
| self.assertNotIn("weight", model._parameters) |
| # Result should be skew-symmetric |
| A = model.weight |
| self.assertEqual(A, -A.T) |
| # Remove and check consistency |
| parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) |
| self.assertFalse(hasattr(model, "parametrizations")) |
| self.assertEqual(model.weight, initial_model.weight) |
| self.assertEqual(id(model.weight), initial_weight_id) |
| self.assertEqual(model.__class__, nn.Linear) |
| |
| # Test two parametrizations at the same time and removing them |
| parametrize.register_parametrization(model, "weight", Skew()) |
| parametrize.register_parametrization(model, "weight", Orthogonal()) |
| # Result should be orthogonal |
| X = model.weight |
| Id = torch.eye(X.size(0), device=X.device) |
| self.assertEqual(X.T @ X, Id) |
| # Structure tests |
| self.assertTrue(hasattr(model, "parametrizations")) |
| self.assertTrue(parametrize.is_parametrized(model)) |
| self.assertTrue(parametrize.is_parametrized(model, "weight")) |
| self.assertFalse(parametrize.is_parametrized(model, "bias")) |
| self.assertIn("weight", model.parametrizations) |
| self.assertNotIn("weight", model._parameters) |
| # Remove |
| parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) |
| self.assertEqual(model.weight, initial_model.weight) |
| self.assertEqual(id(model.weight), initial_weight_id) |
| self.assertFalse(hasattr(model, "parametrizations")) |
| self.assertEqual(model.__class__, nn.Linear) |
| |
| # Add everything |
| parametrize.register_parametrization(model, "weight", Skew()) |
| parametrize.register_parametrization(model, "weight", Orthogonal()) |
| parametrize.register_parametrization(model, "bias", FirstZero()) |
| parametrize.register_parametrization(model, "bias", LastZero()) |
| |
| # Basic tests |
| self.assertTrue(parametrize.is_parametrized(model)) |
| self.assertTrue(parametrize.is_parametrized(model, "weight")) |
| self.assertTrue(parametrize.is_parametrized(model, "bias")) |
| self.assertEqual(model.bias[0].item(), 0.) |
| self.assertEqual(model.bias[-1].item(), 0.) |
| self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happened |
| # Should not throw |
| |
| sgd = torch.optim.SGD(model.parameters(), lr=0.01) |
| |
| weight_copy = model.weight.clone() |
| bias_copy = model.bias.clone() |
| sgd.zero_grad() |
| (model.weight.T @ model.bias).sum().backward() |
| sgd.step() |
| self.assertNotEqual(model.weight, weight_copy) |
| self.assertNotEqual(model.bias, bias_copy) |
| |
| # Remove first parametrization. |
| # Check that the model is still parametrized and so is the second parameter |
| parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) |
| self.assertTrue(parametrize.is_parametrized(model)) # Still parametrized |
| self.assertFalse(parametrize.is_parametrized(model, "weight")) # Parametrization removed |
| self.assertTrue(parametrize.is_parametrized(model, "bias")) # Still parametrized |
| self.assertEqual(model.bias[0].item(), 0.) # Still parametrized |
| self.assertEqual(model.bias[-1].item(), 0.) # Still parametrized |
| self.assertNotEqual(model.weight, initial_model.weight) # Has been updated |
| self.assertEqual(id(model.weight), initial_weight_id) # Keeps the same id |
| self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happened |
| # Should not throw |
| weight_copy = model.weight.clone() |
| bias_copy = model.bias.clone() |
| sgd.zero_grad() |
| (model.weight.T @ model.bias).sum().backward() |
| sgd.step() |
| self.assertNotEqual(model.weight, weight_copy) |
| self.assertNotEqual(model.bias, bias_copy) |
| |
| # Remove the second parametrization. |
| # Check that the module is not parametrized |
| parametrize.remove_parametrizations(model, "bias", leave_parametrized=False) |
| self.assertFalse(parametrize.is_parametrized(model)) # Not parametrized |
| self.assertNotEqual(model.bias, initial_model.bias) # Has been updated |
| self.assertNotEqual(model.bias[0].item(), 0.) # Not parametrized |
| self.assertNotEqual(model.bias[-1].item(), 0.) # Not parametrized |
| self.assertEqual(id(model.bias), initial_bias_id) # Keeps the same id |
| self.assertFalse(hasattr(model, "parametrizations")) # The module is no longer parametrized |
| self.assertEqual(model.__class__, nn.Linear) # Restores the previous class |
| self.assertEqual(len(list(model.parameters())), 2) # Nothing weird has happened |
| |
| # Should not throw. Things are updated |
| weight_copy = model.weight.clone() |
| bias_copy = model.bias.clone() |
| sgd.zero_grad() |
| (model.weight.T @ model.bias).sum().backward() |
| sgd.step() |
| self.assertNotEqual(model.weight, weight_copy) |
| self.assertNotEqual(model.bias, bias_copy) |
| |
| # Test leave_parametrized=True |
| for _ in range(2): |
| parametrize.register_parametrization(model, "weight", Skew()) |
| parametrize.register_parametrization(model, "weight", Orthogonal()) |
| parametrize.remove_parametrizations(model, "weight", leave_parametrized=True) |
| # We didn't change the dtype and didn't use multiple inputs, so the ids should stay the same |
| self.assertEqual(id(model.weight), initial_weight_id) |
| self.assertEqual(id(model.bias), initial_bias_id) |
| |
| # Should not throw. Things are updated |
| weight_copy = model.weight.clone() |
| bias_copy = model.bias.clone() |
| sgd.zero_grad() |
| (model.weight.T @ model.bias).sum().backward() |
| sgd.step() |
| self.assertNotEqual(model.weight, weight_copy) |
| self.assertNotEqual(model.bias, bias_copy) |
| |
| def test_register_and_remove_nested_parametrization(self): |
| r"""Test that it is possible to nest the parametrizations |
| meaning that the original param is parametrized again |
| """ |
| class Skew(nn.Module): |
| def forward(self, X): |
| X = X.tril(-1) |
| return X - X.T |
| |
| model = nn.Linear(8, 8) |
| # Add top level parametrization |
| parametrize.register_parametrization(model, "weight", Skew()) |
| self.assertTrue(hasattr(model, "parametrizations")) |
| self.assertTrue(parametrize.is_parametrized(model)) |
| self.assertTrue(parametrize.is_parametrized(model, "weight")) |
| self.assertFalse(parametrize.is_parametrized(model, "bias")) |
| self.assertNotIn("weight", model._parameters) |
| # Result should be skew-symmetric |
| A = model.weight |
| self.assertEqual(A, -A.T) |
| |
| # Add nested parametrization |
| param_mod = model.parametrizations.weight |
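| # The unconstrained tensor is stored as `original` inside this ParametrizationList |
| # and can itself be parametrized |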
| self.assertFalse(hasattr(param_mod, "parametrizations")) |
| self.assertFalse(parametrize.is_parametrized(param_mod)) |
| self.assertFalse(parametrize.is_parametrized(param_mod, "original")) |
| |
| parametrize.register_parametrization(param_mod, "original", Skew()) |
| self.assertTrue(hasattr(param_mod, "parametrizations")) |
| self.assertTrue(parametrize.is_parametrized(param_mod)) |
| self.assertTrue(parametrize.is_parametrized(param_mod, "original")) |
| self.assertNotIn("original", param_mod._parameters) |
| # Result should be skew-symmetric |
| A = param_mod.original |
| self.assertEqual(A, -A.T) |
| |
| # Remove nested param and check consistency |
| parametrize.remove_parametrizations(param_mod, "original", leave_parametrized=False) |
| self.assertFalse(hasattr(param_mod, "parametrizations")) |
| self.assertEqual(param_mod.__class__, parametrize.ParametrizationList) |
| |
| # Remove top level and check consistency |
| parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) |
| self.assertFalse(hasattr(model, "parametrizations")) |
| self.assertEqual(model.__class__, nn.Linear) |
| |
| def test_register_and_remove_buffer_parametrization(self): |
| r"""Test that it is possible to add and remove parametrizations on buffers""" |
| # Define a couple vector parametrizations |
| class FirstZero(nn.Module): |
| def forward(self, x): |
| return torch.cat([x.new_zeros(1), x[1:]]) |
| |
| class LastZero(nn.Module): |
| def forward(self, x): |
| return torch.cat([x[:-1], x.new_zeros(1)]) |
| |
| model = nn.Linear(8, 8) |
| |
| # Instantiate parametrizations on buffers. It should work as expected |
| delattr(model, "bias") |
| model.register_buffer("bias", torch.ones(8)) |
| parametrize.register_parametrization(model, "bias", FirstZero()) |
| parametrize.register_parametrization(model, "bias", LastZero()) |
| self.assertTrue(parametrize.is_parametrized(model)) |
| self.assertTrue(parametrize.is_parametrized(model, "bias")) |
| self.assertEqual(model.bias[0].item(), 0.) |
| self.assertEqual(model.bias[-1].item(), 0.) |
| self.assertTrue((model.bias[1:-1] == torch.ones(6)).all()) |
| self.assertEqual(len(list(model.parameters())), 1) |
| |
| # Remove parametrizations on buffers. It should work as expected |
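| # leave_parametrized=True keeps the buffer set to its current parametrized value |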
| parametrize.remove_parametrizations(model, "bias", leave_parametrized=True) |
| self.assertFalse(parametrize.is_parametrized(model)) |
| self.assertFalse(parametrize.is_parametrized(model, "bias")) |
| self.assertEqual(model.bias[0].item(), 0.) |
| self.assertEqual(model.bias[-1].item(), 0.) |
| self.assertTrue((model.bias[1:-1] == torch.ones(6)).all()) |
| self.assertEqual(len(list(model.parameters())), 1) |
| |
| # FIXME: Rewrite this test using functions not depending on LAPACK |
| # and remove the `@skipIfNoLapack` (see #70995) |
| @skipIfNoLapack |
| def test_serialization_parametrization(self): |
| r"""Test that it is possible to serialize a parametrized model via state_dict""" |
| # A stateful parametrization |
| class Orthogonal(nn.Module): |
| def __init__(self, n): |
| super().__init__() |
| self.register_buffer("id", torch.eye(n)) |
| self.register_buffer("B", torch.empty(n, n)) |
| init.orthogonal_(self.B) |
| |
| def forward(self, X): |
| A = X.triu(1) |
| A = A - A.T |
| return self.B @ torch.linalg.solve(self.id + A, self.id - A) |
| |
| def get_model(): |
| model = torch.nn.Sequential( |
| torch.nn.Linear(5, 5), |
| torch.nn.ReLU(), |
| torch.nn.Linear(5, 1), |
| ) |
| |
| parametrize.register_parametrization(model[0], "weight", Orthogonal(5)) |
| return model |
| |
| model = get_model() |
| |
| prev_weight = model[0].weight |
| prev_B = model[0].parametrizations.weight[0].B |
| |
| new_model = get_model() |
| with TemporaryFileName() as fname: |
| torch.save(model.state_dict(), fname) |
| new_model.load_state_dict(torch.load(fname)) |
| |
| # Integrity tests |
| self.assertTrue(parametrize.is_parametrized(new_model[0], "weight")) |
| self.assertEqual(prev_weight, new_model[0].weight) |
| self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B) |
| |
| # Trying to save the whole parametrized model raises |
| with self.assertRaisesRegex(RuntimeError, "state_dict"): |
| with TemporaryFileName() as fname: |
| torch.save(model, fname) |
| |
| # FIXME: Rewrite this test using functions not depending on LAPACK |
| # and remove the `@skipIfNoLapack` (see #70995) |
| @skipIfNoLapack |
| def test_initialization_parametrization(self): |
| r"""Test that it is possible to initialize a parametrization when it |
| implements a `right_inverse` method |
| """ |
| class Skew(nn.Module): |
| def forward(self, X): |
| A = X.triu(1) |
| return A - A.T |
| |
| def is_skew(self, A): |
| return torch.allclose(A, -A.T, atol=1e-6) |
| |
| def right_inverse(self, X): |
| if not self.is_skew(X): |
| raise ValueError("The matrix is not skew-symmetric.") |
| return X.triu(1) |
| |
| # Implements a Cayley map where right_inverse is not quite the inverse of forward |
| class Orthogonal(nn.Module): |
| def __init__(self, n): |
| super().__init__() |
| self.register_buffer("B", torch.eye(n)) |
| |
| def forward(self, X): |
| Id = torch.eye(X.size(0)) |
| return self.B @ torch.linalg.solve(Id + X, Id - X) |
| |
| def is_orthogonal(self, X): |
| Id = torch.eye(X.size(0)) |
| return torch.allclose(X.T @ X, Id, atol=1e-4) |
| |
| def right_inverse(self, X): |
| if not self.is_orthogonal(X): |
| raise ValueError("The input is not orthogonal.") |
| # cayley(0) == Id, so B @ cayley(0) == B |
| self.B = X |
| return torch.zeros_like(X) |
| |
| N = 5 |
| model = nn.Linear(N, N) |
| # Register the skew-symmetric constraint. The result is now skew-symmetric |
| skew = Skew() |
| # Make the weight skew-symmetric before registering the parametrization |
| with torch.no_grad(): |
| model.weight.set_(skew(model.weight)) |
| parametrize.register_parametrization(model, "weight", skew) |
| X = torch.rand(N, N) |
| # X is not skew-symmetric, so it throws an error |
| with self.assertRaises(ValueError): |
| model.weight = X |
| # Make X skew-symmetric |
| X = X - X.T |
| model.weight = X |
| self.assertEqual(model.parametrizations.weight.original, X.triu(1)) |
| self.assertEqual(model.weight, X) |
| |
| # Having several parametrizations registered should work in the same way |
| # Register the Cayley map on top: the resulting weight is now orthogonal |
| parametrize.register_parametrization(model, "weight", Orthogonal(N)) |
| X = torch.rand(N, N) |
| # X is not orthogonal, so it throws an error |
| with self.assertRaises(ValueError): |
| model.weight = X |
| init.orthogonal_(X) |
| model.weight = X |
| self.assertEqual(model.weight, X) |
| self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X)) |
| |
| def test_errors_unparametrized_tensor_parametrization(self): |
| # Test errors when registering a parametrization on an unparametrized tensor |
| module = nn.Linear(3, 4) |
| weight_init = module.weight.clone() |
| |
| class Identity(nn.Module): |
| def forward(self, x): |
| return x |
| |
| # Registering a parametrization on a non-existing parameter throws |
| with self.assertRaisesRegex(ValueError, "does not have a parameter"): |
| parametrize.register_parametrization(module, "foo", Identity()) |
| self.assertFalse(parametrize.is_parametrized(module)) |
| |
| # Removing parametrizations from an unparametrized tensor throws |
| with self.assertRaisesRegex(ValueError, "does not have a parametrization"): |
| parametrize.remove_parametrizations(module, "bias") |
| self.assertFalse(parametrize.is_parametrized(module)) |
| |
| # A correct parametrization with several outputs |
| class Sum(nn.Module): |
| def forward(self, x, y): |
| return x + y |
| |
| def right_inverse(self, z): |
| return z, torch.zeros_like(z) |
| |
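| # A right_inverse that returns several tensors makes the parametrization store |
| # them as original0, original1, ... instead of a single original |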
| parametrize.register_parametrization(module, "weight", Sum()) |
| # Cannot remove a parametrization with several outputs with `leave_parametrized=False` |
| with self.assertRaisesRegex(ValueError, "leave_parametrized=False"): |
| parametrize.remove_parametrizations(module, "weight", leave_parametrized=False) |
| parametrize.remove_parametrizations(module, "weight", leave_parametrized=True) |
| |
| # A parametrization with an incorrect number of outputs |
| class WrongNumberParams(nn.Module): |
| def forward(self, x, y, z): |
| return x + y + z |
| |
| def right_inverse(self, w): |
| return w, torch.zeros_like(w) |
| |
| # Makes param(*param.right_inverse(X)) fail |
| with self.assertRaisesRegex(TypeError, "positional argument"): |
| parametrize.register_parametrization(module, "weight", WrongNumberParams()) |
| self.assertFalse(parametrize.is_parametrized(module)) |
| |
| # A parametrization with a right_inverse that does not return a Tensor or Sequence[Tensor] |
| class WrongRightInverse(Identity): |
| def right_inverse(self, z): |
| return None |
| |
| # right_inverse should return a Tensor or a Sequence[Tensor] |
| with self.assertRaisesRegex(ValueError, "Tensor or a Sequence of"): |
| parametrize.register_parametrization(module, "weight", WrongRightInverse()) |
| self.assertFalse(parametrize.is_parametrized(module)) |
| |
| # If it's a sequence, it must be a sequence of tensors |
| class WrongRightInverseSequence(nn.Module): |
| def forward(self, x, y): |
| return x |
| |
| def right_inverse(self, z): |
| return None, z |
| |
| with self.assertRaisesRegex(ValueError, "of the sequence with type"): |
| parametrize.register_parametrization(module, "weight", WrongRightInverseSequence()) |
| self.assertFalse(parametrize.is_parametrized(module)) |
| |
| # A parametrization from one tensor to one tensor that changes the dtype |
| class ChangeDtypeInverse(nn.Module): |
| def forward(self, x): |
| return x.float() |
| |
| def right_inverse(self, w): |
| return w.bool() |
| |
| # For parametrizations that return one tensor, right_inverse may not change the dtype |
| with self.assertRaisesRegex(ValueError, "outputs one tensor, it may not change the dtype"): |
| parametrize.register_parametrization(module, "weight", ChangeDtypeInverse()) |
| self.assertFalse(parametrize.is_parametrized(module)) |
| |
| # Doesn't return a tensor |
| class NotTensor(nn.Module): |
| def forward(self, x): |
| return 2 |
| |
| # Forward must return a tensor |
| with self.assertRaisesRegex(ValueError, "must return a tensor"): |
| parametrize.register_parametrization(module, "weight", NotTensor()) |
| self.assertFalse(parametrize.is_parametrized(module)) |
| |
| # A parametrization from one tensor to one tensor that changes the dtype |
| class ChangeDtype(nn.Module): |
| def forward(self, x): |
| return x.bool() |
| |
| # forward should not change the initial dtype |
| with self.assertRaisesRegex(ValueError, "may not change the dtype"): |
| parametrize.register_parametrization(module, "weight", ChangeDtype()) |
| self.assertFalse(parametrize.is_parametrized(module)) |
| |
| # Change shape |
| class ChangeShape(nn.Module): |
| def forward(self, x): |
| return x[:-1] |
| |
| # forward should not change the original shape |
| with self.assertRaisesRegex(ValueError, "may not change the shape"): |
| parametrize.register_parametrization(module, "weight", ChangeShape()) |
| self.assertFalse(parametrize.is_parametrized(module)) |
| |
| # Many to one that changes dtype |
| class ChangeDtypeMulti(nn.Module): |
| def forward(self, x, y): |
| return (x + y).bool() |
| |
| def right_inverse(self, w): |
| return w, w + 1 |
| |
| # forward should not change the original dtype, even for parametrizations with several inputs |
| with self.assertRaisesRegex(ValueError, "may not change the dtype"): |
| parametrize.register_parametrization(module, "weight", ChangeDtypeMulti()) |
| self.assertFalse(parametrize.is_parametrized(module)) |
| |
| # Returning a sequence of length one, although weird, is correct |
| class SequenceLen1(nn.Module): |
| def forward(self, x): |
| return x |
| |
| def right_inverse(self, w): |
| return (w,) |
| |
| parametrize.register_parametrization(module, "weight", SequenceLen1()) |
| self.assertTrue(hasattr(module.parametrizations.weight, "original0")) |
| self.assertFalse(hasattr(module.parametrizations.weight, "original1")) |
| _ = module.weight # Does not throw |
| self.assertTrue(parametrize.is_parametrized(module)) |
| parametrize.remove_parametrizations(module, "weight", leave_parametrized=True) |
| |
| # None of the operations above should have altered the weight |
| self.assertFalse(parametrize.is_parametrized(module)) |
| self.assertEqual(module.weight, weight_init) |
| |
| def test_errors_parametrized_tensor_parametrization(self): |
| # Test errors when registering a parametrization on a parametrized tensor |
| |
| class Identity(nn.Module): |
| def forward(self, x): |
| return x |
| |
| module = nn.Linear(3, 4) |
| parametrize.register_parametrization(module, "weight", Identity()) |
| |
| # Has to return a tensor |
| class WrongReturn(nn.Module): |
| def forward(self, x): |
| return x, x |
| |
| with self.assertRaisesRegex(ValueError, "must return a tensor"): |
| parametrize.register_parametrization(module, "weight", WrongReturn()) |
| self.assertTrue(parametrize.is_parametrized(module)) |
| self.assertEqual(len(module.parametrizations.weight), 1) |
| self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) |
| |
| # Cannot change dtype |
| class ChangeDtype(nn.Module): |
| def forward(self, x): |
| return x.bool() |
| |
| with self.assertRaisesRegex(ValueError, "may not change the dtype"): |
| parametrize.register_parametrization(module, "weight", ChangeDtype()) |
| self.assertTrue(parametrize.is_parametrized(module)) |
| self.assertEqual(len(module.parametrizations.weight), 1) |
| self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) |
| |
| # Cannot change shape |
| class ChangeShape(nn.Module): |
| def forward(self, x): |
| return x[:-1] |
| |
| with self.assertRaisesRegex(ValueError, "may not change the shape"): |
| parametrize.register_parametrization(module, "weight", ChangeShape()) |
| self.assertTrue(parametrize.is_parametrized(module)) |
| self.assertEqual(len(module.parametrizations.weight), 1) |
| self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) |
| |
| # The following checks catch bugs in the code of the parametrization itself |
| |
| # right_inverse has to return a tensor |
| class WrongReturnInverse(Identity): |
| def right_inverse(self, x): |
| return x, x |
| |
| with self.assertRaisesRegex(ValueError, "right_inverse must return a tensor"): |
| parametrize.register_parametrization(module, "weight", WrongReturnInverse()) |
| self.assertTrue(parametrize.is_parametrized(module)) |
| self.assertEqual(len(module.parametrizations.weight), 1) |
| self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) |
| |
| # Cannot change dtype |
| class ChangeDtypeInverse(Identity): |
| def right_inverse(self, x): |
| return x.bool() |
| |
| with self.assertRaisesRegex(ValueError, "must have the same dtype"): |
| parametrize.register_parametrization(module, "weight", ChangeDtypeInverse()) |
| self.assertTrue(parametrize.is_parametrized(module)) |
| self.assertEqual(len(module.parametrizations.weight), 1) |
| self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) |
| |
| # Cannot change shape |
| class ChangeShapeInverse(Identity): |
| def right_inverse(self, x): |
| return x[:-1] |
| |
| with self.assertRaisesRegex(ValueError, "must have the same shape"): |
| parametrize.register_parametrization(module, "weight", ChangeShapeInverse()) |
| self.assertTrue(parametrize.is_parametrized(module)) |
| self.assertEqual(len(module.parametrizations.weight), 1) |
| self.assertTrue(isinstance(module.parametrizations.weight[0], Identity)) |
| |
| # FIXME: Rewrite this test using functions not depending on LAPACK |
| # and remove the `@skipIfNoLapack` (see #70995) |
| @skipIfNoLapack |
| def test_multiple_inputs_parametrization(self): |
| # A parametrization with several outputs |
| class RankOne(nn.Module): |
| def forward(self, x, y): |
| # Form a rank-1 matrix from a pair of vectors |
| return x.unsqueeze(-1) @ y.unsqueeze(-2) |
| |
| def right_inverse(self, Y): |
| # We project the given matrix onto the rank 1 matrices |
| U, S, Vh = torch.linalg.svd(Y, full_matrices=False) |
| # S is ordered in a decreasing way. |
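| # The best rank-1 approximation of Y is s0 * u0 @ v0.T; splitting sqrt(s0) |
| # between the two factors lets forward(x, y) reconstruct that matrix |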
| s0_sqrt = S[0].sqrt().unsqueeze(-1) |
| return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt |
| |
| # A simple parametrization |
| class Double(nn.Module): |
| def forward(self, x): |
| return 2.0 * x |
| |
| def right_inverse(self, w): |
| return 0.5 * w |
| |
| model = nn.Linear(3, 3) |
| # Test one parametrization |
| parametrize.register_parametrization(model, "weight", RankOne()) |
| self.assertTrue(hasattr(model, "parametrizations")) |
| self.assertTrue(parametrize.is_parametrized(model)) |
| self.assertTrue(parametrize.is_parametrized(model, "weight")) |
| self.assertTrue(hasattr(model.parametrizations.weight, "original0")) |
| self.assertIn("original0", model.parametrizations.weight._parameters) |
| self.assertTrue(hasattr(model.parametrizations.weight, "original1")) |
| self.assertIn("original1", model.parametrizations.weight._parameters) |
| self.assertFalse(parametrize.is_parametrized(model, "bias")) |
| self.assertNotIn("weight", model._parameters) |
| # Result should be rank 1 |
| self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) |
| |
| with self.assertRaisesRegex(ValueError, "leave_parametrized=False"): |
| # Cannot remove a parametrization with multiple inputs and not leave it parametrized |
| parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) |
| # Remove parametrization and check consistency |
| parametrize.remove_parametrizations(model, "weight", leave_parametrized=True) |
| self.assertFalse(hasattr(model, "parametrizations")) |
| self.assertEqual(model.__class__, nn.Linear) |
| self.assertFalse(parametrize.is_parametrized(model)) |
| self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) |
| self.assertIn("weight", model._parameters) |
| |
| # Registering parametrizations with one input on top of one with multiple inputs should work |
| init_weight = model.weight.clone() |
| parametrize.register_parametrization(model, "weight", RankOne()) |
| # Projecting a rank 1 matrix onto the matrices of rank one does not change the matrix |
| self.assertEqual(init_weight, model.weight) |
| parametrize.register_parametrization(model, "weight", Double()) |
| # The matrix now is twice the initial matrix |
| self.assertEqual(2.0 * init_weight, model.weight) |
| # Multiplying by a scalar does not change the rank |
| self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) |
| |
| # The model now has three parameters |
| self.assertEqual(len(list(model.parameters())), 3) |
| |
| sgd = torch.optim.SGD(model.parameters(), lr=0.1) |
| |
| # Test backward. Should not throw |
| for _ in range(2): |
| sgd.zero_grad() |
| loss = (model.weight.T @ model.bias).sum() |
| loss.backward() |
| sgd.step() |
| |
| # Same drill as before, removing should work as expected |
| with self.assertRaisesRegex(ValueError, "leave_parametrized=False"): |
| # Cannot remove a parametrization with multiple inputs and not leave it parametrized |
| parametrize.remove_parametrizations(model, "weight", leave_parametrized=False) |
| # Remove parametrization and check consistency |
| parametrize.remove_parametrizations(model, "weight", leave_parametrized=True) |
| self.assertFalse(hasattr(model, "parametrizations")) |
| self.assertEqual(model.__class__, nn.Linear) |
| self.assertFalse(parametrize.is_parametrized(model)) |
| self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1) |
| self.assertIn("weight", model._parameters) |
| |
| # The model now has two parameters |
| self.assertEqual(len(list(model.parameters())), 2) |
| |
| # Test backward. Should not throw |
| sgd = torch.optim.SGD(model.parameters(), lr=0.1) |
| for _ in range(2): |
| sgd.zero_grad() |
| loss = (model.weight.T @ model.bias).sum() |
| loss.backward() |
| sgd.step() |
| |
| # FIXME: Rewrite this test using functions not depending on LAPACK |
| # and remove the `@skipIfNoLapack` (see #70995) |
| @skipIfNoLapack |
| def test_caching_parametrization(self): |
| r"""Test the caching system of a parametrization""" |
| # Define a couple matrix parametrizations |
| class Skew(nn.Module): |
| def forward(self, X): |
| X = X.tril(-1) |
| return X - X.T |
| |
| class Orthogonal(nn.Module): |
| def forward(self, X): |
| Id = torch.eye(X.size(0), device=X.device) |
| return torch.linalg.solve(Id + X, Id - X) |
| |
| model = nn.Linear(5, 5) |
| parametrize.register_parametrization(model, "weight", Skew()) |
| parametrize.register_parametrization(model, "weight", Orthogonal()) |
| |
| # Test that the caching system works |
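| # Inside parametrize.cached() the parametrized value is computed once and reused, |
| # so repeated accesses return the very same tensor object |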
| with parametrize.cached(): |
| X = model.weight |
| Y = model.weight |
| self.assertEqual(id(X), id(Y)) |
| |
| # FIXME: Rewrite this test using functions not depending on LAPACK |
| # and remove the `@skipIfNoLapack` (see #70995) |
| @skipIfNoLapack |
| def test_caching_parametrization_with_transfer_parametrizations_and_params(self): |
| r"""Test that transferring parametrizations doesn't cause issues with caching""" |
| class Skew(nn.Module): |
| def forward(self, X): |
| X = X.tril(-1) |
| return X - X.T |
| |
| class Orthogonal(nn.Module): |
| def forward(self, X): |
| Id = torch.eye(X.size(0), device=X.device) |
| return torch.linalg.solve(Id + X, Id - X) |
| |
| model = nn.Linear(5, 5) |
| parametrize.register_parametrization(model, "weight", Skew()) |
| parametrize.register_parametrization(model, "weight", Orthogonal()) |
| |
| to_model = nn.Linear(5, 5) |
| parametrize.transfer_parametrizations_and_params(model, to_model) |
| |
| with parametrize.cached(): |
| X = model.weight |
| Y = model.weight |
| self.assertEqual(id(X), id(Y)) |
| |
| A = to_model.weight |
| B = to_model.weight |
| self.assertEqual(id(A), id(B)) |
| |
| # test that the results are distinct objects for each module |
| self.assertNotEqual(id(A), id(X)) |
| |
| def test_parametrization_same_training_mode(self): |
| r"""Test training mode updated on parametrization registration""" |
| class Identity(nn.Module): |
| def forward(self, X): |
| return X |
| |
| module = nn.Linear(4, 4) |
| module.eval() |
| parametrize.register_parametrization(module, "weight", Identity()) |
| self.assertFalse(module.parametrizations.weight[0].training) |
| module.train() |
| parametrize.register_parametrization(module, "weight", Identity().eval()) |
| self.assertTrue(module.parametrizations.weight[0].training) |
| self.assertTrue(module.parametrizations.weight[1].training) |
| |
| def test_type_before_parametrizations(self): |
| r"""Test that type_before_parametrizations always retrieves original type""" |
| |
| class Identity(nn.Module): |
| def forward(self, X): |
| return X |
| |
| model = nn.Linear(5, 5) |
| original_type = type(model) |
| self.assertTrue( |
| parametrize.type_before_parametrizations(model) == original_type |
| ) |
| parametrize.register_parametrization(model, "weight", Identity()) |
| self.assertTrue( |
| parametrize.type_before_parametrizations(model) == original_type |
| ) |
| |
| def test_deepcopy_after_parametrization(self): |
| r"""Test that we are able to create a deepcopy of the module when it's parametrized.""" |
| |
| class AddOne(nn.Module): |
| def forward(self, x): |
| return x + 1.0 |
| |
| class ModelWithoutDeepcopy(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.weight = nn.Parameter(torch.tensor([1., 1., 1., 1.]), requires_grad=True) |
| self.bias = nn.Parameter(torch.tensor([0., 0., 0., 0.]), requires_grad=True) |
| self.attr = [1.0, 2.0, 3.0, 4.0] |
| |
| class ActualModel(ModelWithoutDeepcopy): |
| # Emulate a custom implementation of deepcopy. |
| def __deepcopy__(self, memo): |
| result = self.__new__(self.__class__) |
| memo[id(self)] = result |
| result.__dict__ = deepcopy(self.__dict__, memo) |
| return result |
| |
| def check_deepcopy(m1: nn.Module, m2: nn.Module): |
| w1 = m1.parametrizations.weight.original |
| w2 = m2.parametrizations.weight.original |
| b1 = m1.parametrizations.bias.original if parametrize.is_parametrized(m1, "bias") else m1.bias |
| b2 = m2.parametrizations.bias.original if parametrize.is_parametrized(m2, "bias") else m2.bias |
| # Weights, biases and attributes should be equal but they must be different objects. |
| self.assertEqual(m1.__dict__.keys(), m2.__dict__.keys()) |
| self.assertIsNot(m1, m2) |
| self.assertEqual(w1, w2) |
| self.assertIsNot(w1, w2) |
| self.assertEqual(b1, b2) |
| self.assertIsNot(b1, b2) |
| self.assertEqual(m1.attr, m2.attr) |
| self.assertIsNot(m1.attr, m2.attr) |
| |
| for model in (ModelWithoutDeepcopy(), ActualModel()): |
| # General check that we are able to create deepcopy. |
| parametrize.register_parametrization(model, "weight", AddOne()) |
| check_deepcopy(model, deepcopy(model)) |
| # Check that this works on models with several parametrized tensors. |
| parametrize.register_parametrization(model, "bias", AddOne()) |
| check_deepcopy(model, deepcopy(model)) |
| # Check that this works on models where tensors have more than one parametrization. |
| parametrize.register_parametrization(model, "weight", AddOne()) |
| check_deepcopy(model, deepcopy(model)) |
| |
| def test_transfer_parametrizations_and_params(self): |
| r"""Test that all parametrizations and their associated parameters are transferred.""" |
| |
| class AddOne(nn.Module): |
| def forward(self, x): |
| return x + 1.0 |
| |
| class Double(nn.Module): |
| def forward(self, x): |
| return 2.0 * x |
| |
| def right_inverse(self, x): |
| return 0.5 * x |
| |
| class MinusOne(nn.Module): |
| def forward(self, x): |
| return x - 1.0 |
| |
| model = nn.Linear(5, 5) |
| parametrize.register_parametrization(model, "weight", AddOne()) |
| parametrize.register_parametrization(model, "weight", Double()) |
| parametrize.register_parametrization(model, "weight", MinusOne()) |
| hold_weight = model.weight |
| |
| to_model = torch.ao.nn.qat.Linear( |
| 5, 5, qconfig=torch.ao.quantization.get_default_qconfig() |
| ) |
| parametrize.transfer_parametrizations_and_params(model, to_model) |
| |
| # check that the final and original values are correct and that to_model is parametrized |
| self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight")) |
| self.assertEqual(model.weight, to_model.weight) |
| self.assertEqual( |
| model.parametrizations.weight.original, |
| to_model.parametrizations.weight.original, |
| ) |
| |
| # check that the transfer didn't affect the original value |
| self.assertEqual(hold_weight, model.weight) |
| |
| # testing that changes to one set of parametrizations do not affect the other |
| parametrize.remove_parametrizations(to_model, "weight") |
| self.assertFalse(torch.nn.utils.parametrize.is_parametrized(to_model, "weight")) |
| self.assertTrue(torch.nn.utils.parametrize.is_parametrized(model, "weight")) |
| |
| # also test that parameters that don't exist in to_model get transferred |
| model.test_param = Parameter(torch.randn(5, 5)) |
| |
| self.assertTrue(not hasattr(to_model, "test_param")) |
| parametrize.register_parametrization(model, "test_param", Double()) |
| hold_test_param = model.test_param |
| parametrize.transfer_parametrizations_and_params(model, to_model, "test_param") |
| |
| # check that previously missing params got transferred correctly |
| self.assertEqual(model.test_param, to_model.test_param) |
| self.assertEqual( |
| model.parametrizations.test_param.original, |
| to_model.parametrizations.test_param.original, |
| ) |
| |
| # check that the new transfer didn't change the value for the from_module |
| self.assertEqual(hold_test_param, model.test_param) |
| |
| def test_transfer_parametrizations_and_params_right_inverse(self): |
| r"""Test that all parametrizations and their associated parameters are transferred.""" |
| |
| class Double(nn.Module): |
| def forward(self, x): |
| return 2.0 * x |
| |
| def right_inverse(self, x): |
| return 0.5 * x |
| |
| model = nn.Linear(5, 5) |
| parametrize.register_parametrization(model, "weight", Double()) |
| hold_weight = model.weight |
| |
| to_model = torch.ao.nn.qat.Linear( |
| 5, 5, qconfig=torch.ao.quantization.get_default_qconfig() |
| ) |
| parametrize.transfer_parametrizations_and_params(model, to_model) |
| |
| # check that transfer occurs successfully |
| self.assertEqual(model.weight, to_model.weight) |
| self.assertEqual( |
| model.parametrizations.weight.original, |
| to_model.parametrizations.weight.original, |
| ) |
| |
| # check that transfer doesn't affect the from_model weight |
| self.assertEqual(hold_weight, model.weight) |
| |
| def test_transfer_parametrizations_and_params_single_param(self): |
| r"""Test that all parametrizations and their associated parameters are transferred.""" |
| |
| class AddOne(nn.Module): |
| def forward(self, x): |
| return x + 1.0 |
| |
| class Double(nn.Module): |
| def forward(self, x): |
| return 2.0 * x |
| |
| class MinusOne(nn.Module): |
| def forward(self, x): |
| return x - 1.0 |
| |
| model = nn.Linear(5, 5, bias=True) |
| parametrize.register_parametrization(model, "weight", AddOne()) |
| parametrize.register_parametrization(model, "weight", Double()) |
| parametrize.register_parametrization(model, "weight", MinusOne()) |
| parametrize.register_parametrization(model, "bias", AddOne()) |
| parametrize.register_parametrization(model, "bias", Double()) |
| parametrize.register_parametrization(model, "bias", MinusOne()) |
| |
| to_model = torch.ao.nn.qat.Linear( |
| 5, 5, bias=True, qconfig=torch.ao.quantization.get_default_qconfig() |
| ) |
| parametrize.transfer_parametrizations_and_params(model, to_model, "weight") |
| |
| # check that weight and only weight was transferred |
| self.assertEqual(model.weight, to_model.weight) |
| self.assertEqual( |
| model.parametrizations.weight.original, |
| to_model.parametrizations.weight.original, |
| ) |
| self.assertTrue("bias" not in to_model.parametrizations) |
| |
| # FIXME: Rewrite this test using functions not depending on LAPACK |
| # and remove the `@skipIfNoLapack` (see #70995) |
| @skipIfNoLapack |
| def test_transfer_parametrizations_and_params_many_to_one(self): |
| # A parametrization with several outputs |
| class RankOne(nn.Module): |
| def forward(self, x, y): |
| # Form a rank-1 matrix from a pair of vectors |
| return x.unsqueeze(-1) @ y.unsqueeze(-2) |
| |
| def right_inverse(self, Y): |
| # We project the given matrix onto the rank 1 matrices |
| U, S, Vh = torch.linalg.svd(Y, full_matrices=False) |
| # S is ordered in a decreasing way. |
| s0_sqrt = S[0].sqrt().unsqueeze(-1) |
| return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt |
| |
| class Double(nn.Module): |
| def forward(self, x): |
| return 2.0 * x |
| |
| model = nn.Linear(3, 3) |
| parametrize.register_parametrization(model, "weight", RankOne()) |
| parametrize.register_parametrization(model, "weight", Double()) |
| hold_weight = model.weight |
| |
| to_model = torch.ao.nn.qat.Linear( |
| 3, 3, qconfig=torch.ao.quantization.get_default_qconfig() |
| ) |
| |
| parametrize.transfer_parametrizations_and_params(model, to_model) |
| |
| # check that the final and original values are correct and that to_model is parametrized |
| self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight")) |
| self.assertEqual(model.weight, to_model.weight) |
| self.assertEqual( |
| model.parametrizations.weight.original0, |
| to_model.parametrizations.weight.original0, |
| ) |
| self.assertEqual( |
| model.parametrizations.weight.original1, |
| to_model.parametrizations.weight.original1, |
| ) |
| |
| # check that the transfer didn't affect the original value |
| self.assertEqual(hold_weight, model.weight) |
| |
| # also test that parameters that don't exist in to_model get transferred |
| model.test_param = Parameter(torch.randn(3, 3)) |
| |
| self.assertTrue(not hasattr(to_model, "test_param")) |
| parametrize.register_parametrization(model, "test_param", RankOne()) |
| hold_test_param = model.test_param |
| parametrize.transfer_parametrizations_and_params(model, to_model, "test_param") |
| |
| # also check that previously missing params got transferred correctly |
| self.assertEqual(model.test_param, to_model.test_param) |
| self.assertEqual( |
| model.parametrizations.test_param.original0, |
| to_model.parametrizations.test_param.original0, |
| ) |
| self.assertEqual( |
| model.parametrizations.test_param.original1, |
| to_model.parametrizations.test_param.original1, |
| ) |
| |
| # check that the new transfer didn't change the value for the from_module |
| self.assertEqual(hold_test_param, model.test_param) |
| |
| def test_new_spectral_norm(self): |
| with set_default_dtype(torch.double): |
| input = torch.randn(3, 5) |
| m = nn.Linear(5, 7) |
| m = torch.nn.utils.parametrizations.spectral_norm(m) |
| spectral_norm_m = m.parametrizations.weight[0] |
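| # The spectral norm parametrization keeps the power-iteration vectors _u and _v |
| # as buffers and divides the original weight by its estimated largest singular value |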
| |
| self.assertEqual(spectral_norm_m._u.size(), torch.Size([m.weight.size(0)])) |
| |
| # .parametrizations.weight.original should be trainable |
| self.assertTrue(hasattr(m.parametrizations.weight, 'original')) |
| self.assertTrue('original' in m.parametrizations.weight._parameters) |
| |
| # u should be just a reused buffer |
| self.assertTrue(hasattr(spectral_norm_m, '_u')) |
| self.assertTrue('_u' in spectral_norm_m._buffers) |
| self.assertTrue('_v' in spectral_norm_m._buffers) |
| |
| # weight should be a plain attribute, not counted as a buffer or a param |
| self.assertIsNotNone(m.weight) |
| self.assertFalse('weight' in m._buffers) |
| self.assertFalse('weight' in m._parameters) |
| |
| # it should also be sharing storage with the original weight |
| # self.assertEqual(m.parametrizations.weight.original.storage(), m.weight.storage()) |
| self.assertEqual(m.parametrizations.weight.original.size(), m.weight.size()) |
| self.assertEqual(m.parametrizations.weight.original.stride(), m.weight.stride()) |
| |
| m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight') |
| |
| # spectral_norm is the only parametrization |
| self.assertFalse(hasattr(m, 'parametrizations')) |
| self.assertTrue('weight' in m._parameters) |
| |
| # We can register spectral_norm multiple times on the same parameter |
| # and on multiple parameters in the same module |
| m = torch.nn.utils.parametrizations.spectral_norm(m, 'weight') |
| m = torch.nn.utils.parametrizations.spectral_norm(m, 'weight') |
| m = torch.nn.utils.parametrizations.spectral_norm(m, 'bias') |
| |
| # If we remove the parametrization on bias, weight is still parametrized |
| # Removing a parametrization runs forward in eval mode if leave_parametrized=True |
| m = torch.nn.utils.parametrize.remove_parametrizations(m, 'bias') |
| self.assertTrue('bias' in m._parameters) |
| self.assertTrue(hasattr(m, 'parametrizations')) |
| self.assertFalse('weight' in m._parameters) |
| |
| m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight') |
| # Neither weight nor bias is parametrized |
| self.assertFalse(hasattr(m, 'parametrizations')) |
| self.assertTrue('weight' in m._parameters) |
| self.assertFalse(torch.nn.utils.parametrize.is_parametrized(m)) |
| |
| # test correctness in training/eval modes and cpu/multi-gpu settings |
| for apply_dp in (True, False): |
| if apply_dp: |
| if not TEST_MULTIGPU: |
| continue |
| device = torch.device('cuda:0') |
| |
| def maybe_wrap(m): |
| return torch.nn.DataParallel(m, [0, 1]) |
| else: |
| device = torch.device('cpu') |
| |
| def maybe_wrap(m): |
| return m |
| |
| for requires_grad in (True, False): |
| def get_modules(): |
| m = nn.Linear(3, 4).to(device) |
| m.weight.requires_grad_(requires_grad) |
| m = torch.nn.utils.parametrizations.spectral_norm(m) |
| wrapped_m = maybe_wrap(m) |
| spectral_norm_m = m.parametrizations.weight[0] |
| return m, wrapped_m, spectral_norm_m |
| |
| input = torch.randn(2, 3, device=device) |
| |
| m, wrapped_m, spectral_norm_m = get_modules() |
| |
| self.assertTrue(hasattr(spectral_norm_m, '_u')) |
| u0 = spectral_norm_m._u.clone() |
| v0 = spectral_norm_m._v.clone() |
| |
| # TEST TRAINING BEHAVIOR |
| |
| # We perform GD first to modify the initial matrix |
| opt = torch.optim.SGD(wrapped_m.parameters(), lr=0.1) |
| |
| opt.zero_grad() |
| wrapped_m(input).sum().backward() |
| opt.step() |
| |
| out = wrapped_m(input) |
| if requires_grad: |
| # run forward again and assert that u and v are updated |
| self.assertNotEqual(u0, spectral_norm_m._u) |
| self.assertNotEqual(v0, spectral_norm_m._v) |
| |
| # assert that backprop reaches original weight |
| # can't use gradcheck because the function changes as we |
| # pass through it in training mode |
| if requires_grad: |
| torch.autograd.grad(out.sum(), m.parametrizations.weight.original) |
| |
| # test backward works with multiple forwards |
| # it uses training mode so we need to reset `u` and `v` vectors |
| # to same value at beginning for finite difference test to pass |
| saved_u = spectral_norm_m._u.clone() |
| saved_v = spectral_norm_m._v.clone() |
| |
| def fn(input): |
| spectral_norm_m._u.data.copy_(saved_u) |
| spectral_norm_m._v.data.copy_(saved_v) |
| out0 = wrapped_m(input) |
| out1 = wrapped_m(input) |
| return out0 + out1 |
| |
| # Make sure we can compute gradients wrt to all the parameters in the case |
| # of double forward |
| fn(input.clone().requires_grad_()).sum().backward() |
| gradcheck(fn, (input.clone().requires_grad_(),), check_batched_grad=False) |
| |
| # test removing |
| # spectral norm module needs to be in eval mode if we'd like to |
| # avoid doing another power iteration |
| m, wrapped_m, _ = get_modules() |
| pre_remove_out = wrapped_m(input) |
| m.eval() |
| m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight') |
| self.assertEqual(wrapped_m(input), pre_remove_out) |
| |
| torch.nn.utils.parametrizations.spectral_norm(m) |
| for _ in range(3): |
| pre_remove_out = wrapped_m(input) |
| m.eval() |
| m = torch.nn.utils.parametrize.remove_parametrizations(m, 'weight') |
| self.assertEqual(wrapped_m(input), pre_remove_out) |
| |
| # TEST EVAL BEHAVIOR |
| m, wrapped_m, spectral_norm_m = get_modules() |
| wrapped_m(input) |
| last_train_out = wrapped_m(input) |
| last_train_u = spectral_norm_m._u.clone() |
| last_train_v = spectral_norm_m._v.clone() |
| wrapped_m.zero_grad() |
| wrapped_m.eval() |
| |
| eval_out0 = wrapped_m(input) |
| # assert eval gives same result as last training iteration |
| self.assertEqual(eval_out0, last_train_out) |
| # assert that doing more iterations in eval doesn't change things |
| self.assertEqual(eval_out0, wrapped_m(input)) |
| self.assertEqual(last_train_u, spectral_norm_m._u) |
| self.assertEqual(last_train_v, spectral_norm_m._v) |
| |
| # FIXME: the code below is flaky when executed with DataParallel |
| # see https://github.com/pytorch/pytorch/issues/13818 |
| if apply_dp: |
| continue |
| |
| # test backward works with multiple forwards in mixed training |
| # and eval modes |
| # it uses training mode so we need to reset `u` and `v` vectors |
| # to same value at beginning for finite difference test to pass |
| saved_u = spectral_norm_m._u.clone() |
| saved_v = spectral_norm_m._v.clone() |
| |
| def fn(input): |
| spectral_norm_m._u.data.copy_(saved_u) |
| spectral_norm_m._v.data.copy_(saved_v) |
| wrapped_m.train() |
| out0 = wrapped_m(input) |
| wrapped_m.eval() |
| out1 = wrapped_m(input) |
| wrapped_m.train() |
| out2 = wrapped_m(input) |
| wrapped_m.eval() |
| out3 = wrapped_m(input) |
| return out0 + out1 + out2 + out3 |
| |
| gradcheck(fn, (input.clone().requires_grad_(),)) |
| |
| # assert that backprop reaches weight_orig in eval |
| if requires_grad: |
| def fn(weight): |
| return wrapped_m(input) |
| |
| gradcheck(fn, (m.parametrizations.weight.original,)) |
| |
| def test_new_spectral_norm_load_state_dict(self): |
| for activate_times in (0, 3): |
| inp = torch.randn(2, 3) |
| m = nn.Linear(3, 5) |
| snm = torch.nn.utils.parametrizations.spectral_norm(m) |
| snm.train() |
| |
| for _ in range(activate_times): |
| snm(inp) |
| |
| state_dict = deepcopy(snm.state_dict()) |
| self.assertEqual({ |
| 'parametrizations.weight.original', |
| 'bias', |
| 'parametrizations.weight.0._v', |
| 'parametrizations.weight.0._u' |
| }, set(state_dict.keys())) |
| |
| # test that non-strict loading works |
| non_strict_state_dict = deepcopy(state_dict) |
| non_strict_state_dict['nonsense'] = 'nonsense' |
| with self.assertRaisesRegex(RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'): |
| snm.load_state_dict(non_strict_state_dict, strict=True) |
| snm.load_state_dict(non_strict_state_dict, strict=False) |
| del non_strict_state_dict['parametrizations.weight.original'] |
| snm.load_state_dict(non_strict_state_dict, strict=False) |
| del non_strict_state_dict['parametrizations.weight.0._u'] |
| snm.load_state_dict(non_strict_state_dict, strict=False) |
| del non_strict_state_dict['parametrizations.weight.0._v'] |
| snm.load_state_dict(non_strict_state_dict, strict=False) |
| non_strict_state_dict['weight'] = snm.weight.detach().clone() # set W as a buffer |
| snm.load_state_dict(non_strict_state_dict, strict=False) |
| del non_strict_state_dict._metadata['parametrizations.weight.0'] # remove metadata info |
| snm.load_state_dict(non_strict_state_dict, strict=False) |
| del non_strict_state_dict['weight'] # remove W buffer |
| snm.load_state_dict(non_strict_state_dict, strict=False) |
| del non_strict_state_dict['bias'] |
| snm.load_state_dict(non_strict_state_dict, strict=False) |
| |
| # normal state_dict |
| |
| # test that re-wrapping does not matter |
| m = torch.nn.utils.parametrize.remove_parametrizations(snm, 'weight') |
| snm = torch.nn.utils.parametrizations.spectral_norm(m) |
| |
| snm.load_state_dict(state_dict) |
| with torch.no_grad(): |
| snm.eval() |
| out0_eval = snm(inp) |
| snm.train() |
| out1_train = snm(inp) |
| out2_train = snm(inp) |
| snm.eval() |
| out3_eval = snm(inp) |
| |
| # test that re-wrapping does not matter |
| m = torch.nn.utils.parametrize.remove_parametrizations(snm, 'weight') |
| snm = torch.nn.utils.parametrizations.spectral_norm(m) |
| |
| # Test normal loading |
| snm.load_state_dict(state_dict) |
| with torch.no_grad(): |
| snm.eval() |
| self.assertEqual(out0_eval, snm(inp)) |
| snm.train() |
| self.assertEqual(out1_train, snm(inp)) |
| self.assertEqual(out2_train, snm(inp)) |
| snm.eval() |
| self.assertEqual(out3_eval, snm(inp)) |
| |
| def test_new_spectral_norm_dim(self): |
| inp = torch.randn(2, 3, 10, 12) |
| m = nn.ConvTranspose2d(3, 4, (5, 6)) |
| m = torch.nn.utils.parametrizations.spectral_norm(m) |
| snm = m.parametrizations.weight[0] |
| # this should not run into incompatible shapes |
| x = m(inp) |
| # check that u refers to the same dimension |
| self.assertEqual(snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape) |
| |
| def test_new_spectral_norm_forward(self): |
| input = torch.randn(3, 5) |
| m = nn.Linear(5, 7) |
| m = torch.nn.utils.parametrizations.spectral_norm(m) |
| snm = m.parametrizations.weight[0] |
| # naive forward |
| _weight = m.parametrizations.weight.original |
| _bias, _v = m.bias, snm._v |
| _weight_mat = _weight.view(_weight.size(0), -1) |
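| # One power-iteration step: u <- normalize(W @ v), v <- normalize(W.T @ u), |
| # then the weight is divided by the estimated spectral norm u.T @ W @ v |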
| _u = torch.mv(_weight_mat, _v) |
| _u = F.normalize(_u, dim=0, eps=1e-12) |
| _v = torch.mv(_weight_mat.t(), _u) |
| _v = F.normalize(_v, dim=0, eps=1e-12) |
| _weight.data /= torch.dot(_u, torch.matmul(_weight_mat, _v)) |
| out_hat = torch.nn.functional.linear(input, _weight, _bias) |
| expect_out = m(input) |
| self.assertEqual(expect_out, out_hat) |
| |
| @skipIfNoLapack |
| def test_orthogonal_parametrization(self): |
| # Orthogonal implements 6 algorithms (3 parametrizations x 2 options of use_trivialization) |
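| # Each of them maps an unconstrained matrix to an orthogonal one: via the matrix |
| # exponential, the Cayley transform, or a product of Householder reflectors |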
| |
| def assert_is_orthogonal(X): |
| n, k = X.size(-2), X.size(-1) |
| if n < k: |
| X = X.mT |
| n, k = k, n |
| Id = torch.eye(k, dtype=X.dtype, device=X.device).expand(*(X.size()[:-2]), k, k) |
| eps = 10 * n * torch.finfo(X.dtype).eps |
| torch.testing.assert_close(X.mH @ X, Id, atol=eps, rtol=0.) |
| |
| def assert_weight_allclose_Q(weight, W): |
| # Test that weight is equal to the Q part of the QR decomposition of W |
| # (or of its transpose if the matrix is wide) |
| wide_matrix = W.size(-2) < W.size(-1) |
| if wide_matrix: |
| W = W.mT |
| Q, R = torch.linalg.qr(W) |
| Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2) |
| if wide_matrix: |
| Q = Q.mT |
| torch.testing.assert_close(Q, weight, atol=1e-5, rtol=0.) |
| |
| for shape, dtype, use_linear in product(((4, 4), (5, 3), (3, 5)), # square/ tall / wide |
| (torch.float32, torch.complex64), |
| (True, False)): |
| # Conv2d does not support complex yet |
| if not use_linear: |
| continue |
| |
| if use_linear: |
| input = torch.randn(3, shape[0], dtype=dtype) |
| else: |
| input = torch.randn(2, 2, shape[0] + 2, shape[1] + 1, dtype=dtype) |
| |
| for parametrization, use_trivialization in product(("matrix_exp", "cayley", "householder"), |
| (False, True)): |
| # right_inverse for Cayley and matrix_exp not implemented for use_trivialization=False |
| # See Note [right_inverse expm cayley] |
| can_initialize = use_trivialization or parametrization == "householder" |
| |
| # We generate them every time to always start with fresh weights |
| if use_linear: |
| m = nn.Linear(*shape, dtype=dtype) |
| else: |
| m = nn.Conv2d(2, 3, shape, dtype=dtype) |
| |
| # We do not support householder for complex inputs |
| # See Note [Householder complex] |
| w_init = m.weight.clone() |
| if parametrization == "householder" and m.weight.is_complex(): |
| msg = "householder parametrization does not support complex tensors" |
| with self.assertRaisesRegex(ValueError, msg): |
| torch.nn.utils.parametrizations.orthogonal(m, |
| "weight", |
| parametrization, |
| use_trivialization=use_trivialization) |
| continue |
| |
| wide_matrix = w_init.size(-2) < w_init.size(-1) |
| torch.nn.utils.parametrizations.orthogonal(m, |
| "weight", |
| parametrization, |
| use_trivialization=use_trivialization) |
| # Forward works as expected |
| self.assertEqual(w_init.shape, m.weight.shape) |
| assert_is_orthogonal(m.weight) |
| if can_initialize: |
| assert_weight_allclose_Q(m.weight, w_init) |
| |
| # Initializing with a given orthogonal matrix works |
| X = torch.randn_like(m.weight) |
| if wide_matrix: |
| X = X.mT |
| w_new = torch.linalg.qr(X).Q |
| if wide_matrix: |
| w_new = w_new.mT |
| if can_initialize: |
| m.weight = w_new |
| torch.testing.assert_close(w_new, m.weight, atol=1e-5, rtol=0.) |
| else: |
| msg = "assign to the matrix exponential or the Cayley parametrization" |
| with self.assertRaisesRegex(NotImplementedError, msg): |
| m.weight = w_new |
| |
| # Initializing with a non-orthogonal matrix makes m.weight the Q factor of the given matrix |
| w_new = torch.randn_like(m.weight) |
| if can_initialize: |
| m.weight = w_new |
| assert_weight_allclose_Q(m.weight, w_new) |
| else: |
| msg = "assign to the matrix exponential or the Cayley parametrization" |
| with self.assertRaisesRegex(NotImplementedError, msg): |
| m.weight = w_new |
| |
| opt = torch.optim.SGD(m.parameters(), lr=0.1) |
| for _ in range(2): |
| opt.zero_grad() |
| m(input).norm().backward() |
| grad = m.parametrizations.weight.original.grad |
| self.assertIsNotNone(grad) |
| # We do not update the upper-triangular part of the matrix if it is tall, |
| # nor the lower-triangular part if it is wide |
| if grad.size(-2) >= grad.size(-1): |
| zeros_grad = grad.triu(1) |
| else: |
| zeros_grad = grad.tril(-1) |
| self.assertEqual(zeros_grad, torch.zeros_like(zeros_grad)) |
| # The gradient on the diagonal can only be imaginary because a skew-Hermitian |
| # matrix has a purely imaginary diagonal |
| diag_grad = grad.diagonal(dim1=-2, dim2=-1) |
| if grad.is_complex(): |
| diag_grad = diag_grad.real |
| self.assertEqual(diag_grad, torch.zeros_like(diag_grad)) |
| opt.step() |
| assert_is_orthogonal(m.weight) |
| |
| @skipIfNoLapack |
| def test_orthogonal_errors(self): |
| m = nn.Linear(3, 4) |
| with self.assertRaisesRegex(ValueError, "has to be one of"): |
| torch.nn.utils.parametrizations.orthogonal(m, "weight", "foo") |
| |
| with self.assertRaisesRegex(ValueError, "Expected a matrix"): |
| torch.nn.utils.parametrizations.orthogonal(m, "bias") |
| |
| torch.nn.utils.parametrizations.orthogonal(m, "weight") |
| with self.assertRaisesRegex(ValueError, "matrices of shape"): |
| m.weight = torch.randn(5, 5) |
| torch.nn.utils.parametrize.remove_parametrizations(m, "weight") |
| |
| |
| def test_weight_norm_parametrization(self): |
| for dtype in [torch.float, torch.bfloat16]: |
| input = torch.randn(3, 4, dtype=dtype) |
| m = nn.Linear(4, 5).to(dtype=dtype) |
| expected_output = m(input) |
| |
| # add weight normalization |
| m = torch.nn.utils.parametrizations.weight_norm(m) |
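| # weight_norm factors the weight as w = g * v / ||v||: original0 holds the |
| # magnitude g and original1 holds the direction v |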
| self.assertEqual(m.parametrizations.weight.original1.size(), m.weight.size()) |
| self.assertEqual(m.parametrizations.weight.original0.size(), (5, 1)) |
| self.assertEqual(m(input), expected_output) |
| |
| # remove weight norm |
| torch.nn.utils.parametrize.remove_parametrizations(m, "weight") |
| self.assertFalse(hasattr(m, "parametrizations")) |
| self.assertEqual(m(input), expected_output) |
| |
| # test with dim=1 |
| m = torch.nn.utils.parametrizations.weight_norm(m, dim=1) |
| self.assertEqual(m.parametrizations.weight.original1.size(), m.weight.size()) |
| self.assertEqual(m.parametrizations.weight.original0.size(), (1, 4)) |
| self.assertEqual(m(input), expected_output) |
| |
| # test with dim=None |
| m = nn.Linear(4, 5).to(dtype=dtype) |
| expected_output = m(input) |
| m = torch.nn.utils.parametrizations.weight_norm(m, dim=None) |
| self.assertEqual(m(input), expected_output) |
| |
| def test_weight_norm_state_dict_compat(self): |
| m = nn.Linear(4, 5) |
| m = torch.nn.utils.weight_norm(m) |
| old_dict = m.state_dict() |
| |
| m2 = nn.Linear(4, 5) |
| m2 = torch.nn.utils.parametrizations.weight_norm(m2) |
| m2.load_state_dict(old_dict) |
| |
| input = torch.randn(3, 4) |
| self.assertEqual(m(input), m2(input)) |
| |
| def test_weight_norm_pickle(self): |
| m = nn.Linear(4, 5) |
| m = torch.nn.utils.parametrizations.weight_norm(m) |
| with self.assertRaisesRegex(RuntimeError, 'state_dict'): |
| pickle.dumps(m) |
| |
| def test_weight_norm_deepcopy(self): |
| m = nn.Linear(4, 5) |
| m = torch.nn.utils.parametrizations.weight_norm(m) |
| m2 = deepcopy(m) |
| input = torch.randn(3, 4) |
| self.assertEqual(m(input), m2(input)) |
| |
| |
| instantiate_parametrized_tests(TestNNParametrization) |
| |
| if __name__ == '__main__': |
| run_tests() |