| # -*- coding: utf-8 -*- |
| # Owner(s): ["oncall: quantization"] |
| |
| # torch |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.jit |
| import torch.jit.quantized |
| |
| # torch.ao.quantization |
| from torch.ao.quantization import ( |
| QConfig, |
| default_dynamic_qconfig, |
| float16_dynamic_qconfig, |
| default_observer, |
| per_channel_dynamic_qconfig, |
| default_per_channel_weight_observer, |
| default_qconfig, |
| get_default_qconfig, |
| quantize, |
| quantize_dynamic, |
| default_weight_observer, |
| default_histogram_observer, |
| fuse_modules, |
| quantize_jit, |
| quantize_dynamic_jit, |
| PlaceholderObserver, |
| ) |
| |
| # torch.ao.quantization.quantize_jit |
| from torch.ao.quantization.quantize_jit import ( |
| convert_jit, |
| convert_dynamic_jit, |
| fuse_conv_bn_jit, |
| prepare_jit, |
| prepare_dynamic_jit, |
| script_qconfig, |
| ) |
| |
| # Testing utils |
| from torch.testing._internal.common_quantized import ( |
| override_qengines, |
| qengine_is_fbgemm, |
| qengine_is_qnnpack, |
| ) |
| |
| from torch.testing._internal.common_quantization import ( |
| QuantizationTestCase, |
| skipIfNoFBGEMM, |
| get_script_module, |
| SingleLayerLinearModel, |
| SkipQuantModel, |
| NestedModel, |
| ConvModel, |
| ConvTransposeModel, |
| default_per_channel_qconfig, |
| test_only_eval_fn, |
| ConvBnModel, |
| ) |
| |
| # Annotated models |
| from torch.testing._internal.common_quantization import ( |
| AnnotatedSingleLayerLinearModel, |
| AnnotatedSkipQuantModel, |
| AnnotatedNestedModel, |
| AnnotatedConvModel, |
| AnnotatedConvTransposeModel, |
| AnnotatedConvBnModel, |
| ) |
| |
| from torch.testing import FileCheck |
| from torch.testing._internal.jit_utils import attrs_with_prefix |
| from torch.testing._internal.jit_utils import get_forward |
| from torch.testing._internal.jit_utils import get_forward_graph |
| from torch.testing._internal.common_utils import skipIfSlowGradcheckEnv |
| |
| from torch.jit._recursive import wrap_cpp_module |
| |
| # Standard library |
| from typing import List, Tuple |
| import io |
| import itertools |
| import unittest |
| |
| |
| class TestQuantizeJitPasses(QuantizationTestCase): |
| """Test graph mode quantization passes used by quantize_jit""" |
| |
| def test_skip_dequant_constant_prop(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 5, 3).float() |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| m = torch.jit.script(M()) |
        observer = default_per_channel_weight_observer.with_args(ch_axis=1)
| qconfig_dict = {"": QConfig(activation=default_observer, weight=observer)} |
| m = prepare_jit(m, qconfig_dict) |
| data = torch.randn(1, 3, 10, 10, dtype=torch.float) |
| |
| m(data) |
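        # debug=True emits the reference pattern (explicit quantize/dequantize around
        # float aten ops) rather than fused quantized ops, which is what lets freezing
        # below fold the weight quantization into a constant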
| m = convert_jit(m, debug=True) |
| |
| freezed = torch.jit.freeze(m) |
| freezed(data) |
| |
| # After freezing, weight becomes Constant. |
| # We have this pattern in the original graph: Constant f32_weight -> quant -> dequant |
| # After skipping dequant during Constant Propagation, the resulting graph will be: |
| # Constant int8_weight -> dequant |
| FileCheck().check_count("aten::quantize_per_tensor", 2, exactly=True).run(freezed.graph) |
| FileCheck().check_count("aten::quantize_per_channel", 0, exactly=True).run(freezed.graph) |
| FileCheck().check_count("aten::dequantize", 3, exactly=True).run(freezed.graph) |
| FileCheck().check("aten::quantize_per_tensor").check_next("aten::dequantize").check_not( |
| "aten::quantize_per_channel" |
| ).check("aten::dequantize").check_next("aten::conv2d").check_next( |
| "aten::quantize_per_tensor" |
| ).check_next( |
| "aten::dequantize" |
| ).run( |
| freezed.graph |
| ) |
| |
| def test_foldbn_trivial(self): |
| bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d} |
| conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d} |
| |
| # Test trivial case |
| class TestModule(torch.nn.Module): |
| def __init__(self, dim): |
| super(TestModule, self).__init__() |
| self.conv = conv_module[dim](1, 20, 5, 1) |
| self.bn = bn_module[dim](num_features=20) |
| self.bn.eps = 0.0023 |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.bn(x) |
| return x |
| |
| options = itertools.product([True, False], [2, 3]) |
| data = {2: torch.rand(1, 1, 6, 6), 3: torch.rand(1, 1, 6, 6, 6)} |
| # Check that the transformation doesn't change numerics |
| for tracing, dim in options: |
| eager = TestModule(dim).eval() |
| x = data[dim] |
| scripted_or_traced = get_script_module(eager, tracing, x).eval() |
| # Check that in the original script module's forward we have two |
| # CallMethod nodes. One of them should be for conv.forward and the other |
| # for bn.forward. |
| FileCheck().check_count( |
| 'prim::CallMethod[name="forward"]', 2, exactly=True |
| ).run(str(get_forward(scripted_or_traced._c).graph)) |
| |
| # Run FoldConvBatchnorm pass. |
| scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced) |
| |
| # Check that after the pass one of the CallMethods is gone (supposedly, |
| # the bn.forward). |
| FileCheck().check_count( |
| 'prim::CallMethod[name="forward"]', 1, exactly=True |
| ).run(str(get_forward_graph(scripted_or_traced._c))) |
| |
| # Check that the transformation doesn't change numerics |
| self.assertEqual(eager(x), scripted_or_traced(x)) |
| |
| def test_foldbn_trivial_nobias(self): |
| bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d} |
| conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d} |
| |
| # Test trivial case |
| class TestModule(torch.nn.Module): |
| def __init__(self, dim): |
| super(TestModule, self).__init__() |
| self.conv = conv_module[dim](1, 20, 5, 1, bias=False) |
| self.bn = bn_module[dim](num_features=20) |
| # to make sure new bias is not zero |
| self.bn.eps = 0.0027 |
| self.bn.bias = torch.nn.Parameter(torch.rand([20])) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.bn(x) |
| return x |
| |
| options = itertools.product([True, False], [2, 3]) |
| data = {2: torch.rand(1, 1, 6, 6), 3: torch.rand(1, 1, 6, 6, 6)} |
| for tracing, dim in options: |
| eager = TestModule(dim).eval() |
| x = data[dim] |
| scripted_or_traced = get_script_module(eager, tracing, x).eval() |
| # Check that in the original script module's forward we have two |
| # CallMethod nodes. One of them should be for conv.forward and the other |
| # for bn.forward. |
| FileCheck().check_count( |
| 'prim::CallMethod[name="forward"]', 2, exactly=True |
| ).run(str(get_forward_graph(scripted_or_traced._c))) |
| |
| # Run FoldConvBatchnorm pass. |
| scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced) |
| |
| # Check that after the pass one of the CallMethods is gone (supposedly, |
| # the bn.forward). |
| FileCheck().check_count( |
| 'prim::CallMethod[name="forward"]', 1, exactly=True |
| ).run(str(get_forward_graph(scripted_or_traced._c))) |
| |
| # Check that the transformation doesn't change numerics |
| self.assertEqual(eager(x), scripted_or_traced(x)) |
| |
| def test_foldbn_in_submodule(self): |
| bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d} |
| conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d} |
| |
| # Test that we find Conv-BN patterns in submodules |
| class SubModule(torch.nn.Module): |
| def __init__(self, dim): |
| super(SubModule, self).__init__() |
| self.conv = conv_module[dim](1, 20, 5, 1) |
| self.bn = bn_module[dim](num_features=20) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.bn(x) |
| return x |
| |
| class TestModule(torch.nn.Module): |
| def __init__(self, dim): |
| super(TestModule, self).__init__() |
| self.sub = SubModule(dim) |
| |
| def forward(self, x): |
| x = self.sub(x) |
| return x |
| |
| options = itertools.product([True, False], [2, 3]) |
| data = {2: torch.rand(1, 1, 10, 10), 3: torch.rand(1, 1, 10, 10, 10)} |
| for tracing, dim in options: |
| eager = TestModule(dim).eval() |
| x = data[dim] |
| scripted_or_traced = get_script_module(eager, tracing, x).eval() |
| FileCheck().check_count( |
| 'prim::CallMethod[name="forward"]', 2, exactly=True |
| ).run(str(get_forward_graph(scripted_or_traced.sub._c))) |
| |
| scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced) |
| |
| FileCheck().check_count( |
| 'prim::CallMethod[name="forward"]', 1, exactly=True |
| ).run(str(get_forward_graph(scripted_or_traced.sub._c))) |
| |
| self.assertEqual(eager(x), scripted_or_traced(x)) |
| |
| def test_foldbn_shared_classtype(self): |
| bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d} |
| conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d} |
| |
| class TestModule(torch.nn.Module): |
| def __init__(self, dim, bias=False): |
| super(TestModule, self).__init__() |
| self.conv1 = conv_module[dim](5, 5, 3, bias=bias) |
| self.bn1 = bn_module[dim](num_features=5) |
| self.bn1.running_mean.fill_(-0.2) |
| self.bn1.bias = torch.nn.Parameter(torch.rand([5])) |
| # to make sure new bias is not zero |
| self.bn1.eps = 0.0023 |
| self.conv2 = conv_module[dim](5, 5, 3, bias=bias) |
| self.bn2 = bn_module[dim](num_features=5) |
| self.bn2.eps = 0.0029 |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.bn1(x) |
| x = self.relu(x) |
| x = self.conv2(x) |
| x = self.bn2(x) |
| x = self.relu(x) |
| return x |
| |
        options = itertools.product([True, False], [2, 3], [True, False])
| data = {2: torch.rand(1, 5, 6, 6), 3: torch.rand(1, 5, 6, 6, 6)} |
| for tracing, dim, bias in options: |
| eager = TestModule(dim, bias).eval() |
| x = data[dim] |
| scripted_or_traced = get_script_module(eager, tracing, x) |
            folded = fuse_conv_bn_jit(scripted_or_traced)
            # Check that the transformation doesn't change numerics
            self.assertEqual(eager(x), folded(x))
| |
    def test_foldbn_no_fusion(self):
        """Test that we don't fuse when the module types do not match"""
| |
| class CustomConv(torch.nn.Module): |
| def __init__(self): |
| super(CustomConv, self).__init__() |
| |
| def forward(self, x): |
| return x |
| |
| class CustomBn(torch.nn.Module): |
| def __init__(self): |
| super(CustomBn, self).__init__() |
| |
| def forward(self, x): |
| return x |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv = CustomConv() |
| self.bn = CustomBn() |
| |
| def forward(self, x): |
| return self.bn(self.conv(x)) |
| |
| m = torch.jit.script(M()) |
| m = fuse_conv_bn_jit(m) |
| FileCheck().check_count("prim::CallMethod", 2, exactly=True).run(m.graph) |
| |
| def test_foldbn_complex_cases(self): |
        # This test case tries combinations of Conv2d/Conv3d with bias/no-bias
        # as well as BatchNorm with affine/non-affine, while varying the
        # number of layers.
        # The numerics check only works when the default dtype is double.
| torch.set_default_dtype(torch.double) |
| bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d} |
| conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d} |
| |
| class SubModule(torch.nn.Module): |
| def __init__(self, dim, num_blocks, enable_bias, enable_affine): |
| super(SubModule, self).__init__() |
| layers = [] |
| for i in range(num_blocks): |
| layers.append(conv_module[dim](20, 20, 5, 1, bias=enable_bias)) |
| bn_obj = bn_module[dim](num_features=20, affine=enable_affine) |
| if enable_affine: |
| bn_obj.weight = torch.nn.Parameter( |
| torch.rand_like(bn_obj.weight) |
| ) |
| bn_obj.bias = torch.nn.Parameter(torch.rand_like(bn_obj.bias)) |
| bn_obj.running_mean = torch.rand_like(bn_obj.running_mean) |
| bn_obj.running_var = torch.rand_like(bn_obj.running_var) |
| layers.append(bn_obj) |
| self.layers = nn.Sequential(*layers) |
| |
| def forward(self, x): |
| return self.layers(x) |
| |
| class TestModule(torch.nn.Module): |
| def __init__(self, dim, num_blocks, enable_bias, enable_affine): |
| super(TestModule, self).__init__() |
| self.sub = SubModule(dim, num_blocks, enable_bias, enable_affine) |
| |
| def forward(self, x): |
| x = self.sub(x) |
| return x |
| |
| options = itertools.product( |
| [True, False], [2, 3], [True, False], [True, False], [1, 2] |
| ) |
| data = {2: torch.rand(1, 20, 10, 10), 3: torch.rand(1, 20, 10, 10, 10)} |
| for tracing, dim, enable_bias, enable_bn_affine, num_layers in options: |
| eager = TestModule(dim, num_layers, enable_bias, enable_bn_affine).eval() |
| x = data[dim] |
| scripted_or_traced = get_script_module(eager, tracing, x).eval() |
| |
| FileCheck().check_count( |
| 'prim::CallMethod[name="forward"]', num_layers * 2, exactly=True |
| ).run(str(get_forward_graph(scripted_or_traced.sub.layers._c))) |
| |
| scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced) |
| |
| FileCheck().check_count( |
| 'prim::CallMethod[name="forward"]', num_layers, exactly=True |
| ).run(str(get_forward_graph(scripted_or_traced.sub.layers._c))) |
| |
| self.assertEqual(eager(x), scripted_or_traced(x)) |
| |
| torch.set_default_dtype(torch.float) |
| |
| def test_fuse_linear(self): |
| class FunctionalLinear(torch.nn.Module): |
| def __init__(self, weight, bias): |
| super(FunctionalLinear, self).__init__() |
| self.weight = weight |
| self.bias = bias |
| |
| def forward(self, x): |
| res = torch.matmul(x, self.weight.t()) |
| if self.bias is not None: |
| res.add_(self.bias) |
| return res |
| |
| x1 = torch.rand(3) |
| w1 = torch.rand(5, 3) |
| b1 = torch.rand(5) |
| |
| x2 = torch.rand(5, 5) |
| w2 = torch.rand(5, 5) |
| b2 = torch.rand(5) |
| |
| x3 = torch.rand(5, 5, 5) |
| w3 = torch.rand(5, 5) |
| b3 = torch.rand(5) |
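        # _jit_pass_fuse_linear rewrites the matmul(x, weight.t()) [+ add_(bias)]
        # pattern into a single aten::linear call; the checks below also verify that
        # the source range of the original matmul is carried over to the fused node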
| for has_bias, (x, weight, b) in itertools.product( |
| [True, False], [(x1, w1, b1), (x2, w2, b2), (x3, w3, b3)] |
| ): |
| bias = b if has_bias else None |
| model = torch.jit.trace(FunctionalLinear(weight, bias), [x]) |
| for node in model.graph.nodes(): |
| if node.kind() == "aten::matmul": |
| source_range_1 = node.sourceRange() |
| torch._C._jit_pass_fuse_linear(model.graph) |
| for node in model.graph.nodes(): |
| if node.kind() == "aten::linear": |
| source_range_2 = node.sourceRange() |
| FileCheck().check("aten::linear").run(model.graph) |
| check_not = ["aten::matmul", "aten::addmm", "aten::add_", "aten::t("] |
| for cn in check_not: |
| FileCheck().check_not(cn).run(model.graph) |
            self.assertEqual(source_range_1, source_range_2)
            # make sure it runs
            model(x)
| |
| # check matmuls are not fused |
| class Matmul(torch.nn.Module): |
| def __init__(self, weight): |
| super(Matmul, self).__init__() |
| self.weight = weight |
| |
| def forward(self, x): |
| return torch.matmul(x, self.weight) |
| |
| x = torch.rand(5, 6, 5) |
| w = torch.rand(5, 5, 100) |
| model = torch.jit.trace(Matmul(w), [x]) |
| torch._C._jit_pass_fuse_linear(model.graph) |
| # check 3d matmul is not fused |
| FileCheck().check("aten::matmul").run(model.graph) |
| FileCheck().check_not("aten::linear").run(model.graph) |
| # make sure it runs |
| model(x) |
| |
| def test_insert_observers(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 5, 3) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| m = torch.jit.script(M()) |
| qconfig_dict = {"": default_qconfig} |
| m = prepare_jit(m, qconfig_dict) |
| # for input and output of conv |
| assert len(attrs_with_prefix(m, "_observer_")) == 2 |
| # for weight |
| assert len(attrs_with_prefix(m.conv, "_observer_")) == 1 |
| |
| def test_insert_observers_interface(self): |
| @torch.jit.interface |
| class SubInterface(torch.nn.Module): |
| def addOne(self, inp) -> torch.Tensor: |
| pass |
| |
| class Sub(torch.nn.Module): |
| def __init__(self): |
| super(Sub, self).__init__() |
| self.fc = torch.nn.Linear(5, 5) |
| |
| def addOne(self, inp): |
| return self.fc(inp) + 1 |
| |
| def forward(self, x): |
| return self.addOne(x) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 5, 3) |
| self.sub = Sub() |
| |
| def forward(self, x): |
| return self.sub(self.conv(x)) |
| |
| m = torch.jit.script(M()) |
| qconfig_dict = {"sub.conv": default_qconfig} |
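        # no assertions here: this is a smoke test that prepare_jit runs without
        # error on a module defined alongside a torch.jit.interface declaration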
| m = prepare_jit(m, qconfig_dict) |
| |
| def test_insert_observers_interface_unshare_type(self): |
| @torch.jit.interface |
| class OperatorIf(nn.Module): |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| pass |
| |
| class Operator(nn.Module): |
| def __init__(self, a): |
| super().__init__() |
| self.a = a |
| |
| def forward(self, inp: torch.Tensor) -> torch.Tensor: |
| return self.a * (inp + self.a) |
| |
| class Inner(nn.Module): |
| op: OperatorIf |
| |
| def __init__(self, op): |
| super().__init__() |
| self.op = op |
| |
| def forward(self, inp): |
| return self.op(inp) |
| |
| class Outer(nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.inner_a = Inner(Operator(1)) |
| self.inner_b = Inner(Operator(3.0)) |
| |
| def forward(self, inp): |
| return self.inner_a(inp) + self.inner_b(inp) |
| |
| qconfig_dict = {"inner_a": default_qconfig, "inner_b": default_qconfig} |
| |
| eager_model = Outer() |
| for tracing in [True, False]: |
| x = torch.rand(3) |
| script_model = get_script_module(eager_model, tracing, x) |
| # make sure it runs |
| prepare_jit(script_model, qconfig_dict) |
| |
| def test_insert_observers_child_qconfig(self): |
| class Sub(torch.nn.Module): |
| def __init__(self): |
| super(Sub, self).__init__() |
| self.fc = torch.nn.Linear(5, 5) |
| |
| def forward(self, x): |
| return self.fc(x) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 5, 3) |
| self.sub = Sub() |
| |
| def forward(self, x): |
| return self.sub(self.conv(x)) |
| |
| m = torch.jit.script(M()) |
| qconfig_dict = {"sub.fc": default_qconfig} |
| m = prepare_jit(m, qconfig_dict) |
| # input and output of sub |
| assert len(attrs_with_prefix(m, "_observer_")) == 2 |
        # conv is not quantized
        assert len(attrs_with_prefix(m.conv, "_observer_")) == 0
        # no observers on sub since we observe at the outermost call site
        assert len(attrs_with_prefix(m.sub, "_observer_")) == 0
| # weight of linear |
| assert len(attrs_with_prefix(m.sub.fc, "_observer_")) == 1 |
| |
    @unittest.skipUnless(
        "fbgemm" in torch.backends.quantized.supported_engines,
        "Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
        " with instruction set support for AVX2 or newer.",
    )
| def test_insert_observers_skip_values(self): |
| class ConvFunctionalReLU(torch.nn.Module): |
| def __init__(self): |
| super(ConvFunctionalReLU, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 5, 3) |
| |
| def forward(self, x): |
| return F.relu(self.conv(x)) |
| |
| class ConvReLUModule(torch.nn.Module): |
| def __init__(self): |
| super(ConvReLUModule, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 5, 3) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| return self.relu(self.conv(x)) |
| |
| class AddReLUModule(torch.nn.Module): |
| def __init__(self): |
| super(AddReLUModule, self).__init__() |
| self.relu = torch.nn.ReLU() |
| self.conv = torch.nn.Conv2d(3, 3, 3).float() |
| |
| def forward(self, x): |
| out = self.conv(x) |
| out += x |
| return self.relu(out) |
| |
| class AddFunctionalReLU(torch.nn.Module): |
| def __init__(self): |
| super(AddFunctionalReLU, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3).float() |
| |
| def forward(self, x): |
| out = self.conv(x) |
| out += x |
| return F.relu(out) |
| |
| def attrs_with_prefix(module, prefix): |
| return [x for x, _ in module._modules._c.items() if x.startswith(prefix)] |
| |
| qconfig_dict = {"": default_qconfig} |
| m = torch.jit.script(ConvFunctionalReLU()) |
| m = prepare_jit(m, qconfig_dict) |
| # observer for weight of conv |
| assert len(attrs_with_prefix(m.conv, "_observer_")) == 1 |
| # observer for input of conv and output of relu |
| assert len(attrs_with_prefix(m, "_observer_")) == 2 |
| |
| m = torch.jit.script(ConvReLUModule()) |
| m = prepare_jit(m, qconfig_dict) |
| # observer for input of conv and output of relu |
| assert len(attrs_with_prefix(m, "_observer_")) == 2 |
| # observer for weight of conv |
| assert len(attrs_with_prefix(m.conv, "_observer_")) == 1 |
| # observer for output of relu |
| assert len(attrs_with_prefix(m.relu, "_observer_")) == 0 |
| |
| m = torch.jit.script(AddReLUModule()) |
| qconfig_dict = {"": default_qconfig} |
| m = prepare_jit(m, qconfig_dict) |
| assert len(attrs_with_prefix(m, "_observer")) == 3 |
| assert len(attrs_with_prefix(m.relu, "_observer")) == 0 |
| FileCheck().check("aten::add_").check_not( |
| 'Observer = prim::GetAttr[name="_observer_' |
| ).check("ReLU = prim::GetAttr").run(str(get_forward_graph(m._c))) |
| |
| m = torch.jit.script(AddFunctionalReLU()) |
| qconfig_dict = {"": default_qconfig} |
| m = prepare_jit(m, qconfig_dict) |
| assert len(attrs_with_prefix(m, "_observer")) == 3 |
| FileCheck().check("aten::add_").check_not( |
| 'Observer = prim::GetAttr[name="_observer_' |
| ).check("CallFunction").check('Observer = prim::GetAttr[name="_observer_').run( |
| str(get_forward_graph(m._c)) |
| ) |
| |
| def test_insert_observers_weight_dtype(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 5, 3) |
| |
| def forward(self, x): |
| return F.relu(self.conv(x)) |
| |
| m = torch.jit.script(M()) |
| qconfig_dict = {"": default_qconfig} |
| m = prepare_jit(m, qconfig_dict) |
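        # default_qconfig observes activations with a quint8 observer and weights
        # with a qint8 observer, so the dtypes recorded below should differ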
| activation_dtypes = set( |
| obs.getattr("dtype") |
| for x, obs in m._modules._c.items() |
| if x.startswith("_observer_") |
| ) |
| weight_dtypes = set( |
| obs.getattr("dtype") |
| for x, obs in m.conv._modules._c.items() |
| if x.startswith("_observer_") |
| ) |
| assert len(activation_dtypes) == 1, "Expected to have 1 activation dtype" |
| assert len(weight_dtypes) == 1, "Expected to have 1 weight dtype" |
        assert (
            list(activation_dtypes)[0] != list(weight_dtypes)[0]
        ), "Expected activation dtype to be different from weight dtype"
| |
| def test_insert_observers_for_reused_weight(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| |
| def forward(self, x, y, weight): |
| x = F.conv2d(x, weight) |
| y = F.conv2d(y, weight) |
| return x + y |
| |
| m = torch.jit.script(M()).eval() |
| m = prepare_jit(m, {"": default_qconfig}) |
| # 3 for x, y, weight, one for output of each F.conv2d and one for output of add |
| assert len(attrs_with_prefix(m, "_observer")) == 6 |
| |
| def test_insert_observers_shared_class_type(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv1 = torch.nn.Conv2d(3, 5, 3).float() |
| self.conv2 = torch.nn.Conv2d(3, 5, 3).float() |
| |
| def forward(self, x): |
| return self.conv2(self.conv1(x)) |
| |
| m = torch.jit.script(M()) |
| qconfig_dict = {"": default_qconfig} |
| m = prepare_jit(m, qconfig_dict) |
        # conv1 and conv2 share the same class type; make sure we only insert
        # observers into the shared type once
        conv1_observers = attrs_with_prefix(m.conv1, "_observer_")
        conv2_observers = attrs_with_prefix(m.conv2, "_observer_")
        assert len(conv1_observers) == 1, "Expected to have 1 observer submodule"
        assert len(conv2_observers) == 1, "Expected to have 1 observer submodule"
        assert (
            conv1_observers == conv2_observers
        ), "Expected conv1 and conv2 to have the same observers since the class type is shared"
| |
    def test_insert_observers_for_general_ops(self):
        """Make sure we skip observers for ops that don't require
        observation, e.g. flatten
        """
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3).float() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = torch.flatten(x) |
| return x |
| |
| m = torch.jit.script(M()) |
| qconfig_dict = {"": default_qconfig} |
| m = prepare_jit(m, qconfig_dict) |
| # input and output of conv |
| assert len(attrs_with_prefix(m, "_observer_")) == 2 |
| FileCheck().check('Observer = prim::GetAttr[name="_observer_').check( |
| 'prim::GetAttr[name="conv"]' |
| ).check("prim::CallMethod").check( |
| 'Observer = prim::GetAttr[name="_observer_' |
| ).check( |
| "aten::flatten" |
| ).check_not( |
| 'Observer = prim::GetAttr[name="_observer_' |
| ).run( |
| m.graph |
| ) |
| |
    # TODO: this is too long, split this into test_insert_observers.py and remove
    # the insert_observers prefix
    def test_insert_observers_propagate_observed(self):
        """Make sure we propagate the observed property through general ops"""
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv1 = torch.nn.Conv2d(3, 3, 3).float() |
| self.conv2 = torch.nn.Conv2d(3, 3, 3).float() |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = torch.flatten(x) |
| # we don't want to insert observer for input of self.conv2 |
| # because output of self.conv1 is already observed |
| x = self.conv2(x) |
| return x |
| |
| m = torch.jit.script(M()) |
| qconfig_dict = {"": default_qconfig} |
| m = prepare_jit(m, qconfig_dict) |
        # observers for input of conv1, output of conv1 (whose observed-ness
        # propagates through flatten to the input of conv2), and output of conv2
        assert len(attrs_with_prefix(m, "_observer_")) == 3
| FileCheck().check('Observer = prim::GetAttr[name="_observer_').check( |
| 'prim::GetAttr[name="conv1"]' |
| ).check("prim::CallMethod").check( |
| 'Observer = prim::GetAttr[name="_observer_' |
| ).check( |
| "aten::flatten" |
| ).check_not( |
| 'Observer = prim::GetAttr[name="_observer_' |
| ).check( |
| 'prim::GetAttr[name="conv2"]' |
| ).check( |
| 'Observer = prim::GetAttr[name="_observer_' |
| ).run( |
| m.graph |
| ) |
| |
    def test_insert_observers_propagate_observed_in_submodule(self):
        """Make sure we propagate the observed property through general ops"""
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv1 = torch.nn.Conv2d(3, 3, 3).float() |
| self.conv2 = torch.nn.Conv2d(3, 3, 3).float() |
| self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1)) |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = self.avgpool(x) |
| # we don't want to insert observer for input of self.conv2 |
| # because output of self.conv1 is already observed |
| x = self.conv2(x) |
| return x |
| |
| m = torch.jit.script(M()) |
| qconfig_dict = {"": default_qconfig} |
| m = prepare_jit(m, qconfig_dict) |
        # observers for input of conv1, output of conv1 (whose observed-ness
        # propagates through avgpool to the input of conv2), and output of conv2
        assert len(attrs_with_prefix(m, "_observer_")) == 3
| FileCheck().check('Observer = prim::GetAttr[name="_observer_').check( |
| 'prim::GetAttr[name="conv1"]' |
| ).check("prim::CallMethod").check( |
| 'Observer = prim::GetAttr[name="_observer_' |
| ).check( |
| "prim::CallMethod" |
| ).check_not( |
| 'Observer = prim::GetAttr[name="_observer_' |
| ).check( |
| 'prim::GetAttr[name="conv2"]' |
| ).check( |
| 'Observer = prim::GetAttr[name="_observer_' |
| ).run( |
| m.graph |
| ) |
| |
| def test_insert_observers_propagate_observed_for_function(self): |
| def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor: |
| batchsize, num_channels, height, width = x.data.size() |
| channels_per_group = num_channels // groups |
| # reshape |
| x = x.view(batchsize, groups, channels_per_group, height, width) |
| x = torch.transpose(x, 1, 2).contiguous() |
| # flatten |
| x = x.view(batchsize, -1, height, width) |
| return x |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv1 = torch.nn.Conv2d(3, 3, 1).float() |
| self.conv2 = torch.nn.Conv2d(3, 3, 1).float() |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| x = channel_shuffle(x, 1) |
| x = self.conv2(x) |
| return x |
| |
| data = [ |
| ( |
| torch.rand((1, 3, 10, 10), dtype=torch.float), |
| torch.randint(0, 1, (1,), dtype=torch.long), |
| ) |
| for _ in range(2) |
| ] |
| m = torch.jit.script(M()).eval() |
| m = prepare_jit(m, {"": default_qconfig}) |
| # we want to test that channel_shuffle is going to pass |
| # the observed property from the output of conv1 to input of conv2 |
| # so that we don't insert observers for input of conv2 |
| assert ( |
| len( |
| attrs_with_prefix( |
| m, |
| "_observer_", |
| ) |
| ) |
| == 3 |
| ) |
| |
| def test_insert_observers_for_if(self): |
| class QuantProp(torch.nn.Module): |
| def __init__(self, use_skip): |
| super(QuantProp, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 1).float() |
| self.use_skip = use_skip |
| |
| def forward(self, x): |
| if self.use_skip: |
| x = self.conv(x) |
| return torch.reshape(x, x.shape) |
| else: |
| x = self.conv(x) |
| return torch.reshape(x, x.shape) |
| |
| class Res(torch.nn.Module): |
| def __init__(self, use_skip): |
| super(Res, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 1).float() |
| self.use_skip = use_skip |
| |
| def forward(self, x): |
| if self.use_skip: |
| return self.conv(x) |
| else: |
| return self.conv(x) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.quant_prop = QuantProp(True) |
| self.res = Res(False) |
| |
| def forward(self, x): |
| x = self.quant_prop(x) |
| x = self.res(x) |
| return x |
| |
| data = [torch.rand(1, 3, 10, 10, dtype=torch.float)] |
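        # expected observer counts keyed by `tracing`: [top level, quant_prop, res]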
| result = {False: [1, 2, 2], True: [2, 1, 0]} |
| for tracing in [True, False]: |
| if tracing: |
| m = torch.jit.trace(M(), data).eval() |
| else: |
| m = torch.jit.script(M()).eval() |
| m = prepare_jit(m, {"": default_qconfig}) |
| assert ( |
| len( |
| attrs_with_prefix( |
| m, |
| "_observer_", |
| ) |
| ) |
| == result[tracing][0] |
| ) |
| assert ( |
| len( |
| attrs_with_prefix( |
| m.quant_prop, |
| "_observer_", |
| ) |
| ) |
| == result[tracing][1] |
| ) |
| assert ( |
| len( |
| attrs_with_prefix( |
| m.res, |
| "_observer_", |
| ) |
| ) |
| == result[tracing][2] |
| ) |
| |
| def test_insert_observers_for_nested_if(self): |
| class Res(torch.nn.Module): |
| def __init__(self, use_skip): |
| super(Res, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 1).float() |
| self.cond = use_skip |
| self.use_skip = use_skip |
| |
| def forward(self, x): |
| if self.use_skip: |
| if self.cond: |
| return self.conv(x) |
| else: |
| return self.conv(x) |
| else: |
| return self.conv(x) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.res1 = Res(True) |
| self.res2 = Res(False) |
| |
| def forward(self, x): |
| x = self.res1(x) |
| x = self.res2(x) |
| return x |
| |
| data = torch.rand((1, 3, 10, 10), dtype=torch.float) |
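        # expected number of top-level observers, keyed by `tracing`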
| result = {True: 3, False: 1} |
| for tracing in [True, False]: |
| if tracing: |
| m = torch.jit.trace(M(), data).eval() |
| else: |
| m = torch.jit.script(M()).eval() |
| m = prepare_jit(m, {"": default_qconfig}) |
| assert len(attrs_with_prefix(m, "_observer_")) == result[tracing] |
| |
    def test_insert_observers_for_if_consistent_observation(self):
        """Check that quantization for `if` works as long as the outputs of
        all branches are quantized/observed consistently
        """
| |
| class M(torch.nn.Module): |
| def __init__(self, cond): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3).float() |
| self.cond = cond |
| |
| def forward(self, x): |
| x = self.conv(x) |
| # x is already observed |
| if self.cond: |
| x = torch.flatten(x) |
| return x |
| |
| class M2(torch.nn.Module): |
| def __init__(self, cond): |
| super(M2, self).__init__() |
| self.conv1 = torch.nn.Conv2d(3, 3, 3).float() |
| self.conv2 = torch.nn.Conv2d(3, 3, 3).float() |
| self.cond = cond |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| if self.cond: |
| x = self.conv2(x) |
| # x will be observed in the branch |
| else: |
| x = torch.flatten(x) |
                # since the outputs of both branches are observed/quantized,
                # the if node is quantized consistently
| return x |
| |
| data = torch.rand((1, 3, 5, 5), dtype=torch.float) |
| options = list(itertools.product([True, False], [True, False])) |
| for cond, tracing in options: |
| if tracing: |
| m = torch.jit.trace(M(cond), data) |
| else: |
| m = torch.jit.script(M(cond)) |
| m = prepare_jit(m, {"": default_qconfig}) |
| assert len(attrs_with_prefix(m, "_observer_")) == 2 |
| |
| for cond, tracing in options: |
| if tracing: |
| m = torch.jit.trace(M2(cond), data) |
| else: |
| m = torch.jit.script(M2(cond)) |
| m = prepare_jit(m, {"": default_qconfig}) |
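            # when tracing with cond=False only the flatten branch is recorded, so
            # conv2 (and the extra observer around it) never appears in the graph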
| num_observers = 2 if tracing and not cond else 3 |
| assert len(attrs_with_prefix(m, "_observer_")) == num_observers |
| |
| def test_insert_quant_dequant(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 5, 3).float() |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| for is_per_channel in [True, False]: |
| m = torch.jit.script(M()) |
| observer = ( |
| default_per_channel_weight_observer.with_args(ch_axis=1) |
| if is_per_channel |
| else default_observer |
| ) |
| qconfig_dict = {"": QConfig(activation=observer, weight=observer)} |
| m = prepare_jit(m, qconfig_dict) |
| data = torch.randn(1, 3, 10, 10, dtype=torch.float) |
| |
| m(data) |
| m = convert_jit(m, debug=True) |
            assert (
                len(m._modules._c.items()) == 1
            ), "Expected to have a single conv submodule"
| # make sure the quantized model is executable |
| m(data) |
| quant_func = ( |
| "aten::quantize_per_channel" |
| if is_per_channel |
| else "aten::quantize_per_tensor" |
| ) |
| FileCheck().check_count(quant_func, 3, exactly=True).run(m.graph) |
| |
| def test_insert_quant_dequant_shared_class_type(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv1 = torch.nn.Conv2d(3, 3, 3).float() |
| self.conv2 = torch.nn.Conv2d(3, 3, 3).float() |
| |
| def forward(self, x): |
| return self.conv2(self.conv1(x)) |
| |
| for is_per_channel in [True, False]: |
| m = torch.jit.script(M()) |
| observer = ( |
| default_per_channel_weight_observer.with_args(ch_axis=1) |
| if is_per_channel |
| else default_observer |
| ) |
| qconfig = QConfig(activation=observer, weight=observer) |
| qconfig_dict = {"": qconfig} |
| m = prepare_jit(m, qconfig_dict) |
| # observers for input, output and value between conv1/conv2 |
| assert ( |
| len(attrs_with_prefix(m, "_observer_")) == 3 |
| ), "Expected to have 3 obervers" |
| # observer for weight |
| assert ( |
| len(attrs_with_prefix(m.conv1, "_observer_")) == 1 |
| ), "Expected to have 1 obervers" |
| # observer for weight |
| assert ( |
| len(attrs_with_prefix(m.conv2, "_observer_")) == 1 |
| ), "Expected to have 1 obervers" |
| |
| data = torch.randn(1, 3, 10, 10, dtype=torch.float) |
| m(data) |
| m = convert_jit(m, debug=True) |
| m(data) |
| assert m.conv1._c._type() == m.conv2._c._type() |
| |
| # check all observers have been removed |
| assert ( |
| len(attrs_with_prefix(m, "_observer_")) == 0 |
| ), "Expected to have 0 obervers" |
| assert ( |
| len(attrs_with_prefix(m.conv1, "_observer_")) == 0 |
| ), "Expected to have 0 obervers" |
| assert ( |
| len(attrs_with_prefix(m.conv2, "_observer_")) == 0 |
| ), "Expected to have 0 obervers" |
| |
| quant_func = ( |
| "aten::quantize_per_channel" |
| if is_per_channel |
| else "aten::quantize_per_tensor" |
| ) |
| for module in ["conv1", "conv2"]: |
| conv = m._c.getattr(module) |
| # quantize weight |
| FileCheck().check(quant_func).check_next("aten::dequantize").check( |
| 'prim::CallMethod[name="_conv_forward"]' |
| ).check("return").run(get_forward_graph(conv)) |
| # no quantize node in _conv_forward |
| FileCheck().check_not(quant_func).check("aten::conv2d").check_not( |
| quant_func |
| ).check("return").run(conv._get_method("_conv_forward").graph) |
| |
| def test_dedup_module_uses(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| x = self.relu(x) |
| x -= 0.5 |
| return self.relu(x) |
| |
| data = torch.randn((2, 2)) |
| m = torch.jit.script(M()) |
| ref_res = m(data) |
        assert (
            len([x for x, _ in m._modules._c.items() if x.startswith("relu")]) == 1
        ), "Expected to have 1 relu module before dedup module uses"
| torch._C._jit_pass_dedup_module_uses(m._c) |
| m = torch.jit._recursive.wrap_cpp_module(m._c) |
| res = m(data) |
| assert ( |
| len([x for x, _ in m._modules._c.items() if x.startswith("relu")]) == 2 |
| ), "Expected to have 2 relu modules after dedup module uses" |
| self.assertEqual(res, ref_res) |
| |
| def test_replicate_dequantize(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 1).float() |
| |
| def forward(self, x): |
| x = torch.dequantize(x) |
| r = self.conv(x) |
| r += x |
| return r |
| |
| x = torch.randn([1, 3, 10, 10], dtype=torch.float) |
| x = torch.quantize_per_tensor(x, 0.5, 1, torch.quint8) |
| m = torch.jit.script(M()) |
| ref_res = m(x) |
| FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph) |
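        # replicating the dequantize gives each use of x (the conv input and the add)
        # its own dequantize node, so the consumers can be quantized independently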
| torch._C._jit_pass_replicate_dequantize(m.graph) |
| FileCheck().check_count("aten::dequantize", 2, exactly=True).run(m.graph) |
| res = get_forward(m._c)(x) |
| self.assertEqual(res, ref_res) |
| |
| def test_replicate_dequantize_in_block(self): |
| class M(torch.nn.Module): |
| def __init__(self, cond): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 1).float() |
| |
| self.cond = cond |
| |
| def forward(self, x): |
| x = torch.dequantize(x) |
| if self.cond: |
| x = self.conv(x) |
| else: |
| x = x + 3 |
| return x |
| |
| x = torch.randn([1, 3, 10, 10], dtype=torch.float) |
| x = torch.quantize_per_tensor(x, 0.5, 1, torch.quint8) |
| m = torch.jit.script(M(True)) |
| ref_res = m(x) |
| FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph) |
| torch._C._jit_pass_replicate_dequantize(m.graph) |
| FileCheck().check_count("aten::dequantize", 2, exactly=True).run(m.graph) |
| # check dequantize is right before CallMethod of conv |
| FileCheck().check("aten::dequantize").check_next("CallMethod").run(m.graph) |
| # check dequantize is right before add |
| FileCheck().check("aten::dequantize").check("aten::dequantize").check_next( |
| "aten::add" |
| ).run(m.graph) |
| res = get_forward(m._c)(x) |
| self.assertEqual(res, ref_res) |
| |
| def test_swap_functional_linear(self): |
| # TODO: This pass replaces any function called "linear" with "aten::linear" |
| # No longer necessary, and also quite surprising |
| def linear(input, weight, bias): |
| return torch.nn.functional.linear(input, weight, bias) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| |
| def forward(self, x, weight, bias): |
| x = torch.dequantize(x) |
| weight = torch.dequantize(weight) |
| x = linear(x, weight, bias) |
| x = torch.quantize_per_tensor( |
| x, scale=1.0, zero_point=0, dtype=torch.quint8 |
| ) |
| return x |
| |
| x = torch.rand((10, 5), dtype=torch.float) |
| x = torch.quantize_per_tensor(x, scale=0.5, zero_point=1, dtype=torch.quint8) |
| weight = torch.rand((5, 5), dtype=torch.float) |
| weight = torch.quantize_per_tensor( |
| weight, scale=0.5, zero_point=1, dtype=torch.qint8 |
| ) |
| bias = torch.rand((5), dtype=torch.float) |
| m = torch.jit.script(M()) |
| ref_res = m(x, weight, bias) |
| FileCheck().check("CallFunction").run(m.graph) |
| torch._C._jit_pass_swap_functional_linear(m.graph) |
| FileCheck().check("aten::linear").check_not("CallFunction").run(m.graph) |
| res = m(x, weight, bias) |
| self.assertEqual(res, ref_res) |
| |
| def test_replicate_quantize_for_if(self): |
| """We want to move quantize nodes for output of prim::If |
| inside the prim::If blocks so that we can match quantization |
| patterns. |
| """ |
| |
| class Res(torch.nn.Module): |
| def __init__(self): |
| super(Res, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 1).float() |
| self.conv2 = torch.nn.Conv2d(3, 3, 1).float() |
| self.use_skip = True |
| |
| def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor: |
| # to avoid being frozen |
| self.use_skip = cond |
| if self.use_skip: |
| return self.conv(x) |
| else: |
| return self.conv2(x) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.res1 = Res() |
| self.res2 = Res() |
| |
| def forward(self, x): |
| x = self.res1(x, True) |
| x = self.res2(x, False) |
| return x |
| |
| data = [[torch.rand((1, 3, 10, 10), dtype=torch.float)]] |
| qconfig_dict = {"": default_qconfig} |
| m = torch.jit.script(M()).eval() |
| m = quantize_jit(m, qconfig_dict, test_only_eval_fn, [data]) |
| # make sure patterns in both branches are fused |
| FileCheck().check_count("quantized::conv2d(", 4, exactly=True).run(m.graph) |
| |
| def test_finalize_for_linear(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.fc = torch.nn.Linear(5, 5).float() |
| |
| def forward(self, x): |
| return self.fc(x) |
| |
| data = [[torch.rand((1, 5), dtype=torch.float)]] |
| qconfig_dict = {"": default_qconfig} |
| model = torch.jit.script(M()).eval() |
| model = quantize_jit(model, qconfig_dict, test_only_eval_fn, [data]) |
| # make sure there is only one quantize_per_tensor for input |
| # and linear_prepack is folded |
| FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).check_not( |
| "quantized::linear_prepack" |
| ).check("quantized::linear").run(model.graph) |
| |
| def test_inplace_option(self): |
| for tracing in [True, False]: |
| model = get_script_module( |
| torch.nn.Conv2d(3, 3, 3).float(), tracing, self.img_data_2d[0][0] |
| ) |
| qconfig_dict = {"": default_qconfig} |
| quantize_jit( |
| model, qconfig_dict, test_only_eval_fn, [self.img_data_2d], inplace=True |
| ) |
| FileCheck().check("quantized::conv2d").run(model.graph) |
| |
| FileCheck().check_not("aten::conv2d").run(model.graph) |
| |
| def test_finalize_debug(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3).float() |
| self.avgpool = torch.nn.AvgPool2d(3) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.avgpool(x) |
| return x |
| |
| data = [[torch.rand((1, 3, 10, 10), dtype=torch.float)]] |
| qconfig_dict = {"": default_qconfig} |
| model = torch.jit.script(M()).eval() |
| model = quantize_jit(model, qconfig_dict, test_only_eval_fn, [data], debug=True) |
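        # debug=True keeps the float aten ops with explicit quantize/dequantize around
        # them; the avg_pool2d output is re-quantized with its input's scale and
        # zero_point, which is what the q_scale/q_zero_point/dtype checks verify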
| FileCheck().check_not("quantized::conv2d").check("aten::conv2d").check( |
| "aten::avg_pool2d" |
| ).check("aten::q_scale").check_next("aten::q_zero_point").check_next( |
| "prim::dtype" |
| ).check_next( |
| "aten::quantize_per_tensor" |
| ).check( |
| "aten::dequantize" |
| ).run( |
| model.graph |
| ) |
| |
| def test_module_list(self): |
| class SimpleLinearLayer(torch.nn.Module): |
| def __init__(self): |
| super(SimpleLinearLayer, self).__init__() |
| self.fc = torch.nn.Linear(5, 5).float() |
| |
| def forward(self, x): |
| return self.fc(x) |
| |
| class ComplexModel(torch.nn.Module): |
| def __init__(self): |
| super(ComplexModel, self).__init__() |
| self.layers = torch.nn.ModuleList( |
| [SimpleLinearLayer() for i in range(2)] |
| ) |
| |
| def forward(self, x: torch.Tensor) -> List[torch.Tensor]: |
| states = [] |
| for layer in self.layers: |
| val = layer(x) |
| states.append(val) |
| return states |
| |
| data = torch.rand((1, 5), dtype=torch.float) |
| qconfig_dict = {"": default_qconfig} |
| model = torch.jit.script(ComplexModel()).eval() |
| model = prepare_jit(model, qconfig_dict) |
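        # one observer for the shared input plus one for each layer's output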
| assert len(attrs_with_prefix(model, "_observer")) == 3 |
| model(data) |
| model = convert_jit(model, debug=False) |
| FileCheck().check("quantized::linear").check("quantized::linear").run( |
| model.graph |
| ) |
| |
| def test_conv_trace(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv1d = torch.nn.Conv1d(3, 3, 3).float() |
| self.conv2d = torch.nn.Conv2d(3, 3, 3).float() |
| self.conv3d = torch.nn.Conv3d(3, 3, 3).float() |
| |
| def forward(self, x, y, z): |
| a = self.conv1d(x) |
| b = self.conv2d(y) |
| c = self.conv3d(z) |
| return (a, b, c) |
| |
| qconfig_dict = {"": default_qconfig} |
| inputs = ( |
| torch.rand((1, 3, 10), dtype=torch.float), |
| torch.rand((1, 3, 10, 10), dtype=torch.float), |
| torch.rand((1, 3, 10, 10, 10), dtype=torch.float), |
| ) |
| model = torch.jit.trace(M(), inputs).eval() |
| m = prepare_jit(model, qconfig_dict) |
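        # after prepare_jit the traced graphs should contain the dimension-specific
        # conv ops rather than the generic aten::_convolution, since those are the
        # patterns the quantization passes match on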
| FileCheck().check("aten::conv1d").check_not("aten::_convolution").run( |
| str(get_forward_graph(m.conv1d._c)) |
| ) |
| FileCheck().check("aten::conv2d").check_not("aten::_convolution").run( |
| str(get_forward_graph(m.conv2d._c)) |
| ) |
| FileCheck().check("aten::conv3d").check_not("aten::_convolution").run( |
| str(get_forward_graph(m.conv3d._c)) |
| ) |
| |
| def test_convtranspose_trace(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.convtranspose1d = torch.nn.ConvTranspose1d(3, 3, 3).float() |
| self.convtranspose2d = torch.nn.ConvTranspose2d(3, 3, 3).float() |
| self.convtranspose3d = torch.nn.ConvTranspose3d(3, 3, 3).float() |
| |
| def forward(self, x, y, z): |
| a = self.convtranspose1d(x) |
| b = self.convtranspose2d(y) |
| c = self.convtranspose3d(z) |
| return (a, b, c) |
| |
| qconfig_dict = {"": default_qconfig} |
| inputs = ( |
| torch.rand((1, 3, 10), dtype=torch.float), |
| torch.rand((1, 3, 10, 10), dtype=torch.float), |
| torch.rand((1, 3, 10, 10, 10), dtype=torch.float), |
| ) |
| model = torch.jit.trace(M(), inputs).eval() |
| m = prepare_jit(model, qconfig_dict) |
| FileCheck().check("aten::conv_transpose1d").check_not("aten::_convolution").run( |
| str(get_forward_graph(m.convtranspose1d._c)) |
| ) |
| FileCheck().check("aten::conv_transpose2d").check_not("aten::_convolution").run( |
| str(get_forward_graph(m.convtranspose2d._c)) |
| ) |
| FileCheck().check("aten::conv_transpose3d").check_not("aten::_convolution").run( |
| str(get_forward_graph(m.convtranspose3d._c)) |
| ) |
| |
    @unittest.skipUnless(
        "fbgemm" in torch.backends.quantized.supported_engines,
        "Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
        " with instruction set support for AVX2 or newer.",
    )
| def test_replicate_dequant_same_value(self): |
| class Mul(torch.nn.Module): |
| def __init__(self): |
| super(Mul, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3).float() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return x * x |
| |
| data = [[torch.rand((1, 3, 10, 10), dtype=torch.float)]] |
| qconfig_dict = {"": default_qconfig} |
| model = torch.jit.script(Mul()).eval() |
| m = quantize_jit(model, qconfig_dict, test_only_eval_fn, [data]) |
| FileCheck().check("quantized::mul(").check_not("aten::mul").run(m.graph) |
| |
| def test_interface_with_fork(self): |
| class SubModule(torch.nn.Module): |
| def __init__(self): |
| super(SubModule, self).__init__() |
| self.embedding1 = torch.nn.EmbeddingBag( |
| num_embeddings=10, |
| embedding_dim=12, |
| include_last_offset=True, |
| sparse=False, |
| mode="sum", |
| ) |
| |
| def forward(self, x, y): |
| return self.embedding1(x, y) |
| |
| class OrigMod(torch.nn.Module): |
| def __init__(self): |
| super(OrigMod, self).__init__() |
| self.embedding1 = torch.nn.EmbeddingBag( |
| num_embeddings=10, |
| embedding_dim=12, |
| include_last_offset=True, |
| sparse=False, |
| mode="sum", |
| ) |
| |
| def forward(self, x, y): |
| return self.embedding1(x, y) |
| |
| @torch.jit.interface |
| class ModInterface(torch.nn.Module): |
| def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: |
| pass |
| |
| class TestModule(torch.nn.Module): |
| proxy_mod: ModInterface |
| |
| def __init__(self): |
| super(TestModule, self).__init__() |
| self.proxy_mod = OrigMod() |
| self.sub = SubModule() |
| |
| def forward(self, x, y): |
| a = self.proxy_mod(x, y) |
| b = self.sub(x, y) |
| return b |
| |
| class MainModule(torch.nn.Module): |
| def __init__(self): |
| super(MainModule, self).__init__() |
| self.test = TestModule() |
| |
| def forward(self, x, y): |
| fut = torch.jit._fork(self.test.forward, x, y) |
| z = torch.jit._wait(fut) |
| return z |
| |
| indices = torch.tensor( |
| [ |
| 9, |
| 6, |
| 5, |
| 7, |
| 8, |
| 8, |
| 9, |
| 2, |
| 8, |
| 6, |
| 6, |
| 9, |
| 1, |
| 6, |
| 8, |
| 8, |
| 3, |
| 2, |
| 3, |
| 6, |
| 3, |
| 6, |
| 5, |
| 7, |
| 0, |
| 8, |
| 4, |
| 6, |
| 5, |
| 8, |
| 2, |
| 3, |
| ] |
| ) |
| offsets = torch.tensor([0, 19, 20, 28, 28, 32]) |
| m = torch.jit.trace(MainModule(), (indices, offsets)) |
| m.eval() |
| |
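        # PlaceholderObserver with custom_op_name="embedding_bag_byte" selects 8-bit
        # row-wise quantization of the EmbeddingBag weights instead of the regular
        # observe-and-quantize flow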
| int8_qconfig = QConfig( |
| activation=PlaceholderObserver.with_args( |
| dtype=torch.float, custom_op_name="embedding_bag_byte" |
| ), |
| weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_byte"), |
| ) |
| |
| m = prepare_jit(m, {"": int8_qconfig}) |
| m = convert_jit(m) |
| FileCheck().check("quantized::embedding_bag_byte_rowwise_offsets").run(m.graph) |
| |
| @skipIfNoFBGEMM |
    def test_quantize_fork_wait(self):
        """Tests the case where fork and wait calls are in different subgraphs.
        Calling inline fork-wait only removes the fork call and leaves aten::wait
        calls in the graph, with a Tensor as input (instead of Future[Tensor]).
        """
| |
| class MainModule(nn.Module): |
| def __init__(self): |
| super(MainModule, self).__init__() |
| self.fork_ops = ForkModule() |
| |
| def init_values(self, x): |
| shared_module = self.fork_ops(x) |
| self.fork_dict = shared_module |
| |
| def forward(self, x): |
| val = torch.jit._wait(self.fork_ops(x)) |
| return val |
| |
| class TestModule(torch.nn.Module): |
| def __init__(self): |
| super(TestModule, self).__init__() |
| |
| def forward(self, x): |
| w = torch.ones(5, 5) |
| b = torch.zeros(5) |
| return torch.nn.functional.linear(x, w, b) |
| |
| class ForkModule(nn.Module): |
| def __init__(self): |
| super(ForkModule, self).__init__() |
| self.test = TestModule() |
| |
| def forward(self, x): |
| fut = torch.jit._fork(self.test.forward, x) |
| return fut |
| |
| model = MainModule().eval() |
| traced = torch.jit.trace(model, (torch.randn(5, 5),)) |
| model = prepare_dynamic_jit(traced, {"": default_qconfig}) |
| model = convert_dynamic_jit(model) |
| FileCheck().check("quantized::linear_dynamic").run(model.graph) |
| # Make sure model save works |
| b = io.BytesIO() |
| torch.jit.save(model, b) |
| |
| |
| @skipIfSlowGradcheckEnv |
| class TestQuantizeJitOps(QuantizationTestCase): |
| """Test graph mode post training static quantization works |
| for individual ops end to end. |
| """ |
| |
| @skipIfNoFBGEMM |
| def test_linear(self): |
| class ModuleLinear(torch.nn.Module): |
| def __init__(self, has_relu=False, f_relu=False): |
| super(ModuleLinear, self).__init__() |
| self.linear = torch.nn.Linear(30, 4).float() |
| if has_relu: |
| if f_relu: |
| self.relu = F.relu |
| else: |
| self.relu = torch.nn.ReLU() |
| else: |
| self.relu = torch.nn.Identity() |
| |
| def forward(self, x): |
| return self.relu(self.linear(x)) |
| |
| class FuncLinear(torch.nn.Module): |
| def __init__(self, has_relu=False, f_relu=False): |
| super(FuncLinear, self).__init__() |
| self.w = torch.randn(4, 30) |
| self.b = torch.randn(4) |
| if has_relu: |
| if f_relu: |
| self.relu = F.relu |
| else: |
| self.relu = torch.nn.ReLU() |
| else: |
| self.relu = torch.nn.Identity() |
| |
| def forward(self, x): |
| return self.relu(F.linear(x, self.w, self.b)) |
| |
| data = [[torch.rand((1, 30), dtype=torch.float)]] |
| for model, tracing in itertools.product( |
| [ModuleLinear(has_relu=False), FuncLinear(has_relu=False)], [True, False] |
| ): |
| model = self.checkGraphModeOp(model, data, "quantized::linear", tracing) |
| FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run( |
| model.graph |
| ) |
| FileCheck().check_not("quantized::linear_prepack").run(model.graph) |
| |
| for f_relu, tracing in itertools.product([True, False], [True, False]): |
| for model in [ |
| ModuleLinear(has_relu=True, f_relu=f_relu), |
| FuncLinear(has_relu=True, f_relu=f_relu), |
| ]: |
| model = self.checkGraphModeOp( |
| model, data, "quantized::linear_relu", tracing |
| ) |
                FileCheck().check_not("aten::linear").check_not("aten::relu").check_not(
                    "quantized::linear("
                ).check_not("quantized::relu(").run(model.graph)
| |
| @skipIfNoFBGEMM |
| def test_quantized_conv(self): |
| conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} |
| |
| class Conv(torch.nn.Module): |
| def __init__(self, dim): |
| super(Conv, self).__init__() |
| self.conv = conv_module[dim](3, 3, 3).float() |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| options = itertools.product([1, 2, 3], [True, False]) |
| for dim, tracing in options: |
| model = self.checkGraphModeOp( |
| Conv(dim), |
| self.img_data_dict[dim], |
| "quantized::conv{}d".format(dim), |
| tracing, |
| ) |
| # make sure there is only one quantize_per_tensor for input |
| # and conv2d_prepack is folded |
| FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run( |
| model.graph |
| ) |
| |
| FileCheck().check_not("quantized::conv{}d_prepack".format(dim)).run( |
| model.graph |
| ) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_conv_relu(self): |
| """tests for conv1d_relu/conv2d_relu/conv3d_relu""" |
| conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d} |
| |
| class ConvNdRelu(torch.nn.Module): |
| def __init__(self, dim, inplace): |
| super(ConvNdRelu, self).__init__() |
| self.conv = conv_module[dim](3, 3, 3).float() |
| self.relu = torch.nn.ReLU(inplace) |
| |
| def forward(self, x): |
| return self.relu(self.conv(x)) |
| |
| class ConvNdFunctionalRelu(torch.nn.Module): |
| def __init__(self, dim): |
| super(ConvNdFunctionalRelu, self).__init__() |
| self.conv = conv_module[dim](3, 3, 3).float() |
| |
| def forward(self, x): |
| return F.relu(self.conv(x)) |
| |
| class ConvNdInplaceFunctionalRelu(torch.nn.Module): |
| def __init__(self, dim): |
| super(ConvNdInplaceFunctionalRelu, self).__init__() |
| self.conv = conv_module[dim](3, 3, 3).float() |
| |
| def forward(self, x): |
| return F.relu(self.conv(x), True) |
| |
| options = itertools.product([1, 2, 3], [True, False]) |
| for dim, tracing in options: |
| for orig_m in [ |
| ConvNdRelu(dim, True), |
| ConvNdRelu(dim, False), |
| ConvNdFunctionalRelu(dim), |
| ConvNdInplaceFunctionalRelu(dim), |
| ]: |
| conv_name = "conv{}d".format(dim) |
| m = self.checkGraphModeOp( |
| orig_m, |
| self.img_data_dict[dim], |
| "quantized::conv{}d_relu(".format(dim), |
| tracing=tracing, |
| ) |
| |
| FileCheck().check_not("aten::conv{}d(".format(dim)).check_not( |
| "aten::relu" |
| ).check_not("quantized::conv{}d(".format(dim)).check_not( |
| "quantized::relu(" |
| ).run( |
| m.graph |
| ) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_add_alpha(self): |
| """Test quant fusion for multiple aten::add using same |
| constant alpha as the third argument |
| """ |
| |
| class QuantizedAdd(torch.nn.Module): |
| def __init__(self): |
| super(QuantizedAdd, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| z = x + y |
| w = y + z |
| return z + w |
| |
| data = [ |
| [ |
| torch.randn(1, 2, 5, 5, dtype=torch.float), |
| torch.randn(1, 2, 5, 5, dtype=torch.float), |
| ] |
| ] |
| for tracing in [True, False]: |
| m = self.checkGraphModeOp(QuantizedAdd(), data, "quantized::add", tracing) |
| FileCheck().check_count("quantized::add", 3, exactly=True).run(m.graph) |
| FileCheck().check_not("aten::add").check_not("aten::add_").run(m.graph) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_add_relu_alpha(self): |
| """Test quant fusion for multiple aten::add using same |
| constant alpha as the third argument in add_relu pattern |
| """ |
| |
| class AddRelu(torch.nn.Module): |
| def __init__(self, inplace): |
| super(AddRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| self.relu = torch.nn.ReLU(inplace) |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x = x + y |
| x = self.relu(x) |
| x = x + y |
| return self.relu(x) |
| |
| class InplaceAddRelu(torch.nn.Module): |
| def __init__(self, inplace): |
| super(InplaceAddRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| self.relu = torch.nn.ReLU(inplace) |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x += y |
| x = self.relu(x) |
| x += y |
| return self.relu(x) |
| |
| class AddFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(AddFunctionalRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x = x + y |
| x = F.relu(x) |
| x = x + y |
| return F.relu(x) |
| |
| class InplaceAddFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(InplaceAddFunctionalRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x += y |
| x = F.relu(x) |
| x += y |
| return F.relu(x) |
| |
| class AddInplaceFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(AddInplaceFunctionalRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x = x + y |
| x = F.relu(x, True) |
| x = x + y |
| return F.relu(x, True) |
| |
| class InplaceAddInplaceFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(InplaceAddInplaceFunctionalRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x += y |
| x = F.relu(x, True) |
| x += y |
| return F.relu(x, True) |
| |
| data = [ |
| [ |
| torch.rand((1, 2, 5, 5), dtype=torch.float), |
| torch.rand((1, 2, 5, 5), dtype=torch.float), |
| ] |
| ] |
| for m_orig in [ |
| AddRelu(True), |
| AddRelu(False), |
| InplaceAddRelu(True), |
| InplaceAddRelu(False), |
| AddFunctionalRelu(), |
| InplaceAddFunctionalRelu(), |
| AddInplaceFunctionalRelu(), |
| InplaceAddInplaceFunctionalRelu(), |
| ]: |
| for tracing in [True, False]: |
| m = self.checkGraphModeOp( |
| m_orig, data, "quantized::add_relu(", tracing=tracing |
| ) |
| FileCheck().check_count("quantized::add_relu(", 2, exactly=True).run( |
| m.graph |
| ) |
| FileCheck().check_not("aten::add(").check_not("aten::add_(").check_not( |
| "aten::relu(" |
| ).check_not("aten::relu_(").check_not("quantized::add(").check_not( |
| "quantized::relu(" |
| ).run( |
| m.graph |
| ) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_add(self): |
| class QuantizedAdd(torch.nn.Module): |
| def __init__(self): |
| super(QuantizedAdd, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| return x + y |
| |
| class QuantizedInplaceAdd(torch.nn.Module): |
| def __init__(self): |
| super(QuantizedInplaceAdd, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x += y |
| return x |
| |
| class NonQuantizedAdd(torch.nn.Module): |
| def __init__(self): |
| super(NonQuantizedAdd, self).__init__() |
| |
| def forward(self, x, y): |
| return x + y |
| |
| class NonQuantizedInplaceAdd(torch.nn.Module): |
| def __init__(self): |
| super(NonQuantizedInplaceAdd, self).__init__() |
| |
| def forward(self, x, y): |
| x += y |
| return x |
| |
| data = [ |
| [ |
| torch.randn(1, 2, 3, 3, dtype=torch.float), |
| torch.randn(1, 2, 3, 3, dtype=torch.float), |
| ] |
| ] |
| for m, quantized in [ |
| (QuantizedAdd(), True), |
| (QuantizedInplaceAdd(), True), |
| (NonQuantizedAdd(), False), |
| (NonQuantizedInplaceAdd(), False), |
| ]: |
| for tracing in [True, False]: |
| op = "quantized::add" if quantized else "aten::add" |
| m = self.checkGraphModeOp(m, data, op, tracing) |
| # TODO: remove after refactor of checkGraphModeOp |
| if quantized: |
| FileCheck().check_not("aten::add").check_not("aten::add_").run( |
| m.graph |
| ) |
| else: |
| FileCheck().check_not("quantized::add").run(m.graph) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_add_scalar(self): |
| class QuantizedAddScalar(torch.nn.Module): |
| def __init__(self): |
| super(QuantizedAddScalar, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return x + 3 |
| |
| class QuantizedInplaceAddScalar(torch.nn.Module): |
| def __init__(self): |
| super(QuantizedInplaceAddScalar, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x += 3 |
| return x |
| |
| class NonQuantizedAddScalar(torch.nn.Module): |
| def __init__(self): |
| super(NonQuantizedAddScalar, self).__init__() |
| |
| def forward(self, x): |
| return x + 3 |
| |
| class NonQuantizedInplaceAddScalar(torch.nn.Module): |
| def __init__(self): |
| super(NonQuantizedInplaceAddScalar, self).__init__() |
| |
| def forward(self, x): |
| x += 3 |
| return x |
| |
| data = [[torch.randn(1, 2, 3, 3, dtype=torch.float)]] |
| for m, quantized in [ |
| (QuantizedAddScalar(), True), |
| (QuantizedInplaceAddScalar(), True), |
| (NonQuantizedAddScalar(), False), |
| (NonQuantizedInplaceAddScalar(), False), |
| ]: |
| for tracing in [True, False]: |
| op = "quantized::add_scalar" if quantized else "aten::add" |
| # we don't check the numerical consistency for add_scalar |
| # since it's not supported |
| m = self.checkGraphModeOp(m, data, op, tracing, check=False) |
| # TODO: remove after refactor of checkGraphModeOp |
| if quantized: |
| FileCheck().check_not("aten::add").check_not("aten::add_").run( |
| m.graph |
| ) |
| else: |
| FileCheck().check_not("quantized::add_scalar").run(m.graph) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_add_relu(self): |
| class AddRelu(torch.nn.Module): |
| def __init__(self, inplace): |
| super(AddRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| self.relu = torch.nn.ReLU(inplace) |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x = x + y |
| return self.relu(x) |
| |
| class InplaceAddRelu(torch.nn.Module): |
| def __init__(self, inplace): |
| super(InplaceAddRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| self.relu = torch.nn.ReLU(inplace) |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x += y |
| return self.relu(x) |
| |
| class AddFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(AddFunctionalRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x = x + y |
| return F.relu(x) |
| |
| class InplaceAddFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(InplaceAddFunctionalRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x += y |
| return F.relu(x) |
| |
| class AddInplaceFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(AddInplaceFunctionalRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x = x + y |
| return F.relu(x, True) |
| |
| class InplaceAddInplaceFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(InplaceAddInplaceFunctionalRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x += y |
| return F.relu(x, True) |
| |
| data = [ |
| [ |
| torch.rand((1, 2, 5, 5), dtype=torch.float), |
| torch.rand((1, 2, 5, 5), dtype=torch.float), |
| ] |
| ] |
| for m in [ |
| AddRelu(True), |
| AddRelu(False), |
| InplaceAddRelu(True), |
| InplaceAddRelu(False), |
| AddFunctionalRelu(), |
| InplaceAddFunctionalRelu(), |
| AddInplaceFunctionalRelu(), |
| InplaceAddInplaceFunctionalRelu(), |
| ]: |
| for tracing in [True, False]: |
| m = self.checkGraphModeOp(m, data, "quantized::add_relu(", tracing) |
| FileCheck().check_not("aten::add(").check_not("aten::add_(").check_not( |
| "aten::relu(" |
| ).check_not("aten::relu_(").check_not("quantized::add(").check_not( |
| "quantized::relu(" |
| ).run( |
| m.graph |
| ) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_add_scalar_relu(self): |
| class AddScalarRelu(torch.nn.Module): |
| def __init__(self, inplace): |
| super(AddScalarRelu, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| self.relu = torch.nn.ReLU(inplace) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return self.relu(x + 3) |
| |
| class InplaceAddScalarRelu(torch.nn.Module): |
| def __init__(self, inplace): |
| super(InplaceAddScalarRelu, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| self.relu = torch.nn.ReLU(inplace) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x += 3 |
| return self.relu(x) |
| |
| class AddScalarFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(AddScalarFunctionalRelu, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return F.relu(x + 3) |
| |
| class InplaceAddScalarFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(InplaceAddScalarFunctionalRelu, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x += 3 |
| return F.relu(x) |
| |
| class AddScalarInplaceFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(AddScalarInplaceFunctionalRelu, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return F.relu(x + 3, True) |
| |
| class InplaceAddScalarInplaceFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(InplaceAddScalarInplaceFunctionalRelu, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x += 3 |
| return F.relu(x, True) |
| |
| data = [[torch.rand((1, 2, 5, 5), dtype=torch.float)]] |
| for m in [ |
| AddScalarRelu(True), |
| AddScalarRelu(False), |
| InplaceAddScalarRelu(True), |
| InplaceAddScalarRelu(False), |
| AddScalarFunctionalRelu(), |
| InplaceAddScalarFunctionalRelu(), |
| AddScalarInplaceFunctionalRelu(), |
| InplaceAddScalarInplaceFunctionalRelu(), |
| ]: |
| for tracing in [True, False]: |
| # quantized::add_scalar_relu or quantized::add_scalar_relu_out |
| # TODO: split this after refactor of checkGraphModeOp |
| m = self.checkGraphModeOp( |
| m, data, "quantized::add_scalar_relu", tracing, check=False |
| ) |
| FileCheck().check_not("aten::add(").check_not("aten::add_(").check_not( |
| "aten::relu(" |
| ).check_not("aten::relu_(").check_not( |
| "quantized::add_scalar(" |
| ).check_not( |
| "quantized::relu(" |
| ).run( |
| m.graph |
| ) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_cat(self): |
| """quantization of the output of cat will be depend on the |
| input of cat. we only quantize the output of cat when its inputs are quantized. |
| """ |
| |
| class QuantizedCat(torch.nn.Module): |
| def __init__(self): |
| super(QuantizedCat, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| return torch.cat([x, y], 1) |
| |
| class NonQuantizedCat(torch.nn.Module): |
| def __init__(self): |
| super(NonQuantizedCat, self).__init__() |
| |
| def forward(self, x, y): |
| return torch.cat([x, y], 1) |
| |
| data = [ |
| [ |
| torch.randn(1, 2, 5, 5, dtype=torch.float), |
| torch.randn(1, 2, 5, 5, dtype=torch.float), |
| ] |
| ] |
| for tracing in [True, False]: |
| m = self.checkGraphModeOp(QuantizedCat(), data, "quantized::cat", tracing) |
| FileCheck().check_not("aten::cat").run(m.graph) |
| |
| m = self.checkGraphModeOp(NonQuantizedCat(), data, "aten::cat", tracing) |
| FileCheck().check_not("quantized::cat").run(m.graph) |
| |
| @skipIfNoFBGEMM |
| def test_qbatch_norm(self): |
| bn_module = { |
| 1: torch.nn.BatchNorm1d, |
| 2: torch.nn.BatchNorm2d, |
| 3: torch.nn.BatchNorm3d, |
| } |
| |
| class M(torch.nn.Module): |
| def __init__(self, dim): |
| super(M, self).__init__() |
| self.bn = bn_module[dim](3).to(torch.float) |
| |
| def forward(self, x): |
| return self.bn(x) |
| |
| options = itertools.product([True, False], [1, 2, 3]) |
| for tracing, dim in options: |
| model = self.checkGraphModeOp( |
| M(dim), self.img_data_dict[dim], "quantized::batch_norm", tracing |
| ) |
| |
| FileCheck().check_not("aten::batch_norm").run(model.graph) |
| |
| @skipIfNoFBGEMM |
| def test_qbatch_norm_relu_BNRelu(self): |
| bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d} |
| |
| class BNRelu(torch.nn.Module): |
| def __init__(self, dim, inplace): |
| super(BNRelu, self).__init__() |
| self.bn = bn_module[dim](3).to(torch.float) |
| self.relu = torch.nn.ReLU(inplace=inplace) |
| |
| def forward(self, x): |
| return self.relu(self.bn(x)) |
| |
| options = itertools.product([True, False], [2, 3]) |
| for tracing, dim in options: |
| for instance in [BNRelu(dim, True), BNRelu(dim, False)]: |
| model = self.checkGraphModeOp(instance, self.img_data_dict[dim], |
| "quantized::batch_norm_relu", tracing) |
| FileCheck().check_not("aten::batch_norm") \ |
| .check_not("aten::relu") \ |
| .check_not("aten::relu_") \ |
| .run(model.graph) |
| |
| @skipIfNoFBGEMM |
| def test_qbatch_norm_relu_BNFuncRelu(self): |
| bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d} |
| |
| class BNFuncRelu(torch.nn.Module): |
| def __init__(self, dim): |
| super(BNFuncRelu, self).__init__() |
| self.bn = bn_module[dim](3).to(torch.float) |
| |
| def forward(self, x): |
| return F.relu(self.bn(x), False) |
| |
| options = itertools.product([True, False], [2, 3]) |
| for tracing, dim in options: |
| instance = BNFuncRelu(dim) |
| model = self.checkGraphModeOp(instance, self.img_data_dict[dim], |
| "quantized::batch_norm_relu", tracing) |
| FileCheck().check_not("aten::batch_norm") \ |
| .check_not("aten::relu") \ |
| .check_not("aten::relu_") \ |
| .run(model.graph) |
| |
| @skipIfNoFBGEMM |
| def test_qbatch_norm_relu_BNFuncInplaceRelu(self): |
| bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d} |
| |
| class BNFuncInplaceRelu(torch.nn.Module): |
| def __init__(self, dim): |
| super(BNFuncInplaceRelu, self).__init__() |
| self.bn = bn_module[dim](3).to(torch.float) |
| |
| def forward(self, x): |
| return F.relu(self.bn(x), True) |
| |
| options = itertools.product([True, False], [2, 3]) |
| for tracing, dim in options: |
| instance = BNFuncInplaceRelu(dim) |
| model = self.checkGraphModeOp(instance, self.img_data_dict[dim], |
| "quantized::batch_norm_relu", tracing) |
| FileCheck().check_not("aten::batch_norm") \ |
| .check_not("aten::relu") \ |
| .check_not("aten::relu_") \ |
| .run(model.graph) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_mul(self): |
| class QuantizedMul(torch.nn.Module): |
| def __init__(self): |
| super(QuantizedMul, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| return x * y |
| |
| class QuantizedInplaceMul(torch.nn.Module): |
| def __init__(self): |
| super(QuantizedInplaceMul, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x *= y |
| return x |
| |
| class NonQuantizedMul(torch.nn.Module): |
| def __init__(self): |
| super(NonQuantizedMul, self).__init__() |
| |
| def forward(self, x, y): |
| return x * y |
| |
| class NonQuantizedInplaceMul(torch.nn.Module): |
| def __init__(self): |
| super(NonQuantizedInplaceMul, self).__init__() |
| |
| def forward(self, x, y): |
| x *= y |
| return x |
| |
| data = [ |
| [ |
| torch.randn(1, 2, 10, 10, dtype=torch.float), |
| torch.randn(1, 2, 10, 10, dtype=torch.float), |
| ] |
| ] |
| for m, quantized in [ |
| (QuantizedMul(), True), |
| (QuantizedInplaceMul(), True), |
| (NonQuantizedMul(), False), |
| (NonQuantizedInplaceMul(), False), |
| ]: |
| for tracing in [True, False]: |
| op = "quantized::mul" if quantized else "aten::mul" |
| m = self.checkGraphModeOp(m, data, op, tracing) |
| # TODO: remove after refactor of checkGraphModeOp |
| if quantized: |
| FileCheck().check_not("aten::mul").check_not("aten::mul_").run( |
| m.graph |
| ) |
| else: |
| FileCheck().check_not("quantized::mul").run(m.graph) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_mul_scalar(self): |
| class QuantizedMulScalar(torch.nn.Module): |
| def __init__(self): |
| super(QuantizedMulScalar, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return x * 3 |
| |
| class QuantizedInplaceMulScalar(torch.nn.Module): |
| def __init__(self): |
| super(QuantizedInplaceMulScalar, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x *= 3 |
| return x |
| |
| class NonQuantizedMulScalar(torch.nn.Module): |
| def __init__(self): |
| super(NonQuantizedMulScalar, self).__init__() |
| |
| def forward(self, x): |
| return x * 3 |
| |
| class NonQuantizedInplaceMulScalar(torch.nn.Module): |
| def __init__(self): |
| super(NonQuantizedInplaceMulScalar, self).__init__() |
| |
| def forward(self, x): |
| x *= 3 |
| return x |
| |
| data = [[torch.randn(1, 2, 5, 5, dtype=torch.float)]] |
| for m, quantized in [ |
| (QuantizedMulScalar(), True), |
| (QuantizedInplaceMulScalar(), True), |
| (NonQuantizedMulScalar(), False), |
| (NonQuantizedInplaceMulScalar(), False), |
| ]: |
| for tracing in [True, False]: |
| op = "quantized::mul_scalar" if quantized else "aten::mul" |
                # we don't check the numerical consistency for mul_scalar
                # since it's not supported
| m = self.checkGraphModeOp(m, data, op, tracing, check=False) |
| # TODO: remove after refactor of checkGraphModeOp |
| if quantized: |
| FileCheck().check_not("aten::mul").check_not("aten::mul_").run( |
| m.graph |
| ) |
| else: |
| FileCheck().check_not("quantized::mul_scalar").run(m.graph) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_mul_relu(self): |
| class MulRelu(torch.nn.Module): |
| def __init__(self, inplace): |
| super(MulRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| self.relu = torch.nn.ReLU(inplace) |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x = x * y |
| return self.relu(x) |
| |
| class InplaceMulRelu(torch.nn.Module): |
| def __init__(self, inplace): |
| super(InplaceMulRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| self.relu = torch.nn.ReLU(inplace) |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x *= y |
| return self.relu(x) |
| |
| class MulFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(MulFunctionalRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x = x * y |
| return F.relu(x) |
| |
| class InplaceMulFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(InplaceMulFunctionalRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x *= y |
| return F.relu(x) |
| |
| class MulInplaceFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(MulInplaceFunctionalRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x = x * y |
| return F.relu(x, True) |
| |
| class InplaceMulInplaceFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(InplaceMulInplaceFunctionalRelu, self).__init__() |
| self.conv1 = torch.nn.Conv2d(2, 2, 2).float() |
| self.conv2 = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x, y): |
| x = self.conv1(x) |
| y = self.conv2(y) |
| x *= y |
| return F.relu(x, True) |
| |
| data = [ |
| [ |
| torch.rand((1, 2, 5, 5), dtype=torch.float), |
| torch.rand((1, 2, 5, 5), dtype=torch.float), |
| ] |
| ] |
| for m in [ |
| MulRelu(True), |
| MulRelu(False), |
| InplaceMulRelu(True), |
| InplaceMulRelu(False), |
| MulFunctionalRelu(), |
| InplaceMulFunctionalRelu(), |
| MulInplaceFunctionalRelu(), |
| InplaceMulInplaceFunctionalRelu(), |
| ]: |
| for tracing in [True, False]: |
| m = self.checkGraphModeOp(m, data, "quantized::mul_relu(", tracing) |
| FileCheck().check_not("aten::mul(").check_not("aten::mul_(").check_not( |
| "aten::relu(" |
| ).check_not("aten::relu_(").check_not("quantized::mul(").check_not( |
| "quantized::relu(" |
| ).run( |
| m.graph |
| ) |
| |
| @skipIfNoFBGEMM |
| def test_quantized_mul_scalar_relu(self): |
| class MulScalarRelu(torch.nn.Module): |
| def __init__(self, inplace): |
| super(MulScalarRelu, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| self.relu = torch.nn.ReLU(inplace) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return self.relu(x * 3) |
| |
| class InplaceMulScalarRelu(torch.nn.Module): |
| def __init__(self, inplace): |
| super(InplaceMulScalarRelu, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| self.relu = torch.nn.ReLU(inplace) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x *= 3 |
| return self.relu(x) |
| |
| class MulScalarFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(MulScalarFunctionalRelu, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return F.relu(x * 3) |
| |
| class InplaceMulScalarFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(InplaceMulScalarFunctionalRelu, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x *= 3 |
| return F.relu(x) |
| |
| class MulScalarInplaceFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(MulScalarInplaceFunctionalRelu, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| return F.relu(x * 3, True) |
| |
| class InplaceMulScalarInplaceFunctionalRelu(torch.nn.Module): |
| def __init__(self): |
| super(InplaceMulScalarInplaceFunctionalRelu, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x *= 3 |
| return F.relu(x, True) |
| |
| data = [[torch.randn(1, 2, 5, 5, dtype=torch.float)]] |
| for m in [ |
| MulScalarRelu(True), |
| MulScalarRelu(False), |
| InplaceMulScalarRelu(True), |
| InplaceMulScalarRelu(False), |
| MulScalarFunctionalRelu(), |
| InplaceMulScalarFunctionalRelu(), |
| MulScalarInplaceFunctionalRelu(), |
| InplaceMulScalarInplaceFunctionalRelu(), |
| ]: |
| for tracing in [True, False]: |
| # quantized::mul_scalar_relu or quantized::mul_scalar_relu_out |
| m = self.checkGraphModeOp( |
| m, data, "quantized::mul_scalar_relu", tracing, check=False |
| ) |
| FileCheck().check_not("aten::mul(").check_not("aten::mul_(").check_not( |
| "aten::relu(" |
| ).check_not("aten::relu_(").check_not( |
| "quantized::mul_scalar(" |
| ).check_not( |
| "quantized::relu(" |
| ).run( |
| m.graph |
| ) |
| |
| def test_hardswish(self): |
| class FunctionalHardswish(torch.nn.Module): |
| def __init__(self, inplace): |
| super(FunctionalHardswish, self).__init__() |
| self.inplace = inplace |
| |
| def forward(self, input): |
| return torch.nn.functional.hardswish(input, inplace=self.inplace) |
| |
| modules = [ |
| torch.nn.Hardswish(), |
| FunctionalHardswish(True), |
| FunctionalHardswish(False), |
| ] |
| |
| for test_case in itertools.product([True, False], modules): |
| tracing, m = test_case |
| m = self.checkGraphModeOp( |
| m, self.img_data_2d, "quantized::hardswish", tracing |
| ) |
| FileCheck().check_not("aten::hardswish").check_not("aten::hardswish_").run( |
| m.graph |
| ) |
| |
| def test_elu(self): |
| class FunctionalELU(torch.nn.Module): |
| def __init__(self, inplace=False): |
| super(FunctionalELU, self).__init__() |
| self.inplace = inplace |
| |
| def forward(self, input): |
| return torch.nn.functional.elu(input, inplace=self.inplace) |
| |
| modules = [torch.nn.ELU, FunctionalELU] |
| for test_case in itertools.product([True, False], [True, False], modules): |
| tracing, inplace, mod_class = test_case |
| m = mod_class(inplace=inplace) |
| m = self.checkGraphModeOp(m, self.img_data_2d, "quantized::elu", tracing) |
| FileCheck().check_not("aten::elu").check_not("aten::elu_").run(m.graph) |
| |
| def test_layer_norm(self): |
| data = [[torch.rand((1, 2, 5, 5), dtype=torch.float)] for _ in range(2)] |
| layer_norm = torch.nn.LayerNorm([2, 5, 5]) |
| for tracing in [True, False]: |
| m = self.checkGraphModeOp( |
| layer_norm, data, "quantized::layer_norm", tracing |
| ) |
| FileCheck().check_not("aten::layer_norm").run(m.graph) |
| |
| def test_group_norm(self): |
| data = [[torch.rand((1, 4, 5, 5), dtype=torch.float)] for _ in range(2)] |
| group_norm = torch.nn.GroupNorm(2, 4) |
| for tracing in [True, False]: |
| m = self.checkGraphModeOp( |
| group_norm, data, "quantized::group_norm", tracing |
| ) |
| FileCheck().check_not("aten::group_norm").run(m.graph) |
| |
| def test_instance_norm(self): |
| data_1d = [[torch.rand((1, 4, 5), dtype=torch.float)] for _ in range(2)] |
| data_2d = [[torch.rand((1, 4, 5, 1), dtype=torch.float)] for _ in range(2)] |
| data_3d = [[torch.rand((1, 4, 5, 1, 1), dtype=torch.float)] for _ in range(2)] |
| data = {1: data_1d, 2: data_2d, 3: data_3d} |
| instance_norm_modules = { |
| 1: torch.nn.InstanceNorm1d, |
| 2: torch.nn.InstanceNorm2d, |
| 3: torch.nn.InstanceNorm3d, |
| } |
| |
| options = itertools.product([1, 2, 3], [True, False]) |
| for dim, tracing in options: |
| instance_norm = instance_norm_modules[dim](4) |
| m = self.checkGraphModeOp( |
| instance_norm, data[dim], "quantized::instance_norm", tracing |
| ) |
| FileCheck().check_not("aten::instance_norm").run(m.graph) |
| |
| @skipIfNoFBGEMM |
| def test_dequantize_tuple(self): |
| """Make sure dequantize can support Tuple of tensor""" |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv1 = torch.nn.Conv2d(3, 3, 3).float() |
| self.conv2 = torch.nn.Conv2d(3, 3, 3).float() |
| |
| def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| x1 = self.conv1(x) |
| x2 = self.conv2(x) |
| return x1, x2 |
| |
| for tracing in [True, False]: |
| self.checkGraphModeOp(M(), self.img_data_2d, "quantized::conv2d", tracing) |
| |
| @skipIfNoFBGEMM |
| def test_clamp(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(2, 2, 2).float() |
| self.relu6 = torch.nn.ReLU6() |
| self.relu6_ = torch.nn.ReLU6(True) |
| self.hardtanh = torch.nn.Hardtanh() |
| self.hardtanh_ = torch.nn.Hardtanh(inplace=True) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.relu6(x) |
| self.relu6_(x) |
| x = F.relu6(x) |
| x = torch.clamp(x, -3, 3) |
| x = x.clamp(-2.5, 2.5) |
| # x = x.clamp_(-2, 2) # Enable when quantized `clamp_` is ready |
| x = self.hardtanh(x) |
| self.hardtanh_(x) |
| x = F.hardtanh(x) |
| F.hardtanh_(x) |
| return x |
| |
| data = [[torch.rand((1, 2, 5, 5), dtype=torch.float)]] |
| options = itertools.product( |
| ["aten::clamp", "aten::hardtanh", "aten::hardtanh_"], [True, False] |
| ) |
| for op, tracing in options: |
| m = self.checkGraphModeOp(M(), data, op, tracing) |
| FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run( |
| m.graph |
| ) |
| |
| FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph) |
| |
| def test_general_shape_ops(self): |
| """A test that checks dequantize will be swapped for |
| all supported general shape ops like aten::flatten |
| without actually checking for execution of these ops |
| """ |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3) |
| self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3) |
| self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3) |
| self.dropout = torch.nn.Dropout() |
| self.conv1 = torch.nn.Conv2d(3, 3, 3) |
| self.conv2 = torch.nn.Conv2d(3, 3, 3) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| x = self.conv1(x) |
| # add_scalar |
| x = x + 3 |
| # mul_scalar |
| x = x * 3 |
| # add_scalar_out |
| x += 3 |
| # mul_scalar_out |
| x *= 3 |
| # add_scalar_relu |
| x = x + 3 |
| x = F.relu(x) |
| # add_scalar_relu_out |
| x += 3 |
| x = F.relu(x) |
| # mul_scalar_relu |
| x = x * 3 |
| x = F.relu(x) |
| # mul_scalar_relu_out |
| x *= 3 |
| x = F.relu(x) |
| x = self.maxpool1d(x) |
| x = self.maxpool2d(x) |
| x = self.maxpool3d(x) |
| x = torch.flatten(x) |
| x = torch.max(x) |
| x = torch.min(x) |
| x = x.reshape([-1]) |
| x = x.resize_(1, 1, x.numel()) |
| x = x.view(-1) |
| # prim::ListConstruct |
| xs = [x, x] |
| # prim::ListUnpack |
| x, y = xs |
| # prim::TupleConstruct |
| xs = (x, x) |
| # prim::TupleUnpack |
| x, y = xs |
| x = x.transpose(1, 2) |
| x = x.contiguous() |
| x, y = torch.chunk(x, 2) |
| x = F.dropout(x) |
| x = self.dropout(x) |
| x, _ = torch.sort(x) |
| x = x.permute(0, 2, 3, 1) |
| x = torch.repeat_interleave(x, 3, 1) |
| x = self.relu(x) |
| x = F.relu(x) |
| x.relu_() |
| x = x.squeeze(0) |
| x.squeeze_(0) |
| x = torch.squeeze(x, 0) |
| x = x.unsqueeze(0) |
| x.unsqueeze_(0) |
| x = torch.unsqueeze(x, 0) |
| x = x.detach() |
| x.detach_() |
| x = x.repeat(4, 2) |
| y = [] |
| y.append(x) |
| z = torch.stack(y, 0) |
| z = [z, z] |
| x, _ = z |
| x = self.conv2(x) |
| return x |
| |
| data = torch.rand(1, 3, 10, 10) |
        # This model is not executable since we just put all ops
        # in the same forward, so we only test scripting
| m = torch.jit.script(M()) |
| qconfig = script_qconfig(default_qconfig) |
| # dummy data to suppress warning |
| get_forward(qconfig.activation)(data) |
| get_forward(qconfig.weight)(data) |
| |
| m = wrap_cpp_module( |
| torch._C._jit_pass_insert_observers( |
| m._c, "forward", {"": qconfig}, inplace=False |
| ) |
| ) |
| m = convert_jit(m) |
        # This checks that the dequantize from the output of the first conv
        # is propagated to the end, so that we don't insert extra observers
        # and also successfully fuse the two quantized::conv2d patterns.
        # one quantize_per_tensor for input
| FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run( |
| m.graph |
| ) |
| |
| FileCheck().check_count("quantized::conv2d(", 2, exactly=True).run(m.graph) |
| |
| FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph) |
| |
| FileCheck().check("quantized::add_scalar").check("quantized::mul_scalar").run( |
| m.graph |
| ) |
| |
| def test_general_value_ops(self): |
| """ A test that checks correct patterns are produced for |
| all supported general value ops like aten::avg_pool2d \ |
| without actually checking for execution of these ops |
| """ |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 3, 3) |
| self.avg_pool1d = torch.nn.AvgPool1d(3) |
| self.avg_pool2d = torch.nn.AvgPool2d(3) |
| self.avg_pool3d = torch.nn.AvgPool3d(3) |
| self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1)) |
| self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1)) |
| self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1)) |
| self.leaky_relu = torch.nn.LeakyReLU() |
| self.hardsigmoid = torch.nn.Hardsigmoid() |
| self.sigmoid = torch.nn.Sigmoid() |
| self.tanh = torch.nn.Tanh() |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.avg_pool1d(x) |
| x = self.avg_pool2d(x) |
| x = self.avg_pool3d(x) |
| x = self.adaptive_avg_pool1d(x) |
| x = self.adaptive_avg_pool2d(x) |
| x = self.adaptive_avg_pool3d(x) |
| x = F.avg_pool1d(x, 3) |
| x = F.avg_pool2d(x, 3) |
| x = F.avg_pool3d(x, 3) |
| x = F.adaptive_avg_pool1d(x, (1)) |
| x = F.adaptive_avg_pool2d(x, (1, 1)) |
| x = F.adaptive_avg_pool3d(x, (1, 1, 1)) |
| x = torch.mean(x) |
| x = torch.mean(x, [2, 3], False) |
| x = x.mean() |
| x = x.mean([2, 3], True) |
| # interpolate node will introduce 3 quantize_per_tensor ops |
| x = F.interpolate(x, 4, mode="nearest") # interpolate node |
| x = F.upsample(x, (32, 32)) # interpolate node |
| x = F.upsample_nearest(x, (32, 32)) # interpolate node |
| x = F.interpolate(x, 4, mode="linear") # common node |
| x = F.upsample_bilinear(x, (32, 32)) # common node |
| x = self.leaky_relu(x) |
| x = F.leaky_relu(x) |
| x.leaky_relu_() |
| x = self.hardsigmoid(x) |
| x = F.hardsigmoid(x) |
| x.hardsigmoid_() |
| x = self.sigmoid(x) |
| x = torch.sigmoid(x) |
| # F.sigmoid is deprecated |
| x = x.sigmoid() |
| x.sigmoid_() |
| x = self.tanh(x) |
| # F.tanh is deprecated |
| x = torch.tanh(x) |
| x = x.tanh() |
| x.tanh_() |
| x = self.conv(x) |
| return x |
| |
        # This model is not executable since we just put all ops
        # in the same forward, so we only test scripting
| m = torch.jit.script(M()) |
| qconfig = script_qconfig(default_qconfig) |
| # dummy data to suppress warning |
| data = torch.rand(1, 3, 10, 10) |
| get_forward(qconfig.activation)(data) |
| get_forward(qconfig.weight)(data) |
| |
| m = wrap_cpp_module( |
| torch._C._jit_pass_insert_observers( |
| m._c, "forward", {"": qconfig}, inplace=False |
| ) |
| ) |
        # Check that the model before finalize contains unfused patterns
        # that numerically match the model after quantization, by counting
        # the number of aten::quantize_per_tensor calls.
        # conv has 3 quantize_per_tensor for activations and 1 for weight,
        # and for N general value ops between convs we should have
        # N + 1 quantize_per_tensor calls between these ops
| m1 = convert_jit(m, debug=True) |
        # NB: this needs to be updated when we add more ops to the test
        # mapping from the number of quantize_per_tensor calls per op to the
        # number of such ops; for example, a key of `3` means each op of that
        # type introduces 3 quantize_per_tensor calls
| num_op_by_num_quant = {1: 32, 2: 2, 3: 3} |
| num_quantize_per_tensor = 1 # for output |
| for num_quant, num_op in num_op_by_num_quant.items(): |
| num_quantize_per_tensor += num_op * num_quant |
| num_quantize_per_tensor -= 4 # constant propagation removes some prepacks |
| FileCheck().check_count( |
| "aten::quantize_per_tensor(", num_quantize_per_tensor, exactly=True |
| ).run(m1.graph) |
| |
        # This checks that the dequantize from the output of the first conv
        # is propagated to the end, so that we don't insert extra observers
        # and also successfully fuse the two quantized::conv2d patterns.
        # one quantize_per_tensor for input
| m2 = convert_jit(m, debug=False) |
| FileCheck().check_count("aten::quantize_per_tensor(", 1, exactly=True).run( |
| m2.graph |
| ) |
| FileCheck().check_count("quantized::conv2d(", 2, exactly=True).check( |
| "aten::dequantize(" |
| ).run(m2.graph) |
| |
| @override_qengines |
| def test_conv_with_benchmark_flag(self): |
| r"""Verifies that convolutions get quantized when |
| torch.backends.cudnn.benchmark is enabled |
| """ |
| if not qengine_is_qnnpack(): |
| return |
| with torch.backends.cudnn.flags(enabled=True): |
| m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1)) |
| m.eval() |
| m = torch.jit.trace(m, torch.rand(4, 1, 4, 4)) |
| qconfig = torch.ao.quantization.get_default_qconfig("qnnpack") |
| prepared_model = torch.ao.quantization.prepare_jit(m, {"": qconfig}) |
| prepared_model(torch.rand(4, 1, 4, 4)) |
| converted_model = torch.ao.quantization.convert_jit(prepared_model) |
| FileCheck().check("quantized::conv2d").run(converted_model.graph) |
| |
| @skipIfNoFBGEMM |
| def test_cat_linear(self): |
| class LinearModel(torch.nn.Module): |
| def __init__(self): |
| super(LinearModel, self).__init__() |
| self.weight = torch.randn(5, 5) |
| |
| def forward(self, x, y): |
| a = torch.cat([x, y]) |
| b = F.linear(a, self.weight) |
| c = F.linear(b, self.weight) |
| return b, c |
| |
| model = LinearModel().eval() |
| qconfig = {"": default_qconfig} |
| float_model = torch.jit.script(model) |
| prepared_model = prepare_jit(float_model, qconfig) |
| prepared_model(torch.rand(5, 5), torch.rand(5, 5)) |
| converted_model = convert_jit(prepared_model) |
| FileCheck().check("quantized::linear").check("quantized::linear").run( |
| converted_model.graph |
| ) |
| |
| |
| class TestQuantizeDynamicJitPasses(QuantizationTestCase): |
| def test_prepare_dynamic(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.fc = torch.nn.Linear(5, 5) |
| |
| def forward(self, x): |
| return self.fc(x) |
| |
| model = torch.jit.script(M()) |
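        # For fp16 dynamic quantization only the weight gets a (placeholder)
        # observer; for int8 dynamic quantization the input activation of the
        # linear is observed in the parent module as well, which is what the
        # assertions below distinguish.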
| for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]: |
| m = prepare_dynamic_jit(model, {"": qconfig}) |
| |
| # observer for weight |
| assert len(attrs_with_prefix(m.fc, "_observer_")) == 1 |
| |
| if qconfig == float16_dynamic_qconfig: |
| observer_name = 'PlaceholderObserver = prim::GetAttr[name="_observer_' |
| FileCheck().check(observer_name).run(m.fc.graph) |
| else: |
| # for input of FC for dynamic quant |
| assert len(attrs_with_prefix(m, "_observer_")) == 1 |
| observer_name = 'Observer = prim::GetAttr[name="_observer_' |
| FileCheck().check(observer_name).check( |
| 'prim::GetAttr[name="fc"]' |
| ).check("prim::CallMethod").check_not(observer_name).run(m.graph) |
| |
| def test_prepare_dynamic_child_qconfig(self): |
| class Sub(torch.nn.Module): |
| def __init__(self): |
| super(Sub, self).__init__() |
| self.fc = torch.nn.Linear(5, 5) |
| |
| def forward(self, x): |
| return self.fc(x) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.conv = torch.nn.Conv2d(3, 5, 3) |
| self.sub = Sub() |
| |
| def forward(self, x): |
| return self.sub(self.conv(x)) |
| |
| m = torch.jit.script(M()) |
| # only quantize child module. |
| m = prepare_dynamic_jit(m, {"sub.fc": default_dynamic_qconfig}) |
| |
| # input of sub for dynamic quant |
| assert len(attrs_with_prefix(m, "_observer_")) == 1 |
| # not quantized |
| assert len(attrs_with_prefix(m.conv, "_observer_")) == 0 |
        # no observers since we observe at the outermost call site
| assert len(attrs_with_prefix(m.sub, "_observer_")) == 0 |
| # weight of linear |
| assert len(attrs_with_prefix(m.sub.fc, "_observer_")) == 1 |
| FileCheck().check('prim::GetAttr[name="sub').check("prim::CallMethod").check( |
| 'Observer = prim::GetAttr[name="_observer_' |
| ).check("prim::CallMethod").check_not( |
| 'Observer = prim::GetAttr[name="_observer_' |
| ).run( |
| m.graph |
| ) |
| |
| def test_insert_quant_dequant_linear_dynamic(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.fc1 = torch.nn.Linear(5, 5).float() |
| self.fc2 = torch.nn.Linear(5, 5).float() |
| |
| def forward(self, x): |
| x = self.fc1(x) |
| return self.fc2(x) |
| |
| for is_per_channel in [True, False]: |
| m = torch.jit.script(M()) |
| qconfig = ( |
| per_channel_dynamic_qconfig |
| if is_per_channel is True |
| else default_dynamic_qconfig |
| ) |
| m = quantize_dynamic_jit(m, {"": qconfig}, debug=True) |
| assert ( |
| len(m._modules._c.items()) == 2 |
| ), "Expected to have two submodule of linear" |
| |
| wt_quant_func = ( |
| "aten::quantize_per_channel" |
| if is_per_channel |
| else "aten::quantize_per_tensor" |
| ) |
| act_quant_func = "aten::quantize_per_tensor" |
| # quantizing activations |
| FileCheck().check("aten::_choose_qparams_per_tensor").check_next( |
| act_quant_func |
| ).check_next("aten::dequantize").check( |
| "aten::_choose_qparams_per_tensor" |
| ).check_next( |
| act_quant_func |
| ).check_next( |
| "aten::dequantize" |
| ).check( |
| wt_quant_func |
| ).check_next( |
| "aten::dequantize" |
| ).check_not( |
| wt_quant_func |
| ).check( |
| "return" |
| ).run( |
| m.graph |
| ) |
| |
| @override_qengines |
| def test_dynamic_multi_op(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float) |
| |
| def forward(self, x): |
| x = x + 5 |
| return self.fc1(x) |
| |
| x = torch.randn(5, 5) |
| for tracing in [True, False]: |
| model = self.checkGraphModeOp( |
| M(), x, "quantized::linear_dynamic", tracing=tracing, dynamic=True |
| ) |
| # add op is not dynamically quantized. |
| FileCheck().check("aten::add").run(model.graph) |
| |
| @override_qengines |
| def test_dynamic_quant_multi_uses(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.fc = torch.nn.Linear(5, 5).float() |
| |
| def forward(self, x): |
| size1 = x.size() |
| size2 = x.size() |
| return self.fc(x), size1, size2 |
| |
| x = torch.randn(5, 5) |
| for tracing in [True, False]: |
| model = self.checkGraphModeOp( |
| M(), x, "quantized::linear_dynamic", tracing=tracing, dynamic=True |
| ) |
| FileCheck().check_not("aten::_choose_qparams_per_tensor").run(model.graph) |
| |
| @override_qengines |
| def test_dynamic_shared_weights(self): |
| class myMod(torch.nn.Module): |
| def __init__(self, weight): |
| super().__init__() |
| self.linear = nn.Linear(5, 5) |
| self.linear.weight = weight |
| |
| def forward(self, x): |
| return self.linear(x) |
| |
| class DynamicModel(torch.nn.Module): |
| def __init__(self): |
| super(DynamicModel, self).__init__() |
| self.weight = torch.nn.Parameter(torch.ones(5, 5)) |
| self.mod1 = myMod(self.weight) |
| |
| def forward(self, x): |
| y = self.mod1(x) |
| z = torch.nn.functional.linear(y, self.weight) |
| return z |
| |
| model = torch.jit.script(DynamicModel()).eval() |
| data = torch.randn(5, 5, dtype=torch.float) |
| quant_ops = ["mod1", ""] |
| counts = [1, 2] |
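        # mod1 and the functional linear share the same weight Parameter:
        # quantizing only "mod1" should produce one quantized::linear_dynamic,
        # while quantizing the top level ("") should produce two, matching
        # quant_ops / counts above.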
| for op, count in zip(quant_ops, counts): |
| qconfig_dict = {op: default_dynamic_qconfig} |
| m1 = quantize_dynamic_jit(model, qconfig_dict) |
| out_graph = m1(data) |
| |
| FileCheck().check_count( |
| "quantized::linear_dynamic(", count, exactly=True |
| ).check_not("aten::_choose_qparams_per_tensor").run(m1.graph) |
| |
| # Explicitly call forward on model before convert |
| m2 = prepare_dynamic_jit(model, qconfig_dict) |
| m2(data) |
| m2 = convert_dynamic_jit(m2, debug=False) |
| out_ref = m2(data) |
| self.assertEqual(out_graph, out_ref) |
| |
| @override_qengines |
| def test_dynamic_with_if(self): |
| class Res(torch.nn.Module): |
| def __init__(self): |
| super(Res, self).__init__() |
| self.weight = torch.nn.Parameter(torch.ones(5, 5)) |
| |
| def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor: |
| if cond: |
| return torch.nn.functional.linear(x, self.weight) |
| else: |
| return torch.nn.functional.linear(x, self.weight) |
| |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.res1 = Res() |
| self.res2 = Res() |
| |
| def forward(self, x): |
| x = self.res1(x, True) |
| x = self.res2(x, False) |
| return x |
| |
| model = torch.jit.script(M()).eval() |
| data = torch.randn(5, 5, dtype=torch.float) |
| qconfig_dict = {"": default_dynamic_qconfig} |
| for tracing in [True, False]: |
| m1 = self.checkGraphModeOp( |
| M(), data, "quantized::linear_dynamic", tracing=tracing, dynamic=True |
| ) |
| FileCheck().check_count( |
| "quantized::linear_dynamic(", 2, exactly=True |
| ).check_not("aten::_choose_qparams_per_tensor").run(m1.graph) |
| |
| # Check to make sure weight observers run correctly |
| ref_qparams = [] |
| qconfig = script_qconfig(default_dynamic_qconfig) |
| wt_module = wrap_cpp_module(qconfig.weight) |
| for wt in [model.res1.weight, model.res2.weight]: |
| wt_module(wt) |
| qparams = wt_module.calculate_qparams() |
| ref_qparams.append((qparams[0].item(), qparams[1].item())) |
| |
| m2 = quantize_dynamic_jit(model, qconfig_dict, debug=True) |
| graph_params = [] |
| for x, obs in m2._modules._c.items(): |
| if x == "res1": |
| graph_params.append( |
| (obs.getattr("weight.2_scale_0"), obs.getattr("weight.2_zero_point_0")) |
| ) |
| elif x == "res2": |
| graph_params.append( |
| (obs.getattr("weight.4_scale_0"), obs.getattr("weight.4_zero_point_0")) |
| ) |
| self.assertEqual(ref_qparams, graph_params) |
| |
| def test_dynamic_weight_observer(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.fc = torch.nn.Linear(5, 5).float() |
| self.fc2 = torch.nn.Linear(5, 5).float() |
| |
| def forward(self, x): |
| x = self.fc(x) |
| return self.fc2(x) |
| |
| qconfig_dict = {"": default_dynamic_qconfig} |
| eager_model = M().eval() |
| for tracing in [True, False]: |
| x = torch.rand(5, 5) |
| model = get_script_module(eager_model, tracing, x) |
| ref_qparams = [] |
| for wt in [model.fc.weight, model.fc2.weight]: |
| wt_module = default_dynamic_qconfig.weight() |
| wt_module(wt) |
| qparams = wt_module.calculate_qparams() |
| ref_qparams.append((qparams[0].item(), qparams[1].item())) |
| model = quantize_dynamic_jit(model, qconfig_dict, debug=True) |
| graph_qparams = [] |
| for x, obs in model._modules._c.items(): |
| n = 2 if x == 'fc' and tracing else 1 |
| graph_qparams.append( |
| (obs.getattr(f"weight.{n}_scale_0"), |
| obs.getattr(f"weight.{n}_zero_point_0")) |
| ) |
| self.assertEqual(ref_qparams, graph_qparams) |
| |
| def test_convert_dynamic_fp16(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.fc = torch.nn.Linear(5, 5) |
| |
| def forward(self, x): |
| return self.fc(x) |
| |
| m = torch.jit.script(M()) |
| m = quantize_dynamic_jit(m, {"": float16_dynamic_qconfig}, debug=True) |
| FileCheck().check("aten::_saturate_weight_to_fp16").check( |
| "aten::linear" |
| ).check_not("aten::dequantize").check_not("aten::quantize").run(m.graph) |
| |
| def test_quantize_dynamic_fp16(self): |
| class M(torch.nn.Module): |
| def __init__(self): |
| super(M, self).__init__() |
| self.fc = torch.nn.Linear(5, 5) |
| |
| def forward(self, x): |
| return self.fc(x) |
| |
| m = torch.jit.script(M()) |
| m = quantize_dynamic_jit(m, {"": float16_dynamic_qconfig}) |
| |
| FileCheck().check("quantized::linear_dynamic_fp16").check_not( |
| "aten::linear" |
| ).check_not("aten::dequantize").check_not("aten::quantize").run(m.graph) |
| |
| |
| class TestQuantizeDynamicJitOps(QuantizationTestCase): |
| """Test graph mode post training dynamic quantization works |
| for individual ops end to end. |
| """ |
| |
| @override_qengines |
| def test_linear(self): |
| class FunctionalLinear(torch.nn.Module): |
| def __init__(self, weight, bias): |
| super(FunctionalLinear, self).__init__() |
| self.weight = weight |
| self.bias = bias |
| |
| def forward(self, x): |
| return F.linear(x, self.weight, self.bias) |
| |
| x = torch.rand(5, 5) |
| for tracing in [True, False]: |
| model = self.checkGraphModeOp( |
| torch.nn.Linear(5, 5), |
| x, |
| "quantized::linear_dynamic", |
| tracing=tracing, |
| dynamic=True, |
| ) |
| |
| weight = torch.rand(5, 5) |
| b = torch.rand(5) |
| for tracing, has_bias in itertools.product([True, False], [True, False]): |
| bias = b if has_bias else None |
| model = self.checkGraphModeOp( |
| FunctionalLinear(weight, bias), |
| x, |
| "quantized::linear_dynamic", |
| tracing=tracing, |
| dynamic=True, |
| ) |
| |
| @skipIfNoFBGEMM |
| def test_embedding_bag(self): |
| class M(torch.nn.Module): |
| def __init__(self, weights): |
| super(M, self).__init__() |
| self.embedding1 = torch.nn.EmbeddingBag( |
| num_embeddings=10, |
| embedding_dim=12, |
| include_last_offset=True, |
| sparse=True, |
| _weight=weights, |
| mode="sum", |
| ) |
| |
| self.embedding2 = torch.nn.EmbeddingBag( |
| num_embeddings=10, |
| embedding_dim=12, |
| include_last_offset=True, |
| sparse=True, |
| _weight=weights, |
| mode="sum", |
| ) |
| |
| def forward(self, indices1, offsets1, indices2, offsets2): |
| e1 = self.embedding1(indices1, offsets1) |
| e2 = self.embedding2(indices2, offsets2) |
| return e1, e2 |
| |
| weights = torch.randn(10, 12, dtype=torch.float32) |
| module = M(weights) |
| |
        indices = torch.tensor(
            [9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8,
             3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]
        )
| offsets = torch.tensor([0, 19, 20, 28, 28, 32]) |
| dummy_inputs = (indices, offsets, indices, offsets) |
| for trace in [True, False]: |
| if trace: |
| m = torch.jit.trace(module, dummy_inputs) |
| else: |
| m = torch.jit.script(module) |
| int4_qconfig = QConfig( |
| activation=PlaceholderObserver.with_args( |
| dtype=torch.float, custom_op_name="embedding_bag_4bit" |
| ), |
| weight=PlaceholderObserver.with_args( |
| custom_op_name="embedding_bag_4bit" |
| ), |
| ) |
| int8_qconfig = QConfig( |
| activation=PlaceholderObserver.with_args( |
| dtype=torch.float, custom_op_name="embedding_bag_byte" |
| ), |
| weight=PlaceholderObserver.with_args( |
| custom_op_name="embedding_bag_byte" |
| ), |
| ) |
| m = prepare_jit(m, {"embedding1": int4_qconfig, "embedding2": int8_qconfig}) |
| m = convert_jit(m) |
| FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets").check( |
| "quantized::embedding_bag_byte_rowwise_offsets" |
| ).run(m.graph) |
| m(*dummy_inputs) |
| |
| # Ensure that attempting to quantize an EmbeddingBag throws an error if |
| # padding_idx is not None |
| @skipIfNoFBGEMM |
| def test_embedding_bag_padding_idx_error(self): |
| class M(torch.nn.Module): |
| def __init__(self, weights): |
| super(M, self).__init__() |
| self.embedding = torch.nn.EmbeddingBag( |
| num_embeddings=10, |
| embedding_dim=12, |
| include_last_offset=True, |
| sparse=True, |
| _weight=weights, |
| mode="sum", |
| padding_idx=0, |
| ) |
| |
| def forward(self, indices, offsets): |
| e = self.embedding(indices, offsets) |
| return e |
| |
| weights = torch.randn(10, 12, dtype=torch.float32) |
| module = M(weights) |
| |
| indices = torch.tensor([0, 1, 2, 3, 4]) |
| offsets = torch.tensor([0, 2, 5]) |
| dummy_inputs = (indices, offsets) |
| |
| int4_qconfig = QConfig( |
| activation=PlaceholderObserver.with_args( |
| dtype=torch.float, custom_op_name="embedding_bag_4bit" |
| ), |
| weight=PlaceholderObserver.with_args( |
| custom_op_name="embedding_bag_4bit" |
| ), |
| ) |
| int8_qconfig = QConfig( |
| activation=PlaceholderObserver.with_args( |
| dtype=torch.float, custom_op_name="embedding_bag_byte" |
| ), |
| weight=PlaceholderObserver.with_args( |
| custom_op_name="embedding_bag_byte" |
| ), |
| ) |
| |
| error_msg = r'Expected aten::embedding_bag padding_idx input to be None' |
| for trace, qconfig in itertools.product([True, False], [int4_qconfig, int8_qconfig]): |
| if trace: |
| m = torch.jit.trace(module, dummy_inputs) |
| else: |
| m = torch.jit.script(module) |
| m = prepare_jit(m, {"embedding": qconfig}) |
| with self.assertRaisesRegex(RuntimeError, error_msg): |
| m = convert_jit(m) |
| |
| |
| class TestQuantizeJit(QuantizationTestCase): |
| @override_qengines |
| def test_single_linear(self): |
| r"""Compare the result of quantizing single linear layer in |
| eager mode and graph mode |
| """ |
| # eager mode |
| annotated_linear_model = AnnotatedSingleLayerLinearModel( |
| torch.backends.quantized.engine |
| ).eval() |
| linear_model = SingleLayerLinearModel().eval() |
| # copy the weight from eager mode so that we can |
| # compare the result of the two quantized models later |
| linear_model.fc1.weight = torch.nn.Parameter( |
| annotated_linear_model.fc1.module.weight.detach() |
| ) |
| linear_model.fc1.bias = torch.nn.Parameter( |
| annotated_linear_model.fc1.module.bias.detach() |
| ) |
| model_eager = quantize( |
| annotated_linear_model, test_only_eval_fn, [self.calib_data] |
| ) |
| |
| qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)} |
| model_traced = torch.jit.trace(linear_model, self.calib_data[0][0]) |
| model_script = torch.jit.script(linear_model) |
| result_eager = model_eager(self.calib_data[0][0]) |
| for model_under_test in [model_traced, model_script]: |
| model_quantized = quantize_jit( |
| model_under_test, |
| qconfig_dict, |
| test_only_eval_fn, |
| [self.calib_data], |
| inplace=False, |
| ) |
| self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager) |
| |
| @skipIfNoFBGEMM |
| def test_observer_with_ignored_function(self): |
| r"""Test observers with ignored function and make sure it works in |
| graph mode |
| """ |
| # eager mode |
| annotated_linear_model = AnnotatedSingleLayerLinearModel("fbgemm").eval() |
| for qconfig in [ |
| QConfig(activation=default_observer, weight=default_weight_observer), |
| QConfig( |
| activation=default_histogram_observer, weight=default_weight_observer |
| ), |
| QConfig( |
| activation=default_observer, weight=default_per_channel_weight_observer |
| ), |
| ]: |
| annotated_linear_model.qconfig = qconfig |
| linear_model = SingleLayerLinearModel().eval() |
| # copy the weight from eager mode so that we can |
| # compare the result of the two quantized models later |
| linear_model.fc1.weight = torch.nn.Parameter( |
| annotated_linear_model.fc1.module.weight.detach() |
| ) |
| linear_model.fc1.bias = torch.nn.Parameter( |
| annotated_linear_model.fc1.module.bias.detach() |
| ) |
| model_eager = quantize( |
| annotated_linear_model, test_only_eval_fn, [self.calib_data] |
| ) |
| |
| qconfig_dict = {"": qconfig} |
| model_traced = torch.jit.trace(linear_model, self.calib_data[0][0]) |
| model_script = torch.jit.script(linear_model) |
| result_eager = model_eager(self.calib_data[0][0]) |
| for model_under_test in [model_traced, model_script]: |
| model_quantized = quantize_jit( |
| model_under_test, |
| qconfig_dict, |
| test_only_eval_fn, |
| [self.calib_data], |
| inplace=False, |
| ) |
| self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager) |
| |
| @override_qengines |
| def test_conv(self): |
| r"""Compare the result of quantizing conv layer in |
| eager mode and graph mode |
| """ |
| # eager mode |
| annotated_conv_model = AnnotatedConvModel( |
| torch.backends.quantized.engine |
| ).eval() |
| conv_model = ConvModel().eval() |
| # copy the weight from eager mode so that we can |
| # compare the result of the two quantized models later |
| conv_model.conv.weight = torch.nn.Parameter( |
| annotated_conv_model.conv.weight.detach() |
| ) |
| model_eager = quantize( |
| annotated_conv_model, test_only_eval_fn, [self.img_data_2d] |
| ) |
| qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)} |
| model_traced = torch.jit.trace(conv_model, self.img_data_2d[0][0]) |
| model_script = torch.jit.script(conv_model) |
| result_eager = model_eager(self.img_data_2d[0][0]) |
| for model_under_test in [model_traced, model_script]: |
| model_quantized = quantize_jit( |
| model_under_test, |
| qconfig_dict, |
| test_only_eval_fn, |
| [self.img_data_2d], |
| inplace=False, |
| ) |
| self.assertEqual(model_quantized(self.img_data_2d[0][0]), result_eager) |
| |
| @override_qengines |
| def test_conv_transpose(self): |
| r"""Compare the result of quantizing conv_transpose layer in |
| eager mode and graph mode |
| """ |
| if not qengine_is_qnnpack(): |
| return # Currently only qnnpack is supported |
| # eager mode |
| annotated_conv_model = AnnotatedConvTransposeModel( |
| torch.backends.quantized.engine |
| ).eval() |
| conv_model = ConvTransposeModel().eval() |
| # copy the weight from eager mode so that we can |
| # compare the result of the two quantized models later |
| conv_model.conv.weight = torch.nn.Parameter( |
| annotated_conv_model.conv.weight.detach() |
| ) |
| model_eager = quantize( |
| annotated_conv_model, test_only_eval_fn, [self.img_data_2d] |
| ) |
| qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)} |
| model_traced = torch.jit.trace(conv_model, self.img_data_2d[0][0]) |
| model_script = torch.jit.script(conv_model) |
| result_eager = model_eager(self.img_data_2d[0][0]) |
| for model_under_test in [model_traced, model_script]: |
| model_quantized = quantize_jit( |
| model_under_test, |
| qconfig_dict, |
| test_only_eval_fn, |
| [self.img_data_2d], |
| inplace=False, |
| ) |
| self.assertEqual(model_quantized(self.img_data_2d[0][0]), result_eager) |
| |
| @override_qengines |
| def test_conv_bn(self): |
| r"""Compare the result of quantizing conv + bn layer in |
| eager mode and graph mode |
| """ |
| # eager mode |
| conv_model = AnnotatedConvBnModel().eval() |
| conv_model_to_script = ConvBnModel().eval() |
| # copy the weight from the eager-mode model so that we can
| # compare the results of the two quantized models later
| conv_model_to_script.conv.weight = torch.nn.Parameter( |
| conv_model.conv.weight.detach() |
| ) |
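| # Eager mode requires explicit conv+bn fusion before quantization;
| # quantize_jit performs the equivalent fusion as a graph pass on the
| # scripted model, so no manual fusion is needed there.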
| fuse_modules(conv_model, ["conv", "bn"], inplace=True) |
| model_eager = quantize(conv_model, test_only_eval_fn, [self.img_data_2d]) |
| qconfig_dict = {"": default_qconfig} |
| model_script = quantize_jit( |
| torch.jit.script(conv_model_to_script), |
| qconfig_dict, |
| test_only_eval_fn, |
| [self.img_data_2d], |
| inplace=False, |
| ) |
| result_eager = model_eager(self.img_data_2d[0][0]) |
| result_script = model_script(self.img_data_2d[0][0]) |
| self.assertEqual(result_eager, result_script) |
| |
| @override_qengines |
| def test_nested(self): |
| # Eager mode |
| eager_model = AnnotatedNestedModel(torch.backends.quantized.engine).eval() |
| |
| # Graph mode |
| script_model = NestedModel().eval() |
| # Copy weights from eager_model so the two quantized models can be compared
| script_model.sub1.fc.weight = torch.nn.Parameter( |
| eager_model.sub1.fc.weight.detach() |
| ) |
| script_model.sub1.fc.bias = torch.nn.Parameter( |
| eager_model.sub1.fc.bias.detach() |
| ) |
| script_model.sub2.fc1.weight = torch.nn.Parameter( |
| eager_model.sub2.fc1.module.weight.detach() |
| ) |
| script_model.sub2.fc1.bias = torch.nn.Parameter( |
| eager_model.sub2.fc1.module.bias.detach() |
| ) |
| script_model.sub2.fc2.weight = torch.nn.Parameter( |
| eager_model.sub2.fc2.weight.detach() |
| ) |
| script_model.sub2.fc2.bias = torch.nn.Parameter( |
| eager_model.sub2.fc2.bias.detach() |
| ) |
| script_model.fc3.weight = torch.nn.Parameter( |
| eager_model.fc3.module.weight.detach() |
| ) |
| script_model.fc3.bias = torch.nn.Parameter(eager_model.fc3.module.bias.detach()) |
| |
| model_eager = quantize(eager_model, test_only_eval_fn, [self.calib_data]) |
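| # Only the submodules named in qconfig_dict are quantized in graph mode,
| # matching the per-module annotations of AnnotatedNestedModel; fbgemm also
| # exercises the per-channel qconfig for sub2.fc1.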
| qconfig_dict = { |
| "sub2.fc1": default_per_channel_qconfig |
| if qengine_is_fbgemm() |
| else default_qconfig, |
| "fc3": default_qconfig, |
| } |
| model_traced = torch.jit.trace(script_model, self.calib_data[0][0]) |
| model_script = torch.jit.script(script_model) |
| result_eager = model_eager(self.calib_data[0][0]) |
| for model_under_test in [model_traced, model_script]: |
| model_quantized = quantize_jit( |
| model_under_test, |
| qconfig_dict, |
| test_only_eval_fn, |
| [self.calib_data], |
| inplace=False, |
| ) |
| self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager) |
| |
| @override_qengines |
| def test_skip_quant(self): |
| """Test None qconfig""" |
| # Eager mode |
| eager_model = AnnotatedSkipQuantModel(torch.backends.quantized.engine).eval() |
| |
| # Graph mode |
| script_model = SkipQuantModel().eval() |
| # Copy weights from eager_model so the two quantized models can be compared
| script_model.sub.fc1.weight = torch.nn.Parameter( |
| eager_model.sub.module.fc1.weight.detach() |
| ) |
| script_model.sub.fc1.bias = torch.nn.Parameter( |
| eager_model.sub.module.fc1.bias.detach() |
| ) |
| script_model.sub.fc2.weight = torch.nn.Parameter( |
| eager_model.sub.module.fc2.weight.detach() |
| ) |
| script_model.sub.fc2.bias = torch.nn.Parameter( |
| eager_model.sub.module.fc2.bias.detach() |
| ) |
| script_model.fc.weight = torch.nn.Parameter(eager_model.fc.weight.detach()) |
| script_model.fc.bias = torch.nn.Parameter(eager_model.fc.bias.detach()) |
| |
| eager_model.fuse_modules() |
| |
| model_eager = quantize(eager_model, test_only_eval_fn, [self.calib_data]) |
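| # Mapping "fc" to None disables quantization for that submodule, so fc
| # stays in fp32 just as it does in the annotated eager model.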
| qconfig_dict = { |
| "": get_default_qconfig(torch.backends.quantized.engine), |
| "fc": None, |
| } |
| model_traced = torch.jit.trace(script_model, self.calib_data[0][0]) |
| model_script = torch.jit.script(script_model) |
| result_eager = model_eager(self.calib_data[0][0]) |
| for model_under_test in [model_traced, model_script]: |
| model_quantized = quantize_jit( |
| model_under_test, |
| qconfig_dict, |
| test_only_eval_fn, |
| [self.calib_data], |
| inplace=False, |
| ) |
| self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager) |
| |
| @override_qengines |
| def test_single_linear_dynamic(self): |
| r"""Compare the result of dynamic quantization of single linear layer in |
| eager mode and graph mode. |
| """ |
| if qengine_is_qnnpack(): |
| # eager mode |
| annotated_linear_model = AnnotatedSingleLayerLinearModel("qnnpack").eval() |
| linear_model = SingleLayerLinearModel().eval() |
| # copy the weight and bias from the eager-mode model so that we can
| # compare the results of the two quantized models later
| linear_model.fc1.weight = torch.nn.Parameter( |
| annotated_linear_model.fc1.module.weight.detach() |
| ) |
| linear_model.fc1.bias = torch.nn.Parameter( |
| annotated_linear_model.fc1.module.bias.detach() |
| ) |
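| # Dynamic quantization quantizes weights ahead of time and computes
| # activation qparams at runtime, so quantize_dynamic_jit below does not
| # need calibration data.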
| qconfig_dict = {"": default_dynamic_qconfig} |
| model_eager = quantize_dynamic(annotated_linear_model, qconfig_dict) |
| |
| model_traced = torch.jit.trace(linear_model, self.calib_data[0][0]) |
| model_script = torch.jit.script(linear_model) |
| result_eager = model_eager(self.calib_data[0][0]) |
| |
| for model_under_test in [model_traced, model_script]: |
| model_quantized = quantize_dynamic_jit(model_under_test, qconfig_dict) |
| self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager) |
| |
| # debug=True produces a fake-quantized (reference) model; check that the
| # choose_qparams->quant->dequant->linear pattern it keeps is numerically
| # equivalent to the final quantized model.
| model_fake_quantized = quantize_dynamic_jit( |
| model_under_test, qconfig_dict, debug=True |
| ) |
| self.assertEqual( |
| model_fake_quantized(self.calib_data[0][0]), result_eager |
| ) |
| |
| @skipIfNoFBGEMM |
| def test_linear_dynamic_fp16(self): |
| linear_model = SingleLayerLinearModel().eval() |
| # Create weight tensor values that are beyond the fp16 max (65504)
| x = torch.ones(5, 5) * 65532 |
| linear_model.fc1.weight = torch.nn.Parameter(x) |
| import warnings |
| |
| model_eager = quantize_dynamic(linear_model, dtype=torch.float16) |
| result_eager = model_eager(self.calib_data[0][0]) |
| # check both the traced and the scripted path
| for trace in [True, False]:
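| # The out-of-range weights are expected to be saturated to the fp16
| # range during packing; capture any resulting warnings so they do not
| # clutter the test output.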
| with warnings.catch_warnings(record=True) as w: |
| quantized_model = self.checkGraphModeOp( |
| linear_model, |
| self.calib_data[0][0], |
| "quantized::linear_dynamic_fp16", |
| tracing=trace, |
| dynamic=True, |
| qconfig=float16_dynamic_qconfig, |
| ) |
| # compare result with eager mode |
| self.assertEqual(quantized_model(self.calib_data[0][0]), result_eager) |