| # -*- coding: utf-8 -*- |
| # Owner(s): ["oncall: quantization"] |
| |
| import sys |
| import os |
| import unittest |
| from typing import Set |
| |
| # torch |
| import torch |
| import torch.nn as nn |
| import torch.ao.nn.quantized as nnq |
| import torch.ao.nn.quantized.dynamic as nnqd |
| import torch.ao.nn.intrinsic.quantized as nniq |
| from torch.fx import GraphModule |
| |
| # Testing utils |
| from torch.testing._internal.common_utils import TestCase, IS_AVX512_VNNI_SUPPORTED |
| from torch.testing._internal.common_quantized import override_qengines, qengine_is_fbgemm |
| from torch.testing._internal.common_quantization import skipIfNoFBGEMM |
| from torch.testing._internal.quantization_torch_package_models import LinearReluFunctional |
| |
| from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver |
| import torch.ao.quantization.quantize_fx as quantize_fx |
| |
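# NB: equivalent to str.removeprefix, which is only available on Python 3.9+.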
| def remove_prefix(text, prefix): |
| if text.startswith(prefix): |
| return text[len(prefix):] |
| return text |
| |
| def get_filenames(self, subname): |
    # NB: we take __file__ from the module that defined the test
    # class, so we place the serialized data directory where the test
    # script lives, NOT where test/common_utils.py lives.
| module_id = self.__class__.__module__ |
| munged_id = remove_prefix(self.id(), module_id + ".") |
| test_file = os.path.realpath(sys.modules[module_id].__file__) |
| base_name = os.path.join(os.path.dirname(test_file), |
| "../serialized", |
| munged_id) |
| |
    if subname:
        base_name += "_" + subname
| |
| input_file = base_name + ".input.pt" |
| state_dict_file = base_name + ".state_dict.pt" |
| scripted_module_file = base_name + ".scripted.pt" |
| traced_module_file = base_name + ".traced.pt" |
| expected_file = base_name + ".expected.pt" |
| package_file = base_name + ".package.pt" |
| get_attr_targets_file = base_name + ".get_attr_targets.pt" |
| |
| return input_file, state_dict_file, scripted_module_file, \ |
| traced_module_file, expected_file, package_file, get_attr_targets_file |
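
# Example of the resulting layout: for TestSerialization.test_linear_dynamic
# with subname "qint8", base_name resolves to
# "<test_dir>/../serialized/TestSerialization.test_linear_dynamic_qint8" and
# the returned paths append the suffixes above (".input.pt",
# ".state_dict.pt", ".scripted.pt", ...).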
| |
| |
| class TestSerialization(TestCase): |
| """ Test backward compatiblity for serialization and numerics |
| """ |
| # Copy and modified from TestCase.assertExpected |
| def _test_op(self, qmodule, subname=None, input_size=None, input_quantized=True, |
| generate=False, prec=None, new_zipfile_serialization=False): |
| r""" Test quantized modules serialized previously can be loaded |
| with current code, make sure we don't break backward compatibility for the |
| serialization of quantized modules |
| """ |
| input_file, state_dict_file, scripted_module_file, traced_module_file, \ |
| expected_file, _package_file, _get_attr_targets_file = \ |
| get_filenames(self, subname) |
| |
        # Only generate once; @override_qengines runs this body once per
        # qengine, so gate generation on fbgemm.
| if generate and qengine_is_fbgemm(): |
| input_tensor = torch.rand(*input_size).float() |
| if input_quantized: |
| input_tensor = torch.quantize_per_tensor(input_tensor, 0.5, 2, torch.quint8) |
| torch.save(input_tensor, input_file) |
| # Temporary fix to use _use_new_zipfile_serialization until #38379 lands. |
| torch.save(qmodule.state_dict(), state_dict_file, _use_new_zipfile_serialization=new_zipfile_serialization) |
| torch.jit.save(torch.jit.script(qmodule), scripted_module_file) |
| torch.jit.save(torch.jit.trace(qmodule, input_tensor), traced_module_file) |
| torch.save(qmodule(input_tensor), expected_file) |
| |
| input_tensor = torch.load(input_file) |
| qmodule.load_state_dict(torch.load(state_dict_file)) |
| qmodule_scripted = torch.jit.load(scripted_module_file) |
| qmodule_traced = torch.jit.load(traced_module_file) |
| expected = torch.load(expected_file) |
| self.assertEqual(qmodule(input_tensor), expected, atol=prec) |
| self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec) |
| self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec) |
| |
| def _test_op_graph(self, qmodule, subname=None, input_size=None, input_quantized=True, |
| generate=False, prec=None, new_zipfile_serialization=False): |
| r""" |
| Input: a floating point module |
| |
| If generate == True, traces and scripts the module and quantizes the results with |
| PTQ, and saves the results. |
| |
| If generate == False, traces and scripts the module and quantizes the results with |
| PTQ, and compares to saved results. |
| """ |
| input_file, state_dict_file, scripted_module_file, traced_module_file, \ |
| expected_file, _package_file, _get_attr_targets_file = \ |
| get_filenames(self, subname) |
| |
        # Only generate once; @override_qengines runs this body once per
        # qengine, so gate generation on fbgemm.
| if generate and qengine_is_fbgemm(): |
| input_tensor = torch.rand(*input_size).float() |
| torch.save(input_tensor, input_file) |
| |
| # convert to TorchScript |
| scripted = torch.jit.script(qmodule) |
| traced = torch.jit.trace(qmodule, input_tensor) |
| |
| # quantize |
| |
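            # calibration callback passed to quantize_jit: a single forward
            # pass lets the inserted observers record activation statistics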
| def _eval_fn(model, data): |
| model(data) |
| |
| qconfig_dict = {'': torch.ao.quantization.default_qconfig} |
| scripted_q = torch.ao.quantization.quantize_jit( |
| scripted, qconfig_dict, _eval_fn, [input_tensor]) |
| traced_q = torch.ao.quantization.quantize_jit( |
| traced, qconfig_dict, _eval_fn, [input_tensor]) |
| |
| torch.jit.save(scripted_q, scripted_module_file) |
| torch.jit.save(traced_q, traced_module_file) |
| torch.save(scripted_q(input_tensor), expected_file) |
| |
| input_tensor = torch.load(input_file) |
| qmodule_scripted = torch.jit.load(scripted_module_file) |
| qmodule_traced = torch.jit.load(traced_module_file) |
| expected = torch.load(expected_file) |
| self.assertEqual(qmodule_scripted(input_tensor), expected, atol=prec) |
| self.assertEqual(qmodule_traced(input_tensor), expected, atol=prec) |
| |
| def _test_obs(self, obs, input_size, subname=None, generate=False, check_numerics=True): |
| """ |
| Test observer code can be loaded from state_dict. |
| """ |
        input_file, state_dict_file, _, _traced_module_file, expected_file, \
            _package_file, _get_attr_targets_file = get_filenames(self, subname)
| if generate: |
| input_tensor = torch.rand(*input_size).float() |
| torch.save(input_tensor, input_file) |
| torch.save(obs(input_tensor), expected_file) |
| torch.save(obs.state_dict(), state_dict_file) |
| |
| input_tensor = torch.load(input_file) |
| obs.load_state_dict(torch.load(state_dict_file)) |
| expected = torch.load(expected_file) |
| if check_numerics: |
| self.assertEqual(obs(input_tensor), expected) |
| |
| def _test_package(self, fp32_module, input_size, generate=False): |
| """ |
| Verifies that files created in the past with torch.package |
| work on today's FX graph mode quantization transforms. |
| """ |
| input_file, state_dict_file, _scripted_module_file, _traced_module_file, \ |
| expected_file, package_file, get_attr_targets_file = \ |
| get_filenames(self, None) |
| |
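        # location of the pickled model inside the torch.package archive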
| package_name = 'test' |
| resource_name_model = 'test.pkl' |
| |
| def _do_quant_transforms( |
| m: torch.nn.Module, |
| input_tensor: torch.Tensor, |
| ) -> torch.nn.Module: |
| example_inputs = (input_tensor,) |
            # do the quantization transforms and return the result
| qconfig = torch.ao.quantization.get_default_qconfig('fbgemm') |
| mp = quantize_fx.prepare_fx(m, {'': qconfig}, example_inputs=example_inputs) |
            mp(input_tensor)  # calibration forward pass for the observers
| mq = quantize_fx.convert_fx(mp) |
| return mq |
| |
| def _get_get_attr_target_strings(m: GraphModule) -> Set[str]: |
| results = set() |
| for node in m.graph.nodes: |
| if node.op == 'get_attr': |
| results.add(node.target) |
| return results |
| |
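        # only generate once, and only when the fbgemm qengine is active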
| if generate and qengine_is_fbgemm(): |
| input_tensor = torch.randn(*input_size) |
| torch.save(input_tensor, input_file) |
| |
| # save the model with torch.package |
| with torch.package.PackageExporter(package_file) as exp: |
| exp.intern('torch.testing._internal.quantization_torch_package_models') |
| exp.save_pickle(package_name, resource_name_model, fp32_module) |
| |
| # do the quantization transforms and save the result |
| mq = _do_quant_transforms(fp32_module, input_tensor) |
| get_attrs = _get_get_attr_target_strings(mq) |
| torch.save(get_attrs, get_attr_targets_file) |
| q_result = mq(input_tensor) |
| torch.save(q_result, expected_file) |
| |
| # load input tensor |
| input_tensor = torch.load(input_file) |
| expected_output_tensor = torch.load(expected_file) |
| expected_get_attrs = torch.load(get_attr_targets_file) |
| |
| # load model from package and verify output and get_attr targets match |
| imp = torch.package.PackageImporter(package_file) |
| m = imp.load_pickle(package_name, resource_name_model) |
| mq = _do_quant_transforms(m, input_tensor) |
| |
| get_attrs = _get_get_attr_target_strings(mq) |
| self.assertTrue( |
| get_attrs == expected_get_attrs, |
| f'get_attrs: expected {expected_get_attrs}, got {get_attrs}') |
| output_tensor = mq(input_tensor) |
| self.assertTrue(torch.allclose(output_tensor, expected_output_tensor)) |
| |
| @override_qengines |
| def test_linear(self): |
| module = nnq.Linear(3, 1, bias_=True, dtype=torch.qint8) |
| self._test_op(module, input_size=[1, 3], generate=False) |
| |
| @override_qengines |
| def test_linear_relu(self): |
| module = nniq.LinearReLU(3, 1, bias=True, dtype=torch.qint8) |
| self._test_op(module, input_size=[1, 3], generate=False) |
| |
| @override_qengines |
| def test_linear_dynamic(self): |
| module_qint8 = nnqd.Linear(3, 1, bias_=True, dtype=torch.qint8) |
| self._test_op(module_qint8, "qint8", input_size=[1, 3], input_quantized=False, generate=False) |
| if qengine_is_fbgemm(): |
| module_float16 = nnqd.Linear(3, 1, bias_=True, dtype=torch.float16) |
| self._test_op(module_float16, "float16", input_size=[1, 3], input_quantized=False, generate=False) |
| |
| @override_qengines |
| def test_conv2d(self): |
| module = nnq.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, |
| groups=1, bias=True, padding_mode="zeros") |
| self._test_op(module, input_size=[1, 3, 6, 6], generate=False) |
| |
| @override_qengines |
| def test_conv2d_nobias(self): |
| module = nnq.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, |
| groups=1, bias=False, padding_mode="zeros") |
| self._test_op(module, input_size=[1, 3, 6, 6], generate=False) |
| |
| @override_qengines |
| def test_conv2d_graph(self): |
| module = nn.Sequential( |
| torch.ao.quantization.QuantStub(), |
| nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, |
| groups=1, bias=True, padding_mode="zeros"), |
| ) |
| self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) |
| |
| @override_qengines |
| def test_conv2d_nobias_graph(self): |
| module = nn.Sequential( |
| torch.ao.quantization.QuantStub(), |
| nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, |
| groups=1, bias=False, padding_mode="zeros"), |
| ) |
| self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) |
| |
| @override_qengines |
| def test_conv2d_graph_v2(self): |
| # tests the same thing as test_conv2d_graph, but for version 2 of |
| # ConvPackedParams{n}d |
| module = nn.Sequential( |
| torch.ao.quantization.QuantStub(), |
| nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, |
| groups=1, bias=True, padding_mode="zeros"), |
| ) |
| self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) |
| |
| @override_qengines |
| def test_conv2d_nobias_graph_v2(self): |
| # tests the same thing as test_conv2d_nobias_graph, but for version 2 of |
| # ConvPackedParams{n}d |
| module = nn.Sequential( |
| torch.ao.quantization.QuantStub(), |
| nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, |
| groups=1, bias=False, padding_mode="zeros"), |
| ) |
| self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) |
| |
| @override_qengines |
| def test_conv2d_graph_v3(self): |
| # tests the same thing as test_conv2d_graph, but for version 3 of |
| # ConvPackedParams{n}d |
| module = nn.Sequential( |
| torch.ao.quantization.QuantStub(), |
| nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, |
| groups=1, bias=True, padding_mode="zeros"), |
| ) |
| self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) |
| |
| @override_qengines |
| def test_conv2d_nobias_graph_v3(self): |
| # tests the same thing as test_conv2d_nobias_graph, but for version 3 of |
| # ConvPackedParams{n}d |
| module = nn.Sequential( |
| torch.ao.quantization.QuantStub(), |
| nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, |
| groups=1, bias=False, padding_mode="zeros"), |
| ) |
| self._test_op_graph(module, input_size=[1, 3, 6, 6], generate=False) |
| |
| @override_qengines |
| def test_conv2d_relu(self): |
| module = nniq.ConvReLU2d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, |
| groups=1, bias=True, padding_mode="zeros") |
| self._test_op(module, input_size=[1, 3, 6, 6], generate=False) |
| # TODO: graph mode quantized conv2d module |
| |
| @override_qengines |
| def test_conv3d(self): |
| if qengine_is_fbgemm(): |
| module = nnq.Conv3d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, |
| groups=1, bias=True, padding_mode="zeros") |
| self._test_op(module, input_size=[1, 3, 6, 6, 6], generate=False) |
| # TODO: graph mode quantized conv3d module |
| |
| @override_qengines |
| def test_conv3d_relu(self): |
| if qengine_is_fbgemm(): |
| module = nniq.ConvReLU3d(3, 3, kernel_size=3, stride=1, padding=0, dilation=1, |
| groups=1, bias=True, padding_mode="zeros") |
| self._test_op(module, input_size=[1, 3, 6, 6, 6], generate=False) |
| # TODO: graph mode quantized conv3d module |
| |
| @override_qengines |
| @unittest.skipIf(IS_AVX512_VNNI_SUPPORTED, "This test fails on machines with AVX512_VNNI support. Ref: GH Issue 59098") |
| def test_lstm(self): |
| class LSTMModule(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.lstm = nnqd.LSTM(input_size=3, hidden_size=7, num_layers=1).to(dtype=torch.float) |
| |
| def forward(self, x): |
| x = self.lstm(x) |
                return x

        if qengine_is_fbgemm():
| mod = LSTMModule() |
| self._test_op(mod, input_size=[4, 4, 3], input_quantized=False, generate=False, new_zipfile_serialization=True) |
| |
| def test_per_channel_observer(self): |
| obs = PerChannelMinMaxObserver() |
| self._test_obs(obs, input_size=[5, 5], generate=False) |
| |
| def test_per_tensor_observer(self): |
| obs = MinMaxObserver() |
| self._test_obs(obs, input_size=[5, 5], generate=False) |
| |
| def test_default_qat_qconfig(self): |
| class Model(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.linear = nn.Linear(5, 5) |
| self.relu = nn.ReLU() |
| |
| def forward(self, x): |
| x = self.linear(x) |
| x = self.relu(x) |
| return x |
| |
| model = Model() |
| model.linear.weight = torch.nn.Parameter(torch.randn(5, 5)) |
| model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm") |
| ref_model = torch.ao.quantization.QuantWrapper(model) |
| ref_model = torch.ao.quantization.prepare_qat(ref_model) |
| self._test_obs(ref_model, input_size=[5, 5], generate=False, check_numerics=False) |
| |
| @skipIfNoFBGEMM |
| def test_linear_relu_package_quantization_transforms(self): |
| m = LinearReluFunctional(4).eval() |
| self._test_package(m, input_size=(1, 1, 4, 4), generate=False) |