blob: b3de59382e39a427fbe1dcf18701003f5595f398 [file] [log] [blame]
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.intrinsic.quantized as nniq
import torch.multiprocessing as mp
# symbolic trace
from torch.fx import symbolic_trace
# graph mode quantization based on fx
from torch.quantization import (
QuantType,
fuse_fx,
prepare_fx,
convert_fx,
)
from torch.quantization import (
default_qconfig,
default_qat_qconfig,
prepare,
prepare_qat,
convert,
)
# test utils
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
skipIfNoFBGEMM,
skip_if_no_torchvision,
train_one_epoch,
run_ddp,
)
from torch.testing._internal.common_distributed import skip_if_not_multigpu
from torch.testing._internal.common_quantization import NodeSpec as ns
import itertools
import operator
import unittest
class TestQuantizeFx(QuantizationTestCase):
    """Unit tests for the core FX graph-mode quantization flow.
    """
    @skipIfNoFBGEMM
    def test_functional(self):
        """Quantize functional conv/linear and an nn.Linear module.

        Every case runs through ``checkGraphModeFxOp`` and checks that the
        converted graph contains the expected quantized function or module,
        for both static and dynamic quantization where applicable.
        """
        class Conv(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight):
                return F.conv2d(x, weight, None, self.stride, self.padding,
                                self.dilation, self.groups)

        class Linear(torch.nn.Module):
            # No custom __init__ needed: nn.Module.__init__ is inherited.
            def forward(self, x, weight):
                return F.linear(x, weight)

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        # Inputs are drawn in this fixed order: conv input, conv weight,
        # linear input, linear weight, linear-module input.
        conv_args = (torch.rand(1, 3, 224, 224), torch.rand(3, 3, 3, 3))
        linear_args = (torch.rand(8, 5), torch.rand(10, 5))
        linear_module_args = (torch.rand(8, 5),)

        # (quant type, model class, forward args, expected quantized node)
        cases = (
            (QuantType.STATIC, Conv, conv_args,
             ns.call_function(torch.ops.quantized.conv2d)),
            (QuantType.DYNAMIC, Linear, linear_args,
             ns.call_function(torch.ops.quantized.linear_dynamic)),
            (QuantType.STATIC, Linear, linear_args,
             ns.call_function(torch.ops.quantized.linear)),
            (QuantType.DYNAMIC, LinearModule, linear_module_args,
             ns.call_module(nnqd.Linear)),
            (QuantType.STATIC, LinearModule, linear_module_args,
             ns.call_module(nnq.Linear)),
        )
        for quant_type, model_cls, args, expected_node in cases:
            # A new model instance is built for every case.
            self.checkGraphModeFxOp(model_cls(), args, quant_type, expected_node)
