import torch
import torch.nn.functional as F
import torch.nn.quantized as nnq
import torch.nn.quantized.dynamic as nnqd
import torch.nn.intrinsic.quantized as nniq
from torch.quantization.fx import QuantType
# test utils
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    skipIfNoFBGEMM,
)
import itertools
import operator
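
# Each test passes checkGraphModeFxOp an expected "quantized node" as a
# (node_kind, target) pair: ('call_module', <quantized module class>) when a
# float module should be swapped for its quantized counterpart, or
# ('call_function', <torch.ops.quantized.*>) when a functional/op call should
# be rewritten to a quantized op. Roughly (hedged; the exact helper lives in
# torch.testing._internal.common_quantization), the harness runs the FX
# quantization flow for the given QuantType and asserts that the resulting
# graph contains that node.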

class TestQuantizeFx(QuantizationTestCase):
    """Unit tests for FX graph mode quantization functionality."""
    @skipIfNoFBGEMM
    def test_functional(self):
        """Test quantizing functional conv and linear."""
        class Conv(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight):
                return F.conv2d(x, weight, None, self.stride, self.padding, self.dilation, self.groups)

        conv_input = torch.rand(1, 3, 224, 224)
        conv_weight = torch.rand(3, 3, 3, 3)

        class Linear(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, weight):
                return F.linear(x, weight)

        linear_input = torch.rand(8, 5)
        linear_weight = torch.rand(10, 5)

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        # (is_dynamic, module class, inputs, expected quantized node)
        tests = [
            (False, Conv, (conv_input, conv_weight), ('call_function', torch.ops.quantized.conv2d)),
            (True, Linear, (linear_input, linear_weight), ('call_function', torch.ops.quantized.linear_dynamic)),
            (False, Linear, (linear_input, linear_weight), ('call_function', torch.ops.quantized.linear)),
            (True, LinearModule, (linear_module_input,), ('call_module', torch.nn.quantized.dynamic.Linear)),
            (False, LinearModule, (linear_module_input,), ('call_module', torch.nn.quantized.Linear)),
        ]
        for is_dynamic, M, inputs, quantized_node in tests:
            quant_type = QuantType.DYNAMIC if is_dynamic else QuantType.STATIC
            self.checkGraphModeFxOp(M(), inputs, quantized_node, quant_type=quant_type)
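
# For reference, a hand-written (hedged) sketch of what the dynamic functional
# linear case above is expected to lower to. The scale/zero_point values are
# illustrative only; the real flow derives the weight qparams from an observer:
#
#   qw = torch.quantize_per_tensor(linear_weight, scale=0.1, zero_point=0,
#                                  dtype=torch.qint8)
#   packed = torch.ops.quantized.linear_prepack(qw, None)
#   out = torch.ops.quantized.linear_dynamic(linear_input, packed)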

class TestQuantizeFxOps(QuantizationTestCase):
    """Unit tests for individual ops."""
    @skipIfNoFBGEMM
    def test_linear(self):
        class ModuleLinear(torch.nn.Module):
            def __init__(self, has_relu=False, f_relu=False):
                super(ModuleLinear, self).__init__()
                self.linear = torch.nn.Linear(30, 4).float()
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(self.linear(x))

        class FuncLinear(torch.nn.Module):
            def __init__(self, has_relu=False, f_relu=False):
                super(FuncLinear, self).__init__()
                self.w = torch.randn(4, 30)
                self.b = torch.randn(4)
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(F.linear(x, self.w, self.b))

        data = (torch.rand((1, 30), dtype=torch.float),)
        options = itertools.product(
            [(ModuleLinear(has_relu=False), True)],
            # TODO: enable after raw `tensor` is supported in fx
            # (FuncLinear(has_relu=False), False)],
            self.all_quant_types)
        quantized_nodes = {
            # is_module
            True: {
                # quant_type:
                QuantType.DYNAMIC: ('call_module', nnqd.Linear),
                QuantType.STATIC: ('call_module', nnq.Linear),
                # note that we are checking the final result
                QuantType.QAT: ('call_module', nnq.Linear),
            },
            False: {
                # quant_type:
                QuantType.DYNAMIC: ('call_function', torch.ops.quantized.linear_dynamic),
                QuantType.STATIC: ('call_function', torch.ops.quantized.linear),
                QuantType.QAT: ('call_function', torch.ops.quantized.linear),
            }
        }
        for (model, is_module), quant_type in options:
            self.checkGraphModeFxOp(
                model, data, quantized_nodes[is_module][quant_type],
                quant_type=quant_type)

        for f_relu, quant_type in itertools.product([True, False], [QuantType.STATIC, QuantType.QAT]):
            for model, quantized_node in [
                    (ModuleLinear(has_relu=True, f_relu=f_relu), ('call_module', nniq.LinearReLU))]:
                # TODO: support functional linear + relu fusion
                # (FuncLinear(has_relu=True, f_relu=f_relu), ('call_function', torch.ops.quantized.linear_relu))]:
                self.checkGraphModeFxOp(model, data, quantized_node, quant_type=quant_type)
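
    # A hand-written (hedged) sketch of the static functional linear path
    # checked above; scale/zero_point values are illustrative, the FX flow
    # derives them from observers:
    #
    #   qx = torch.quantize_per_tensor(x, scale=0.05, zero_point=64, dtype=torch.quint8)
    #   qw = torch.quantize_per_tensor(w, scale=0.1, zero_point=0, dtype=torch.qint8)
    #   packed = torch.ops.quantized.linear_prepack(qw, b)
    #   qout = torch.ops.quantized.linear(qx, packed, 0.1, 0)  # output scale/zp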

    @skipIfNoFBGEMM
    def test_quantized_conv(self):
        conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

        class Conv(torch.nn.Module):
            def __init__(self, dim):
                super(Conv, self).__init__()
                self.conv = conv_module[dim](3, 3, 3).float()

            def forward(self, x):
                return self.conv(x)

        options = itertools.product([1, 2, 3], self.static_quant_types)
        quantized_nodes = {
            # dim
            1: ('call_module', nnq.Conv1d),
            2: ('call_module', nnq.Conv2d),
            3: ('call_module', nnq.Conv3d),
        }
        for dim, quant_type in options:
            model = self.checkGraphModeFxOp(
                Conv(dim), self.img_data_dict[dim],
                quantized_nodes[dim], quant_type=quant_type)
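
    # For comparison, the eager-mode flow that produces the same module swap
    # (hedged sketch; checkGraphModeFxOp performs the FX equivalent internally):
    #
    #   m = Conv(2).eval()
    #   m.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    #   prepared = torch.quantization.prepare(m)
    #   prepared(calibration_data)                        # observe activation ranges
    #   quantized = torch.quantization.convert(prepared)  # nn.Conv2d -> nnq.Conv2d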

    @skipIfNoFBGEMM
    def test_quantized_conv_relu(self):
        """Tests for conv1d_relu/conv2d_relu/conv3d_relu."""
        conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

        class ConvNdRelu(torch.nn.Module):
            def __init__(self, dim, inplace):
                super(ConvNdRelu, self).__init__()
                self.conv = conv_module[dim](3, 3, 3).float()
                self.relu = torch.nn.ReLU(inplace)

            def forward(self, x):
                return self.relu(self.conv(x))

        class ConvNdFunctionalRelu(torch.nn.Module):
            def __init__(self, dim):
                super(ConvNdFunctionalRelu, self).__init__()
                self.conv = conv_module[dim](3, 3, 3).float()

            def forward(self, x):
                return F.relu(self.conv(x))

        class ConvNdInplaceFunctionalRelu(torch.nn.Module):
            def __init__(self, dim):
                super(ConvNdInplaceFunctionalRelu, self).__init__()
                self.conv = conv_module[dim](3, 3, 3).float()

            def forward(self, x):
                return F.relu(self.conv(x), True)

        options = itertools.product([1, 2, 3], self.static_quant_types)
        quantized_nodes = {
            # dim
            1: ('call_module', nniq.ConvReLU1d),
            2: ('call_module', nniq.ConvReLU2d),
            3: ('call_module', nniq.ConvReLU3d),
        }
        for dim, quant_type in options:
            for orig_m in [ConvNdRelu(dim, True),
                           ConvNdRelu(dim, False),
                           ConvNdFunctionalRelu(dim),
                           ConvNdInplaceFunctionalRelu(dim)]:
                conv_name = "conv{}d".format(dim)
                m = self.checkGraphModeFxOp(
                    orig_m, self.img_data_dict[dim],
                    quantized_nodes[dim], quant_type=quant_type)
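
    # Note: FX fuses all four variants above (module ReLU, functional ReLU,
    # in-place or not) into a single nniq.ConvReLU{1,2,3}d module. The
    # eager-mode analogue only handles the named-module case, e.g. (hedged):
    #
    #   fused = torch.quantization.fuse_modules(ConvNdRelu(2, False).eval(),
    #                                           [['conv', 'relu']])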

    def _test_quantized_binary_op_impl(self, binary_op, ibinary_op, quantized_op):
        class Op(torch.nn.Module):
            def __init__(self, is_inplace, is_scalar):
                super(Op, self).__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
                self.is_scalar = is_scalar
                self.op = ibinary_op if is_inplace else binary_op

            def forward(self, x, y):
                x = self.conv1(x)
                y = 3 if self.is_scalar else self.conv2(y)
                x = self.op(x, y)
                return x

        # TODO: decide whether we want to quantize or not
        # in this case
        # class NonQuantizedOp(torch.nn.Module):
        #     def __init__(self, is_inplace, is_scalar):
        #         super(NonQuantizedOp, self).__init__()
        #         self.is_scalar = is_scalar
        #         self.op = ibinary_op if is_inplace else binary_op
        #
        #     def forward(self, x, y):
        #         y = 3 if self.is_scalar else y
        #         x = self.op(x, y)
        #         return x

        data = (torch.randn(1, 2, 3, 3, dtype=torch.float),
                torch.randn(1, 2, 3, 3, dtype=torch.float))
        quantized_node = ('call_function', quantized_op)
        options = itertools.product([True, False], [True, False], self.static_quant_types)
        for is_inplace, is_scalar, quant_type in options:
            self.checkGraphModeFxOp(Op(is_inplace, is_scalar), data, quantized_node, quant_type=quant_type)
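
    # A hedged sketch of the lowering being checked: unlike float add,
    # torch.ops.quantized.add takes the output scale/zero_point explicitly
    # (values here are illustrative):
    #
    #   qx = torch.quantize_per_tensor(x, 0.1, 0, torch.quint8)
    #   qy = torch.quantize_per_tensor(y, 0.1, 0, torch.quint8)
    #   qz = torch.ops.quantized.add(qx, qy, 0.2, 0)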

    def _test_quantized_binary_op_relu_impl(self, binary_op, ibinary_op, quantized_op):
        class OpRelu(torch.nn.Module):
            def __init__(self, is_inplace, is_functional_relu,
                         is_inplace_relu, is_scalar):
                super(OpRelu, self).__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
                self.op = ibinary_op if is_inplace else binary_op
                self.is_functional_relu = is_functional_relu
                self.is_inplace_relu = is_inplace_relu
                self.is_scalar = is_scalar
                if self.is_functional_relu:
                    self.relu = F.relu
                else:
                    self.relu = torch.nn.ReLU(self.is_inplace_relu)

            def forward(self, x, y):
                x = self.conv1(x)
                y = 3 if self.is_scalar else self.conv2(y)
                x = self.op(x, y)
                x = self.relu(x, self.is_inplace_relu) if \
                    self.is_functional_relu else self.relu(x)
                return x

        data = (torch.rand((1, 2, 5, 5), dtype=torch.float),
                torch.rand((1, 2, 5, 5), dtype=torch.float))
        quantized_node = ('call_function', quantized_op)
        options = itertools.product(
            [True, False], [True, False], [True, False], [True, False],
            self.static_quant_types)
        for is_inplace_op, is_functional_relu, is_inplace_relu, is_scalar, quant_type in options:
            self.checkGraphModeFxOp(
                OpRelu(is_inplace_op, is_functional_relu, is_inplace_relu, is_scalar),
                data, quantized_node, quant_type=quant_type)
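
    # The op + relu pattern is expected to lower to a single fused op, e.g.
    # (hedged sketch, with illustrative output qparams):
    #
    #   qz = torch.ops.quantized.add_relu(qx, qy, 0.2, 0)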

    @skipIfNoFBGEMM
    def test_quantized_binary_op(self):
        self._test_quantized_binary_op_impl(
            operator.add, operator.iadd, torch.ops.quantized.add)

    @skipIfNoFBGEMM
    def test_quantized_binary_op_relu(self):
        self._test_quantized_binary_op_relu_impl(
            operator.add, operator.iadd, torch.ops.quantized.add_relu)

    @skipIfNoFBGEMM
    def test_quantized_cat(self):
        """Quantization of the output of cat depends on the inputs of cat:
        we only quantize the output of cat when all of its inputs are quantized.
        """
        class QuantizedCat(torch.nn.Module):
            def __init__(self):
                super(QuantizedCat, self).__init__()
                self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
                self.conv2 = torch.nn.Conv2d(2, 2, 2).float()

            def forward(self, x, y):
                x = self.conv1(x)
                y = self.conv2(y)
                return torch.cat([x, y], 1)

        # TODO: decide whether to quantize in this case
        # class NonQuantizedCat(torch.nn.Module):
        #     def __init__(self):
        #         super(NonQuantizedCat, self).__init__()
        #
        #     def forward(self, x, y):
        #         return torch.cat([x, y], 1)

        data = (torch.randn(1, 2, 5, 5, dtype=torch.float),
                torch.randn(1, 2, 5, 5, dtype=torch.float))
        quantized_node = ('call_function', torch.ops.quantized.cat)
        for quant_type in self.static_quant_types:
            self.checkGraphModeFxOp(QuantizedCat(), data, quantized_node, quant_type=quant_type)
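
    # A hedged sketch of the expected lowering: quantized cat requires
    # quantized inputs and takes the output scale/zero_point explicitly
    # (illustrative values; argument order per quantized::cat):
    #
    #   qz = torch.ops.quantized.cat([qx, qy], dim=1, scale=0.2, zero_point=0)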

    @skipIfNoFBGEMM
    def test_qbatch_norm(self):
        bn_module = {
            # TODO: quantized batchnorm 1d module is missing
            # 1: torch.nn.BatchNorm1d,
            2: torch.nn.BatchNorm2d,
            3: torch.nn.BatchNorm3d,
        }

        class M(torch.nn.Module):
            def __init__(self, dim):
                super(M, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)

            def forward(self, x):
                return self.bn(x)

        options = itertools.product(self.static_quant_types, [2, 3])
        quantized_nodes = {
            # 1: ('call_module', nnq.BatchNorm1d),
            2: ('call_module', nnq.BatchNorm2d),
            3: ('call_module', nnq.BatchNorm3d),
        }
        for quant_type, dim in options:
            model = self.checkGraphModeFxOp(M(dim), self.img_data_dict[dim], quantized_nodes[dim], quant_type)
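
    # Quantized batch norm is a pure module swap: nn.BatchNorm{2,3}d becomes
    # nnq.BatchNorm{2,3}d, which consumes and produces quantized tensors while
    # the affine parameters and running stats stay in float.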

    @skipIfNoFBGEMM
    def test_qbatch_norm_relu(self):
        bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}

        class BNRelu(torch.nn.Module):
            def __init__(self, dim, inplace):
                super(BNRelu, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)
                self.relu = torch.nn.ReLU(inplace=inplace)

            def forward(self, x):
                return self.relu(self.bn(x))

        class BNFuncRelu(torch.nn.Module):
            def __init__(self, dim):
                super(BNFuncRelu, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)

            def forward(self, x):
                return F.relu(self.bn(x), False)

        class BNFuncInplaceRelu(torch.nn.Module):
            def __init__(self, dim):
                super(BNFuncInplaceRelu, self).__init__()
                self.bn = bn_module[dim](3).to(torch.float)

            def forward(self, x):
                return F.relu(self.bn(x), True)

        options = itertools.product(self.static_quant_types, [2, 3])
        quantized_nodes = {
            2: ('call_module', nniq.BNReLU2d),
            3: ('call_module', nniq.BNReLU3d),
        }
        for quant_type, dim in options:
            for instance in [BNRelu(dim, True), BNRelu(dim, False),
                             BNFuncRelu(dim), BNFuncInplaceRelu(dim)]:
                self.checkGraphModeFxOp(
                    instance, self.img_data_dict[dim], quantized_nodes[dim], quant_type)
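
    # As with conv + relu, FX fuses bn + relu (module or functional, in-place
    # or not) into a single nniq.BNReLU{2,3}d module; the eager-mode analogue
    # would be torch.quantization.fuse_modules(m, [['bn', 'relu']]) (hedged).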