| # Owner(s): ["module: nn"] |
| import unittest |
| import unittest.mock as mock |
| import pickle |
| |
| import torch |
| |
| import torch.nn as nn |
| import torch.nn.utils.prune as prune |
| from torch.testing._internal.common_utils import TEST_NUMPY, TemporaryFileName, \ |
| instantiate_parametrized_tests, run_tests |
| from torch.testing._internal.common_nn import NNTestCase |
| |
| class TestPruningNN(NNTestCase): |
| _do_cuda_memory_leak_check = True |
| _do_cuda_non_default_stream = True |
| |
| # torch/nn/utils/prune.py |
| @unittest.skipIf(not TEST_NUMPY, "numpy not found") |
| def test_validate_pruning_amount_init(self): |
| r"""Test the first util function that validates the pruning |
        amount requested by the user at the moment the pruning method
| is initialized. This test checks that the expected errors are |
| raised whenever the amount is invalid. |
        The function under test runs basic type checks plus value range checks.
| It doesn't check the validity of the pruning amount with |
| respect to the size of the tensor to prune. That's left to |
| `_validate_pruning_amount`, tested below. |
| """ |
        # a value that is neither float nor int should raise TypeError
| with self.assertRaises(TypeError): |
| prune._validate_pruning_amount_init(amount="I'm a string") |
| |
| # float not in [0, 1] should raise ValueError |
| with self.assertRaises(ValueError): |
| prune._validate_pruning_amount_init(amount=1.1) |
| with self.assertRaises(ValueError): |
| prune._validate_pruning_amount_init(amount=20.) |
| |
| # negative int should raise ValueError |
| with self.assertRaises(ValueError): |
| prune._validate_pruning_amount_init(amount=-10) |
| |
| # all these should pass without errors because they're valid amounts |
| prune._validate_pruning_amount_init(amount=0.34) |
| prune._validate_pruning_amount_init(amount=1500) |
| prune._validate_pruning_amount_init(amount=0) |
| prune._validate_pruning_amount_init(amount=0.) |
| prune._validate_pruning_amount_init(amount=1) |
| prune._validate_pruning_amount_init(amount=1.) |
| self.assertTrue(True) |
| |
| @unittest.skipIf(not TEST_NUMPY, "numpy not found") |
| def test_validate_pruning_amount(self): |
| r"""Tests the second util function that validates the pruning |
| amount requested by the user, this time with respect to the size |
        of the tensor to prune. The rationale is that if the pruning amount,
        converted to an absolute number of units to prune, is larger than
        the number of units in the tensor, then we expect the util function
        to raise a ValueError.
| """ |
| # if amount is int and amount > tensor_size, raise ValueError |
| with self.assertRaises(ValueError): |
| prune._validate_pruning_amount(amount=20, tensor_size=19) |
| |
        # amount is a float (a fraction), so this should not raise an error
        # even for a tensor of size 0
| prune._validate_pruning_amount(amount=0.3, tensor_size=0) |
| |
| # this is okay |
| prune._validate_pruning_amount(amount=19, tensor_size=20) |
| prune._validate_pruning_amount(amount=0, tensor_size=0) |
| prune._validate_pruning_amount(amount=1, tensor_size=1) |
| self.assertTrue(True) |
| |
| @unittest.skipIf(not TEST_NUMPY, "numpy not found") |
| def test_compute_nparams_to_prune(self): |
| r"""Test that requested pruning `amount` gets translated into the |
| correct absolute number of units to prune. |
| """ |
| self.assertEqual( |
| prune._compute_nparams_toprune(amount=0, tensor_size=15), |
| 0 |
| ) |
| self.assertEqual( |
| prune._compute_nparams_toprune(amount=10, tensor_size=15), |
| 10 |
| ) |
        # an int amount of 1 means exactly 1 unit
| self.assertEqual( |
| prune._compute_nparams_toprune(amount=1, tensor_size=15), |
| 1 |
| ) |
        # a float amount of 1. means 100% of the units
| self.assertEqual( |
| prune._compute_nparams_toprune(amount=1., tensor_size=15), |
| 15 |
| ) |
| self.assertEqual( |
| prune._compute_nparams_toprune(amount=0.4, tensor_size=17), |
| 7 |
| ) |
| |
| def test_random_pruning_sizes(self): |
| r"""Test that the new parameters and buffers created by the pruning |
| method have the same size as the input tensor to prune. These, in |
| fact, correspond to the pruned version of the tensor itself, its |
        mask, and its original copy, so their sizes must match.
| """ |
| # fixturize test |
| # TODO: add other modules |
| modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] |
| names = ['weight', 'bias'] |
| |
| for m in modules: |
| for name in names: |
| with self.subTest(m=m, name=name): |
| original_tensor = getattr(m, name) |
| |
| prune.random_unstructured(m, name=name, amount=0.1) |
                    # the mask has the same size as the tensor being pruned
| self.assertEqual( |
| original_tensor.size(), |
| getattr(m, name + '_mask').size() |
| ) |
| # 'orig' tensor has the same size as the original tensor |
| self.assertEqual( |
| original_tensor.size(), |
| getattr(m, name + '_orig').size() |
| ) |
| # new tensor has the same size as the original tensor |
| self.assertEqual( |
| original_tensor.size(), |
| getattr(m, name).size() |
| ) |
| |
| def test_random_pruning_orig(self): |
| r"""Test that original tensor is correctly stored in 'orig' |
        after pruning is applied. This is important to make sure we don't
        lose information about the original, unpruned parameter.
| """ |
| # fixturize test |
| # TODO: add other modules |
| modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] |
| names = ['weight', 'bias'] |
| |
| for m in modules: |
| for name in names: |
| with self.subTest(m=m, name=name): |
| |
| # tensor prior to pruning |
| original_tensor = getattr(m, name) |
| prune.random_unstructured(m, name=name, amount=0.1) |
| self.assertEqual( |
| original_tensor, |
| getattr(m, name + '_orig') |
| ) |
| |
| def test_random_pruning_new_weight(self): |
| r"""Test that module.name now contains a pruned version of |
        the original tensor, obtained by multiplying it by the mask.
| """ |
| # fixturize test |
| # TODO: add other modules |
| modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] |
| names = ['weight', 'bias'] |
| |
| for m in modules: |
| for name in names: |
| with self.subTest(m=m, name=name): |
| # tensor prior to pruning |
| original_tensor = getattr(m, name) |
| prune.random_unstructured(m, name=name, amount=0.1) |
| # weight = weight_orig * weight_mask |
| self.assertEqual( |
| getattr(m, name), |
| getattr(m, name + '_orig') |
| * getattr(m, name + '_mask').to( |
| dtype=original_tensor.dtype |
| ), |
| ) |
| |
| def test_identity_pruning(self): |
| r"""Test that a mask of 1s does not change forward or backward. |
| """ |
| input_ = torch.ones(1, 5) |
| m = nn.Linear(5, 2) |
| y_prepruning = m(input_) # output prior to pruning |
| |
| # compute grad pre-pruning and check it's equal to all ones |
| y_prepruning.sum().backward() |
| old_grad_weight = m.weight.grad.clone() # don't grab pointer! |
| self.assertEqual(old_grad_weight, torch.ones_like(m.weight)) |
| old_grad_bias = m.bias.grad.clone() |
| self.assertEqual(old_grad_bias, torch.ones_like(m.bias)) |
| |
| # remove grads |
| m.zero_grad() |
| |
| # force the mask to be made of all 1s |
| prune.identity(m, name="weight") |
| |
| # with mask of 1s, output should be identical to no mask |
| y_postpruning = m(input_) |
| self.assertEqual(y_prepruning, y_postpruning) |
| |
| # with mask of 1s, grad should be identical to no mask |
| y_postpruning.sum().backward() |
| self.assertEqual(old_grad_weight, m.weight_orig.grad) |
| self.assertEqual(old_grad_bias, m.bias.grad) |
| |
| # calling forward twice in a row shouldn't change output |
| y1 = m(input_) |
| y2 = m(input_) |
| self.assertEqual(y1, y2) |
| |
| def test_random_pruning_0perc(self): |
| r"""Test that a mask of 1s does not change forward or backward. |
| """ |
| input_ = torch.ones(1, 5) |
| m = nn.Linear(5, 2) |
| y_prepruning = m(input_) # output prior to pruning |
| |
| # compute grad pre-pruning and check it's equal to all ones |
| y_prepruning.sum().backward() |
| old_grad_weight = m.weight.grad.clone() # don't grab pointer! |
| self.assertEqual(old_grad_weight, torch.ones_like(m.weight)) |
| old_grad_bias = m.bias.grad.clone() |
| self.assertEqual(old_grad_bias, torch.ones_like(m.bias)) |
| |
| # remove grads |
| m.zero_grad() |
| |
| # force the mask to be made of all 1s |
| with mock.patch( |
| "torch.nn.utils.prune.RandomUnstructured.compute_mask" |
| ) as compute_mask: |
| compute_mask.return_value = torch.ones_like(m.weight) |
            # amount is irrelevant here because compute_mask is mocked
            prune.random_unstructured(m, name='weight', amount=0.9)
| |
| # with mask of 1s, output should be identical to no mask |
| y_postpruning = m(input_) |
| self.assertEqual(y_prepruning, y_postpruning) |
| |
| # with mask of 1s, grad should be identical to no mask |
| y_postpruning.sum().backward() |
| self.assertEqual(old_grad_weight, m.weight_orig.grad) |
| self.assertEqual(old_grad_bias, m.bias.grad) |
| |
| # calling forward twice in a row shouldn't change output |
| y1 = m(input_) |
| y2 = m(input_) |
| self.assertEqual(y1, y2) |
| |
| def test_random_pruning(self): |
| input_ = torch.ones(1, 5) |
| m = nn.Linear(5, 2) |
| |
| # define custom mask to assign with mock |
| mask = torch.ones_like(m.weight) |
| mask[1, 0] = 0 |
| mask[0, 3] = 0 |
| |
| # check grad is zero for masked weights |
| with mock.patch( |
| "torch.nn.utils.prune.RandomUnstructured.compute_mask" |
| ) as compute_mask: |
| compute_mask.return_value = mask |
| prune.random_unstructured(m, name='weight', amount=0.9) |
| |
| y_postpruning = m(input_) |
| y_postpruning.sum().backward() |
| # weight_orig is the parameter, so it's the tensor that will accumulate the grad |
| self.assertEqual(m.weight_orig.grad, mask) # all 1s, except for masked units |
| self.assertEqual(m.bias.grad, torch.ones_like(m.bias)) |
| |
| # make sure that weight_orig update doesn't modify [1, 0] and [0, 3] |
| old_weight_orig = m.weight_orig.clone() |
| # update weights |
| learning_rate = 1. |
| for p in m.parameters(): |
| p.data.sub_(p.grad.data * learning_rate) |
| # since these are pruned, they should not be updated |
| self.assertEqual(old_weight_orig[1, 0], m.weight_orig[1, 0]) |
| self.assertEqual(old_weight_orig[0, 3], m.weight_orig[0, 3]) |
| |
| def test_random_pruning_forward(self): |
| r"""check forward with mask (by hand). |
| """ |
| input_ = torch.ones(1, 5) |
| m = nn.Linear(5, 2) |
| |
| # define custom mask to assign with mock |
| mask = torch.zeros_like(m.weight) |
| mask[1, 0] = 1 |
| mask[0, 3] = 1 |
| |
| with mock.patch( |
| "torch.nn.utils.prune.RandomUnstructured.compute_mask" |
| ) as compute_mask: |
| compute_mask.return_value = mask |
| prune.random_unstructured(m, name='weight', amount=0.9) |
| |
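        # with an all-ones input, each output unit reduces to the sum of its
        # surviving weights plus its bias; only weight[0, 3] and weight[1, 0]
        # survive the mask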
| yhat = m(input_) |
| self.assertEqual(yhat[0, 0], m.weight_orig[0, 3] + m.bias[0]) |
| self.assertEqual(yhat[0, 1], m.weight_orig[1, 0] + m.bias[1]) |
| |
| def test_remove_pruning_forward(self): |
| r"""Remove pruning and check forward is unchanged from previous |
| pruned state. |
| """ |
| input_ = torch.ones(1, 5) |
| m = nn.Linear(5, 2) |
| |
| # define custom mask to assign with mock |
| mask = torch.ones_like(m.weight) |
| mask[1, 0] = 0 |
| mask[0, 3] = 0 |
| |
| # check grad is zero for masked weights |
| with mock.patch( |
| "torch.nn.utils.prune.RandomUnstructured.compute_mask" |
| ) as compute_mask: |
| compute_mask.return_value = mask |
| prune.random_unstructured(m, name='weight', amount=0.9) |
| |
| y_postpruning = m(input_) |
| |
| prune.remove(m, 'weight') |
| |
| y_postremoval = m(input_) |
| self.assertEqual(y_postpruning, y_postremoval) |
| |
| def test_pruning_id_consistency(self): |
| r"""Test that pruning doesn't change the id of the parameters, which |
| would otherwise introduce issues with pre-existing optimizers that |
| point to old parameters. |
| """ |
| m = nn.Linear(5, 2, bias=False) |
| |
| tensor_id = id(list(m.parameters())[0]) |
| |
| prune.random_unstructured(m, name="weight", amount=0.9) |
| self.assertEqual(tensor_id, id(list(m.parameters())[0])) |
| |
| prune.remove(m, "weight") |
| self.assertEqual(tensor_id, id(list(m.parameters())[0])) |
| |
| def test_random_pruning_pickle(self): |
| modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] |
| names = ['weight', 'bias'] |
| |
| for m in modules: |
| for name in names: |
| with self.subTest(m=m, name=name): |
| prune.random_unstructured(m, name=name, amount=0.1) |
| m_new = pickle.loads(pickle.dumps(m)) |
| self.assertIsInstance(m_new, type(m)) |
| |
| def test_multiple_pruning_calls(self): |
| # if you call pruning twice, the hook becomes a PruningContainer |
| m = nn.Conv3d(2, 2, 2) |
| prune.l1_unstructured(m, name='weight', amount=0.1) |
| weight_mask0 = m.weight_mask # save it for later sanity check |
| |
| # prune again |
| prune.ln_structured(m, name='weight', amount=0.3, n=2, dim=0) |
| hook = next(iter(m._forward_pre_hooks.values())) |
| self.assertIsInstance( |
| hook, |
| torch.nn.utils.prune.PruningContainer |
| ) |
| # check that container._tensor_name is correctly set no matter how |
| # many pruning methods are in the container |
| self.assertEqual(hook._tensor_name, 'weight') |
| |
| # check that the pruning container has the right length |
| # equal to the number of pruning iters |
| self.assertEqual(len(hook), 2) # m.weight has been pruned twice |
| |
| # check that the entries of the pruning container are of the expected |
| # type and in the expected order |
| self.assertIsInstance(hook[0], torch.nn.utils.prune.L1Unstructured) |
| self.assertIsInstance(hook[1], torch.nn.utils.prune.LnStructured) |
| |
| # check that all entries that are 0 in the 1st mask are 0 in the |
| # 2nd mask too |
| self.assertTrue(torch.all(m.weight_mask[weight_mask0 == 0] == 0)) |
| |
| # prune again |
| prune.ln_structured(m, name='weight', amount=0.1, n=float('inf'), dim=1) |
| # check that container._tensor_name is correctly set no matter how |
| # many pruning methods are in the container |
| hook = next(iter(m._forward_pre_hooks.values())) |
| self.assertEqual(hook._tensor_name, 'weight') |
| |
| def test_pruning_container(self): |
| # create an empty container |
| container = prune.PruningContainer() |
| container._tensor_name = 'test' |
| self.assertEqual(len(container), 0) |
| |
| p = prune.L1Unstructured(amount=2) |
| p._tensor_name = 'test' |
| |
| # test adding a pruning method to a container |
| container.add_pruning_method(p) |
| |
| # test error raised if tensor name is different |
| q = prune.L1Unstructured(amount=2) |
| q._tensor_name = 'another_test' |
| with self.assertRaises(ValueError): |
| container.add_pruning_method(q) |
| |
| # test that adding a non-pruning method object to a pruning container |
| # raises a TypeError |
| with self.assertRaises(TypeError): |
| container.add_pruning_method(10) |
| with self.assertRaises(TypeError): |
| container.add_pruning_method('ugh') |
| |
| def test_pruning_container_compute_mask(self): |
| r"""Test `compute_mask` of pruning container with a known `t` and |
| `default_mask`. Indirectly checks that Ln structured pruning is |
| acting on the right axis. |
| """ |
| # create an empty container |
| container = prune.PruningContainer() |
| container._tensor_name = 'test' |
| |
| # 1) test unstructured pruning |
| # create a new pruning method |
| p = prune.L1Unstructured(amount=2) |
| p._tensor_name = 'test' |
| # add the pruning method to the container |
| container.add_pruning_method(p) |
| |
| # create tensor to be pruned |
| t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) |
| # create prior mask by hand |
| default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) |
        # since we are pruning the two lowest-magnitude units (1 and 2), the
        # outcome of the calculation should be this:
| expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]], dtype=torch.float32) |
| computed_mask = container.compute_mask(t, default_mask) |
| self.assertEqual(expected_mask, computed_mask) |
| |
| # 2) test structured pruning |
| q = prune.LnStructured(amount=1, n=2, dim=0) |
| q._tensor_name = 'test' |
| container.add_pruning_method(q) |
        # since we are pruning the row with the lowest L2 norm, the
        # outcome of the calculation should be this:
| expected_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 0, 1]], dtype=torch.float32) |
| computed_mask = container.compute_mask(t, default_mask) |
| self.assertEqual(expected_mask, computed_mask) |
| |
        # 3) test structured pruning, along another axis
| r = prune.LnStructured(amount=1, n=2, dim=1) |
| r._tensor_name = 'test' |
| container.add_pruning_method(r) |
        # since we are pruning the column with the lowest L2 norm out of the
        # four, the outcome of the calculation should be this:
| expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]], dtype=torch.float32) |
| computed_mask = container.compute_mask(t, default_mask) |
| self.assertEqual(expected_mask, computed_mask) |
| |
| def test_l1_unstructured_pruning(self): |
| r"""Test that l1 unstructured pruning actually removes the lowest |
| entries by l1 norm (by hand). It also checks that applying l1 |
| unstructured pruning more than once respects the previous mask. |
| """ |
| m = nn.Linear(4, 2) |
| # modify its weight matrix by hand |
| m.weight = torch.nn.Parameter( |
| torch.tensor( |
| [[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32 |
| ) |
| ) |
| |
| prune.l1_unstructured(m, 'weight', amount=2) |
| expected_weight = torch.tensor([[0, 2, 3, 4], [-4, -3, -2, 0]], |
| dtype=m.weight.dtype) |
| self.assertEqual(expected_weight, m.weight) |
| |
| # check that pruning again removes the next two smallest entries |
| prune.l1_unstructured(m, 'weight', amount=2) |
| expected_weight = torch.tensor([[0, 0, 3, 4], [-4, -3, 0, 0]], |
| dtype=m.weight.dtype) |
| self.assertEqual(expected_weight, m.weight) |
| |
| def test_l1_unstructured_pruning_with_importance_scores(self): |
| r"""Test that l1 unstructured pruning actually removes the lowest |
| entries of importance scores and not the parameter by l1 norm (by hand). |
| It also checks that applying l1 unstructured pruning more than once |
| respects the previous mask. |
| """ |
| m = nn.Linear(4, 2) |
| # modify its weight matrix by hand |
| m.weight = torch.nn.Parameter( |
| torch.tensor( |
| [[1, 2, 3, 4], [-4, -3, -2, -1]], dtype=torch.float32 |
| ) |
| ) |
| importance_scores = torch.tensor( |
| [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32 |
| ) |
| |
| prune.l1_unstructured(m, 'weight', amount=2, importance_scores=importance_scores) |
| expected_weight = torch.tensor([[1, 2, 0, 4], [-4, 0, -2, -1]], |
| dtype=m.weight.dtype) |
| self.assertEqual(expected_weight, m.weight) |
| |
| # check that pruning again removes two entries of m.weight that are colocated with |
| # the next two smallest absolute values of importance scores. |
| prune.l1_unstructured(m, 'weight', amount=2, importance_scores=importance_scores) |
| expected_weight = torch.tensor([[1, 0, 0, 4], [-4, 0, 0, -1]], |
| dtype=m.weight.dtype) |
| self.assertEqual(expected_weight, m.weight) |
| |
| def test_unstructured_pruning_same_magnitude(self): |
| r"""Since it may happen that the tensor to prune has entries with the |
| same exact magnitude, it is important to check that pruning happens |
        consistently based on the bottom % of weights, and not by threshold,
| which would instead kill off *all* units with magnitude = threshold. |
| """ |
| AMOUNT = 0.2 |
| p = prune.L1Unstructured(amount=AMOUNT) |
        # create a random tensor with entries in {-2, 0, 2}
| t = 2 * torch.randint(low=-1, high=2, size=(10, 7)) |
| nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.nelement()) |
| |
| computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t)) |
| nparams_pruned = torch.sum(computed_mask == 0) |
| self.assertEqual(nparams_toprune, nparams_pruned) |
| |
| def test_random_structured_pruning_amount(self): |
| AMOUNT = 0.6 |
| AXIS = 2 |
| p = prune.RandomStructured(amount=AMOUNT, dim=AXIS) |
| t = 2 * torch.randint(low=-1, high=2, size=(5, 4, 2)).to( |
| dtype=torch.float32 |
| ) |
| nparams_toprune = prune._compute_nparams_toprune(AMOUNT, t.shape[AXIS]) |
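        # round(0.6 * 2) = 1, so exactly one of the 2 slices along AXIS
        # (each containing 5 * 4 = 20 entries) should be zeroed out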
| |
| computed_mask = p.compute_mask(t, default_mask=torch.ones_like(t)) |
        # check that exactly 1 slice along AXIS is fully pruned and the
        # others are left untouched
        remaining_axes = [d for d in range(t.dim()) if d != AXIS]
        per_column_sums = sorted(
            torch.sum(computed_mask == 0, dim=remaining_axes)
        )
| assert per_column_sums == [0, 20] |
| |
| def test_ln_structured_pruning(self): |
| r"""Check Ln structured pruning by hand. |
| """ |
| m = nn.Conv2d(3, 1, 2) |
| m.weight.data = torch.tensor( |
| [[[[1., 2.], [1., 2.5]], |
| [[0.5, 1.], [0.1, 0.1]], |
| [[-3., -5.], [0.1, -1.]]]] |
| ) |
| # expected effect of pruning 1 of the 3 channels by L2-norm |
| expected_mask_axis1 = torch.ones_like(m.weight) |
| expected_mask_axis1[:, 1] = 0. |
| |
| prune.ln_structured(m, 'weight', amount=1, n=2, dim=1) |
| self.assertEqual(expected_mask_axis1, m.weight_mask) |
| |
| # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm |
| expected_mask_axis3 = expected_mask_axis1 |
| expected_mask_axis3[:, :, :, 0] = 0. |
| |
| prune.ln_structured(m, 'weight', amount=1, n=1, dim=-1) |
| self.assertEqual(expected_mask_axis3, m.weight_mask) |
| |
| def test_ln_structured_pruning_importance_scores(self): |
| r"""Check Ln structured pruning by hand. |
| """ |
| m = nn.Conv2d(3, 1, 2) |
| m.weight.data = torch.tensor( |
| [[[[1., 2.], [1., 2.5]], |
| [[0.5, 1.], [0.1, 0.1]], |
| [[-3., -5.], [0.1, -1.]]]] |
| ) |
| importance_scores = torch.tensor( |
| [[[[10., 1.], [10., 1.]], |
| [[30., 3.], [30., 3.]], |
| [[-20., -2.], [-20., -2.]]]] |
| ) |
| # expected effect of pruning 1 of the 3 channels by L2-norm |
| expected_mask_axis1 = torch.ones_like(m.weight) |
| expected_mask_axis1[:, 0] = 0. |
| |
| prune.ln_structured(m, 'weight', amount=1, n=2, dim=1, importance_scores=importance_scores) |
| self.assertEqual(expected_mask_axis1, m.weight_mask) |
| |
| # expected effect of pruning 1 of the 2 columns along axis -1 by L1-norm |
| expected_mask_axis3 = expected_mask_axis1 |
| expected_mask_axis3[:, :, :, 1] = 0. |
| |
| prune.ln_structured(m, 'weight', amount=1, n=1, dim=-1, importance_scores=importance_scores) |
| self.assertEqual(expected_mask_axis3, m.weight_mask) |
| |
| def test_remove_pruning(self): |
| r"""`prune.remove` removes the hook and the reparametrization |
| and makes the pruning final in the original parameter. |
| """ |
| modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] |
| names = ['weight', 'bias'] |
| |
| for m in modules: |
| for name in names: |
| with self.subTest(m=m, name=name): |
| # first prune |
| prune.random_unstructured(m, name, amount=0.5) |
| self.assertIn(name + "_orig", dict(m.named_parameters())) |
| self.assertIn(name + "_mask", dict(m.named_buffers())) |
| self.assertNotIn(name, dict(m.named_parameters())) |
| self.assertTrue(hasattr(m, name)) |
| pruned_t = getattr(m, name) |
| |
| # then remove pruning |
| prune.remove(m, name) |
| self.assertIn(name, dict(m.named_parameters())) |
| self.assertNotIn(name + "_orig", dict(m.named_parameters())) |
| self.assertNotIn(name + "_mask", dict(m.named_buffers())) |
| final_t = getattr(m, name) |
| |
| self.assertEqual(pruned_t, final_t) |
| |
| def test_remove_pruning_exception(self): |
| r"""Removing from an unpruned tensor throws an assertion error |
| """ |
| modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] |
| names = ['weight', 'bias'] |
| |
| for m in modules: |
| for name in names: |
| with self.subTest(m=m, name=name): |
| # check that the module isn't pruned |
| self.assertFalse(prune.is_pruned(m)) |
| # since it isn't pruned, pruning can't be removed from it |
| with self.assertRaises(ValueError): |
                        prune.remove(m, name)

| def test_global_pruning(self): |
| r"""Test that global l1 unstructured pruning over 2 parameters removes |
| the `amount=4` smallest global weights across the 2 parameters. |
| """ |
| m = nn.Linear(4, 2) |
| n = nn.Linear(3, 1) |
| # modify the weight matrices by hand |
| m.weight = torch.nn.Parameter( |
| torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to( |
| dtype=torch.float32) |
| ) |
| n.weight = torch.nn.Parameter( |
| torch.tensor([[0, 0.1, -2]]).to( |
| dtype=torch.float32) |
| ) |
| |
| params_to_prune = ( |
| (m, 'weight'), |
| (n, 'weight'), |
| ) |
| |
| # prune the 4 smallest weights globally by L1 magnitude |
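        # (the four smallest magnitudes are |0| and |0.1| in n.weight, plus
        # the magnitude-1 entries m.weight[0, 0] = 1 and m.weight[1, 3] = -1)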
| prune.global_unstructured( |
| params_to_prune, |
| pruning_method=prune.L1Unstructured, |
| amount=4 |
| ) |
| |
| expected_mweight = torch.tensor([[0, 2, 3, 4], [-4, -3, -2, 0]], |
| dtype=m.weight.dtype) |
| self.assertEqual(expected_mweight, m.weight) |
| |
| expected_nweight = torch.tensor([[0, 0, -2]]).to(dtype=n.weight.dtype) |
| self.assertEqual(expected_nweight, n.weight) |
| |
| def test_global_pruning_importance_scores(self): |
| r"""Test that global l1 unstructured pruning over 2 parameters removes |
| the `amount=4` smallest global weights across the 2 parameters. |
| """ |
| m = nn.Linear(4, 2) |
| n = nn.Linear(3, 1) |
| # modify the weight matrices by hand |
| m.weight = torch.nn.Parameter( |
| torch.tensor([[1, 2, 3, 4], [-4, -3, -2, -1]]).to( |
| dtype=torch.float32) |
| ) |
| m_importance_scores = torch.tensor( |
| [[4, 2, 1, 3], [-3, -1, -2, -4]], dtype=torch.float32 |
| ) |
| n.weight = torch.nn.Parameter( |
| torch.tensor([[0, 0.1, -2]]).to( |
| dtype=torch.float32) |
| ) |
| n_importance_scores = torch.tensor([[0, 10., -0.2]]).to(dtype=torch.float32) |
| |
| params_to_prune = ( |
| (m, 'weight'), |
| (n, 'weight'), |
| ) |
| importance_scores = { |
| (m, 'weight'): m_importance_scores, |
| (n, 'weight'): n_importance_scores, |
| } |
| |
        # prune the 4 weights with the globally smallest importance scores
| prune.global_unstructured( |
| params_to_prune, |
| pruning_method=prune.L1Unstructured, |
| amount=4, |
| importance_scores=importance_scores, |
| ) |
| |
| expected_m_weight = torch.tensor([[1, 2, 0, 4], [-4, 0, -2, -1]], |
| dtype=m.weight.dtype) |
| self.assertEqual(expected_m_weight, m.weight) |
| |
| expected_n_weight = torch.tensor([[0, 0.1, 0]]).to(dtype=n.weight.dtype) |
| self.assertEqual(expected_n_weight, n.weight) |
| |
| def test_custom_from_mask_pruning(self): |
| r"""Test that the CustomFromMask is capable of receiving |
| as input at instantiation time a custom mask, and combining it with |
| the previous default mask to generate the correct final mask. |
| """ |
| # new mask |
| mask = torch.tensor([[0, 1, 1, 0], [0, 0, 1, 1]]) |
| # old mask |
| default_mask = torch.tensor([[0, 0, 0, 0], [1, 1, 1, 1]]) |
| |
| # some tensor (not actually used) |
| t = torch.rand_like(mask.to(dtype=torch.float32)) |
| |
| p = prune.CustomFromMask(mask=mask) |
| |
| computed_mask = p.compute_mask(t, default_mask) |
| expected_mask = torch.tensor([[0, 0, 0, 0], [0, 0, 1, 1]], dtype=computed_mask.dtype) |
| |
| self.assertEqual(computed_mask, expected_mask) |
| |
| def test_pruning_rollback(self): |
| r"""Test that if something fails when the we try to compute the mask, |
| then the model isn't left in some intermediate half-pruned state. |
| The try/except statement in `apply` should handle rolling back |
| to the previous state before pruning began. |
| """ |
| modules = [nn.Linear(5, 7), nn.Conv3d(2, 2, 2)] |
| names = ['weight', 'bias'] |
| |
| for m in modules: |
| for name in names: |
| with self.subTest(m=m, name=name): |
| |
| with mock.patch( |
| "torch.nn.utils.prune.L1Unstructured.compute_mask" |
| ) as compute_mask: |
| compute_mask.side_effect = Exception('HA!') |
| with self.assertRaises(Exception): |
| prune.l1_unstructured(m, name=name, amount=0.9) |
| |
| self.assertTrue( |
| name in dict(m.named_parameters()) |
| ) |
| self.assertFalse( |
| name + '_mask' in dict(m.named_buffers()) |
| ) |
| self.assertFalse( |
| name + '_orig' in dict(m.named_parameters()) |
| ) |
| |
| def test_pruning_serialization_model(self): |
| # create a model |
| model = torch.nn.Sequential( |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| torch.nn.Linear(10, 1), |
| ) |
| # check that everything looks normal before pruning |
| self.assertNotIn('0.weight_orig', model.state_dict()) |
| self.assertNotIn('0.weight_mask', model.state_dict()) |
| self.assertIn('0.weight', model.state_dict()) |
| |
| # prune one of its parameters |
| prune.l1_unstructured(module=model[0], name='weight', amount=0.9) |
| |
| # check that the original weight and the new mask are present |
| self.assertIn('0.weight_orig', model.state_dict()) |
| self.assertIn('0.weight_mask', model.state_dict()) |
| self.assertNotIn('0.weight', model.state_dict()) |
| self.assertTrue(hasattr(model[0], 'weight')) |
| |
| pruned_weight = model[0].weight |
| |
| with TemporaryFileName() as fname: |
| torch.save(model, fname) |
| new_model = torch.load(fname) |
| |
| # check that the original weight and the new mask are present |
| self.assertIn('0.weight_orig', new_model.state_dict()) |
| self.assertIn('0.weight_mask', new_model.state_dict()) |
| self.assertNotIn('0.weight', new_model.state_dict()) |
| self.assertTrue(hasattr(new_model[0], 'weight')) |
| |
| self.assertEqual(pruned_weight, new_model[0].weight) |
| |
| def test_pruning_serialization_state_dict(self): |
| # create a model |
| model = torch.nn.Sequential( |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| torch.nn.Linear(10, 1), |
| ) |
| # check that everything looks normal before pruning |
| self.assertNotIn('0.weight_orig', model.state_dict()) |
| self.assertNotIn('0.weight_mask', model.state_dict()) |
| self.assertIn('0.weight', model.state_dict()) |
| |
| # prune one of its parameters |
| prune.l1_unstructured(module=model[0], name='weight', amount=0.9) |
| |
| # check that the original weight and the new mask are present |
| self.assertIn('0.weight_orig', model.state_dict()) |
| self.assertIn('0.weight_mask', model.state_dict()) |
| self.assertNotIn('0.weight', model.state_dict()) |
| self.assertTrue(hasattr(model[0], 'weight')) |
| |
| pruned_weight = model[0].weight |
| |
| # make pruning permanent and restore parameter names as in base |
| # architecture |
| prune.remove(module=model[0], name='weight') |
| |
| # check that the original weight and the new mask are no longer present |
| self.assertNotIn('0.weight_orig', model.state_dict()) |
| self.assertNotIn('0.weight_mask', model.state_dict()) |
| self.assertIn('0.weight', model.state_dict()) |
| |
| # save the state dict of model and reload it into new_model |
| new_model = torch.nn.Sequential( |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| torch.nn.Linear(10, 1), |
| ) |
| with TemporaryFileName() as fname: |
| torch.save(model.state_dict(), fname) |
| new_model.load_state_dict(torch.load(fname)) |
| |
| # check that the original weight and the new mask are not present in |
| # new_model either. |
| self.assertNotIn('0.weight_orig', new_model.state_dict()) |
| self.assertNotIn('0.weight_mask', new_model.state_dict()) |
| self.assertIn('0.weight', new_model.state_dict()) |
| |
| self.assertEqual(pruned_weight, new_model[0].weight) |
| |
| def test_prune(self): |
| # create a new pruning method |
| p = prune.L1Unstructured(amount=2) |
| # create tensor to be pruned |
| t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) |
| # create prior mask by hand |
| default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) |
        # since we are pruning the two lowest-magnitude units (1 and 2), the
        # outcome of the calculation should be this:
| expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]]) |
| pruned_tensor = p.prune(t, default_mask) |
| self.assertEqual(t * expected_mask, pruned_tensor) |
| |
| def test_prune_importance_scores(self): |
| # create a new pruning method |
| p = prune.L1Unstructured(amount=2) |
| # create tensor to be pruned |
| t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) |
| importance_scores = torch.tensor( |
| [[1, 2, 3, 4], [1.5, 1.6, 1.7, 1.8]] |
| ).to(dtype=torch.float32) |
| # create prior mask by hand |
| default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) |
        # since we are pruning the two units with the lowest-magnitude
        # importance scores (1 and 1.5), the outcome should be this:
| expected_mask = torch.tensor([[0, 1, 1, 0], [0, 1, 0, 1]]) |
| pruned_tensor = p.prune(t, default_mask, importance_scores=importance_scores) |
| self.assertEqual(t * expected_mask, pruned_tensor) |
| |
| def test_prune_importance_scores_mimic_default(self): |
| # create a new pruning method |
| p = prune.L1Unstructured(amount=2) |
| # create tensor to be pruned |
| t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]).to(dtype=torch.float32) |
| # create prior mask by hand |
| default_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 1]]) |
        # since we are pruning the two lowest-magnitude units (1 and 2), the
        # outcome of the calculation should be this:
| expected_mask = torch.tensor([[0, 0, 1, 0], [1, 1, 0, 1]]) |
| pruned_tensor_without_importance_scores = p.prune(t, default_mask) |
| pruned_tensor_with_importance_scores = p.prune(t, default_mask, importance_scores=t) |
| self.assertEqual(pruned_tensor_without_importance_scores, pruned_tensor_with_importance_scores) |
| self.assertEqual(t * expected_mask, pruned_tensor_without_importance_scores) |
| |
| def test_rnn_pruning(self): |
| l = torch.nn.LSTM(32, 32) |
| # This Module has 4 parameters called: |
| # 'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0' |
| |
        # Pruning one of them causes the corresponding entry of
        # `_flat_weights` to become a plain tensor instead of a Parameter
| prune.l1_unstructured(l, 'weight_ih_l0', 0.5) |
| assert ( |
| sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]) |
| == 3 |
| ) |
| |
| # Removing the pruning reparametrization restores the Parameter |
| prune.remove(l, 'weight_ih_l0') |
| assert ( |
| sum([isinstance(p, torch.nn.Parameter) for p in l._flat_weights]) |
| == 4 |
| ) |
| |
| # Make sure that, upon removal of the reparametrization, the |
| # `._parameters` and `.named_parameters` contain the right params. |
| # Specifically, the original weight ('weight_ih_l0') should be placed |
| # back in the parameters, while the reparametrization component |
| # ('weight_ih_l0_orig') should be removed. |
| assert 'weight_ih_l0' in l._parameters |
| assert l._parameters['weight_ih_l0'] is not None |
| assert 'weight_ih_l0_orig' not in l._parameters |
| assert 'weight_ih_l0' in dict(l.named_parameters()) |
| assert dict(l.named_parameters())['weight_ih_l0'] is not None |
| assert 'weight_ih_l0_orig' not in dict(l.named_parameters()) |
| |
| instantiate_parametrized_tests(TestPruningNN) |
| |
| if __name__ == '__main__': |
| run_tests() |