class TestQuantizeFxOps(QuantizationTestCase):
    """Unit tests for quantizing individual ops and modules in FX graph mode.
    """
    @skipIfNoFBGEMM
    def test_linear(self):
        """Quantize nn.Linear, alone and followed by ReLU, for each quant type."""
        class ModuleLinear(torch.nn.Module):
            def __init__(self, has_relu=False, f_relu=False):
                super(ModuleLinear, self).__init__()
                self.linear = torch.nn.Linear(30, 4).float()
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(self.linear(x))

        class FuncLinear(torch.nn.Module):
            def __init__(self, has_relu=False, f_relu=False):
                super(FuncLinear, self).__init__()
                self.w = torch.randn(4, 30)
                self.b = torch.randn(4)
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(F.linear(x, self.w, self.b))

        data = (torch.rand((1, 30), dtype=torch.float),)
        # Entries are (model factory, is_module). A factory is used instead of a
        # single shared instance so that each quant type starts from a fresh
        # model, whatever checkGraphModeFxOp does to the module it is handed.
        options = itertools.product(
            [(lambda: ModuleLinear(has_relu=False), True)],
            # TODO: enable after raw `tensor` is supported in fx
            # (lambda: FuncLinear(has_relu=False), False)],
            self.all_quant_types)
        quantized_nodes = {
            # is_module
            True: {
                # quant_type:
                QuantType.DYNAMIC: ns.call_module(nnqd.Linear),
                QuantType.STATIC: ns.call_module(nnq.Linear),
                # note that we are checking the final result
                QuantType.QAT: ns.call_module(nnq.Linear),
            },
            False: {
                # quant_type:
                QuantType.DYNAMIC: ns.call_function(torch.ops.quantized.linear_dynamic),
                QuantType.STATIC: ns.call_function(torch.ops.quantized.linear),
                QuantType.QAT: ns.call_function(torch.ops.quantized.linear),
            }
        }
        for (make_model, is_module), quant_type in options:
            self.checkGraphModeFxOp(
                make_model(), data, quant_type, quantized_nodes[is_module][quant_type])

        for f_relu, quant_type in itertools.product([True, False], [QuantType.STATIC, QuantType.QAT]):
            for model, quantized_node in [
                    (ModuleLinear(has_relu=True, f_relu=f_relu), ns.call_module(nniq.LinearReLU))]:
                # TODO: support functional linear + relu fusion
                # (FuncLinear(has_relu=True, f_relu=f_relu), ns.call_function(torch.ops.quantized.linear_relu))]:
                self.checkGraphModeFxOp(model, data, quant_type, quantized_node)

    @skipIfNoFBGEMM
    def test_quantized_conv(self):
        """Conv1d/2d/3d modules are swapped for their quantized counterparts."""
        conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}

        class Conv(torch.nn.Module):
            def __init__(self, dim):
                super(Conv, self).__init__()
                self.conv = conv_module[dim](3, 3, 3).float()

            def forward(self, x):
                return self.conv(x)

        options = itertools.product([1, 2, 3], self.static_quant_types)
        quantized_nodes = {
            # dim
            1: ns.call_module(nnq.Conv1d),
            2: ns.call_module(nnq.Conv2d),
            3: ns.call_module(nnq.Conv3d),
        }
        for dim, quant_type in options:
            self.checkGraphModeFxOp(
                Conv(dim), self.img_data_dict[dim], quant_type,
                quantized_nodes[dim])

    @skipIfNoFBGEMM
    def test_quantized_conv_relu(self):
        """tests for conv1d_relu/conv2d_relu/conv3d_relu

        Covers module ReLU (inplace and not) and functional ReLU (inplace and
        not) following the conv; all must become the fused ConvReLU module.
        """
        conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}

        class ConvNdRelu(torch.nn.Module):
            def __init__(self, dim, inplace):
                super(ConvNdRelu, self).__init__()
                self.conv = conv_module[dim](3, 3, 3).float()
                self.relu = torch.nn.ReLU(inplace)

            def forward(self, x):
                return self.relu(self.conv(x))

        class ConvNdFunctionalRelu(torch.nn.Module):
            def __init__(self, dim):
                super(ConvNdFunctionalRelu, self).__init__()
                self.conv = conv_module[dim](3, 3, 3).float()

            def forward(self, x):
                return F.relu(self.conv(x))

        class ConvNdInplaceFunctionalRelu(torch.nn.Module):
            def __init__(self, dim):
                super(ConvNdInplaceFunctionalRelu, self).__init__()
                self.conv = conv_module[dim](3, 3, 3).float()

            def forward(self, x):
                return F.relu(self.conv(x), True)

        options = itertools.product([1, 2, 3], self.static_quant_types)
        quantized_nodes = {
            # dim
            1: ns.call_module(nniq.ConvReLU1d),
            2: ns.call_module(nniq.ConvReLU2d),
            3: ns.call_module(nniq.ConvReLU3d),
        }
        for dim, quant_type in options:
            for orig_m in [ConvNdRelu(dim, True),
                           ConvNdRelu(dim, False),
                           ConvNdFunctionalRelu(dim),
                           ConvNdInplaceFunctionalRelu(dim)]:
                self.checkGraphModeFxOp(
                    orig_m, self.img_data_dict[dim], quant_type,
                    quantized_nodes[dim])

    def _test_quantized_binary_op_impl(self, binary_op, ibinary_op, quantized_op):
        """Check static quantization of a binary op applied after a conv.

        ``binary_op`` (or its in-place variant ``ibinary_op``) is applied to a
        conv output and either a second conv output or the Python scalar 3.
        Every combination must produce a ``quantized_op`` call in the graph.
        """
        class Op(torch.nn.Module):
            def __init__(self, is_inplace, is_scalar):
                super(Op, self).__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
                self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
                self.is_scalar = is_scalar
                self.op = ibinary_op if is_inplace else binary_op

            def forward(self, x, y):
                x = self.conv1(x)
                y = 3 if self.is_scalar else self.conv2(y)
                x = self.op(x, y)
                return x

        # TODO: decide whether we want to quantize or not
        # in this case
        # class NonQuantizedOp(torch.nn.Module):
        #     def __init__(self, is_inplace, is_scalar):
        #         super(NonQuantizedOp, self).__init__()
        #         self.is_scalar = is_scalar
        #         self.op = ibinary_op if is_inplace else binary_op
        #     def forward(self, x, y):
        #         y = 3 if self.is_scalar else y
        #         x = self.op(x, y)
        #         return x
        data = (torch.randn(1, 1, 1, 1, dtype=torch.float),
                torch.randn(1, 1, 1, 1, dtype=torch.float))
        quantized_node = ns.call_function(quantized_op)
        options = itertools.product([True, False], [True, False])
        quant_type = QuantType.STATIC
        for is_inplace, is_scalar in options:
            self.checkGraphModeFxOp(
                Op(is_inplace, is_scalar), data, quant_type, quantized_node)

    def _test_quantized_binary_op_relu_impl(self, binary_op, ibinary_op, quantized_op):
        """Like ``_test_quantized_binary_op_impl`` but with a trailing ReLU.

        The ReLU is either ``F.relu`` or ``nn.ReLU``; every combination of
        in-place op, ReLU flavour, and scalar/tensor second operand must produce
        a ``quantized_op`` call in the graph.
        """
        class OpRelu(torch.nn.Module):
            def __init__(self, is_inplace, is_functional_relu,
                         is_scalar):
                super(OpRelu, self).__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
                self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
                self.op = ibinary_op if is_inplace else binary_op
                self.is_functional_relu = is_functional_relu
                self.is_scalar = is_scalar
                self.relu = F.relu if self.is_functional_relu \
                    else torch.nn.ReLU()

            def forward(self, x, y):
                x = self.conv1(x)
                y = 3 if self.is_scalar else self.conv2(y)
                x = self.op(x, y)
                x = self.relu(x)
                return x

        data = (torch.rand((1, 1, 1, 1), dtype=torch.float),
                torch.rand((1, 1, 1, 1), dtype=torch.float))
        quant_type = QuantType.STATIC
        quantized_node = ns.call_function(quantized_op)
        options = itertools.product(
            [True, False], [True, False], [True, False])
        for is_inplace_op, is_functional_relu, is_scalar in options:
            self.checkGraphModeFxOp(
                OpRelu(is_inplace_op, is_functional_relu, is_scalar),
                data, quant_type, quantized_node)

    @skipIfNoFBGEMM
    def test_quantized_add(self):
        self._test_quantized_binary_op_impl(
            operator.add, operator.iadd, torch.ops.quantized.add)

    @skipIfNoFBGEMM
    def test_quantized_mul(self):
        self._test_quantized_binary_op_impl(
            operator.mul, operator.imul, torch.ops.quantized.mul)

    @skipIfNoFBGEMM
    def test_quantized_add_relu(self):
        self._test_quantized_binary_op_relu_impl(
            operator.add, operator.iadd, torch.ops.quantized.add_relu)

    @skipIfNoFBGEMM
    def test_quantized_mul_relu(self):
        self._test_quantized_binary_op_relu_impl(
            operator.mul, operator.imul, torch.ops.quantized.mul_relu)

    @skipIfNoFBGEMM
    def test_quantized_cat(self):
        """ quantization of the output of cat will depend on the
        input of cat. we only quantize the output of cat when its inputs are quantized.
        """
        class QuantizedCat(torch.nn.Module):
            def __init__(self):
                super(QuantizedCat, self).__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                return torch.cat([x, y], 1)

        # TODO: decide whether to quantize in this case
        # class NonQuantizedCat(torch.nn.Module):
        #     def __init__(self):
        #         super(NonQuantizedCat, self).__init__()
        #     def forward(self, x, y):
        #         return torch.cat([x, y], 1)
        data = (torch.randn(1, 2, 5, 5, dtype=torch.float),
                torch.randn(1, 2, 5, 5, dtype=torch.float))
        quantized_node = ns.call_function(torch.ops.quantized.cat)
        for quant_type in self.static_quant_types:
            self.checkGraphModeFxOp(QuantizedCat(), data, quant_type, quantized_node)

    @skipIfNoFBGEMM
    def test_qbatch_norm(self):
        """BatchNorm2d/3d modules are swapped for quantized BatchNorm."""
        bn_module = {
            # TODO: quantized batchnorm 1d module is missing
            # 1 : torch.nn.BatchNorm1d,
            2 : torch.nn.BatchNorm2d,
            3 : torch.nn.BatchNorm3d,
        }

        class M(torch.nn.Module):
            def __init__(self, dim):
                super(M, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)

            def forward(self, x):
                return self.bn(x)

        options = itertools.product(self.static_quant_types, [2, 3])
        quantized_nodes = {
            # 1: ns.call_module(nnq.BatchNorm1d),
            2: ns.call_module(nnq.BatchNorm2d),
            3: ns.call_module(nnq.BatchNorm3d),
        }
        for quant_type, dim in options:
            self.checkGraphModeFxOp(
                M(dim), self.img_data_dict[dim], quant_type, quantized_nodes[dim])

    @skipIfNoFBGEMM
    def test_qbatch_norm_relu(self):
        """BatchNorm + (module or functional, inplace or not) ReLU become BNReLU."""
        bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}

        class BNRelu(torch.nn.Module):
            def __init__(self, dim, inplace):
                super(BNRelu, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)
                self.relu = torch.nn.ReLU(inplace=inplace)

            def forward(self, x):
                return self.relu(self.bn(x))

        class BNFuncRelu(torch.nn.Module):
            def __init__(self, dim):
                super(BNFuncRelu, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)

            def forward(self, x):
                return F.relu(self.bn(x), False)

        class BNFuncInplaceRelu(torch.nn.Module):
            def __init__(self, dim):
                super(BNFuncInplaceRelu, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)

            def forward(self, x):
                return F.relu(self.bn(x), True)

        options = itertools.product(self.static_quant_types, [2, 3])
        quantized_nodes = {
            2: ns.call_module(nniq.BNReLU2d),
            3: ns.call_module(nniq.BNReLU3d),
        }
        for quant_type, dim in options:
            for instance in [BNRelu(dim, True), BNRelu(dim, False),
                             BNFuncRelu(dim), BNFuncInplaceRelu(dim)]:
                self.checkGraphModeFxOp(
                    instance, self.img_data_dict[dim], quant_type,
                    quantized_nodes[dim])

    def _test_activation_impl(
            self, float_module, float_op, quantized_module, quantized_op):
        ''' Test an activation op with and without the inplace option.

        The activation is used either as a module (``float_module(inplace)``,
        expected to become ``quantized_module``) or as a callable
        (``float_op(input, inplace)``, a torch or functional op, expected to
        become a ``quantized_op`` call).
        '''
        class M(torch.nn.Module):
            def __init__(self, is_module, inplace):
                super(M, self).__init__()
                self.is_module = is_module
                self.inplace = inplace
                if self.is_module:
                    self.op = float_module(self.inplace)
                else:
                    self.op = float_op

            def forward(self, input):
                if self.is_module:
                    return self.op(input)
                else:
                    return self.op(input, self.inplace)

        options = itertools.product([True, False], [True, False], self.static_quant_types)
        quantized_nodes = {
            # is_module
            True: ns.call_module(quantized_module),
            False: ns.call_function(quantized_op),
        }

        for is_module, is_inplace, quant_type in options:
            self.checkGraphModeFxOp(
                M(is_module, is_inplace), self.img_data_2d,
                quant_type, quantized_nodes[is_module])

    def test_hardswish(self):
        self._test_activation_impl(nn.Hardswish, F.hardswish, nnq.Hardswish, torch.ops.quantized.hardswish)

    def test_elu(self):
        self._test_activation_impl(nn.ELU, F.elu, nnq.ELU, torch.ops.quantized.elu)

    def _test_norm_impl(
            self, float_module, float_op, op_args, data, quantized_module, quantized_op,
            skip_op_arg_for_functional=False):
        ''' Test a normalization op used as a module or as a torch/functional op.

        ``op_args`` is the list of positional arguments for the module
        constructor. It is also appended after the input for the functional
        form unless ``skip_op_arg_for_functional`` is True. Module use must
        become ``quantized_module``; functional use must become a
        ``quantized_op`` call.
        '''
        class M(torch.nn.Module):
            def __init__(self, is_module):
                super(M, self).__init__()
                self.is_module = is_module
                if self.is_module:
                    self.op = float_module(*op_args)
                else:
                    self.op = float_op

            def forward(self, input):
                if self.is_module:
                    return self.op(input)
                else:
                    args = [input]
                    if not skip_op_arg_for_functional:
                        args += op_args
                    return self.op(*args)

        options = itertools.product([True, False], self.static_quant_types)
        quantized_nodes = {
            # is_module
            True: ns.call_module(quantized_module),
            False: ns.call_function(quantized_op),
        }

        for is_module, quant_type in options:
            self.checkGraphModeFxOp(
                M(is_module), data, quant_type, quantized_nodes[is_module])

    def test_layer_norm(self):
        data = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
        self._test_norm_impl(
            nn.LayerNorm, F.layer_norm, [[2, 5, 5]], data, nnq.LayerNorm, torch.ops.quantized.layer_norm)

    def test_instance_norm(self):
        data_1d = (torch.rand((1, 4, 5), dtype=torch.float),)
        data_2d = (torch.rand((1, 4, 5, 1), dtype=torch.float),)
        data_3d = (torch.rand((1, 4, 5, 1, 1), dtype=torch.float),)
        data_dict = {1 : data_1d, 2 : data_2d, 3 : data_3d}
        instance_norm_modules = {1 : nn.InstanceNorm1d,
                                 2 : nn.InstanceNorm2d,
                                 3 : nn.InstanceNorm3d}
        quantized_instance_norm_modules = {
            1 : nnq.InstanceNorm1d,
            2 : nnq.InstanceNorm2d,
            3 : nnq.InstanceNorm3d
        }
        for dim in [1, 2, 3]:
            data = data_dict[dim]
            module = instance_norm_modules[dim]
            quantized_module = quantized_instance_norm_modules[dim]
            self._test_norm_impl(
                module, F.instance_norm, [4], data,
                quantized_module, torch.ops.quantized.instance_norm,
                skip_op_arg_for_functional=True)

    @skipIfNoFBGEMM
    def test_clamp(self):
        """Clamp-like ops after a quantized conv stay in the quantized domain."""
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(2, 2, 2).float()
                self.relu6 = torch.nn.ReLU6()
                self.relu6_ = torch.nn.ReLU6(True)
                self.hardtanh = torch.nn.Hardtanh()
                self.hardtanh_ = torch.nn.Hardtanh(inplace=True)

            def forward(self, x):
                x = self.conv(x)
                x = self.relu6(x)
                self.relu6_(x)
                x = F.relu6(x)
                x = torch.clamp(x, -3, 3)
                x = x.clamp(-2.5, 2.5)
                # x = x.clamp_(-2, 2)  # Enable when quantized `clamp_` is ready
                x = self.hardtanh(x)
                self.hardtanh_(x)
                x = F.hardtanh(x)
                F.hardtanh_(x)
                return x

        data = (torch.rand((1, 2, 5, 5), dtype=torch.float),)
        # list of node that should occur in order
        node_list = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_function(F.hardtanh_),
            ns.call_method('dequantize')
        ]
        for quant_type in self.static_quant_types:
            self.checkGraphModeFxOp(
                M(), data, quant_type, expected_node_list=node_list)

    @skipIfNoFBGEMM
    def test_general_shape_ops(self):
        """ A test that checks dequantize will be swapped for
        all supported general shape ops like aten::flatten
        without actually checking for execution of these ops
        """
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
                self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
                self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
                self.dropout = torch.nn.Dropout()
                self.conv1 = torch.nn.Conv2d(3, 3, 3)
                self.conv2 = torch.nn.Conv2d(3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv1(x)
                # add_scalar
                x = x + 3
                # mul_scalar
                x = x * 3
                # add_scalar_out
                x += 3
                # mul_scalar_out
                x *= 3
                # add_scalar_relu
                x = x + 3
                x = F.relu(x)
                # add_scalar_relu_out
                x += 3
                x = F.relu(x)
                # mul_scalar_relu
                x = x * 3
                x = F.relu(x)
                # mul_scalar_relu_out
                x *= 3
                x = F.relu(x)
                x = self.maxpool1d(x)
                x = self.maxpool2d(x)
                x = self.maxpool3d(x)
                x = torch.flatten(x)
                x = torch.max(x)
                x = torch.min(x)
                x = x.reshape([-1])
                x = x.resize_(1, 1, x.numel())
                x = x.view(-1)
                # prim::ListConstruct
                xs = [x, x]
                # prim::ListUnpack
                x, y = xs
                # prim::TupleConstruct
                xs = (x, x)
                # prim::TupleUnpack
                x, y = xs
                x = x.transpose(1, 2)
                x = x.contiguous()
                x, y = torch.chunk(x, 2)
                x = F.dropout(x)
                x = self.dropout(x)
                x, _ = torch.sort(x)
                x = x.permute(0, 2, 3, 1)
                x = x.repeat_interleave(3, 1)
                x = torch.repeat_interleave(x, 3, 1)
                x = self.relu(x)
                x = F.relu(x)
                x = F.relu(x, inplace=True)
                x = x.relu()
                x.relu_()
                x = x.squeeze(0)
                x.squeeze_(0)
                x = torch.squeeze(x, 0)
                x = x.unsqueeze(0)
                x.unsqueeze_(0)
                x = torch.unsqueeze(x, 0)
                x = x.detach()
                x.detach_()
                x = x.repeat(4, 2)
                y = []
                y.append(x)
                z = torch.stack(y, 0)
                z = [z, z]
                x, _ = z
                x = self.conv2(x)
                return x

        # This model is not executable since we just put all ops
        # in the same forward, so no input data is needed.
        m = M()
        original = symbolic_trace(m)
        # nothing to fuse so skipping the fuse step
        qconfig_dict = {'': default_qconfig}
        prepared = prepare_fx(original, qconfig_dict)
        # not runnable
        quantized = convert_fx(prepared)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers and also successfully fused two quantized::conv2d
        # patterns
        # one quantize_per_tensor for input
        # check exact counts of quantize and dequantize
        count_check = {
            ns.call_function(torch.quantize_per_tensor) : 1,
            ns.call_method('dequantize') : 1
        }
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(
            quantized,
            expected_node_occurrence=count_check,
            expected_node_list=order_check)

    @skipIfNoFBGEMM
    def test_general_value_ops(self):
        """ A test that checks correct patterns are produced for
        all supported general value ops like aten::avg_pool2d
        without actually checking for execution of these ops
        """
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.avg_pool1d = torch.nn.AvgPool1d(3)
                self.avg_pool2d = torch.nn.AvgPool2d(3)
                self.avg_pool3d = torch.nn.AvgPool3d(3)
                self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
                self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
                self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
                self.leaky_relu = torch.nn.LeakyReLU()
                self.hardsigmoid = torch.nn.Hardsigmoid()
                self.sigmoid = torch.nn.Sigmoid()
                self.tanh = torch.nn.Tanh()

            def forward(self, x):
                x = self.conv(x)
                x = self.avg_pool1d(x)
                x = self.avg_pool2d(x)
                x = self.avg_pool3d(x)
                x = self.adaptive_avg_pool1d(x)
                x = self.adaptive_avg_pool2d(x)
                x = self.adaptive_avg_pool3d(x)
                x = F.avg_pool1d(x, 3)
                x = F.avg_pool2d(x, 3)
                x = F.avg_pool3d(x, 3)
                x = F.adaptive_avg_pool1d(x, (1))
                x = F.adaptive_avg_pool2d(x, (1, 1))
                x = F.adaptive_avg_pool3d(x, (1, 1, 1))
                x = torch.mean(x)
                x = torch.mean(x, [2, 3], False)
                x = x.mean()
                x = x.mean([2, 3], True)
                x = F.interpolate(x, 4, mode='nearest')
                x = F.interpolate(x, 4, mode='linear')
                x = self.leaky_relu(x)
                x = F.leaky_relu(x)
                x = F.leaky_relu(x, inplace=True)
                x = x.leaky_relu()
                x.leaky_relu_()
                x = self.hardsigmoid(x)
                x = F.hardsigmoid(x)
                x = F.hardsigmoid(x, inplace=True)
                x = x.hardsigmoid()
                x.hardsigmoid_()
                x = self.sigmoid(x)
                x = torch.sigmoid(x)
                # F.sigmoid is deprecated
                x = x.sigmoid()
                x.sigmoid_()
                x = self.tanh(x)
                # F.tanh is deprecated
                x = torch.tanh(x)
                x = x.tanh()
                x.tanh_()
                x = self.conv(x)
                return x

        # This model is not executable since we just put all ops
        # in the same forward
        m = M()
        original = symbolic_trace(m)
        # nothing to fuse so skipping the fuse step
        qconfig_dict = {'': default_qconfig}
        prepared = prepare_fx(original, qconfig_dict)
        # not runnable
        quantized = convert_fx(prepared)

        # This checks that the dequantize from the output of first conv
        # is being propagated to the end, so that we don't insert extra
        # observers
        # check exact counts of quantize and dequantize
        count_check = {
            ns.call_function(torch.quantize_per_tensor) : 1,
            ns.call_method('dequantize') : 1
        }
        order_check = [
            ns.call_function(torch.quantize_per_tensor),
            ns.call_module(nnq.Conv2d),
            ns.call_module(nnq.Conv2d),
            ns.call_method('dequantize'),
        ]
        self.checkGraphModuleNodes(
            quantized,
            expected_node_occurrence=count_check,
            expected_node_list=order_check)
class TestQuantizeFxModels(QuantizationTestCase):
    """End-to-end FX graph-mode quantization tests on whole models."""

    def _calibrate(self, mode, model, input_value, output_value, qat_msg):
        """Calibrate or train ``model`` in place according to ``mode``.

        * ``'ddp'``: spawn ``run_ddp`` worker processes via ``mp.spawn``.
        * ``'qat'``: one epoch of SGD (lr=0.0001, cross-entropy loss) on the
          single ``(input_value, output_value)`` sample, on CPU.
        * anything else: ten forward passes of ``input_value``.

        ``qat_msg`` is the failure message used when a model handed over for
        QAT is not in training mode.
        """
        if mode == 'ddp':
            # world_size must be defined locally (it was an undefined name).
            # One process per visible CUDA device -- presumably what run_ddp
            # expects for its rank/device mapping; confirm against run_ddp.
            world_size = torch.cuda.device_count()
            mp.spawn(run_ddp,
                     args=(world_size, model),
                     nprocs=world_size,
                     join=True)
        elif mode == 'qat':
            self.assertTrue(model.training, qat_msg)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
            criterion = nn.CrossEntropyLoss()
            train_one_epoch(model, criterion, optimizer,
                            [(input_value, output_value)], torch.device('cpu'), 1)
        else:
            for _ in range(10):
                model(input_value)

    def _test_model_impl(
            self, mode, name, model, eager_quantizable_model,
            check_with_eager=True,
            diff_of_quant=None,
            diff_from_eager=None):
        """Quantize ``model`` in FX graph mode and optionally compare to eager mode.

        Args:
            mode: ``'static'``, ``'qat'`` or ``'ddp'``. ``'static'`` uses
                ``default_qconfig``. Any other mode uses
                ``default_qat_qconfig`` and switches to training mode first.
            name: model name. ``'inception_v3'`` gets a (1, 3, 299, 299)
                random input; every other model gets (1, 3, 224, 224).
            model: float model to trace, fuse, prepare and convert.
            eager_quantizable_model: eager-mode quantizable counterpart
                (must provide ``fuse_model()``), or None to skip the eager
                comparison.
            check_with_eager: when True, assert the FX and eager quantized
                outputs are identical.
            diff_of_quant: optional ``{mode: {name: diff}}`` accumulator for
                the max abs difference between float and FX-quantized outputs.
            diff_from_eager: optional accumulator of the same layout for the
                difference between eager- and FX-quantized outputs.
        """
        # Initialise each accumulator independently and never clobber results
        # already recorded for `mode`.
        if diff_of_quant is None:
            diff_of_quant = {}
        if diff_from_eager is None:
            diff_from_eager = {}
        diff_of_quant.setdefault(mode, {})
        diff_from_eager.setdefault(mode, {})

        input_tensor = torch.rand(1, 3, 224, 224)
        input_tensor_inception = torch.rand(1, 3, 299, 299)
        output_value = torch.randint(0, 1, (1,))

        if name == 'inception_v3':
            input_value = input_tensor_inception
        else:
            input_value = input_tensor

        qconfig = default_qconfig if mode == 'static' else default_qat_qconfig
        qconfig_dict = {'': qconfig}
        graph_module = symbolic_trace(model)
        script = torch.jit.script(graph_module)

        # make sure graph module and script module are both runnable
        original_out = graph_module(input_value)
        is_not_tuple_out = not isinstance(original_out, tuple)
        script_out = script(input_value)
        if is_not_tuple_out:
            # Tuple outputs cannot be subtracted; they are only executed here,
            # matching how tuple outputs are handled further below.
            self.assertEqual(
                (original_out - script_out).abs().max(), 0,
                'Result of original graph module and script module does not match')

        # set to train just before quantization
        if mode != 'static':
            # Switch the module that is actually fused/prepared below, not
            # only the original `model`.
            graph_module.train()
        graph_module = fuse_fx(graph_module)
        prepared = prepare_fx(graph_module, qconfig_dict)

        self._calibrate(mode, prepared, input_value, output_value,
                        'prepared must be in training mode for qat')

        qgraph = convert_fx(prepared)
        qgraph.eval()
        qgraph_script = torch.jit.script(qgraph)

        qgraph_out = qgraph(input_value)
        # Keep the scripted module and its output under separate names.
        qgraph_script_out = qgraph_script(input_value)

        if is_not_tuple_out:
            diff_of_quant[mode][name] = (original_out - qgraph_out).abs().max()
            self.assertTrue(torch.allclose(qgraph_out, qgraph_script_out),
                            'graph, scripted graph')
        else:
            print('tuple output')

        if eager_quantizable_model is not None:
            # comparing to eager mode quantization
            qeager = eager_quantizable_model
            qeager.qconfig = qconfig
            if mode == 'static':
                qeager.fuse_model()
                prepare(qeager, inplace=True)
            else:
                qeager.train()
                qeager.fuse_model()
                prepare_qat(qeager, inplace=True)

            # calibration
            self._calibrate(mode, qeager, input_value, output_value,
                            'qeager should be in training mode for qat')

            convert(qeager, inplace=True)
            qeager.eval()

            qeager_out = qeager(input_value)
            # Smoke check: the eager quantized model must also script and run.
            torch.jit.script(qeager)(input_value)
            if is_not_tuple_out:
                diff_from_eager[mode][name] = (qeager_out - qgraph_out).abs().max()
                if check_with_eager:
                    self.assertEqual(diff_from_eager[mode][name], 0,
                                     'Result of graph mode quantization and ' +
                                     'eager mode quantization on model: ' + name +
                                     ' should match. Mode: ' + mode +
                                     ' diff:' + str(diff_from_eager[mode][name]))

    @skip_if_no_torchvision
    @skipIfNoFBGEMM
    @unittest.skip("skip for now since tbb failed")
    def test_torchvision(self):
        """Compare FX and eager quantization of torchvision classification models."""
        from torchvision import models
        from torchvision.models import quantization as quantized_models

        def get_available_classification_models(models):
            return [k for k, v in models.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]

        quantized_model_list = get_available_classification_models(quantized_models)
        no_pretrained_model = {'shufflenet_v2_x0_5', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'}
        quantized_model_list = set(quantized_model_list) - no_pretrained_model
        # test eager and graph consistency: only models that have an
        # eager-mode quantizable counterpart are exercised.
        model_list = quantized_model_list
        # slice need to be fixed in symbolic tracing(https://github.com/pytorch/pytorch/issues/43511)
        model_list = set(model_list) - {'googlenet', 'inception_v3'}
        # getattr should not be used as node name(https://github.com/pytorch/pytorch/issues/43522)
        model_list -= {'shufflenet_v2_x1_0', 'mobilenet_v2'}

        # mobilenet: dropout error RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'QUInt8'
        # inception_v3: looks like there is some problem with AuxLogits
        quantized_not_working = [('qat', 'mobilenet_v2'),
                                 ('qat', 'inception_v3'),
                                 ('static', 'inception_v3')]

        fx_eager_not_matching = ['googlenet',  # because _transform_input is not quantized in eager
                                 'mobilenet_v2']  # because relu6 is replaced as relu in mobilenetv2

        diff_of_quant = {}
        diff_from_eager = {}
        modes = ['static', 'qat']
        options = itertools.product(modes, model_list)
        for mode, name in options:
            # Reset every iteration so a model without an eager counterpart
            # never reuses the previous iteration's eager model.
            eager_quantizable_model = None
            if name in quantized_model_list and (mode, name) not in quantized_not_working:
                eager_quantizable_model = quantized_models.__dict__[name](pretrained=True, quantize=False).eval().float()
            # load pretrained weights only when there is an eager mode
            # quantizable model to compare the quantized result with
            pretrained = eager_quantizable_model is not None
            model = models.__dict__[name](pretrained=pretrained).eval().float()
            check_with_eager = name not in fx_eager_not_matching
            self._test_model_impl(
                mode, name, model, eager_quantizable_model,
                check_with_eager,
                diff_of_quant, diff_from_eager)

        def print_diffs(diffs):
            for mode, diffs_for_mode in diffs.items():
                print('mode:', mode)
                for name, diff in diffs_for_mode.items():
                    print(name, ':', diff)

        # print('differences between float and quantized')
        # print_diffs(diff_of_quant)
        # print('----------------------')
        # print('differences between graph mode and eager mode')
        # print_diffs(diff_from_eager)
        # print('----------------------')

    @skip_if_no_torchvision
    @skip_if_not_multigpu
    @skipIfNoFBGEMM
    @unittest.skip('TODO: not working yet due to https://github.com/pytorch/pytorch/issues/43513')
    def test_resnet18_ddp(self):
        """Quantize resnet18 with FX under DDP and compare to eager mode."""
        from torchvision import models
        from torchvision.models import quantization as quantized_models
        # `name` must be bound before it is used to look up the models below.
        name = 'resnet18'
        eager_quantizable_model = quantized_models.__dict__[name](pretrained=True, quantize=False).eval().float()
        model = models.__dict__[name](pretrained=True).eval().float()
        self._test_model_impl(
            'ddp', name, model, eager_quantizable_model)