# -*- coding: utf-8 -*-

import logging

import torch
from torch import nn
from torch.ao.sparsity import BasePruner, PruningParametrization
from torch.nn.utils import parametrize

from torch.testing._internal.common_utils import TestCase

logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)

# Run every test on CPU and, when available, on CUDA. DEVICES is a set, so
# the duplicate "cpu" entry collapses when CUDA is absent.
DEVICES = {"cpu", "cuda" if torch.cuda.is_available() else "cpu"}


class Linear(nn.Module):
    """Two bias-free Linear layers, one wrapped in a Sequential."""
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(16, 16, bias=False)
        )
        self.linear = nn.Linear(16, 16, bias=False)

    def forward(self, x):
        x = self.seq(x)
        x = self.linear(x)
        return x


class LinearB(nn.Module):
    """Same as Linear, but with biases enabled."""
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(16, 16, bias=True)
        )
        self.linear = nn.Linear(16, 16, bias=True)

    def forward(self, x):
        x = self.seq(x)
        x = self.linear(x)
        return x


class MultipleLinear(nn.Module):
    """A deeper stack of bias-free Linear layers with ReLUs in between."""
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=False),
            nn.ReLU(),
            nn.Linear(5, 8, bias=False),
            nn.ReLU(),
            nn.Linear(8, 6, bias=False)
        )
        self.linear = nn.Linear(6, 4, bias=False)

    def forward(self, x):
        x = self.seq(x)
        x = self.linear(x)
        return x


class MultipleLinearB(nn.Module):
    """Same as MultipleLinear, but with biases enabled."""
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.ReLU(),
            nn.Linear(5, 8, bias=True),
            nn.ReLU(),
            nn.Linear(8, 6, bias=True)
        )
        self.linear = nn.Linear(6, 4, bias=True)

    def forward(self, x):
        x = self.seq(x)
        x = self.linear(x)
        return x


class MultipleLinearMixed(nn.Module):
    """Same as MultipleLinear, but mixing biased and bias-free layers."""
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Linear(7, 5, bias=True),
            nn.ReLU(),
            nn.Linear(5, 8, bias=False),
            nn.ReLU(),
            nn.Linear(8, 6, bias=True)
        )
        self.linear = nn.Linear(6, 4, bias=False)

    def forward(self, x):
        x = self.seq(x)
        x = self.linear(x)
        return x


class Conv2dA(nn.Module):
    """Two bias-free Conv2d layers, one wrapped in a Sequential."""
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=False),
        )
        self.conv2d = nn.Conv2d(32, 64, 3, 1, bias=False)

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d(x)
        return x


class Conv2dB(nn.Module):
    """Same as Conv2dA, but with biases enabled."""
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=True),
        )
        self.conv2d = nn.Conv2d(32, 64, 3, 1, bias=True)

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d(x)
        return x


class Conv2dC(nn.Module):
    """Same as Conv2dA, but mixing biased and bias-free layers."""
    def __init__(self):
        super().__init__()
        self.seq = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, bias=True),
        )
        self.conv2d = nn.Conv2d(32, 64, 3, 1, bias=False)

    def forward(self, x):
        x = self.seq(x)
        x = self.conv2d(x)
        return x


class SimplePruner(BasePruner):
    # Minimal concrete pruner: every step marks output index 1 as pruned.
    def update_mask(self, layer, **kwargs):
        layer.parametrizations.weight[0].pruned_outputs.add(1)


class MultiplePruner(BasePruner):
    # Concrete pruner that prunes several outputs (indices 1 and 2) per step.
    def update_mask(self, layer, **kwargs):
        layer.parametrizations.weight[0].pruned_outputs.update([1, 2])
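

# A minimal end-to-end sketch (not itself a test, and the function name is
# illustrative only) of how the pruners above are meant to be driven. The
# prepare/step/convert flow mirrors exactly what the test cases below exercise.
def _example_pruner_flow():
    model = Linear()
    pruner = SimplePruner(model, None, None)
    pruner.prepare()                   # attaches masks and PruningParametrization
    pruner.enable_mask_update = True
    pruner.step()                      # update_mask marks output index 1 as pruned
    pruner.convert()                   # removes the parametrizations and masks
    return model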


class TestBasePruner(TestCase):
    def _check_pruner_prepared(self, model, pruner, device):
        for g in pruner.module_groups:
            module = g['module']
            # Compare device types: model.to('cuda') places tensors on an
            # indexed device (cuda:0), which compares unequal to plain 'cuda'.
            assert module.weight.device.type == device.type
            # Check mask exists
            assert hasattr(module, 'mask')
            # Check parametrization exists and is correct
            assert parametrize.is_parametrized(module)
            assert hasattr(module, "parametrizations")
            # Assume that this is the 1st/only parametrization
            assert type(module.parametrizations.weight[0]) == PruningParametrization

    def _check_pruner_converted(self, model, pruner, device):
        for g in pruner.module_groups:
            module = g['module']
            assert module.weight.device.type == device.type
            # Both the parametrization and the mask are removed on convert
            assert not hasattr(module, "parametrizations")
            assert not hasattr(module, 'mask')

    def _check_pruner_valid_before_step(self, model, pruner, device):
        for g in pruner.module_groups:
            module = g['module']
            assert module.weight.device.type == device.type
            assert module.parametrizations.weight[0].pruned_outputs == set()

    def _check_pruner_valid_after_step(self, model, pruner, pruned_set, device):
        for g in pruner.module_groups:
            module = g['module']
            assert module.weight.device.type == device.type
            assert module.parametrizations.weight[0].pruned_outputs == pruned_set

    def _test_constructor_on_device(self, model, device):
        # BasePruner is abstract; instantiating it directly must fail
        self.assertRaisesRegex(TypeError, 'with abstract methods update_mask',
                               BasePruner)
        # Can instantiate the pruner without configs
        model = model.to(device)
        pruner = SimplePruner(model, None, None)
        for g in pruner.module_groups:
            module = g['module']
            assert module.weight.device.type == device.type
        assert len(pruner.module_groups) == 2
        pruner.step()
        # Can instantiate the pruner with configs
        pruner = SimplePruner(model, [model.linear], {'test': 3})
        assert len(pruner.module_groups) == 1
        assert pruner.module_groups[0]['path'] == 'linear'
        assert 'test' in pruner.module_groups[0]
        assert pruner.module_groups[0]['test'] == 3

    def test_constructor(self):
        model = Linear()
        for device in DEVICES:
            self._test_constructor_on_device(model, torch.device(device))

    def _test_prepare_linear_on_device(self, model, device):
        model = model.to(device)
        # Create the input on the same device as the model
        x = torch.ones(128, 16, device=device)
        pruner = SimplePruner(model, None, None)
        pruner.prepare()
        self._check_pruner_prepared(model, pruner, device)
        assert model(x).shape == (128, 16)

    def test_prepare_linear(self):
        models = [Linear(), LinearB()]  # without and with bias
        for device in DEVICES:
            for model in models:
                self._test_prepare_linear_on_device(model, torch.device(device))

    def _test_prepare_conv2d_on_device(self, model, device):
        model = model.to(device)
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(model, None, None)
        pruner.prepare()
        self._check_pruner_prepared(model, pruner, device)
        assert model(x).shape == (1, 64, 24, 24)

    def test_prepare_conv2d(self):
        models = [Conv2dA(), Conv2dB(), Conv2dC()]
        for device in DEVICES:
            for model in models:
                self._test_prepare_conv2d_on_device(model, torch.device(device))

    def _test_convert_linear_on_device(self, model, device):
        model = model.to(device)
        x = torch.ones(128, 16, device=device)
        pruner = SimplePruner(model, None, None)
        pruner.prepare()
        pruner.convert()
        self._check_pruner_converted(model, pruner, device)
        assert model(x).shape == (128, 16)

    def test_convert_linear(self):
        models = [Linear(), LinearB()]  # without and with bias
        for device in DEVICES:
            for model in models:
                self._test_convert_linear_on_device(model, torch.device(device))

    def _test_convert_conv2d_on_device(self, model, device):
        model = model.to(device)
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(model, None, None)
        pruner.prepare()
        pruner.convert()
        self._check_pruner_converted(model, pruner, device)
        assert model(x).shape == (1, 64, 24, 24)

    def test_convert_conv2d(self):
        models = [Conv2dA(), Conv2dB(), Conv2dC()]
        for device in DEVICES:
            for model in models:
                self._test_convert_conv2d_on_device(model, torch.device(device))

    def _test_step_linear_on_device(self, model, is_basic, device):
        model = model.to(device)
        if is_basic:
            x = torch.ones(16, 16, device=device)
            pruner = SimplePruner(model, None, None)
            expected_pruned = {1}
        else:
            x = torch.ones(7, 7, device=device)
            pruner = MultiplePruner(model, None, None)
            expected_pruned = {1, 2}
        pruner.prepare()
        pruner.enable_mask_update = True
        self._check_pruner_valid_before_step(model, pruner, device)
        pruner.step()
        self._check_pruner_valid_after_step(model, pruner, expected_pruned, device)
        # Smoke-test that the forward pass still runs after a pruning step
        model(x)

    def test_step_linear(self):
        basic_models = [Linear(), LinearB()]
        complex_models = [MultipleLinear(), MultipleLinearB(), MultipleLinearMixed()]
        for device in DEVICES:
            for model in basic_models:
                self._test_step_linear_on_device(model, True, torch.device(device))
            for model in complex_models:
                self._test_step_linear_on_device(model, False, torch.device(device))

    def _test_step_conv2d_on_device(self, model, device):
        model = model.to(device)
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(model, None, None)
        pruner.prepare()
        pruner.enable_mask_update = True
        self._check_pruner_valid_before_step(model, pruner, device)
        pruner.step()
        self._check_pruner_valid_after_step(model, pruner, {1}, device)
        assert model(x).shape == (1, 64, 24, 24)

    def test_step_conv2d(self):
        models = [Conv2dA(), Conv2dB(), Conv2dC()]
        for device in DEVICES:
            for model in models:
                self._test_step_conv2d_on_device(model, torch.device(device))
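

if __name__ == "__main__":
    # Standard entry point for PyTorch test files (assumed here): run_tests
    # drives the TestCase subclasses above through the unittest machinery.
    from torch.testing._internal.common_utils import run_tests
    run_tests()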