| # Copyright (c) Meta Platforms, Inc. and affiliates. |
| # All rights reserved. |
| # |
| # This source code is licensed under the BSD-style license found in the |
| # LICENSE file in the root directory of this source tree. |
| |
| import itertools |
| import unittest |
| from typing import Optional |
| |
| import torch |
| from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn |
| from executorch.backends.xnnpack.test.tester import Quantize, Tester |
| from torch.ao.quantization.quantizer.xnnpack_quantizer import ( |
| get_symmetric_quantization_config, |
| ) |
| from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig |
| |
| |
| class Conv2d(torch.nn.Module): |
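    """Single-Conv2d module with overridable conv and input-shape parameters;
    the base fixture for the tests below."""
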
| def __init__( |
| self, |
| in_channels=2, |
| out_channels=1, |
| kernel_size=(3, 3), |
| stride=(2, 2), |
| padding=(1, 1), |
| dilation=(1, 1), |
| groups=1, |
| bias=True, |
| padding_mode="zeros", |
| batches=1, |
| width=8, |
| height=8, |
| dtype=torch.float, |
| ): |
| super().__init__() |
| self.batches = batches |
| self.width = width |
| self.height = height |
| self.in_channels = in_channels |
| self.dtype = dtype |
| |
| self.conv = torch.nn.Conv2d( |
| in_channels=in_channels, |
| out_channels=out_channels, |
| kernel_size=kernel_size, |
| stride=stride, |
| padding=padding, |
| dilation=dilation, |
| groups=groups, |
| bias=bias, |
| padding_mode=padding_mode, |
| ).to(dtype) |
| |
| def forward(self, x): |
| return self.conv(x) |
| |
| def get_inputs(self): |
| return ( |
| torch.randn(self.batches, self.in_channels, self.height, self.width).to( |
| self.dtype |
| ), |
| ) |
| |
| |
| class Conv2dSeq(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.first = torch.nn.Conv2d( |
| in_channels=1, |
| out_channels=3, |
| kernel_size=(3, 3), |
| padding=1, |
| bias=False, |
| ) |
| self.second = torch.nn.Conv2d( |
| in_channels=3, |
| out_channels=2, |
| kernel_size=(3, 3), |
| padding=1, |
| bias=False, |
| ) |
| |
| def forward(self, x): |
| y = self.first(x) |
| return self.second(y) |
| |
| def get_inputs(self): |
| return (torch.randn(1, 1, 3, 3),) |
| |
| |
| class Conv2dBatchNorm(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d( |
| 2, |
| 2, |
| (2, 2), |
| bias=False, |
| padding=[1, 1], |
| stride=[4, 4], |
| ) |
| self.bn = randomize_bn(2) |
| self.hardtanh = torch.nn.Hardtanh() |
| self.conv2 = torch.nn.Conv2d( |
| 2, |
| 2, |
| (2, 2), |
| bias=False, |
| padding=[1, 1], |
| stride=[4, 4], |
| ) |
| |
| def forward(self, x): |
| y = self.conv1(x) |
| y = self.bn(y) |
| y = self.hardtanh(y) |
| y = self.conv2(y) |
| y = self.bn(y) |
| y = self.hardtanh(y) |
| return y |
| |
| def get_inputs(self): |
| return (torch.randn(2, 2, 4, 4),) |
| |
| |
| class Conv2dPermute(torch.nn.Module): |
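    """Conv2d whose output is permuted by a parameterized order, exercising
    memory-format handling around the lowered convolution."""
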
| def __init__(self, permute_order): |
| super().__init__() |
| self.conv = torch.nn.Conv2d( |
| 2, |
| 2, |
| (2, 2), |
| bias=False, |
| padding=[2, 2], |
| stride=[2, 2], |
| ) |
| self.permute_order = permute_order |
| |
| def forward(self, x): |
| result = self.conv(x) |
        permuted = torch.permute(result, self.permute_order)
        return permuted
| |
| def get_inputs(self): |
| return (torch.randn(2, 2, 4, 4),) |
| |
| |
| class TestConv2d(unittest.TestCase): |
| def _test( |
| self, |
| m: torch.nn.Module, |
| quant_config: Optional[QuantizationConfig] = None, |
        conv_count: int = 1,
| ): |
| # pyre-fixme[29]: `Union[torch._tensor.Tensor, |
| # torch.nn.modules.module.Module]` is not a function. |
| tester = Tester(m.eval(), m.get_inputs()) |
| |
| if quant_config is not None: |
| tester = tester.quantize(Quantize(quantization_config=quant_config)) |
| tester.check(["torch.ops.quantized_decomposed"]) |
| |
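        # Export, lower to the XNNPACK delegate, and verify that every conv
        # (and any batch norm) was consumed by the delegate, then run the
        # serialized program and compare against eager-mode outputs.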
| ( |
| tester.export() |
| .check_count({"torch.ops.aten.conv2d": conv_count}) |
| .to_edge_transform_and_lower() |
| .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"]) |
| .check_not( |
| [ |
| "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default" |
| ] |
| ) |
| .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) |
| .to_executorch() |
| .serialize() |
| .run_method_and_compare_outputs(qtol=1) |
| ) |
| |
| def test_fp16_conv2d(self) -> None: |
| for has_bias in (True, False): |
| self._test(Conv2d(bias=has_bias, dtype=torch.float16)) |
| |
| def test_fp32_conv2d(self) -> None: |
| for has_bias in (True, False): |
| self._test(Conv2d(bias=has_bias)) |
| |
| def test_fp32_conv2d_permute(self) -> None: |
        for perm_order in itertools.permutations([0, 1, 2, 3]):
| self._test(Conv2dPermute(perm_order)) |
| |
    def test_qs8_conv2d(self) -> None:
| for has_bias in (True, False): |
| self._test( |
| Conv2d(bias=has_bias), quant_config=get_symmetric_quantization_config() |
| ) |
| |
| def test_qs8_conv2d_per_channel(self) -> None: |
| self._test( |
| Conv2d(), |
| quant_config=get_symmetric_quantization_config(is_per_channel=True), |
| ) |
| |
| def test_fp32_conv2d_seq(self) -> None: |
| self._test(Conv2dSeq(), conv_count=2) |
| |
| def test_qs8_conv2d_seq(self) -> None: |
| self._test( |
| Conv2dSeq(), conv_count=2, quant_config=get_symmetric_quantization_config() |
| ) |
| |
    def test_fp32_conv2d_single_int_params(self) -> None:
| self._test( |
| Conv2d( |
| kernel_size=3, |
| stride=2, |
| padding="valid", |
| dilation=1, |
| ) |
| ) |
| |
    def test_fp32_conv2d_depthwise(self) -> None:
| # Depthwise Convolution Requirements: |
| # - Groups must equal In Channels |
| # - Out Channels must be a positive multiple of In Channels |
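        # The module below satisfies both: groups == in_channels == 2 and
        # out_channels == 6 == 3 * in_channels.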
| self._test(Conv2d(groups=2, in_channels=2, out_channels=6)) |
| |
    def test_qs8_conv2d_depthwise(self) -> None:
| self._test( |
| Conv2d(groups=2, in_channels=2, out_channels=6), |
| quant_config=get_symmetric_quantization_config(), |
| ) |
| |
    def test_fp32_conv2d_bn(self) -> None:
| class Conv2dBatchNorm(torch.nn.Module): |
| def __init__(self, in_features: int, out_features: int, kernel_size): |
| super().__init__() |
| self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size) |
| self.bn = randomize_bn(out_features) |
| self.in_features = in_features |
| self.kernel_size = kernel_size |
| |
| def forward(self, x): |
| y = self.conv2d(x) |
| y = self.bn(y) |
| return y |
| |
| def get_inputs(self): |
| return ( |
| torch.randn( |
| 2, |
| self.in_features, |
| self.kernel_size[0] * 2, |
| self.kernel_size[1] * 2, |
| ), |
| ) |
| |
| self._test(Conv2dBatchNorm(in_features=2, out_features=2, kernel_size=(2, 2))) |
| |
    def test_fp32_conv2d_bn_hardtanh_mean_sequence(self) -> None:
        """
        This test makes sure that we can fuse batchnorm and hardtanh
        even when copy nodes are inserted at some spots in the graph
        to change the memory format.
        """
| |
| class Conv2dBatchNormHardTanh(torch.nn.Module): |
| def __init__(self, in_channels: int, out_channels: int, kernel_size): |
| super().__init__() |
| self.conv = torch.nn.Conv2d( |
| in_channels=in_channels, |
| out_channels=out_channels, |
| kernel_size=kernel_size, |
| padding=[1, 1], |
| stride=[2, 2], |
| ) |
| self.in_channels = in_channels |
| self.native_batchnorm = torch.nn.BatchNorm2d(out_channels) |
| self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6) |
| |
| def forward(self, x): |
| x = self.conv(x) |
| x = self.native_batchnorm(x) |
| x = self.hardtanh(x) |
| x = torch.mean(x, (-1, -2), keepdim=True) |
| return x |
| |
| def get_inputs(self): |
| return (torch.randn(2, self.in_channels, 8, 8),) |
| |
| self._test( |
| Conv2dBatchNormHardTanh(in_channels=2, out_channels=1, kernel_size=(2, 2)) |
| ) |
| |
    def test_qs8_conv2d_bn(self) -> None:
| self._test( |
| Conv2dBatchNorm(), |
| quant_config=get_symmetric_quantization_config(), |
| conv_count=2, |
| ) |
| |
    def test_qs8_conv2d_relu(self) -> None:
| class ConvReLU(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d( |
| 2, |
| 2, |
| (2, 2), |
| bias=False, |
| padding=[1, 1], |
| stride=[4, 4], |
| ) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| y = self.conv1(x) |
| y = self.relu(y) |
| return y |
| |
| def get_inputs(self): |
| return (torch.randn(2, 2, 4, 4),) |
| |
| self._test( |
| ConvReLU(), |
| quant_config=get_symmetric_quantization_config(), |
| ) |
| |
    def test_qs8_conv2d_dw_relu(self) -> None:
| # Depthwise Convolution Requirements: |
| # - Groups must equal In Channels |
| # - Out Channels must be a positive multiple of In Channels |
| groups = 2 |
| stride = [2, 2] |
| padding = [1, 1] |
| dilation = [1, 1] |
| in_channels = groups |
| out_channels = 3 * in_channels |
| width = 8 |
| height = 8 |
| batches = 1 |
| |
| class ModelConvReLU(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d( |
| in_channels=in_channels, |
| out_channels=out_channels, |
| kernel_size=(3, 3), |
| stride=stride, |
| padding=padding, |
| groups=groups, |
| dilation=dilation, |
| bias=True, |
| ) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| y = self.conv1(x) |
| y = self.relu(y) |
| return y |
| |
| def get_inputs(self): |
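                # Inputs are scaled up (x11) to widen the value range the
                # quantizer has to cover.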
| return (torch.randn(batches, in_channels, height, width) * 11,) |
| |
| for per_channel_quant in (False, True): |
| model = ModelConvReLU() |
| self._test( |
| model, |
| quant_config=get_symmetric_quantization_config( |
| is_per_channel=per_channel_quant |
| ), |
| ) |
| |
    def test_qs8_conv2d_relu_seq(self) -> None:
| class ConvReLUSeq(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.model = torch.nn.Sequential( |
| torch.nn.Conv2d(1, 1, 1), |
| torch.nn.ReLU(), |
| torch.nn.Conv2d(1, 64, 1), |
| torch.nn.ReLU(), |
| ) |
| |
| def forward(self, x): |
| return self.model(x) |
| |
| def get_inputs(self): |
| return (torch.randn(1, 1, 1, 1),) |
| |
| self._test( |
| ConvReLUSeq(), |
| quant_config=get_symmetric_quantization_config(), |
| conv_count=2, |
| ) |
| |
    def test_qs8_conv2d_relu_multi_users(self) -> None:
| class Conv2dReluMultiUsers(torch.nn.Module): |
| def __init__(self): |
| super().__init__() |
| self.conv1 = torch.nn.Conv2d(1, 1, 1) |
| self.conv2 = torch.nn.Conv2d(1, 64, 1) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
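                # conv1's output has two users (the relu and the final add),
                # so conv+relu fusion must not consume it away from the add.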
| conv_default = self.conv1(x) |
| y = self.relu(conv_default) |
| conv_default_2 = self.conv2(y) |
| return conv_default + conv_default_2 |
| |
| def get_inputs(self): |
| return (torch.randn(1, 1, 1, 1),) |
| |
| self._test( |
| Conv2dReluMultiUsers(), |
| quant_config=get_symmetric_quantization_config(), |
| conv_count=2, |
| ) |