blob: 415679337ff2ee6eeac05b717265f87a504758e7 [file] [log] [blame]
# -*- coding: utf-8 -*-
# Owner(s): ["module: unknown"]
import itertools
import logging
import re
import torch
from torch import nn
from torch.ao.pruning import BaseSparsifier, WeightNormSparsifier, FakeSparsity, NearlyDiagonalSparsifier
from torch.nn.utils.parametrize import is_parametrized
from torch.testing._internal.common_utils import TestCase
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.seq = nn.Sequential(
nn.Linear(16, 16)
)
self.linear = nn.Linear(16, 16)
self.head = nn.Linear(16, 4)
def forward(self, x):
x = self.seq(x)
x = self.linear(x)
x = self.head(x)
return x
class ImplementedSparsifier(BaseSparsifier):
def __init__(self, **kwargs):
super().__init__(defaults=kwargs)
def update_mask(self, module, **kwargs):
module.parametrizations.weight[0].mask[0] = 0
linear_state = self.state['linear.weight']
linear_state['step_count'] = linear_state.get('step_count', 0) + 1
class TestBaseSparsifier(TestCase):
def test_constructor(self):
# Cannot instantiate the abstract base
self.assertRaises(TypeError, BaseSparsifier)
# Can instantiate the model with no configs
model = Model()
sparsifier = ImplementedSparsifier(test=3)
sparsifier.prepare(model, config=None)
assert len(sparsifier.groups) == 3
sparsifier.step()
# Can instantiate the model with configs
sparsifier = ImplementedSparsifier(test=3)
sparsifier.prepare(model, [{'tensor_fqn': 'linear.weight'}])
assert len(sparsifier.groups) == 1
assert sparsifier.groups[0]['tensor_fqn'] == 'linear.weight'
assert 'test' in sparsifier.groups[0]
assert sparsifier.groups[0]['test'] == 3
def test_prepare_config(self):
model = Model()
sparsifier = ImplementedSparsifier(test=3)
# Make sure there are no parametrizations before `prepare`
assert not hasattr(model.seq[0], 'parametrizations')
assert not hasattr(model.linear, 'parametrizations')
assert not hasattr(model.head, 'parametrizations')
sparsifier.prepare(model, config=[
{'tensor_fqn': 'seq.0.weight', 'test': 42},
# No 'linear' to make sure it will be skipped in the sparsification
{'tensor_fqn': 'head.weight'}
])
assert len(sparsifier.groups) == 2
# Check if default argument is not assigned if explicit
assert sparsifier.groups[0]['tensor_fqn'] == 'seq.0.weight'
assert sparsifier.groups[0]['test'] == 42
# Check if FQN and module are pointing to the same location
assert sparsifier.groups[1]['tensor_fqn'] == 'head.weight'
assert sparsifier.groups[1]['module'] == model.head
# Check if parameterizations are attached
assert hasattr(model.seq[0], 'parametrizations')
assert not hasattr(model.linear, 'parametrizations')
assert hasattr(model.head, 'parametrizations')
def test_step(self):
model = Model()
sparsifier = ImplementedSparsifier(test=3)
sparsifier.enable_mask_update = True
sparsifier.prepare(model, [{'tensor_fqn': 'linear.weight'}])
sparsifier.step()
assert torch.all(model.linear.parametrizations.weight[0].mask[0] == 0)
def test_state_dict(self):
step_count = 3
model0 = Model()
sparsifier0 = ImplementedSparsifier(test=3)
sparsifier0.prepare(model0, [{'tensor_fqn': 'linear.weight'}])
mask = model0.linear.parametrizations['weight'][0].mask
mask.data = torch.arange(mask.shape[0] * mask.shape[1]).reshape(mask.shape)
for step in range(step_count):
sparsifier0.step()
state_dict = sparsifier0.state_dict()
# Check the expected keys in the state_dict
assert 'state' in state_dict
assert 'step_count' in state_dict['state']['linear.weight']
assert state_dict['state']['linear.weight']['step_count'] == 3
assert 'groups' in state_dict
assert 'test' in state_dict['groups'][0]
assert 'tensor_fqn' in state_dict['groups'][0]
assert state_dict['groups'][0]['tensor_fqn'] == 'linear.weight'
# Check loading static_dict creates an equivalent model
model1 = Model()
sparsifier1 = ImplementedSparsifier()
sparsifier1.prepare(model1, None)
assert sparsifier0.state != sparsifier1.state
# Make sure the masks are different in the beginning
for mg in sparsifier0.groups:
if mg['tensor_fqn'] == 'linear.weight':
mask0 = mg['module'].parametrizations.weight[0].mask
for mg in sparsifier1.groups:
if mg['tensor_fqn'] == 'linear.weight':
mask1 = mg['module'].parametrizations.weight[0].mask
self.assertNotEqual(mask0, mask1)
sparsifier1.load_state_dict(state_dict)
# Make sure the states are loaded, and are correct
assert sparsifier0.state == sparsifier1.state
# Make sure the masks (and all dicts) are the same after loading
assert len(sparsifier0.groups) == len(sparsifier1.groups)
for idx in range(len(sparsifier0.groups)):
mg0 = sparsifier0.groups[idx]
mg1 = sparsifier1.groups[idx]
for key in mg0.keys():
assert key in mg1
if key == 'module':
# We cannot compare modules as they are different
param0 = mg0[key].parametrizations.weight[0]
param1 = mg1[key].parametrizations.weight[0]
assert hasattr(param0, 'mask')
assert hasattr(param1, 'mask')
self.assertEqual(param0.__dict__, param1.__dict__)
else:
assert mg0[key] == mg1[key]
def test_mask_squash(self):
model = Model()
sparsifier = ImplementedSparsifier(test=3)
sparsifier.prepare(model, [{'tensor_fqn': 'linear.weight'}])
assert hasattr(model.linear.parametrizations.weight[0], 'mask')
assert is_parametrized(model.linear, 'weight')
assert not is_parametrized(model.seq[0], 'weight')
sparsifier.squash_mask()
assert not is_parametrized(model.seq[0], 'weight')
assert not is_parametrized(model.linear, 'weight')
def test_mask_squash_with_params1(self):
model = Model()
sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
sparsifier.prepare(model, [{'tensor_fqn': 'linear.weight'}, {'tensor_fqn': 'seq.0.weight'}])
sparsifier.squash_mask(
params_to_keep_per_layer={
'linear': ('foo', 'bar'),
'seq.0': ('baz',)
})
assert not is_parametrized(model.seq[0], 'weight')
assert not is_parametrized(model.linear, 'weight')
assert hasattr(model.seq[0], 'sparse_params')
assert hasattr(model.linear, 'sparse_params')
assert model.seq[0].sparse_params.get('foo', None) is None
assert model.seq[0].sparse_params.get('bar', None) is None
assert model.seq[0].sparse_params.get('baz', None) == 1
assert model.linear.sparse_params.get('foo', None) == 3
assert model.linear.sparse_params.get('bar', None) == 2
assert model.linear.sparse_params.get('baz', None) is None
def test_mask_squash_with_params2(self):
model = Model()
sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
sparsifier.prepare(model, [{'tensor_fqn': 'linear.weight'}, {'tensor_fqn': 'seq.0.weight'}])
sparsifier.squash_mask(params_to_keep=('foo', 'bar'))
assert not is_parametrized(model.seq[0], 'weight')
assert not is_parametrized(model.linear, 'weight')
assert hasattr(model.seq[0], 'sparse_params')
assert hasattr(model.linear, 'sparse_params')
assert model.seq[0].sparse_params.get('foo', None) == 3
assert model.seq[0].sparse_params.get('bar', None) == 2
assert model.seq[0].sparse_params.get('baz', None) is None
assert model.linear.sparse_params.get('foo', None) == 3
assert model.linear.sparse_params.get('bar', None) == 2
assert model.linear.sparse_params.get('baz', None) is None
def test_mask_squash_with_params3(self):
model = Model()
sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
sparsifier.prepare(model, [{'tensor_fqn': 'linear.weight'}, {'tensor_fqn': 'seq.0.weight'}])
sparsifier.squash_mask(
params_to_keep=('foo', 'bar'),
params_to_keep_per_layer={'seq.0': ('baz',)})
assert not is_parametrized(model.seq[0], 'weight')
assert not is_parametrized(model.linear, 'weight')
assert hasattr(model.seq[0], 'sparse_params')
assert hasattr(model.linear, 'sparse_params')
assert model.seq[0].sparse_params.get('foo', None) == 3
assert model.seq[0].sparse_params.get('bar', None) == 2
assert model.seq[0].sparse_params.get('baz', None) == 1
assert model.linear.sparse_params.get('foo', None) == 3
assert model.linear.sparse_params.get('bar', None) == 2
assert model.linear.sparse_params.get('baz', None) is None
class TestWeightNormSparsifier(TestCase):
def test_constructor(self):
model = Model()
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
for g in sparsifier.groups:
assert isinstance(g['module'], nn.Linear)
# The groups are unordered
assert g['module_fqn'] in ('seq.0', 'linear', 'head')
def test_step(self):
model = Model()
sparsifier = WeightNormSparsifier(sparsity_level=0.5)
sparsifier.prepare(model, config=[{'tensor_fqn': 'linear.weight'}])
for g in sparsifier.groups:
# Before step
module = g['module']
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) == 0 # checking sparsity level is 0
sparsifier.enable_mask_update = True
sparsifier.step()
self.assertAlmostEqual(model.linear.parametrizations['weight'][0].mask.mean().item(), 0.5, places=2)
for g in sparsifier.groups:
# After step
module = g['module']
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) > 0 # checking sparsity level has increased
# Test if the mask collapses to all zeros if the weights are randomized
iters_before_collapse = 1000
for _ in range(iters_before_collapse):
model.linear.weight.data = torch.randn(model.linear.weight.shape)
sparsifier.step()
for g in sparsifier.groups:
# After step
module = g['module']
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) > 0 # checking sparsity level did not collapse
def test_step_2_of_4(self):
model = Model()
sparsifier = WeightNormSparsifier(sparsity_level=1.0,
sparse_block_shape=(1, 4),
zeros_per_block=2)
sparsifier.prepare(model, config=[{'tensor_fqn': 'linear.weight'}])
sparsifier.step()
# make sure the sparsity level is approximately 50%
mask = model.linear.parametrizations['weight'][0].mask.to(torch.float) # mean works on float only
self.assertAlmostEqual(mask.mean().item(), 0.5, places=2)
# Make sure each block has exactly 50% zeros
module = sparsifier.groups[0]['module']
mask = module.parametrizations['weight'][0].mask
for row in mask:
for idx in range(0, len(row), 4):
block = row[idx:idx + 4]
block, _ = block.sort()
assert (block[:2] == 0).all()
assert (block[2:] != 0).all()
def test_prepare(self):
model = Model()
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
for g in sparsifier.groups:
module = g['module']
# Check mask exists
assert hasattr(module.parametrizations['weight'][0], 'mask')
# Check parametrization exists and is correct
assert is_parametrized(module, 'weight')
assert type(module.parametrizations.weight[0]) == FakeSparsity
def test_mask_squash(self):
model = Model()
sparsifier = WeightNormSparsifier()
sparsifier.prepare(model, config=None)
sparsifier.squash_mask()
for g in sparsifier.groups:
module = g['module']
assert not is_parametrized(module, 'weight')
assert not hasattr(module, 'mask')
def test_sparsity_levels(self):
sparsity_levels = [-1.0, 0.0, 0.5, 1.0, 2.0]
sparse_block_shapes = [(1, 1), (1, 4), (2, 2), (4, 1)]
zeros_per_blocks = [0, 1, 2, 3, 4]
testcases = itertools.tee(itertools.product(sparsity_levels,
sparse_block_shapes,
zeros_per_blocks))
# Create a config and model with all the testcases
model = nn.Sequential()
sparsifier = WeightNormSparsifier()
sparsity_per_layer_config = []
p = re.compile(r'[-\.\s]')
for sl, sbs, zpb in testcases[0]:
# Make sure the number of zeros is not > values in a block
if zpb > sbs[0] * sbs[1]:
continue
layer_name = f'{sl}_{sbs}_{zpb}'
layer_name = p.sub('_', layer_name)
layer = nn.Linear(12, 12, bias=False)
layer.weight = nn.Parameter(torch.ones(12, 12))
model.add_module(layer_name, layer)
config = {
'tensor_fqn': layer_name + ".weight",
'sparsity_level': sl,
'sparse_block_shape': sbs,
'zeros_per_block': zpb
}
sparsity_per_layer_config.append(config)
sparsifier.prepare(model, sparsity_per_layer_config)
sparsifier.step()
sparsifier.squash_mask()
model.eval()
for sl, sbs, zpb in testcases[1]:
if zpb > sbs[0] * sbs[1]:
continue
layer_name = f'{sl}_{sbs}_{zpb}'
layer_name = p.sub('_', layer_name)
layer = getattr(model, layer_name)
# Level of sparsity is achieved
sparse_mask = (layer.weight == 0).float()
if zpb == 0:
assert sparse_mask.mean() == 0
else:
# Ratio of individual zeros in the tensor
true_sl = min(max(sl, 0.0), 1.0)
true_sl = true_sl * zpb / sbs[0] / sbs[1]
assert sparse_mask.mean() == true_sl
class TestNearlyDiagonalSparsifier(TestCase):
def test_constructor(self):
model = Model()
sparsifier = NearlyDiagonalSparsifier(nearliness=1)
sparsifier.prepare(model, config=None)
for g in sparsifier.groups:
assert isinstance(g['module'], nn.Linear)
# The groups are unordered
assert g['module_fqn'] in ('seq.0', 'linear', 'head')
def test_step(self):
model = Model()
sparsifier = NearlyDiagonalSparsifier(nearliness=1)
sparsifier.prepare(model, config=[{'tensor_fqn': 'linear.weight'}])
for g in sparsifier.groups:
# Before step
module = g['module']
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) == 0 # checking sparsity level is 0
sparsifier.enable_mask_update = True
sparsifier.step()
mask = module.parametrizations['weight'][0].mask
height, width = mask.shape
assert torch.all(mask == torch.eye(height, width))
for g in sparsifier.groups:
# After step
module = g['module']
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) > 0 # checking sparsity level has increased
# Test if the mask collapses to all zeros if the weights are randomized
iters_before_collapse = 1000
for _ in range(iters_before_collapse):
model.linear.weight.data = torch.randn(model.linear.weight.shape)
sparsifier.step()
for g in sparsifier.groups:
# After step
module = g['module']
assert (1.0 - module.parametrizations['weight'][0].mask.mean()) > 0 # checking sparsity level did not collapse
def test_prepare(self):
model = Model()
sparsifier = NearlyDiagonalSparsifier(nearliness=1)
sparsifier.prepare(model, config=None)
for g in sparsifier.groups:
module = g['module']
# Check mask exists
assert hasattr(module.parametrizations['weight'][0], 'mask')
# Check parametrization exists and is correct
assert is_parametrized(module, 'weight')
assert type(module.parametrizations.weight[0]) == FakeSparsity
def test_mask_squash(self):
model = Model()
sparsifier = NearlyDiagonalSparsifier(nearliness=1)
sparsifier.prepare(model, config=None)
sparsifier.step()
sparsifier.squash_mask()
for g in sparsifier.groups:
module = g['module']
assert not is_parametrized(module, 'weight')
assert not hasattr(module, 'mask')
weights = module.weight
height, width = weights.shape
assert torch.all(weights == torch.eye(height, width) * weights) # only diagonal to be present
def test_sparsity_levels(self):
nearliness_levels = list(nearliness for nearliness in range(-1, 100))
model = nn.Sequential()
p = re.compile(r'[-\.\s]')
for nearliness in nearliness_levels:
sparsifier = NearlyDiagonalSparsifier(nearliness=1)
layer_name = f'{nearliness}'
layer_name = p.sub('_', layer_name)
layer = nn.Linear(32, 32, bias=False)
layer.weight = nn.Parameter(torch.ones(32, 32))
width, height = layer.weight.shape
model.add_module(layer_name, layer)
config = {
'tensor_fqn': layer_name + ".weight",
'nearliness': nearliness
}
sparsifier.prepare(model, [config])
# should raise a ValueError when nearliness arg is illegal
if (nearliness > 0 and nearliness % 2 == 0) or (nearliness // 2 >= min(width, height)):
with self.assertRaises(ValueError):
sparsifier.step()
else:
sparsifier.step()
sparsifier.squash_mask()
model.eval()
layer = getattr(model, layer_name)
# verify that mask created corresponds to the nearliness
self._verify_nearliness(layer.weight, nearliness)
# helper function to verify nearliness of a mask
def _verify_nearliness(self, mask: torch.Tensor, nearliness: int):
if nearliness <= 0:
assert torch.all(mask == torch.zeros(mask.shape[0], mask.shape[1]))
else:
height, width = mask.shape
dist_to_diagonal = nearliness // 2
for row in range(0, height):
for col in range(0, width):
if abs(row - col) <= dist_to_diagonal:
assert mask[row, col] == 1
else:
assert mask[row, col] == 0