#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]

import os
import ctypes
import torch
from typing import Tuple
from torch.backends._nnapi.prepare import convert_model_to_nnapi
from torch.testing._internal.common_utils import TestCase, run_tests


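# Helper: build a per-tensor-quantized tensor from (nested) list or tensor data.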
def qpt(t, scale, zero_point, dtype=torch.quint8):
    t = torch.tensor(t)
    return torch.quantize_per_tensor(t, scale, zero_point, dtype)


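# Helper: clone a tensor into channels_last layout and tag it with `nnapi_nhwc`
# so the NNAPI converter treats its layout as NHWC.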
def nhwc(t):
    t = t.clone().contiguous(memory_format=torch.channels_last)
    t.nnapi_nhwc = True
    return t


class TestNNAPI(TestCase):

    def setUp(self):
        # Avoid saturation in fbgemm
        torch.backends.quantized.engine = 'qnnpack'

        libneuralnetworks_path = os.environ.get("LIBNEURALNETWORKS_PATH")
        if libneuralnetworks_path:
            ctypes.cdll.LoadLibrary(libneuralnetworks_path)
            print("Will attempt to run NNAPI models.")
            self.can_run_nnapi = True
        else:
            self.can_run_nnapi = False

    # Created for easy override by subclasses (e.g. TestNnapiBackend)
    def call_lowering_to_nnapi(self, traced_module, args):
        return convert_model_to_nnapi(traced_module, args)

    # Created for subclasses to set can_run_nnapi (e.g. TestNnapiBackend)
    def set_can_run_nnapi(self, can_run):
        self.can_run_nnapi = can_run

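    # Trace `module`, lower the trace to NNAPI, and (when libneuralnetworks is
    # available) run both and compare outputs.  `atol_rtol` sets comparison
    # tolerances; `limit` caps the number of mismatched elements allowed for
    # quantized outputs before re-checking with zero tolerance.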
    def check(
        self,
        module,
        arg_or_args,
        *,
        trace_args=None,
        convert_args=None,
        atol_rtol=None,
        limit=None,
        expected_memory_format=None
    ):
        with torch.no_grad():
            if isinstance(arg_or_args, torch.Tensor):
                args = [arg_or_args]
            else:
                args = arg_or_args
            module.eval()
            traced = torch.jit.trace(module, trace_args or args)
            nnapi_module = self.call_lowering_to_nnapi(traced, convert_args or args)
            if not self.can_run_nnapi:
                # Only test that the model was converted successfully.
                return
            eager_output = module(*args)
            nnapi_output = nnapi_module(*args)
            kwargs = {}
            if atol_rtol is not None:
                kwargs["atol"] = atol_rtol[0]
                kwargs["rtol"] = atol_rtol[1]
            self.assertEqual(eager_output, nnapi_output, **kwargs)
            if limit is not None:
                mismatches = \
                    eager_output.int_repr().to(torch.int32) - \
                    nnapi_output.int_repr().to(torch.int32)
                if mismatches.count_nonzero() > limit:
                    # Too many mismatches.  Re-run the check with no tolerance
                    # to get a nice message.
                    self.assertEqual(eager_output, nnapi_output, atol=0, rtol=0)
            if expected_memory_format:
                self.assertTrue(nnapi_output.is_contiguous(memory_format=expected_memory_format))

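    # Returns (name, input) pairs covering float/quantized inputs in both
    # contiguous and channels_last layouts.  Note that the quantized variant
    # currently uses a fixed scale/zero_point of 0.03/128 rather than the
    # arguments passed in.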
    def float_and_quant_and_nhwc(self, inp_float, scale, zero_point):
        torch.manual_seed(29)
        inp_quant = qpt(inp_float, 0.03, 128)
        return [
            ("float", inp_float),
            ("float-nhwc", nhwc(inp_float)),
            ("quant", inp_quant),
            ("quant-nhwc", nhwc(inp_quant)),
        ]

    def test_prelu(self):
        arg = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        single_a = torch.nn.PReLU()
        self.check(single_a, arg)
        multi_a = torch.nn.PReLU(4)
        with torch.no_grad():
            multi_a.weight.copy_(torch.tensor([.1, .2, .3, .4]))
        self.check(multi_a, nhwc(arg))

        # Test flexible size
        self.check(
            multi_a,
            arg,
            trace_args=[torch.zeros(1, 4, 3, 3)],
            convert_args=[nhwc(torch.zeros(1, 4, 0, 0))],
        )

    def test_quantize(self):
        self.check(
            torch.nn.quantized.Quantize(0.25, 2, torch.quint8),
            nhwc(torch.tensor([[[[1.0]], [[2.0]]]])))

    def test_dequantize(self):
        self.check(
            torch.nn.quantized.DeQuantize(),
            nhwc(qpt([[[[1.0]], [[2.0]]]], 0.25, 2)))

    def test_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, arg):
                return arg.unsqueeze(self.dim)

        self.check(UnsqueezeModule(-2), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(-1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(0), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(2), torch.randn(4, 2, 2))

    def test_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self, shape):
                super().__init__()
                self.shape = shape

            def forward(self, arg):
                return arg.reshape(self.shape)

        self.check(
            ReshapeModule((2, 4)),
            torch.randn(4, 2, 1, 1))

        self.check(
            ReshapeModule((8, -1)),
            nhwc(torch.randn(4, 2, 1, 1)))

        with self.assertRaisesRegex(Exception, "target size"):
            self.check(
                ReshapeModule((2, 4)),
                nhwc(torch.randn(4, 2, 1, 1)))

    def test_flatten(self):
        for mod in [
            torch.nn.Flatten(),
            torch.nn.Flatten(start_dim=2, end_dim=3),
            torch.nn.Flatten(start_dim=2, end_dim=4),
            torch.nn.Flatten(start_dim=0, end_dim=-2),
            torch.nn.Flatten(start_dim=0, end_dim=4)
        ]:
            self.check(mod, torch.randn(4, 2, 1, 3, 7))

        # flex inputs
        self.check(
            torch.nn.Flatten(),
            torch.randn(4, 2, 1, 3, 7),
            convert_args=[torch.zeros(0, 2, 1, 3, 7)]
        )

        # channels last
        self.check(
            torch.nn.Flatten(),
            nhwc(torch.randn(2, 1, 4, 7))
        )
        self.check(
            torch.nn.Flatten(),
            nhwc(torch.randn(2, 3, 1, 1))
        )

        # Exceptions
        with self.assertRaisesRegex(Exception, "not supported on NHWC"):
            self.check(
                torch.nn.Flatten(),
                nhwc(torch.randn(1, 3, 4, 4))
            )
        with self.assertRaisesRegex(Exception, "Flattening flexible dims is not supported yet"):
            self.check(torch.nn.Flatten(), torch.randn(4, 2, 0, 0, 7))
        with self.assertRaisesRegex(Exception, "Only 1 dim"):
            self.check(
                torch.nn.Flatten(start_dim=1, end_dim=-2),
                torch.randn(0, 2, 1, 3, 0))

    def test_slice(self):
        class SliceModule(torch.nn.Module):
            def __init__(self, start, stop, step):
                super().__init__()
                self.start = start
                self.stop = stop
                self.step = step

            def forward(self, t):
                return t[1:, self.start:self.stop:self.step, :]

        class SliceModule2(torch.nn.Module):
            def forward(self, t):
                return t[3:]

        self.check(
            SliceModule(1, 5, 2),
            torch.randn(4, 6, 2)
        )
        self.check(
            SliceModule2(),
            torch.randn(5)
        )

        # flex inputs
        self.check(
            SliceModule(1, 5, 2),
            torch.randn(4, 6, 2),
            convert_args=[torch.zeros(4, 6, 0)]
        )
        with self.assertRaisesRegex(Exception, "slice with flexible shape"):
            self.check(
                SliceModule(1, 5, 2),
                torch.randn(4, 6, 2),
                convert_args=[torch.zeros(0, 0, 0)]
            )

    def test_cat(self):
        class CatModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, t1, t2):
                return torch.cat([t1, t2], self.dim)

        self.check(
            CatModule(0),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(2, 2, 3, 3),
            ])

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ])

        self.check(
            CatModule(1),
            [
                nhwc(torch.randn(1, 2, 3, 3)),
                nhwc(torch.randn(1, 4, 3, 3)),
            ])

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ],
            convert_args=[
                torch.zeros(0, 0, 0, 0),
                torch.zeros(0, 0, 0, 0)
            ])

    def test_pointwise_unary(self):
        for op in ["relu", "sigmoid"]:
            with self.subTest(op):
                class UnaryModule(torch.nn.Module):
                    def forward(self, arg):
                        if op == "relu":
                            return torch.nn.functional.relu(arg)
                        if op == "sigmoid":
                            return torch.sigmoid(arg)
                        raise Exception("Bad op")
                self.check(UnaryModule(), torch.tensor([-1.0, 1.0]))
                self.check(
                    UnaryModule(),
                    qpt(torch.tensor([-1.0, 1.0]), 1. / 256, 0),
                )

    def test_pointwise_binary(self):
        for op in ["add", "sub", "mul", "div"]:
            with self.subTest(op):
                class BinaryModule(torch.nn.Module):
                    def forward(self, lhs, rhs):
                        if op == "add":
                            return lhs + rhs
                        if op == "sub":
                            return lhs - rhs
                        if op == "mul":
                            return lhs * rhs
                        if op == "div":
                            return lhs / rhs
                        raise Exception("Bad op")

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([1.0, 2.0]),
                        torch.tensor([3.0, 4.0]),
                    ])

                self.check(
                    BinaryModule(),
                    [
                        torch.tensor([[1.0, 2.0]]),
                        torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                    ])

                with self.assertRaisesRegex(Exception, "Non-equal-rank broadcast"):
                    self.check(
                        BinaryModule(),
                        [
                            torch.tensor([1.0, 2.0]),
                            torch.tensor([[3.0, 4.0], [5.0, 6.0]]),
                        ])

    def test_pointwise_binary_const(self):
        const = torch.randn(1, 4, 6, 6)

        class ArgPlusConst(torch.nn.Module):
            def forward(self, arg):
                return arg + const

        class ConstPlusArg(torch.nn.Module):
            def forward(self, arg):
                return const + arg

        arg_contig = torch.randn(2, 4, 6, 6)
        arg_nhwc = nhwc(torch.randn(2, 4, 6, 6))

        for mod_class in [ArgPlusConst, ConstPlusArg]:
            for use_nhwc in [False, True]:
                with self.subTest(mod_class=mod_class.__name__, use_nhwc=use_nhwc):
                    arg = arg_nhwc if use_nhwc else arg_contig
                    memory_format = torch.channels_last if use_nhwc else torch.contiguous_format
                    self.check(mod_class(), arg,
                               expected_memory_format=memory_format)

    def test_hardtanh(self):
        inp = torch.tensor([-2.0, -0.5, 0.5, 2.0, 7.0])
        self.check(torch.nn.Hardtanh(), inp)
        self.check(torch.nn.Hardtanh(0.0, 6.0), inp)
        with self.assertRaisesRegex(Exception, "hardtanh with args"):
            self.check(torch.nn.Hardtanh(0.0, 5.0), inp)

    def test_softmax(self):
        inp = torch.tensor([[-2.0, -0.5], [0.5, 2.0]])
        self.check(torch.nn.Softmax(), inp)
        self.check(torch.nn.Softmax(dim=0), inp)
        # Test flexible size
        self.check(
            torch.nn.Softmax(),
            inp,
            convert_args=[torch.zeros(0, 0)],
        )

    def test_to(self):
        class ToCPU(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.prelu = torch.nn.PReLU()

            def forward(self, x):
                y = x.to("cpu")
                # Add PReLU since an input operand can't also be an output.
                return self.prelu(y)

        arg = torch.randn(1, 2, 3, 3)
        self.check(ToCPU(), arg)
        # Test flexible size
        self.check(
            ToCPU(),
            arg,
            convert_args=[torch.zeros(1, 2, 0, 0)],
        )

    def test_detach(self):
        class DetachModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                y = x.detach()
                return torch.nn.functional.relu(y)

        self.check(DetachModule(), torch.randn(1, 2, 3, 3))
        self.check(
            DetachModule(), torch.randn(1, 2, 3, 3),
            convert_args=[torch.zeros(1, 2, 0, 0)])

    def test_log_softmax(self):
        inp = torch.randn(3, 10)
        self.check(torch.nn.LogSoftmax(), inp)
        self.check(torch.nn.LogSoftmax(0), inp)

    def test_mean(self):
        class MeanModule(torch.nn.Module):
            def __init__(self, dim, keep=False):
                super().__init__()
                self.dim = dim
                self.keep = keep

            def forward(self, t):
                return torch.mean(t, dim=self.dim, keepdim=self.keep)

        self.check(MeanModule(0), torch.randn(2, 3))
        self.check(MeanModule(1), torch.randn(2, 3))
        self.check(MeanModule([2, 3]), torch.randn(2, 3, 6, 6))
        self.check(MeanModule([2, 3]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2], keep=True), nhwc(torch.randn(2, 3, 6, 6)))

    def test_max_pool2d(self):
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128):
            with self.subTest(name):
                self.check(torch.nn.MaxPool2d(2), inp)
                self.check(torch.nn.MaxPool2d((3, 4)), inp)
                self.check(torch.nn.MaxPool2d((3, 4), (1, 2)), inp)

    def test_avg_pool2d(self):
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128):
            with self.subTest(name):
                atol_rtol = None
                limit = None
                convert_dims = (2, 3, 0, 0)
                convert_arg = torch.zeros(*convert_dims)

                for model in (
                        torch.nn.AvgPool2d(2),
                        torch.nn.AvgPool2d((3, 4)),
                        torch.nn.AvgPool2d((3, 4), (1, 2))):
                    if "quant" in name:
                        atol_rtol = (1, 0)
                        limit = model(inp).numel()
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)
                    if "nhwc" in name:
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit
                    )

    def test_adaptive_avg_pool2d(self):
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128):
            with self.subTest(name):
                self.check(torch.nn.AdaptiveAvgPool2d((1, 1)), inp)
                with self.assertRaisesRegex(Exception, "with output size"):
                    self.check(torch.nn.AdaptiveAvgPool2d((2, 2)), inp)

    def test_upsample_nearest2d(self):
        convert_args = dict(self.float_and_quant_and_nhwc(torch.randn(2, 3, 0, 0), 0.3, 128))
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128):
            with self.subTest(name):
                self.check(torch.nn.UpsamplingNearest2d(size=(16, 20)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(24, 32)), inp)
                self.check(torch.nn.UpsamplingNearest2d(size=(36, 48)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(1.5, 1.5)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp)
                self.check(torch.nn.UpsamplingNearest2d(scale_factor=(3.0, 3.0)), inp)

                self.check(
                    torch.nn.UpsamplingNearest2d(size=(24, 32)), inp,
                    convert_args=[convert_args[name]]
                )
                self.check(
                    torch.nn.UpsamplingNearest2d(scale_factor=(2.0, 2.0)), inp,
                    convert_args=[convert_args[name]]
                )

    def test_linear(self):
        torch.manual_seed(29)
        self.check(torch.nn.Linear(16, 32), torch.randn(2, 16))
        self.check(
            torch.nn.Linear(16, 32), torch.randn(2, 16),
            convert_args=[torch.zeros(0, 16)])

    def test_conv2d(self):
        cases = [
            # in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim,      name
            ( 4,     8,      (3, 3), 1,      0,       1,      1,    (2, 4, 16, 16), "3x3"),        # noqa: E201,E241
            ( 4,     8,      (3, 3), 1,      0,       1,      0,    (2, 4, 16, 16), "3x3nobias"),  # noqa: E201,E241
            ( 4,     16,     (3, 3), 1,      1,       1,      1,    (2, 4, 16, 16), "3x3p1"),      # noqa: E201,E241
            ( 8,     8,      (3, 3), 2,      0,       1,      1,    (2, 8, 16, 16), "3x3s2"),      # noqa: E201,E241
            ( 4,     8,      (5, 5), 1,      0,       1,      1,    (2, 4, 16, 16), "5x5"),        # noqa: E201,E241
            ( 4,     4,      (3, 3), 1,      0,       4,      1,    (2, 4, 16, 16), "3x3dw"),      # noqa: E201,E241
            ( 8,     4,      (1, 1), 1,      0,       1,      1,    (2, 8, 16, 16), "1x1"),        # noqa: E201,E241
        ]

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            for case in cases:
                in_ch, out_ch, kernel, stride, padding, groups, bias, input_dim, name = case
                with self.subTest("{}-{}".format(kind, name)):
                    inp = torch.randn(input_dim)
                    model = torch.nn.Conv2d(in_ch, out_ch, kernel, stride, padding, groups=groups, bias=bool(bias))
                    output_size = model(inp).numel()
                    atol_rtol = None
                    limit = None
                    convert_dims = (0, in_ch, 0, 0)
                    convert_arg = torch.zeros(*convert_dims)

                    if "quant" in kind:
                        model = torch.nn.Sequential(model)
                        model.eval()
                        model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
                        model = torch.ao.quantization.prepare(model)
                        model(inp)
                        model = torch.ao.quantization.convert(model)
                        inp = qpt(inp, 1.0 / 16, 128)
                        # I've seen numerical differences between QNNPACK and NNAPI,
                        # but never more than 1 quantum, and never more than ~1% of
                        # the output in this test.
                        atol_rtol = (1, 0)
                        limit = output_size * 0.03
                        convert_arg = qpt(torch.zeros(*convert_dims), 1.0 / 16, 128)

                    if "nhwc" in kind:
                        inp = nhwc(inp)
                        convert_arg = nhwc(convert_arg)

                    self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                    self.check(
                        model,
                        inp,
                        convert_args=[convert_arg],
                        atol_rtol=atol_rtol,
                        limit=limit
                    )

    def test_conv2d_transpose(self):
        torch.manual_seed(29)
        in_ch, out_ch, kernel = (5, 7, (2, 2))
        input_dim = (4, 5, 3, 3)
        convert_dims = input_dim[:2] + (0, 0)

        for kind in ["float", "float-nhwc", "quant", "quant-nhwc"]:
            with self.subTest(kind):
                inp = torch.randn(input_dim)
                model = torch.nn.ConvTranspose2d(in_ch, out_ch, kernel)
                output_size = model(inp).numel()
                atol_rtol = (0.0002, 0)
                limit = None
                convert_arg = torch.zeros(*convert_dims)

                if "quant" in kind:
                    model = torch.nn.quantized.ConvTranspose2d(in_ch, out_ch, kernel)
                    model.qconfig = torch.ao.quantization.get_default_qconfig('qnnpack')
                    inp = qpt(inp, 1.0 / 16, 128)
                    # I've seen numerical differences between QNNPACK and NNAPI,
                    # but never more than 1 quantum, and never more than ~10% of
                    # the output in this test.
                    atol_rtol = (1, 0)
                    limit = output_size * 0.1
                    convert_arg = qpt(convert_arg, 1.0 / 16, 128)

                if "nhwc" in kind:
                    inp = nhwc(inp)
                    convert_arg = nhwc(convert_arg)

                self.check(model, inp, atol_rtol=atol_rtol, limit=limit)
                self.check(
                    model,
                    inp,
                    convert_args=[convert_arg],
                    atol_rtol=atol_rtol,
                    limit=limit
                )

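    # QFunctional supplies the output scale/zero_point used by the quantized
    # add/add_relu/mul ops exercised below.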
    def test_qadd(self):
        func = torch.nn.quantized.QFunctional()
        func.scale = 0.5
        func.zero_point = 120

        class AddMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add(lhs, rhs)

        class AddReluMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.add_relu(lhs, rhs)

        class MulMod(torch.nn.Module):
            def forward(self, lhs, rhs):
                return func.mul(lhs, rhs)

        for (name, mod) in [("add", AddMod), ("add_relu", AddReluMod), ("mul", MulMod)]:
            with self.subTest(name):
                self.check(
                    mod(),
                    [
                        qpt([1.0, 2.0], 0.25, 128),
                        qpt([3.0, 4.0], 0.25, 128),
                    ])
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ]
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ]
                )
                self.check(
                    mod(),
                    [
                        qpt([[1.0, 2.0]], 0.25, 128),
                        qpt([[3.0, 4.0]], 0.25, 128),
                    ],
                    convert_args=[
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                        qpt(torch.zeros((1, 2)), 0.25, 128),
                    ]
                )
        # NOTE: NNAPI qadd supports broadcast, but PT does not.

    def test_qlinear(self):
        torch.manual_seed(29)
        weight = qpt(torch.randn(16, 32), 0.125, 0, torch.qint8)
        bias = torch.randn(16)
        mod = torch.nn.quantized.Linear(32, 16)
        mod.set_weight_bias(weight, bias)
        inp = qpt(torch.randn(2, 32), 0.05, 130, torch.quint8)
        self.check(mod, inp)

    def test_seblock_mul(self):
        class MulModel(torch.nn.Module):
            def forward(self, lhs, rhs):
                return lhs * rhs

        self.check(
            MulModel(),
            [
                nhwc(torch.randn(2, 3, 4, 4)),
                torch.randn(1, 3, 1, 1),
            ])

    def test_multi_output(self):
        class MultiModel(torch.nn.Module):
            def forward(self, lhs, rhs) -> Tuple[torch.Tensor, torch.Tensor]:
                the_sum = lhs + rhs
                the_diff = lhs - rhs
                return the_sum, the_diff

        self.check(MultiModel(), [torch.tensor([1.0, 2.0]), torch.tensor([1.0, 3.0])])


if __name__ == '__main__':
    run_tests()