blob: 47b7b553768037ea152d4c929a0a44caa915029b [file] [log] [blame]
# -*- coding: utf-8 -*-
# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.jit
from torch._C import parse_ir
# torch.quantization
from torch.quantization import QConfig
from torch.quantization import default_dynamic_qconfig
from torch.quantization import QConfigDynamic
from torch.quantization import default_observer
from torch.quantization import default_per_channel_weight_observer
from torch.quantization import default_qconfig
from torch.quantization import get_default_qconfig
# torch.quantization._quantize_script
from torch.quantization._quantize_script import script_qconfig
from torch.quantization._quantize_script import prepare_script
from torch.quantization._quantize_script import convert_script
from torch.quantization._quantize_script import quantize_script
from torch.quantization._quantize_script import prepare_dynamic_script
from torch.quantization._quantize_script import quantize_dynamic_script
# Testing utils
from torch.testing._internal.common_quantization import test_only_eval_fn as _test_only_eval_fn
from torch.testing._internal.common_quantized import override_qengines
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import attrs_with_prefix
from torch.testing._internal.jit_utils import JitTestCase
from torch.testing._internal.jit_utils import get_forward
from torch.testing._internal.jit_utils import get_forward_graph
from torch.testing._internal.jit_utils import get_module_method
from torch.jit._recursive import wrap_cpp_module
# Standard library
import itertools
import unittest
class TestQuantizeScriptJitPasses(JitTestCase):
""" Test graph mode quantization passes used by quantize_script
"""
def test_foldbn_trivial(self):
# Test trivial case
class TestModule(torch.nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.conv = torch.nn.Conv2d(1, 20, 5, 1)
self.bn = torch.nn.BatchNorm2d(num_features=20)
self.bn.eps = 0.0023
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
# Check that the transformation doesn't change numerics
for tracing_mode in [True, False]:
eager = TestModule()
eager.eval()
if tracing_mode:
x = torch.rand(1, 1, 6, 6)
scripted_or_traced = torch.jit.trace(eager, x)
else:
scripted_or_traced = torch.jit.script(eager)
scripted_or_traced.eval()
# Check that in the original script module's forward we have two
# CallMethod nodes. One of them should be for conv.forward and the other
# for bn.forward.
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(str(get_forward(scripted_or_traced._c).graph))
# Run FoldConvBatchnorm2d pass.
scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c))
# Check that after the pass one of the CallMethods is gone (supposedly,
# the bn.forward).
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \
.run(str(get_forward_graph(scripted_or_traced._c)))
# Check that the transformation doesn't change numerics
x = torch.rand(1, 1, 6, 6)
self.assertEqual(eager(x), scripted_or_traced(x))
def test_foldbn_trivial_nobias(self):
# Test trivial case
class TestModule(torch.nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.conv = torch.nn.Conv2d(1, 20, 5, 1, bias=False)
self.bn = torch.nn.BatchNorm2d(num_features=20)
# to make sure new bias is not zero
self.bn.eps = 0.0027
self.bn.bias = torch.nn.Parameter(torch.rand([20]))
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
for tracing_mode in [True, False]:
eager = TestModule()
eager.eval()
if tracing_mode:
x = torch.rand(1, 1, 6, 6)
scripted_or_traced = torch.jit.trace(eager, x)
else:
scripted_or_traced = torch.jit.script(eager)
scripted_or_traced.eval()
# Check that in the original script module's forward we have two
# CallMethod nodes. One of them should be for conv.forward and the other
# for bn.forward.
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(str(get_forward_graph(scripted_or_traced._c)))
# Run FoldConvBatchnorm2d pass.
scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c))
# Check that after the pass one of the CallMethods is gone (supposedly,
# the bn.forward).
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \
.run(str(get_forward_graph(scripted_or_traced._c)))
# Check that the transformation doesn't change numerics
x = torch.rand(1, 1, 6, 6)
self.assertEqual(eager(x), scripted_or_traced(x))
def test_foldbn_in_submodule(self):
# Test that we find Conv-BN patterns in submodules
class SubModule(torch.nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.conv = torch.nn.Conv2d(1, 20, 5, 1)
self.bn = torch.nn.BatchNorm2d(num_features=20)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class TestModule(torch.nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.sub = SubModule()
def forward(self, x):
x = self.sub(x)
return x
for tracing_mode in [True, False]:
eager = TestModule()
eager.eval()
if tracing_mode:
x = torch.rand(1, 1, 10, 10)
scripted_or_traced = torch.jit.trace(eager, x)
else:
scripted_or_traced = torch.jit.script(eager)
scripted_or_traced.eval()
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(str(get_forward_graph(scripted_or_traced.sub._c)))
scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c))
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \
.run(str(get_forward_graph(scripted_or_traced.sub._c)))
x = torch.rand(1, 1, 10, 10)
self.assertEqual(eager(x), scripted_or_traced(x))
def test_foldbn_in_customConv2D(self):
# Make sure a custom Conv2D class is not folded
# as we do not know it does.
class CustomConv2D(torch.nn.Module):
def __init__(self, a, b, c, d):
super(CustomConv2D, self).__init__()
def forward(self, x):
return F.relu(x)
class SubModule(torch.nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.conv = CustomConv2D(1, 20, 5, 1)
self.bn = torch.nn.BatchNorm2d(num_features=20)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class TestModule(torch.nn.Module):
def __init__(self):
super(TestModule, self).__init__()
self.sub = SubModule()
def forward(self, x):
x = self.sub(x)
return x
for tracing_mode in [True, False]:
eager = TestModule()
eager.eval()
if tracing_mode:
x = torch.rand(1, 20, 10, 10)
scripted_or_traced = torch.jit.trace(eager, x)
else:
scripted_or_traced = torch.jit.script(eager)
scripted_or_traced.eval()
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(str(get_forward_graph(scripted_or_traced.sub._c)))
scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c))
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
.run(str(get_forward_graph(scripted_or_traced.sub._c)))
x = torch.rand(1, 20, 10, 10)
self.assertEqual(eager(x), scripted_or_traced(x))
def test_foldbn_shared_classtype(self):
class TestModule(torch.nn.Module):
def __init__(self, bias=False):
super(TestModule, self).__init__()
self.conv1 = torch.nn.Conv2d(5, 5, 3, bias=bias)
self.bn1 = torch.nn.BatchNorm2d(num_features=5)
self.bn1.running_mean.fill_(-0.2)
self.bn1.bias = torch.nn.Parameter(torch.rand([5]))
# to make sure new bias is not zero
self.bn1.eps = 0.0023
self.conv2 = torch.nn.Conv2d(5, 5, 3, bias=bias)
self.bn2 = torch.nn.BatchNorm2d(num_features=5)
self.bn2.eps = 0.0029
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
return x
for tracing_mode in [True, False]:
for bias in [True, False]:
eager = TestModule(bias).eval()
if tracing_mode:
x = torch.rand(1, 5, 6, 6)
scripted_or_traced = torch.jit.trace(eager, x).copy()
else:
scripted_or_traced = torch.jit.script(eager).copy()
torch._C._jit_pass_dedup_module_uses(scripted_or_traced ._c)
folded = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced ._c))
x = torch.rand(1, 5, 6, 6)
self.assertEqual(eager(x), scripted_or_traced(x))
def test_foldbn_complex_cases(self):
# This test case attempt to try combinations of conv2d with bias/nobias
# as well as BatchNorm with affine/no-affine along with varying the
# number of layers.
# this only works when default dtype is double
torch.set_default_dtype(torch.double)
class SubModule(torch.nn.Module):
def __init__(self, num_blocks, enable_bias, enable_affine):
super(SubModule, self).__init__()
layers = []
for i in range(num_blocks):
layers.append(torch.nn.Conv2d(20, 20, 5, 1, bias=enable_bias))
bn_obj = torch.nn.BatchNorm2d(num_features=20, affine=enable_affine)
if enable_affine:
bn_obj.weight = torch.nn.Parameter(torch.rand_like(bn_obj.weight))
bn_obj.bias = torch.nn.Parameter(torch.rand_like(bn_obj.bias))
bn_obj.running_mean = torch.rand_like(bn_obj.running_mean)
bn_obj.running_var = torch.rand_like(bn_obj.running_var)
layers.append(bn_obj)
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class TestModule(torch.nn.Module):
def __init__(self, num_blocks, enable_bias, enable_affine):
super(TestModule, self).__init__()
self.sub = SubModule(num_blocks, enable_bias, enable_affine)
def forward(self, x):
x = self.sub(x)
return x
bias_affine_options = itertools.product([True, False], [True, False], [True, False], [1, 2])
for (tracing_mode, enable_bias, enable_bn_affine, num_layers) in bias_affine_options:
eager = TestModule(num_layers, enable_bias, enable_bn_affine)
eager.eval()
if tracing_mode:
x = torch.rand(1, 20, 10, 10)
scripted_or_traced = torch.jit.trace(eager, x)
else:
scripted_or_traced = torch.jit.script(eager)
scripted_or_traced.eval()
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", num_layers * 2, exactly=True) \
.run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))
scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c))
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", num_layers, exactly=True) \
.run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))
x = torch.rand(1, 20, 10, 10)
self.assertEqual(eager(x), scripted_or_traced(x))
torch.set_default_dtype(torch.float)
def test_fuse_linear(self):
input_strs = ["""
graph(%input, %weight, %bias, %4):
# CHECK-NOT: aten::t
# CHECK-NOT: aten::addmm
# CHECK: aten::linear
%weight_t = aten::t(%weight)
%res = aten::addmm(%bias, %input, %weight_t, %4, %4)
return (%res)""", """
graph(%input, %weight, %bias, %4):
# CHECK-NOT: aten::t
# CHECK-NOT: aten::matmul
# CHECK-NOT: aten::add_
# CHECK: aten::linear
%weight_t = aten::t(%weight)
%output = aten::matmul(%input, %weight_t)
%res = aten::add_(%output, %bias, %4)
return (%res)""", """
graph(%input, %weight):
# CHECK-NOT: aten::t
# CHECK-NOT: aten::matmul
# CHECK: aten::linear
%weight_t = aten::t(%weight)
%output = aten::matmul(%input, %weight_t)
return (%output)"""]
for input_str in input_strs:
graph = parse_ir(input_str)
torch._C._jit_pass_fuse_linear(graph)
FileCheck().run(input_str, graph)
def test_insert_observers(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)
def forward(self, x):
return self.conv(x)
m = torch.jit.script(M())
observer = torch.jit.script(default_observer())
qconfig_dict = {'': script_qconfig(default_qconfig)}
m = wrap_cpp_module(torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, False))
# for input and output of conv
assert len(attrs_with_prefix(m, '_observer_')) == 2
# for weight
assert len(attrs_with_prefix(m.conv, '_observer_')) == 1
def test_insert_observers_child_qconfig(self):
class Sub(torch.nn.Module):
def __init__(self):
super(Sub, self).__init__()
self.fc = torch.nn.Linear(5, 5)
def forward(self, x):
return self.fc(x)
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)
self.sub = Sub()
def forward(self, x):
return self.sub(self.conv(x))
m = torch.jit.script(M())
qconfig = script_qconfig(default_qconfig)
qconfig_dict = {
'sub.fc': qconfig
}
m = wrap_cpp_module(torch._C._jit_pass_insert_observers(m._c, "forward",
qconfig_dict,
False))
# input and output of sub
assert len(attrs_with_prefix(m, '_observer_')) == 2
# not quantized
assert len(attrs_with_prefix(m.conv, '_observer_')) == 0
# no observers since we observe in the outer most call site
assert len(attrs_with_prefix(m.sub, '_observer_')) == 0
# weight of linear
assert len(attrs_with_prefix(m.sub.fc, '_observer_')) == 1
def test_insert_observers_skip_values(self):
class ConvFunctionalReLU(torch.nn.Module):
def __init__(self):
super(ConvFunctionalReLU, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)
def forward(self, x):
return F.relu(self.conv(x))
class ConvReLUModule(torch.nn.Module):
def __init__(self):
super(ConvReLUModule, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(self.conv(x))
class AddReLUModule(torch.nn.Module):
def __init__(self):
super(AddReLUModule, self).__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
out = x
out += x
return self.relu(out)
class AddFunctionalReLU(torch.nn.Module):
def __init__(self):
super(AddFunctionalReLU, self).__init__()
def forward(self, x):
out = x
out += x
return F.relu(out)
def attrs_with_prefix(module, prefix):
return [x for x, _ in module._modules._c.items()
if x.startswith(prefix)]
qconfig_dict = {'': script_qconfig(default_qconfig)}
m = torch.jit.script(ConvFunctionalReLU())
m = wrap_cpp_module(torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, False))
# observer for weight of conv
assert len(attrs_with_prefix(m.conv, '_observer_')) == 1
# observer for input of conv and output of relu
assert len(attrs_with_prefix(m, '_observer_')) == 2
m = torch.jit.script(ConvReLUModule())
m = wrap_cpp_module(torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, False))
# observer for input of conv and output of relu
assert len(attrs_with_prefix(m, '_observer_')) == 2
# observer for weight of conv
assert len(attrs_with_prefix(m.conv, '_observer_')) == 1
# observer for output of relu
assert len(attrs_with_prefix(m.relu, '_observer_')) == 0
m = torch.jit.script(AddReLUModule())
qconfig_dict = {'': script_qconfig(default_qconfig)}
m = wrap_cpp_module(torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, False))
assert len(attrs_with_prefix(m, '_observer')) == 2
assert len(attrs_with_prefix(m.relu, '_observer')) == 0
FileCheck().check('aten::add_') \
.check_not('Observer = prim::GetAttr[name="_observer_') \
.check('ReLU = prim::GetAttr') \
.run(str(get_forward_graph(m._c)))
m = torch.jit.script(AddFunctionalReLU())
qconfig_dict = {'': script_qconfig(default_qconfig)}
m = wrap_cpp_module(torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, False))
assert len(attrs_with_prefix(m, '_observer')) == 2
FileCheck().check('aten::add_') \
.check_not('Observer = prim::GetAttr[name="_observer_') \
.check('CallFunction') \
.check('Observer = prim::GetAttr[name="_observer_') \
.run(str(get_forward_graph(m._c)))
def test_insert_observers_weight_dtype(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)
def forward(self, x):
return F.relu(self.conv(x))
m = torch.jit.script(M())
qconfig_dict = {
'': script_qconfig(default_qconfig)
}
m = wrap_cpp_module(torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, False))
activation_dtypes = set(obs.getattr('dtype') for x, obs in m._modules._c.items()
if x.startswith('_observer_'))
weight_dtypes = set(obs.getattr('dtype') for x, obs in m.conv._modules._c.items()
if x.startswith('_observer_'))
assert len(activation_dtypes) == 1, 'Expected to have 1 activation dtype'
assert len(weight_dtypes) == 1, 'Expected to have 1 weight dtype'
assert list(activation_dtypes)[0] != list(weight_dtypes)[0], 'Expected activation dtype to '
' be different from wegiht dtype'
def test_insert_observers_for_reused_weight(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
def forward(self, x, y, weight):
x = F.conv2d(x, weight)
y = F.conv2d(y, weight)
return x + y
m = torch.jit.script(M()).eval()
m = prepare_script(m, {'': default_qconfig}, False)
# 3 for x, y, weight, one for output of each F.conv2d and one for output of add
assert len(attrs_with_prefix(m, '_observer')) == 6
def test_insert_observers_shared_class_type(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 5, 3).float()
self.conv2 = torch.nn.Conv2d(3, 5, 3).float()
def forward(self, x):
return self.conv2(self.conv1(x))
m = torch.jit.script(M())
qconfig_dict = {'': script_qconfig(default_qconfig)}
m = wrap_cpp_module(torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, False))
# conv1 and conv2 shares the same type, we need to
# make sure we didn't quantize the type twice
conv1_observers = attrs_with_prefix(m.conv1, '_observer_')
conv2_observers = attrs_with_prefix(m.conv2, '_observer_')
assert len(conv1_observers) == 1, \
'Expected to have 1 observer submodules'
assert len(conv2_observers) == 1, \
'Expected to have 1 observer submodules'
assert conv1_observers == conv2_observers, \
'Expect conv1 and conv2 to have same observers since the class type is shared'
def test_insert_observers_for_general_ops(self):
""" Make sure we skip observers for ops that doesn't require
observation, e.g. flatten
"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 3).float()
def forward(self, x):
x = self.conv(x)
x = torch.flatten(x)
return x
qconfig_dict = {'': script_qconfig(default_qconfig)}
m = torch.jit.script(M())
m = wrap_cpp_module(torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, False))
# input and output of conv
assert len(attrs_with_prefix(m, '_observer_')) == 2
FileCheck().check('Observer = prim::GetAttr[name="_observer_') \
.check('prim::GetAttr[name="conv"]') \
.check('prim::CallMethod') \
.check('Observer = prim::GetAttr[name="_observer_') \
.check('aten::flatten') \
.check_not('Observer = prim::GetAttr[name="_observer_') \
.run(m.graph)
# TODO: this is too long, split this to test_insert_observers.py and remove
# insrt_observers prefix
def test_insert_observers_propagate_observed(self):
""" Make sure we propagate observed property through general ops
"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
def forward(self, x):
x = self.conv1(x)
x = torch.flatten(x)
# we don't want to insert observer for input of self.conv2
# because output of self.conv1 is already observed
x = self.conv2(x)
return x
qconfig_dict = {'': script_qconfig(default_qconfig)}
m = torch.jit.script(M())
m = wrap_cpp_module(torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, False))
# input and output of conv
assert len(attrs_with_prefix(m, '_observer_')) == 3
FileCheck().check('Observer = prim::GetAttr[name="_observer_') \
.check('prim::GetAttr[name="conv1"]') \
.check('prim::CallMethod') \
.check('Observer = prim::GetAttr[name="_observer_') \
.check('aten::flatten') \
.check_not('Observer = prim::GetAttr[name="_observer_') \
.check('prim::GetAttr[name="conv2"]') \
.check('Observer = prim::GetAttr[name="_observer_') \
.run(m.graph)
def test_insert_observers_propagate_observed_in_submodule(self):
""" Make sure we propagate observed property through general ops
"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.conv1(x)
x = self.avgpool(x)
# we don't want to insert observer for input of self.conv2
# because output of self.conv1 is already observed
x = self.conv2(x)
return x
qconfig_dict = {'': script_qconfig(default_qconfig)}
m = torch.jit.script(M())
m = wrap_cpp_module(torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, False))
# input and output of conv
assert len(attrs_with_prefix(m, '_observer_')) == 3
FileCheck().check('Observer = prim::GetAttr[name="_observer_') \
.check('prim::GetAttr[name="conv1"]') \
.check('prim::CallMethod') \
.check('Observer = prim::GetAttr[name="_observer_') \
.check('prim::CallMethod') \
.check_not('Observer = prim::GetAttr[name="_observer_') \
.check('prim::GetAttr[name="conv2"]') \
.check('Observer = prim::GetAttr[name="_observer_') \
.run(m.graph)
def test_insert_observers_propagate_observed_for_function(self):
def channel_shuffle(x, groups):
# type: (torch.Tensor, int) -> torch.Tensor
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups
# reshape
x = x.view(batchsize, groups,
channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 1).float()
self.conv2 = torch.nn.Conv2d(3, 3, 1).float()
def forward(self, x):
x = self.conv1(x)
x = channel_shuffle(x, 1)
x = self.conv2(x)
return x
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
m = torch.jit.script(M()).eval()
m = prepare_script(m, {'': default_qconfig}, inplace=False)
# we want to test that channel_shuffle is going to pass
# the observed property from the output of conv1 to input of conv2
# so that we don't insert observers for input of conv2
assert len(attrs_with_prefix(m, '_observer_',)) == 3
def test_insert_observers_for_if(self):
class Res(torch.nn.Module):
def __init__(self, use_skip):
super(Res, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 1).float()
self.use_skip = use_skip
def forward(self, x):
if self.use_skip:
return self.conv(x)
else:
return self.conv(x)
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.res1 = Res(True)
self.res2 = Res(False)
def forward(self, x):
x = self.res1(x)
x = self.res2(x)
return x
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
m = torch.jit.script(M()).eval()
m = prepare_script(m, {'': default_qconfig}, inplace=False)
assert len(attrs_with_prefix(m, '_observer_',)) == 3
def test_insert_quant_dequant(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3).float()
def forward(self, x):
return self.conv(x)
for is_per_channel in [True, False]:
m = torch.jit.script(M())
observer = default_per_channel_weight_observer.with_args(ch_axis=1) \
if is_per_channel else default_observer
qconfig = QConfig(activation=observer, weight=observer)
qconfig_dict = {
'': script_qconfig(qconfig)
}
m = wrap_cpp_module(torch._C._jit_pass_insert_observers(m._c, "forward", qconfig_dict, False))
data = torch.randn(1, 3, 10, 10, dtype=torch.float)
m(data)
m = wrap_cpp_module(torch._C._jit_pass_insert_quant_dequant(m._c, "forward", False))
assert len(m._modules._c.items()) == 1, \
'Expected to have single submodule of conv'
# make sure the quantized model is executable
m(data)
quant_func = "aten::quantize_per_channel" if is_per_channel \
else "aten::quantize_per_tensor"
FileCheck().check_count(quant_func, 3, exactly=True) \
.run(m.graph)
def test_insert_quant_dequant_shared_class_type(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
def forward(self, x):
return self.conv2(self.conv1(x))
for is_per_channel in [True, False]:
m = torch.jit.script(M())
observer = default_per_channel_weight_observer.with_args(ch_axis=1) \
if is_per_channel else default_observer
qconfig = QConfig(activation=observer, weight=observer)
qconfig_dict = {'': qconfig}
m = prepare_script(m, qconfig_dict)
# observers for input, output and value between conv1/conv2
assert len(attrs_with_prefix(m, '_observer_')) == 3, \
'Expected to have 3 obervers'
# observer for weight
assert len(attrs_with_prefix(m.conv1, '_observer_')) == 1, \
'Expected to have 1 obervers'
# observer for weight
assert len(attrs_with_prefix(m.conv2, '_observer_')) == 1, \
'Expected to have 1 obervers'
data = torch.randn(1, 3, 10, 10, dtype=torch.float)
m(data)
m = wrap_cpp_module(torch._C._jit_pass_insert_quant_dequant(m._c, "forward", False))
m(data)
assert m.conv1._c._type() == m.conv2._c._type()
# check all observers have been removed
assert len(attrs_with_prefix(m, '_observer_')) == 0, \
'Expected to have 0 obervers'
assert len(attrs_with_prefix(m.conv1, '_observer_')) == 0, \
'Expected to have 0 obervers'
assert len(attrs_with_prefix(m.conv2, '_observer_')) == 0, \
'Expected to have 0 obervers'
quant_func = "aten::quantize_per_channel" if is_per_channel \
else "aten::quantize_per_tensor"
for module in ['conv1', 'conv2']:
conv = m._c.getattr(module)
# quantize weight
FileCheck().check(quant_func) \
.check_next("aten::dequantize") \
.check("prim::CallMethod[name=\"_conv_forward\"]") \
.check("return") \
.run(get_forward_graph(conv))
# no quantize node in _conv_forward
FileCheck().check_not(quant_func) \
.check("aten::conv2d") \
.check_not(quant_func) \
.check("return") \
.run(conv._get_method('_conv_forward').graph)
def test_dedup_module_uses(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.relu(x)
x -= 0.5
return self.relu(x)
data = torch.randn((2, 2))
m = torch.jit.script(M())
ref_res = m(data)
assert len([x for x, _ in m._modules._c.items()
if x.startswith('relu')]) == 1, \
"Expected to have 1 relu modules after dedup module uses"
torch._C._jit_pass_dedup_module_uses(m._c)
m = torch.jit._recursive.wrap_cpp_module(m._c)
res = m(data)
assert len([x for x, _ in m._modules._c.items()
if x.startswith('relu')]) == 2, \
"Expected to have 2 relu modules after dedup module uses"
self.assertEqual(res, ref_res)
def test_replicate_dequantize(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 1).float()
def forward(self, x):
x = torch.dequantize(x)
r = self.conv(x)
r += x
return r
x = torch.randn([1, 3, 10, 10], dtype=torch.float)
x = torch.quantize_per_tensor(x, 0.5, 1, torch.quint8)
m = torch.jit.script(M())
ref_res = m(x)
FileCheck().check_count("aten::dequantize", 1, exactly=True) \
.run(m.graph)
torch._C._jit_pass_replicate_dequantize(m.graph)
FileCheck().check_count("aten::dequantize", 2, exactly=True) \
.run(m.graph)
res = get_forward(m._c)(x)
self.assertEqual(res, ref_res)
def test_replicate_dequantize_in_block(self):
class M(torch.nn.Module):
def __init__(self, cond):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 1).float()
self.cond = cond
def forward(self, x):
x = torch.dequantize(x)
if self.cond:
x = self.conv(x)
else:
x = x + 3
return x
x = torch.randn([1, 3, 10, 10], dtype=torch.float)
x = torch.quantize_per_tensor(x, 0.5, 1, torch.quint8)
m = torch.jit.script(M(True))
ref_res = m(x)
FileCheck().check_count("aten::dequantize", 1, exactly=True) \
.run(m.graph)
torch._C._jit_pass_replicate_dequantize(m.graph)
FileCheck().check_count("aten::dequantize", 2, exactly=True) \
.run(m.graph)
# check dequantize is right before CallMethod of conv
FileCheck().check("aten::dequantize") \
.check_next("CallMethod") \
.run(m.graph)
# check dequantize is right before add
FileCheck().check("aten::dequantize") \
.check("aten::dequantize") \
.check_next("aten::add") \
.run(m.graph)
res = get_forward(m._c)(x)
self.assertEqual(res, ref_res)
def test_swap_dequantize(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.maxpool = torch.nn.MaxPool2d(kernel_size=3)
self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = torch.dequantize(x)
x = self.maxpool(x)
x = self.avgpool(x)
return x
x = torch.randn([1, 3, 10, 10], dtype=torch.float)
x = torch.quantize_per_tensor(x, 0.5, 1, torch.quint8)
m = torch.jit.script(M())
ref_res = m(x)
torch._C._jit_pass_inline(m.graph)
FileCheck().check("aten::dequantize") \
.check("aten::max_pool2d") \
.check("aten::adaptive_avg_pool2d") \
.run(m.graph)
torch._C._jit_pass_swap_dequantize(m.graph)
FileCheck().check("aten::max_pool2d") \
.check("aten::adaptive_avg_pool2d") \
.check("dequantize") \
.run(m.graph)
res = get_forward(m._c)(x)
self.assertEqual(res, ref_res)
def test_swap_functional_linear(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
def forward(self, x, weight, bias):
x = torch.dequantize(x)
weight = torch.dequantize(weight)
x = F.linear(x, weight, bias)
x = torch.quantize_per_tensor(x, scale=1.0, zero_point=0, dtype=torch.quint8)
return x
x = torch.rand((10, 5), dtype=torch.float)
x = torch.quantize_per_tensor(x, scale=0.5, zero_point=1, dtype=torch.quint8)
weight = torch.rand((5, 5), dtype=torch.float)
weight = torch.quantize_per_tensor(weight, scale=0.5, zero_point=1, dtype=torch.qint8)
bias = torch.rand((5), dtype=torch.float)
m = torch.jit.script(M())
ref_res = m(x, weight, bias)
FileCheck().check("CallFunction") \
.run(m.graph)
torch._C._jit_pass_swap_functional_linear(m.graph)
FileCheck().check("aten::linear") \
.check_not("CallFunction") \
.run(m.graph)
res = m(x, weight, bias)
self.assertEqual(res, ref_res)
def test_replicate_quantize_for_if(self):
""" We want to move quantize nodes for output of prim::If
inside the prim::If blocks so that we can match quantization
patterns.
"""
class Res(torch.nn.Module):
def __init__(self):
super(Res, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 1).float()
self.use_skip = True
def forward(self, x, cond):
# type: (Tensor, bool) -> Tensor
# to avoid being frozen
self.use_skip = cond
if self.use_skip:
return self.conv(x)
else:
return self.conv(x)
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.res1 = Res()
self.res2 = Res()
def forward(self, x):
x = self.res1(x, True)
x = self.res2(x, False)
return x
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
qconfig_dict = {'': default_qconfig}
m = torch.jit.script(M()).eval()
m = quantize_script(m, qconfig_dict, _test_only_eval_fn, [data], inplace=False)
# make sure patterns in both branches are fused
FileCheck().check_count("quantized::conv2d(", 4, exactly=True) \
.run(m.graph)
def test_finalize_for_linear(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.fc = torch.nn.Linear(5, 5).float()
def forward(self, x):
return self.fc(x)
data = [(torch.rand((1, 5), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
qconfig_dict = {'': default_qconfig}
model = torch.jit.script(M()).eval()
model = quantize_script(model, qconfig_dict, _test_only_eval_fn, [data], inplace=False)
# make sure there is only one quantize_per_tensor for input
# and linear_prepack is folded
FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True) \
.check_not("quantized::linear_prepack") \
.check("quantized::linear") \
.run(model.graph)
def test_finalize_debug(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 3).float()
self.avgpool = torch.nn.AvgPool2d(3)
def forward(self, x):
x = self.conv(x)
x = self.avgpool(x)
return x
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
qconfig_dict = {'': default_qconfig}
model = torch.jit.script(M()).eval()
model = quantize_script(model, qconfig_dict, _test_only_eval_fn, [data], inplace=False, debug=True)
FileCheck().check_not("quantized::conv2d") \
.check("aten::conv2d") \
.check("aten::avg_pool2d") \
.check("aten::q_scale") \
.check_next("aten::q_zero_point") \
.check_next("prim::dtype") \
.check_next("aten::quantize_per_tensor") \
.check("aten::dequantize") \
.run(model.graph)
def test_module_list(self):
class SimpleLinearLayer(torch.nn.Module):
def __init__(self):
super(SimpleLinearLayer, self).__init__()
self.fc = torch.nn.Linear(5, 5).float()
def forward(self, x):
return self.fc(x)
class ComplexModel(torch.nn.Module):
def __init__(self):
super(ComplexModel, self).__init__()
self.layers = torch.nn.ModuleList([SimpleLinearLayer() for i in range(2)])
def forward(self, x):
# type: (torch.Tensor) -> List[torch.Tensor]
states = []
for layer in self.layers:
val = layer(x)
states.append(val)
return states
data = torch.rand((1, 5), dtype=torch.float)
qconfig_dict = {'': default_qconfig}
model = torch.jit.script(ComplexModel()).eval()
model = prepare_script(model, qconfig_dict)
assert len(attrs_with_prefix(model, '_observer')) == 3
model(data)
model = convert_script(model, debug=False)
FileCheck().check("quantized::linear") \
.check("quantized::linear") \
.run(model.graph)
def test_conv_trace(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1d = torch.nn.Conv1d(3, 3, 3).float()
self.conv2d = torch.nn.Conv2d(3, 3, 3).float()
self.conv3d = torch.nn.Conv3d(3, 3, 3).float()
def forward(self, x, y, z):
a = self.conv1d(x)
b = self.conv2d(y)
c = self.conv3d(z)
return (a, b, c)
qconfig_dict = {'': default_qconfig}
inputs = (torch.rand((1, 3, 10), dtype=torch.float),
torch.rand((1, 3, 10, 10), dtype=torch.float),
torch.rand((1, 3, 10, 10, 10), dtype=torch.float))
model = torch.jit.trace(M(), inputs).eval()
m = prepare_script(model, qconfig_dict)
FileCheck().check('aten::conv1d') \
.check_not("aten::_convolution") \
.run(str(get_forward_graph(m.conv1d._c)))
FileCheck().check('aten::conv2d') \
.check_not("aten::_convolution") \
.run(str(get_forward_graph(m.conv2d._c)))
FileCheck().check('aten::conv3d') \
.check_not("aten::_convolution") \
.run(str(get_forward_graph(m.conv3d._c)))
def test_replicate_dequant_same_value(self):
class Mul(torch.nn.Module):
def __init__(self):
super(Mul, self).__init__()
def forward(self, x):
return x * x
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float),
torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
qconfig_dict = {'': default_qconfig}
model = torch.jit.script(Mul()).eval()
m = quantize_script(model, qconfig_dict, _test_only_eval_fn, [data])
FileCheck().check("quantized::mul") \
.check_not("aten::mul") \
.run(m.graph)
class TestQuantizeScriptPTSQOps(JitTestCase):
""" Test graph mode post training static quantization works
for individual ops end to end.
"""
def _test_op_impl(self, module, data, quantized_op):
qengine = torch.backends.quantized.engine
if qengine == 'none':
qconfig = default_qconfig
else:
qconfig = get_default_qconfig(qengine)
qconfig_dict = {'': qconfig}
model = torch.jit.script(module).eval()
model = quantize_script(model, qconfig_dict, _test_only_eval_fn, [data], inplace=False)
FileCheck().check(quantized_op) \
.run(model.graph)
# make sure it runs
*inputs, target = data[0]
model(*inputs)
return model
@override_qengines
def test_quantized_conv1d(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv1d(3, 3, 3).float()
def forward(self, x):
return self.conv(x)
data = [(torch.rand((1, 3, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
model = self._test_op_impl(M(), data, "quantized::conv1d")
# make sure there is only one quantize_per_tensor for input
# and conv2d_prepack is folded
FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True) \
.run(model.graph)
FileCheck().check_not("quantized::conv1d_prepack") \
.run(model.graph)
@unittest.skipUnless(
'fbgemm' in torch.backends.quantized.supported_engines,
" with instruction set support avx2 or newer.",
)
def test_quantized_conv2d(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 3).float()
def forward(self, x):
return self.conv(x)
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
model = self._test_op_impl(M(), data, "quantized::conv2d")
# make sure there is only one quantize_per_tensor for input
# and conv2d_prepack is folded
FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True) \
.run(model.graph)
FileCheck().check_not("quantized::conv2d_prepack") \
.run(model.graph)
@unittest.skipUnless(
'fbgemm' in torch.backends.quantized.supported_engines,
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
" with instruction set support avx2 or newer.",
)
def test_quantized_conv3d(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv3d(3, 3, 3).float()
def forward(self, x):
return self.conv(x)
data = [(torch.rand((1, 3, 10, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
model = self._test_op_impl(M(), data, "quantized::conv3d")
# make sure there is only one quantize_per_tensor for input
# and conv3d_prepack is folded
FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True) \
.run(model.graph)
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
" with instruction set support avx2 or newer.")
def test_quantized_conv2d_relu(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(1, 4, 2, 3).float()
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(self.conv(x))
data = [(torch.randn(1, 1, 10, 10, dtype=torch.float),
torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
model = self._test_op_impl(M(), data, "quantized::conv2d_relu")
FileCheck().check_not("aten::conv2d") \
.check_not("aten::relu") \
.check_not("quantized::conv2d(") \
.check_not("quantized::relu(") \
.run(model.graph)
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
" with instruction set support avx2 or newer.")
def test_quantized_conv3d_relu(self):
class M(torch.nn.Module):
def __init__(self, functional):
super(M, self).__init__()
self.conv = torch.nn.Conv3d(1, 4, 2, 3).float()
if functional:
self.relu = F.relu
else:
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(self.conv(x))
data = [(torch.randn(1, 1, 5, 5, 5, dtype=torch.float),
torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
model = self._test_op_impl(M(functional=False), data,
"quantized::conv3d_relu")
model_functional = self._test_op_impl(M(functional=True), data,
"quantized::conv3d_relu")
checker = FileCheck().check_not("aten::conv3d") \
.check_not("aten::relu") \
.check_not("quantized::conv3d(") \
.check_not("quantized::relu(")
checker.run(model.graph)
checker.run(model_functional.graph)
def test_quantized_add(self):
class Add(torch.nn.Module):
def __init__(self):
super(Add, self).__init__()
def forward(self, x, y):
return x + y
class InplaceAdd(torch.nn.Module):
def __init__(self):
super(InplaceAdd, self).__init__()
def forward(self, x, y):
x += y
return x
for M in [Add, InplaceAdd]:
m = torch.jit.script(M()).eval()
m = prepare_script(m, {'': default_qconfig}, True)
# two for input tensor, one for output
assert len(attrs_with_prefix(m, '_observer')) == 3
data = torch.randn(1, 1, 10, 10, dtype=torch.float)
m(data, data)
m = convert_script(m, True)
FileCheck().check_not("aten::add") \
.check_not("aten::add_") \
.check("quantized::add") \
.run(m.graph_for(data, data))
def test_quantized_add_scalar(self):
class AddScalar(torch.nn.Module):
def __init__(self):
super(AddScalar, self).__init__()
def forward(self, x):
return x + 3
class InplaceAddScalar(torch.nn.Module):
def __init__(self):
super(InplaceAddScalar, self).__init__()
def forward(self, x):
x += 3
return x
for M in [AddScalar, InplaceAddScalar]:
m = torch.jit.script(M()).eval()
m = prepare_script(m, {'': default_qconfig}, True)
# for input tensor
assert len(attrs_with_prefix(m, '_observer')) == 1
data = torch.randn(1, 1, 10, 10, dtype=torch.float)
m(data)
m = convert_script(m, True)
FileCheck().check_not("aten::add") \
.check_not("aten::add_") \
.check("quantized::add_scalar") \
.run(m.graph_for(data))
def test_quantized_add_relu(self):
class M(torch.nn.Module):
def __init__(self, inplace):
super(M, self).__init__()
self.relu = torch.nn.ReLU(inplace)
def forward(self, x, y):
x += y
return self.relu(x)
for inplace in [True, False]:
m = torch.jit.script(M(inplace).eval())
m = prepare_script(m, {'': default_qconfig}, True)
data = torch.randn(1, 1, 10, 10, dtype=torch.float)
m(data, data)
m = convert_script(m, True)
FileCheck().check_not("aten::add") \
.check_not("aten::relu") \
.check_not("aten::relu_") \
.check_not("quantized::add") \
.check_not("quantized::relu") \
.check("quantized::add_relu") \
.run(m.graph_for(data, data))
def test_quantized_add_scalar_relu(self):
class AddScalar(torch.nn.Module):
def __init__(self):
super(AddScalar, self).__init__()
def forward(self, x):
return F.relu(x + 3)
class InplaceAddScalar(torch.nn.Module):
def __init__(self):
super(InplaceAddScalar, self).__init__()
def forward(self, x):
x += 3
return F.relu(x)
for M in [AddScalar, InplaceAddScalar]:
m = torch.jit.script(M()).eval()
m = prepare_script(m, {'': default_qconfig}, True)
# for input tensor
assert len(attrs_with_prefix(m, '_observer')) == 1
data = torch.randn(1, 1, 10, 10, dtype=torch.float)
m(data)
m = convert_script(m, True)
FileCheck().check_not("aten::add") \
.check_not("aten::add_") \
.check_not("aten::relu") \
.check_not("quantized::add_scalar") \
.check_not("quantized::relu") \
.check("quantized::add_scalar_relu") \
.run(m.graph_for(data))
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
" with instruction set support avx2 or newer.")
def test_quantized_cat(self):
""" Note that we to support the case that torch.cat is quantized
indepdently, we need to have an observer that works
for list of Tensors.
"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
return torch.cat([x, y], 1)
m = torch.jit.script(M().eval())
m = prepare_script(m, {'': default_qconfig}, True)
# four for input and output of conv and one for output of cat
# this also tests the ListConstruct can preserve the observed property so that
# torch.cat knows that inputs are observed
assert len(attrs_with_prefix(m, '_observer_')) == 5
data = torch.randn(1, 1, 10, 10, dtype=torch.float)
m(data, data)
m = convert_script(m, True)
FileCheck().check_not("aten::cat") \
.check("quantized::cat") \
.run(m.graph_for(data, data))
def test_qbatch_norm(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.bn = torch.nn.BatchNorm2d(3).to(torch.float)
def forward(self, x):
return self.bn(x)
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
model = self._test_op_impl(M(), data, "quantized::batch_norm2d")
FileCheck().check_not("aten::batch_norm") \
.run(model.graph)
def test_qbatch_norm_relu(self):
class BNRelu(torch.nn.Module):
def __init__(self, inplace):
super(BNRelu, self).__init__()
self.bn = torch.nn.BatchNorm2d(3).to(torch.float)
self.relu = torch.nn.ReLU(inplace=inplace)
def forward(self, x):
return self.relu(self.bn(x))
# Note Fusion for functional Relu with inplace argument isn't currently supported in fusion patterns.
class BNFuncRelu(torch.nn.Module):
def __init__(self):
super(BNFuncRelu, self).__init__()
self.bn = torch.nn.BatchNorm2d(3).to(torch.float)
def forward(self, x):
return F.relu(self.bn(x), False)
class BNFuncInplaceRelu(torch.nn.Module):
def __init__(self):
super(BNFuncInplaceRelu, self).__init__()
self.bn = torch.nn.BatchNorm2d(3).to(torch.float)
def forward(self, x):
return F.relu(self.bn(x), True)
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
for instance in [BNRelu(True), BNRelu(False), BNFuncRelu(), BNFuncInplaceRelu()]:
model = self._test_op_impl(instance, data, "quantized::batch_norm2d_relu")
FileCheck().check_not("aten::batch_norm") \
.check_not("aten::relu") \
.check_not("aten::relu_") \
.run(model.graph)
def test_quantized_mul(self):
class Mul(torch.nn.Module):
def __init__(self):
super(Mul, self).__init__()
def forward(self, x, y):
return x * y
class InplaceMul(torch.nn.Module):
def __init__(self):
super(InplaceMul, self).__init__()
def forward(self, x, y):
x *= y
return x
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.rand((1, 3, 10, 10), dtype=torch.float),
torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
for M in [Mul(), InplaceMul()]:
m = self._test_op_impl(M, data, "quantized::mul")
FileCheck().check_not("aten::mul") \
.check_not("aten::mul_") \
.run(m.graph)
def test_quantized_mul_scalar(self):
class MulScalar(torch.nn.Module):
def __init__(self):
super(MulScalar, self).__init__()
def forward(self, x):
return x * 3
class InplaceMulScalar(torch.nn.Module):
def __init__(self):
super(InplaceMulScalar, self).__init__()
def forward(self, x):
x *= 3
return x
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
for M in [MulScalar(), InplaceMulScalar()]:
m = self._test_op_impl(M, data, "quantized::mul_scalar")
FileCheck().check_not("aten::mul") \
.check_not("aten::mul_") \
.run(m.graph)
def test_quantized_mul_relu(self):
class MulRelu(torch.nn.Module):
def __init__(self, inplace):
super(MulRelu, self).__init__()
self.relu = torch.nn.ReLU(inplace)
def forward(self, x, y):
x = x * y
return self.relu(x)
class InplaceMulRelu(torch.nn.Module):
def __init__(self, inplace):
super(InplaceMulRelu, self).__init__()
self.relu = torch.nn.ReLU(inplace)
def forward(self, x, y):
x *= y
return self.relu(x)
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.rand((1, 3, 10, 10), dtype=torch.float),
torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
for M in [MulRelu(True), MulRelu(False), InplaceMulRelu(True), InplaceMulRelu(False)]:
m = self._test_op_impl(M, data, "quantized::mul_relu")
FileCheck().check_not("aten::mul") \
.check_not("aten::mul_") \
.check_not("aten::relu") \
.check_not("aten::relu_") \
.check_not("quantized::relu") \
.run(m.graph)
def test_quantized_mul_scalar_relu(self):
class MulScalar(torch.nn.Module):
def __init__(self):
super(MulScalar, self).__init__()
def forward(self, x):
return F.relu(x * 3)
class InplaceMulScalar(torch.nn.Module):
def __init__(self):
super(InplaceMulScalar, self).__init__()
def forward(self, x):
x *= 3
return F.relu(x)
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
for M in [MulScalar(), InplaceMulScalar()]:
m = self._test_op_impl(M, data, "quantized::mul_scalar_relu")
FileCheck().check_not("aten::mul") \
.check_not("aten::mul_") \
.check_not("aten::relu") \
.check_not("quantized::relu") \
.run(m.graph)
def test_hardswish(self):
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
m = self._test_op_impl(torch.nn.Hardswish(), data, "quantized::hardswish")
FileCheck().check_not("aten::hardswish") \
.run(m.graph)
def test_layer_norm(self):
data = [(torch.rand((1, 3, 10, 10), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
layer_norm = torch.nn.LayerNorm([3, 10, 10])
m = self._test_op_impl(layer_norm, data, "quantized::layer_norm")
FileCheck().check_not("aten::layer_norm") \
.run(m.graph)
def test_quantize_general_shape_ops(self):
""" A test that checks dequantize will be swapped for
all supported general shape ops like aten::flatten
without actually checking for execution of these ops
"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
self.dropout = torch.nn.Dropout()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.sigmoid = torch.nn.Sigmoid()
self.tanh = torch.nn.Tanh()
self.hardsigmoid = torch.nn.Hardsigmoid()
self.relu = torch.nn.ReLU()
self.relu6 = torch.nn.ReLU6()
self.leaky_relu = torch.nn.LeakyReLU()
def forward(self, x):
x = self.conv(x)
x = self.maxpool1d(x)
x = self.maxpool2d(x)
x = self.maxpool3d(x)
x = torch.flatten(x)
x = torch.max(x)
x = torch.min(x)
x = torch.sigmoid(x)
x = x.reshape([-1])
x = x.resize_(1, 1, x.numel())
x = self.sigmoid(x)
x = x.view(-1)
x = x.transpose(1, 2)
x = x.contiguous()
x, y = torch.chunk(x, 2)
x = F.dropout(x)
x = self.dropout(x)
x, _ = torch.sort(x)
x = F.sigmoid(x)
x = x.permute(0, 2, 3, 1)
x = torch.repeat_interleave(x, 3, 1)
x = self.tanh(x)
x = F.tanh(x)
x = torch.tanh(x)
x = self.hardsigmoid(x)
x = F.hardsigmoid(x)
x.hardsigmoid_()
x = self.relu(x)
x = F.relu(x)
x.relu_()
x = self.relu6(x)
x = F.relu6(x)
x = self.leaky_relu(x)
x = F.leaky_relu(x)
x.leaky_relu_()
x = self.conv(x)
return x
m = torch.jit.script(M())
qconfig = script_qconfig(default_qconfig)
# dummy data to suppress warning
data = torch.rand((1, 3, 10, 10))
get_forward(qconfig.activation)(data)
get_forward(qconfig.weight)(data)
m = wrap_cpp_module(torch._C._jit_pass_insert_observers(
m._c, 'forward', {'': qconfig}, inplace=False))
m = convert_script(m, True)
# This checks that the dequantize from the output of first conv
# is being propagated to the end, so that we don't insert extra
# observers and also successfully fused two quantized::conv2d
# patterns
# one quantize_per_tensor for input
FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True) \
.check_count("quantized::conv2d", 2, exactly=True) \
.check("aten::dequantize") \
.run(m.graph)
def test_quantize_general_value_ops(self):
""" A test that checks dequantize will be swapped for \
all supported general value ops like aten::avg_pool2d \
without actually checking for execution of these ops
"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.avg_pool1d = torch.nn.AvgPool1d(3)
self.avg_pool2d = torch.nn.AvgPool2d(3)
self.avg_pool3d = torch.nn.AvgPool2d(3)
self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
self.hardtanh = torch.nn.Hardtanh()
self.elu = torch.nn.ELU()
def forward(self, x):
x = self.conv(x)
x = self.avg_pool1d(x)
x = self.avg_pool2d(x)
x = self.avg_pool3d(x)
x = self.adaptive_avg_pool1d(x)
x = self.adaptive_avg_pool2d(x)
x = self.adaptive_avg_pool3d(x)
x = F.avg_pool1d(x, 3)
x = F.avg_pool2d(x, 3)
x = F.avg_pool3d(x, 3)
x = F.adaptive_avg_pool1d(x, (1))
x = F.adaptive_avg_pool2d(x, (1, 1))
x = F.adaptive_avg_pool3d(x, (1, 1, 1))
x = torch.mean(x)
x = x.mean()
# interpolate node will introduce 3 quantize_per_tensor ops
x = F.interpolate(x, 4, mode='nearest') # interpolate node
x = F.upsample(x, (32, 32)) # interpolate node
x = F.upsample_nearest(x, (32, 32)) # interpolate node
x = F.interpolate(x, 4, mode='linear') # common node
x = F.upsample_bilinear(x, (32, 32)) # common node
x = torch.clamp(x, -3, 3)
x = x.clamp(-2.5, 2.5)
# x = x.clamp_(-2, 2) # Enable when quantized `clamp_` is ready
x = self.hardtanh(x)
x = F.hardtanh(x)
x.hardtanh_()
x = self.elu(x)
x = F.elu(x)
x.elu_()
x = self.conv(x)
return x
m = torch.jit.script(M())
qconfig = script_qconfig(default_qconfig)
# dummy data to suppress warning
data = torch.rand((1, 3, 10, 10))
get_forward(qconfig.activation)(data)
get_forward(qconfig.weight)(data)
m = wrap_cpp_module(torch._C._jit_pass_insert_observers(
m._c, 'forward', {'': qconfig}, inplace=False))
# Checking the model before fianlize contain unfused patterns
# that numerically matches the model after quantize by checking
# number of aten::quantize_per_tensor functions
# conv has 3 quantize_per_tensor for activations and 1 for weight
# and for N general value op between conv we should have
# N + 1 quantize_per_tensor between these ops
m1 = convert_script(m, debug=True)
# NB: This Needs to be updated when we add more ops to test
# number of quantize_per_tensor op for type
num_quant_by_op_type = {'conv': 2, 'common': 1, 'interpolate': 3}
# number of ops for each type
num_op_by_op_type = {'conv': 2, 'common': 24, 'interpolate': 3}
num_quantize_per_tensor = 1 # for output
for op_type, num_op in num_op_by_op_type.items():
num_quantize_per_tensor += num_op * num_quant_by_op_type[op_type]
FileCheck().check_count("aten::quantize_per_tensor(", num_quantize_per_tensor, exactly=True) \
.run(m1.graph)
# This checks that the dequantize from the output of first conv
# is being propagated to the end, so that we don't insert extra
# observers and also successfully fused two quantized::conv2d
# patterns
# one quantize_per_tensor for input
m2 = convert_script(m, debug=False)
FileCheck().check_count("aten::quantize_per_tensor(", 1, exactly=True) \
.check_count("quantized::conv2d(", 2, exactly=True) \
.check("aten::dequantize(") \
.run(m2.graph)
class TestQuantizeDynamicScript(JitTestCase):
def test_prepare_dynamic(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.fc = torch.nn.Linear(5, 5)
def forward(self, x):
return self.fc(x)
m = torch.jit.script(M())
m = prepare_dynamic_script(m, {'': default_dynamic_qconfig})
# for input of FC for dynamic quant
assert len(attrs_with_prefix(m, '_observer_')) == 1
# for weight
assert len(attrs_with_prefix(m.fc, '_observer_')) == 1
FileCheck().check('DynamicQuantObserver = prim::GetAttr[name="_observer_') \
.check('prim::GetAttr[name="fc"]') \
.check('prim::CallMethod') \
.check_not('Observer = prim::GetAttr[name="_observer_') \
.run(m.graph)
def test_prepare_dynamic_child_qconfig(self):
class Sub(torch.nn.Module):
def __init__(self):
super(Sub, self).__init__()
self.fc = torch.nn.Linear(5, 5)
def forward(self, x):
return self.fc(x)
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)
self.sub = Sub()
def forward(self, x):
return self.sub(self.conv(x))
m = torch.jit.script(M())
# only quantize child module.
m = prepare_dynamic_script(m, {'sub.fc': default_dynamic_qconfig})
# input of sub for dynamic quant
assert len(attrs_with_prefix(m, '_observer_')) == 1
# not quantized
assert len(attrs_with_prefix(m.conv, '_observer_')) == 0
# no observers since we observe in the outer most call site
assert len(attrs_with_prefix(m.sub, '_observer_')) == 0
# weight of linear
assert len(attrs_with_prefix(m.sub.fc, '_observer_')) == 1
FileCheck().check('prim::GetAttr[name="sub') \
.check('prim::CallMethod') \
.check('DynamicQuantObserver = prim::GetAttr[name="_observer_') \
.check('prim::CallMethod') \
.check_not('Observer = prim::GetAttr[name="_observer_') \
.run(m.graph)
def test_insert_quant_dequant_linear_dynamic(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.fc1 = torch.nn.Linear(5, 5).float()
self.fc2 = torch.nn.Linear(5, 5).float()
def forward(self, x):
x = self.fc1(x)
return self.fc2(x)
m = torch.jit.script(M())
m = prepare_dynamic_script(m, {'': default_dynamic_qconfig})
data = torch.randn(5, 5, dtype=torch.float)
m(data)
m = wrap_cpp_module(torch._C._jit_pass_insert_quant_dequant(m._c, "forward", False, True))
assert len(m._modules._c.items()) == 2, \
'Expected to have two submodule of linear'
m(data)
quant_func = "aten::quantize_per_tensor"
# quantizing activations
FileCheck().check("aten::_choose_qparams_per_tensor") \
.check_next(quant_func) \
.check_next("aten::dequantize") \
.check("aten::_choose_qparams_per_tensor") \
.check_next(quant_func) \
.check_next("aten::dequantize") \
.check(quant_func) \
.check_next("aten::dequantize") \
.check_not(quant_func) \
.check("return") \
.run(m.graph)
def test_finalize_for_linear_dynamic(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.fc = torch.nn.Linear(5, 5).float()
def forward(self, x):
return self.fc(x)
data = torch.rand((1, 5), dtype=torch.float)
qconfig_dict = {'': default_dynamic_qconfig}
model = torch.jit.script(M()).eval()
model = quantize_dynamic_script(model, qconfig_dict, data)
FileCheck().check("quantized::linear_dynamic") \
.run(model.graph)
def test_dynamic_multi_op(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
def forward(self, x):
x = x + 5
return self.fc1(x)
m = torch.jit.script(M())
data = torch.randn((1, 5), dtype=torch.float)
qconfig_dict = {'' : default_dynamic_qconfig}
model = quantize_dynamic_script(m, qconfig_dict, data)
# add op is not dynamically quantized.
FileCheck().check("aten::add") \
.check("quantized::linear_dynamic") \
.run(model.graph)
def test_dynamic_quant_multi_uses(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.fc = torch.nn.Linear(5, 5).float()
def forward(self, x):
size1 = x.size()
size2 = x.size()
return self.fc(x), size1, size2
model = torch.jit.script(M()).eval()
data = torch.rand((1, 5), dtype=torch.float)
qconfig_dict = {'': default_dynamic_qconfig}
model = quantize_dynamic_script(model, qconfig_dict, [data])
FileCheck().check("quantized::linear_dynamic") \
.check_not("aten::_choose_qparams_per_tensor") \
.run(model.graph)
def test_prepare_dynamic_lstm(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.lstm = torch.nn.LSTM(2, 2).to(dtype=torch.float)
def forward(self, x):
return self.lstm(x)
from torch.quantization.observer import default_dynamic_quant_observer, _MinMaxTensorListObserver
qconfig = QConfigDynamic(activation=default_dynamic_quant_observer,
weight=_MinMaxTensorListObserver)
m = torch.jit.script(M())
m = prepare_dynamic_script(m, {'': qconfig})
assert len(attrs_with_prefix(m.lstm, '_observer_')) == 1
FileCheck().check('_MinMaxTensorListObserver = prim::GetAttr[name="_observer_0') \
.check("aten::lstm") \
.check("return") \
.run(str(get_module_method(m, 'lstm', 'forward__0').graph))