# Owner(s): ["oncall: quantization"]
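"""
Tests for the experimental FX graph mode quantization to TensorRT flow:
models are prepared and converted to reference quantized models using a
TensorRT backend_config_dict, then lowered to TensorRT engines via fx2trt.
"""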
import torch
import torch.fx.experimental.fx_acc.acc_tracer as acc_tracer
from torch.fx.experimental.fx2trt import (
    TRTInterpreter,
    InputTensorSpec,
    TRTModule,
)
from torch.ao.quantization import (
    default_qconfig,
)
from torch.ao.quantization.quantize_fx import (
    prepare_fx,
    prepare_qat_fx,
    get_tensorrt_backend_config_dict,
)
from torch.ao.quantization._quantize_fx_do_not_use import (
    _convert_fx_do_not_use,
)
from torch.ao.quantization.fx.match_utils import (
    MatchAllNode,
)
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
)
from torch.ao.quantization.fx.backend_config.observation_type import ObservationType
import torch.nn.functional as F
import torch.nn as nn
import torch.nn.quantized._reference as nnqr
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.common_quantization import NodeSpec as ns
import unittest
import itertools
import copy
import operator


def lower_to_trt(model, inputs, shape_ranges):
    """ Lower a quantized model to TensorRT
    """
    assert len(inputs) == 1, "lower_to_trt only works for one input currently"
    model = acc_tracer.trace(model, inputs)  # type: ignore[attr-defined]
    # TODO: test multiple inputs setting and enable multiple inputs
    input_specs = [
        InputTensorSpec(
            torch.Size([-1, *inputs[0].shape[1:]]), torch.float,
            shape_ranges=shape_ranges, has_batch_dim=True)
    ]

    interp = TRTInterpreter(
        model,
        input_specs,
        explicit_batch_dimension=True, explicit_precision=True)
    result = interp.run(fp16_mode=False, int8_mode=True)
    trt_mod = TRTModule(result.engine, result.input_names, result.output_names)
    return trt_mod
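
# A minimal usage sketch (mirrors how the tests below call lower_to_trt):
# given a converted reference quantized GraphModule `quantized` and a single
# example input, build an engine whose batch dimension can range over 1..10:
#
#   inputs = [torch.randn(1, 3, 10, 10)]
#   shape_ranges = [((1, 3, 10, 10), (5, 3, 10, 10), (10, 3, 10, 10))]
#   trt_mod = lower_to_trt(quantized, inputs, shape_ranges)
#   trt_mod(inputs[0].cuda())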


class TestConvertFxDoNotUse(QuantizationTestCase):
    def setUp(self):
        super().setUp()
        # TensorRT int8 kernels use symmetric quantization, so the activation
        # observer is a symmetric qint8 HistogramObserver
        self.trt_qconfig = torch.ao.quantization.QConfig(
            activation=torch.ao.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
            ),
            weight=torch.ao.quantization.default_weight_observer
        )
        self.trt_backend_config_dict = get_tensorrt_backend_config_dict()

    def _test_quantized_inputs_outputs(
            self, prepare_custom_config_dict, prepare_count_check,
            convert_count_check):
        """
        Test the option to have inputs and outputs of the graph quantized
        """
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        # quantized input, quantized output
        m = M()
        qconfig_dict = {'': torch.ao.quantization.default_qconfig}
        m.eval()
        mp = torch.ao.quantization.quantize_fx.prepare_fx(
            m, qconfig_dict,
            prepare_custom_config_dict=prepare_custom_config_dict)
        self.checkGraphModuleNodes(mp, expected_node_occurrence=prepare_count_check)
        # calibration
        mp(torch.randn(1, 1, 4, 4))
        mq = _convert_fx_do_not_use(
            mp, is_reference=True, backend_config_dict=self.trt_backend_config_dict)
        self.checkGraphModuleNodes(mq, expected_node_occurrence=convert_count_check)

    def test_quantized_input_quantized_output(self):
        prepare_custom_config_dict = {
            'input_quantized_idxs': [0], 'output_quantized_idxs': [0]}
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
        }
        convert_count_check = {
            # output of ref conv1 and output of ref conv2
            ns.call_function(torch.quantize_per_tensor): 2,
            # input of ref conv1 and input of ref conv2
            ns.call_method('dequantize'): 2,
        }
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def test_fp32_input_quantized_output(self):
        prepare_custom_config_dict = {
            'output_quantized_idxs': [0]}
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
        }
        convert_count_check = {
            # input, output of conv1 and output of conv2
            ns.call_function(torch.quantize_per_tensor): 3,
            # input of conv1, conv2
            ns.call_method('dequantize'): 2,
        }
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def test_quantized_input_fp32_output(self):
        prepare_custom_config_dict = {
            'input_quantized_idxs': [0]}
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.MinMaxObserver): 2,
        }
        convert_count_check = {
            # output of conv1, conv2
            ns.call_function(torch.quantize_per_tensor): 2,
            # input of ref conv1, input of ref conv2, final output
            ns.call_method('dequantize'): 3,
        }
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def test_fp32_input_fp32_output(self):
        prepare_custom_config_dict = {}
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.MinMaxObserver): 3,
        }
        convert_count_check = {
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_method('dequantize'): 3,
        }
        self._test_quantized_inputs_outputs(
            prepare_custom_config_dict, prepare_count_check, convert_count_check)

    def _test_standalone_module(
            self,
            interface_config,
            prepare_count_check,
            standalone_prepare_count_check,
            convert_count_check,
            standalone_convert_count_check,
            qconfig=None,
            backend_config_dict=None):
        """ Test standalone module with different quantized input/quantized output
        configurations
        """
        class StandaloneModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                return self.conv(x)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(1, 1, 1)
                self.standalone = StandaloneModule()

            def forward(self, x):
                x = self.conv(x)
                x = self.standalone(x)
                return x

        class RefM(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)

            def forward(self, x):
                x = self.conv1(x)
                x = self.conv2(x)
                return x

        if backend_config_dict is None:
            backend_config_dict = self.trt_backend_config_dict
        if qconfig is None:
            qconfig = self.trt_qconfig
        data = torch.randn(1, 1, 1, 1)
        # instantiate M and RefM and align the parameters
        original_m = M().eval()
        original_ref_m = RefM().eval()
        original_ref_m.conv1.weight = torch.nn.Parameter(original_m.conv.weight.detach())
        original_ref_m.conv1.bias = torch.nn.Parameter(original_m.conv.bias.detach())
        original_ref_m.conv2.weight = torch.nn.Parameter(original_m.standalone.conv.weight.detach())
        original_ref_m.conv2.bias = torch.nn.Parameter(original_m.standalone.conv.bias.detach())

        prepare_config = {
            "standalone_module_name": [("standalone", None, interface_config, backend_config_dict)]
        }
        original_m_copy = copy.deepcopy(original_m)
        original_ref_m_copy = copy.deepcopy(original_ref_m)
        qconfig_dict = {"": qconfig}
        # check prepared model
        m = prepare_fx(
            original_m_copy, qconfig_dict,
            prepare_custom_config_dict=prepare_config,
            backend_config_dict=backend_config_dict)
        # calibration
        m(data)
        self.checkGraphModuleNodes(m, expected_node_occurrence=prepare_count_check)
        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_prepare_count_check)

        # check converted/quantized model
        m = _convert_fx_do_not_use(m, is_reference=True, backend_config_dict=backend_config_dict)
        self.checkGraphModuleNodes(m, expected_node_occurrence=convert_count_check)
        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_convert_count_check)
        res = m(data)

        # quantize the reference model
        ref_m = prepare_fx(original_ref_m_copy, qconfig_dict, backend_config_dict=backend_config_dict)
        ref_m(data)
        ref_m = _convert_fx_do_not_use(ref_m, is_reference=True, backend_config_dict=backend_config_dict)
        ref_res = ref_m(data)
        self.assertEqual(res, ref_res)
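
    # The two tests below exercise the standalone module interface configs:
    # with a float interface, the quantize/dequantize ops stay inside the
    # standalone module; with a quantized interface, tensors cross the module
    # boundary already quantized, so quantize happens in the parent module.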

    def test_standalone_module_float_interface(self):
        float_interface_config = {
            "input_quantized_idxs": [],  # float input
            "output_quantized_idxs": [],  # float output
        }
        interface_config = float_interface_config
        # input and output of first conv, observer for standalone module
        # will be inserted in the standalone module itself
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.HistogramObserver): 2
        }
        # for input and output of conv in the standalone module
        standalone_prepare_count_check = {
            ns.call_module(torch.ao.quantization.HistogramObserver): 2
        }
        convert_count_check = {
            # input and output of reference conv
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_module(nnqr.Conv2d): 1,
            ns.call_method("dequantize"): 2,
        }
        standalone_convert_count_check = {
            # standalone module will take float as input and output
            # so we'll see quantize and dequantize in the module
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_module(nnqr.Conv2d): 1,
            ns.call_method("dequantize"): 2,
        }
        self._test_standalone_module(
            interface_config,
            prepare_count_check,
            standalone_prepare_count_check,
            convert_count_check,
            standalone_convert_count_check)

    def test_standalone_module_quantized_interface(self):
        quantized_interface_config = {
            "input_quantized_idxs": [0],  # quantized input
            "output_quantized_idxs": [0],  # quantized output
        }
        interface_config = quantized_interface_config
        # TODO: `input_quantized_idxs` currently only supports quint8, so we
        # need this custom_backend_config_dict for now; it can be removed once
        # `input_quantized_idxs` supports more complicated configurations (as a
        # first step, it could be changed to a dictionary from index to dtype)
        qconfig = torch.ao.quantization.QConfig(
            activation=torch.ao.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
            ),
            weight=torch.ao.quantization.default_weight_observer
        )
        weighted_op_quint8_dtype_config = {
            # optional, input activation dtype
            "input_dtype": torch.quint8,
            # optional, weight dtype
            "weight_dtype": torch.qint8,
            # optional, bias dtype
            "bias_dtype": torch.float,
            # optional, output activation dtype
            "output_dtype": torch.quint8
        }
        conv_module_config = {
            "pattern": torch.nn.Conv2d,
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                weighted_op_quint8_dtype_config,
            ],
            "root_module": torch.nn.Conv2d,
            "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
        }
        custom_backend_config_dict = {
            "configs": [conv_module_config]
        }
        # observer for input and output of first conv
        prepare_count_check = {
            ns.call_module(torch.ao.quantization.HistogramObserver): 2
        }
        # for output of conv in the standalone module
        standalone_prepare_count_check = {
            ns.call_module(torch.ao.quantization.HistogramObserver): 1
        }
        convert_count_check = {
            # quantizing input/output for reference conv
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_module(nnqr.Conv2d): 1,
            # dequantize the input of reference conv and
            # dequantize the output of the standalone module
            ns.call_method("dequantize"): 2,
        }
        standalone_convert_count_check = {
            # quantization of input happens in parent module
            # quantization of output happens in the standalone module
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_module(nnqr.Conv2d): 1,
            # dequantization of input happens in the standalone module
            # dequantization of output happens in parent module
            ns.call_method("dequantize"): 1,
        }
        self._test_standalone_module(
            interface_config,
            prepare_count_check,
            standalone_prepare_count_check,
            convert_count_check,
            standalone_convert_count_check,
            qconfig=qconfig,
            backend_config_dict=custom_backend_config_dict)


@unittest.skipIf(not TEST_CUDA, "GPU is not available")
class TestQuantizeFxTRTOps(QuantizationTestCase):
    """ Test TensorRT operator support
    """
    def setUp(self):
        super().setUp()
        self.qconfig = torch.ao.quantization.QConfig(
            activation=torch.ao.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
            ),
            weight=torch.ao.quantization.default_weight_observer
        )
        self.trt_backend_config_dict = get_tensorrt_backend_config_dict()

    def _test_module(
            self,
            m,
            inputs,
            shape_ranges,
            no_prepare=None,
            no_convert=None,
            is_qat=False):
        """
        Args:
          m: the float module we want to test
          inputs: list of inputs for the module
          shape_ranges: a list of shape_range, where every shape_range is a tuple of
            three tuples:
            ((min_input_shape), (optimized_input_shape), (max_input_shape)).
            Each shape_range is used to populate a TensorRT optimization profile.
            e.g. if the input shape varies from (1, 224) to (100, 224) and we want to
            optimize for (25, 224) because it's the most common input shape, then
            we set shape_ranges to [((1, 224), (25, 224), (100, 224))]
          no_prepare: node occurrence after prepare
          no_convert: node occurrence after convert
        """
        if is_qat:
            m = m.train()
            prepare = prepare_qat_fx
        else:
            m = m.eval()
            prepare = prepare_fx
        prepared = prepare(m, {"": self.qconfig}, backend_config_dict=self.trt_backend_config_dict)
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=no_prepare)
        # calibration
        prepared(*inputs)
        quantized = _convert_fx_do_not_use(
            prepared, is_reference=True, backend_config_dict=self.trt_backend_config_dict)
        self.checkGraphModuleNodes(quantized, expected_node_occurrence=no_convert)
        # lower to trt
        trt_mod = lower_to_trt(quantized, inputs, shape_ranges)
        inputs_cuda = [i.cuda() for i in inputs]
        # make sure it runs
        trt_mod(*inputs_cuda)

    def test_conv_relu_module(self):
        conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}

        conv1d_input = torch.rand(1, 3, 10)
        conv2d_input = torch.rand(1, 3, 10, 10)
        conv3d_input = torch.rand(1, 3, 10, 10, 10)
        conv_input = {1: conv1d_input, 2: conv2d_input, 3: conv3d_input}

        class ConvNdModule(torch.nn.Module):
            def __init__(self, dim, has_relu=False, f_relu=False):
                super().__init__()
                self.conv = conv_module[dim](3, 3, 3).float()
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(self.conv(x))

        # just testing conv2d since conv1d and conv3d are not supported in fx2trt
        for dim, has_relu, f_relu, is_qat in itertools.product([2], [True, False], [True, False], [True, False]):
            # when has_relu=False, we have torch.nn.Identity, which would introduce
            # an extra quant-dequant pair
            no_convert = {
                ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
                ns.call_method("dequantize"): 2 + int(not has_relu),
            }
            self._test_module(
                ConvNdModule(dim, has_relu, f_relu),
                [conv_input[dim]],
                [((1, *conv_input[dim].shape[1:]),
                  (5, *conv_input[dim].shape[1:]),
                  (10, *conv_input[dim].shape[1:]))],
                no_convert=no_convert,
                is_qat=is_qat)

    def test_linear_relu_module(self):
        class LinearModule(torch.nn.Module):
            def __init__(self, has_relu=False, f_relu=False):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10).float()
                if has_relu:
                    if f_relu:
                        self.relu = F.relu
                    else:
                        self.relu = torch.nn.ReLU()
                else:
                    self.relu = torch.nn.Identity()

            def forward(self, x):
                return self.relu(self.linear(x))

        linear_input = torch.rand(8, 5)
        shape_ranges = [
            ((1, 5),
             (5, 5),
             (10, 5))
        ]
        for has_relu, f_relu, is_qat in itertools.product([True, False], [True, False], [True, False]):
            # when has_relu=False, we have torch.nn.Identity, which would introduce
            # an extra quant-dequant pair
            no_convert = {
                ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
                ns.call_method("dequantize"): 2 + int(not has_relu),
            }
            self._test_module(
                LinearModule(has_relu, f_relu),
                [linear_input],
                shape_ranges,
                no_convert=no_convert,
                is_qat=is_qat)

    def test_ops(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.linear = torch.nn.Linear(5, 5)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                x = self.conv(x)
                x = self.linear(x)
                x = x + 3
                x = self.relu(x)
                x = x + 6
                return x

        m = M().eval()
        m = prepare_fx(m, {"": default_qconfig})
        m = _convert_fx_do_not_use(
            m, is_reference=True, backend_config_dict=self.trt_backend_config_dict)
        expected_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 5,
            ns.call_method("dequantize"): 5,
            ns.call_module(torch.nn.quantized._reference.Linear): 1,
            ns.call_module(torch.nn.quantized._reference.Conv2d): 1,
        }
        self.checkGraphModuleNodes(
            m,
            expected_node_occurrence=expected_occurrence)

    def test_unsupported_qconfig(self):
        """ Check that we won't quantize the model if the qconfig is not supported
        """
        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        m = LinearModule().eval()
        trt_unsupported_qconfig = default_qconfig
        prepared = prepare_fx(m, {"": trt_unsupported_qconfig}, backend_config_dict=self.trt_backend_config_dict)
        # calibration
        prepared(linear_module_input)
        quantized = _convert_fx_do_not_use(
            prepared, is_reference=True, backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 0,
            ns.call_method("dequantize"): 0,
            ns.call_module(torch.nn.Linear): 1,
            ns.call_module(torch.nn.quantized._reference.Linear): 0,
        }
        # check model is not quantized
        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)

    def test_cat(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.cat([x, x], 1)

        m = M().eval()
        prepared = prepare_fx(
            m, {"": self.qconfig}, backend_config_dict=self.trt_backend_config_dict)
        self.assertTrue(len(dict(prepared.named_children())) == 1)
        quantized = _convert_fx_do_not_use(
            prepared, is_reference=True, backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 2,
            ns.call_function(torch.cat): 1,
            ns.call_method("dequantize"): 2,
        }
        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)

    def test_addmm(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.randn(5, 5)
                self.bias = torch.randn(5)

            def forward(self, x):
                return torch.addmm(self.bias, x, self.weight)

        m = M().eval()
        prepared = prepare_fx(
            m, {"": self.qconfig}, backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            # weight
            ns.call_module(torch.ao.quantization.MinMaxObserver): 1,
            # activation
            ns.call_module(torch.ao.quantization.HistogramObserver): 2,
        }
        self.checkGraphModuleNodes(prepared, expected_node_occurrence=node_occurrence)
        quantized = _convert_fx_do_not_use(
            prepared, is_reference=True, backend_config_dict=self.trt_backend_config_dict)
        node_occurrence = {
            # input activation, output activation and weight
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_function(torch.addmm): 1,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(quantized, expected_node_occurrence=node_occurrence)

    def test_conv_add(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x, y):
                return self.conv(x) + y

        weighted_op_qint8_dtype_config = {
            # optional, input activation dtype
            "input_dtype": torch.qint8,
            # optional, weight dtype
            "weight_dtype": torch.qint8,
            # optional, bias dtype
            "bias_dtype": torch.float,
            # optional, output activation dtype
            "output_dtype": torch.qint8
        }
        conv_add_config = {
            "pattern": (operator.add, torch.nn.Conv2d, MatchAllNode),
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                weighted_op_qint8_dtype_config,
            ],
            "root_module": torch.nn.Conv2d,
            "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
        }

        m = M().eval()
        modified_backend_config_dict = copy.deepcopy(self.trt_backend_config_dict)
        modified_backend_config_dict["configs"].insert(0, conv_add_config)
        m = prepare_fx(m, {"": self.qconfig}, backend_config_dict=modified_backend_config_dict)
        node_occurrence = {
            ns.call_module(torch.ao.quantization.HistogramObserver): 3,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        m = _convert_fx_do_not_use(m, is_reference=True, backend_config_dict=modified_backend_config_dict)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)

    def test_conv_add_standalone_module(self):
        class Standalone(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.relu = torch.nn.ReLU()

            def forward(self, x, y):
                return self.relu(self.conv(x) + y)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)
                self.standalone = Standalone()

            def forward(self, x, y):
                y = self.conv(x)
                return self.standalone(x, y)

        weighted_op_quint8_dtype_config = {
            # optional, input activation dtype
            # TODO: change back to torch.qint8 after input_quantized_idxs and
            # output_quantized_idxs are more flexible
            "input_dtype": torch.quint8,
            # optional, weight dtype
            "weight_dtype": torch.qint8,
            # optional, bias dtype
            "bias_dtype": torch.float,
            # optional, output activation dtype
            "output_dtype": torch.quint8
        }
        conv_add_config = {
            "pattern": (torch.nn.ReLU, (operator.add, torch.nn.Conv2d, MatchAllNode)),
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                weighted_op_quint8_dtype_config,
            ],
            "root_module": torch.nn.Conv2d,
            # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
        }
        conv_config = {
            "pattern": torch.nn.Conv2d,
            "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
            "dtype_configs": [
                weighted_op_quint8_dtype_config,
            ],
            "root_module": torch.nn.Conv2d,
            # "reference_quantized_module_for_root": torch.nn.quantized._reference.Conv2d,
        }

        m = M().eval()
        backend_config_dict = {
            "configs": [
                conv_add_config,
                conv_config,
            ]
        }
        prepare_custom_config_dict = {
            "standalone_module_name": [("standalone", None, {"input_quantized_idxs": [0, 1]}, None)]
        }
        # TODO: use self.qconfig after input_quantized_idxs and output_quantized_idxs
        # are more flexible
        qconfig = torch.ao.quantization.QConfig(
            activation=torch.ao.quantization.observer.HistogramObserver.with_args(
                qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
            ),
            weight=torch.ao.quantization.default_weight_observer
        )
        m = prepare_fx(
            m,
            {"": qconfig},
            prepare_custom_config_dict=prepare_custom_config_dict,
            backend_config_dict=backend_config_dict)
        node_occurrence = {
            # for input and output of conv, where input is used twice, once in conv and
            # once in standalone module
            ns.call_module(torch.ao.quantization.HistogramObserver): 2,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        standalone_node_occurrence = {
            # output of the standalone module
            ns.call_module(torch.ao.quantization.HistogramObserver): 1,
        }
        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)
        m = _convert_fx_do_not_use(m, is_reference=True, backend_config_dict=backend_config_dict)
        node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 3,
            ns.call_module(nn.Conv2d): 1,
            ns.call_method("dequantize"): 1,
        }
        self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence)
        standalone_node_occurrence = {
            ns.call_function(torch.quantize_per_tensor): 1,
            ns.call_module(nn.Conv2d): 1,
            ns.call_module(torch.nn.ReLU): 1,
            ns.call_method("dequantize"): 3,
        }
        self.checkGraphModuleNodes(m.standalone, expected_node_occurrence=standalone_node_occurrence)


if __name__ == "__main__":
    run_tests()