blob: d88f88724bd9042473957add40a5186bd0968041 [file]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import itertools
import unittest
from typing import Optional
import torch
from executorch.backends.xnnpack.test.test_xnnpack_utils import randomize_bn
from executorch.backends.xnnpack.test.tester import Quantize, Tester
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
)
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import QuantizationConfig
class Conv2d(torch.nn.Module):
    """Single ``torch.nn.Conv2d`` layer with configurable hyper-parameters.

    The convolution arguments (``in_channels`` through ``padding_mode``) are
    forwarded unchanged to ``torch.nn.Conv2d``, and the resulting layer is cast
    to ``dtype``. ``batches``, ``height`` and ``width`` only describe the
    example input produced by ``get_inputs``.
    """

    def __init__(
        self,
        in_channels=2,
        out_channels=1,
        kernel_size=(3, 3),
        stride=(2, 2),
        padding=(1, 1),
        dilation=(1, 1),
        groups=1,
        bias=True,
        padding_mode="zeros",
        batches=1,
        width=8,
        height=8,
        dtype=torch.float,
    ):
        super().__init__()
        # Metadata read back by get_inputs() to build the example input.
        self.batches, self.height, self.width = batches, height, width
        self.in_channels = in_channels
        self.dtype = dtype
        layer = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )
        self.conv = layer.to(dtype)

    def forward(self, x):
        return self.conv(x)

    def get_inputs(self):
        """Return a 1-tuple with a random (N, C, H, W) input cast to ``self.dtype``."""
        shape = (self.batches, self.in_channels, self.height, self.width)
        # Sampled in the default dtype first, then cast.
        return (torch.randn(*shape).to(self.dtype),)
class Conv2dSeq(torch.nn.Module):
    """Two bias-free 3x3 convolutions chained back to back (1 -> 3 -> 2 channels).

    Both layers use padding 1, so the spatial size of the input is preserved.
    """

    def __init__(self):
        super().__init__()
        common = dict(kernel_size=(3, 3), padding=1, bias=False)
        self.first = torch.nn.Conv2d(in_channels=1, out_channels=3, **common)
        self.second = torch.nn.Conv2d(in_channels=3, out_channels=2, **common)

    def forward(self, x):
        return self.second(self.first(x))

    def get_inputs(self):
        """Return a 1-tuple holding a random (1, 1, 3, 3) input."""
        return (torch.randn(1, 1, 3, 3),)
class Conv2dBatchNorm(torch.nn.Module):
    """Two conv -> batch norm -> hardtanh stages sharing one batch-norm module.

    Each stage uses a bias-free 2x2 convolution (2 -> 2 channels, padding 1,
    stride 4). The single ``self.bn`` (built by ``randomize_bn(2)``) is applied
    after BOTH convolutions.
    """

    def __init__(self):
        super().__init__()

        def make_conv():
            # Fresh padding/stride lists per layer, same settings each time.
            return torch.nn.Conv2d(
                2, 2, (2, 2), bias=False, padding=[1, 1], stride=[4, 4]
            )

        # Construction order (conv1, bn, hardtanh, conv2) is kept so module
        # registration order and random-init draws are unchanged.
        self.conv1 = make_conv()
        self.bn = randomize_bn(2)
        self.hardtanh = torch.nn.Hardtanh()
        self.conv2 = make_conv()

    def forward(self, x):
        out = x
        for conv in (self.conv1, self.conv2):
            out = self.hardtanh(self.bn(conv(out)))
        return out

    def get_inputs(self):
        """Return a 1-tuple holding a random (2, 2, 4, 4) input."""
        return (torch.randn(2, 2, 4, 4),)
class Conv2dPermute(torch.nn.Module):
    """Bias-free 2x2 convolution whose output is permuted by ``torch.permute``.

    Args:
        permute_order: Dimension order applied to the 4-D convolution output.
    """

    def __init__(self, permute_order):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=2,
            out_channels=2,
            kernel_size=(2, 2),
            stride=[2, 2],
            padding=[2, 2],
            bias=False,
        )
        self.permute_order = permute_order

    def forward(self, x):
        return torch.permute(self.conv(x), self.permute_order)

    def get_inputs(self):
        """Return a 1-tuple holding a random (2, 2, 4, 4) input."""
        return (torch.randn(2, 2, 4, 4),)
class TestConv2d(unittest.TestCase):
    """XNNPACK delegation tests for 2D convolution and conv-based patterns.

    Every test builds a small module and hands it to ``_test``. That helper
    optionally quantizes the module, exports it, lowers it, serializes it, and
    then runs it and compares outputs.
    """

    def _test(
        self,
        m: torch.nn.Module,
        quant_config: Optional[QuantizationConfig] = None,
        conv_count=1,
        dtype: torch.dtype = torch.float,
    ):
        """Export, lower, serialize and run ``m``, checking the graph on the way.

        Args:
            m: Module under test. It must provide ``get_inputs()`` returning
                the tuple of example inputs handed to ``Tester``.
            quant_config: If not ``None``, ``m`` is quantized with this config
                and the quantized graph must contain
                ``torch.ops.quantized_decomposed`` ops.
            conv_count: Number of ``torch.ops.aten.conv2d`` ops expected in the
                exported graph.
            dtype: Not read by this helper; kept only so existing call sites
                keep working. Modules carry their own dtype.
        """
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
        #  torch.nn.modules.module.Module]` is not a function.
        tester = Tester(m.eval(), m.get_inputs())
        if quant_config is not None:
            tester = tester.quantize(Quantize(quantization_config=quant_config))
            tester.check(["torch.ops.quantized_decomposed"])
        (
            tester.export()
            .check_count({"torch.ops.aten.conv2d": conv_count})
            .to_edge_transform_and_lower()
            # After lowering, no convolution or batch-norm edge op may remain
            # outside the delegate, and exactly one delegate call must exist.
            .check_not(["executorch_exir_dialects_edge__ops_aten_convolution_default"])
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops__native_batch_norm_legit_no_training_default"
                ]
            )
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(qtol=1)
        )

    def test_fp16_conv2d(self) -> None:
        for has_bias in (True, False):
            # subTest reports which parametrization failed.
            with self.subTest(has_bias=has_bias):
                self._test(Conv2d(bias=has_bias, dtype=torch.float16))

    def test_fp32_conv2d(self) -> None:
        for has_bias in (True, False):
            with self.subTest(has_bias=has_bias):
                self._test(Conv2d(bias=has_bias))

    def test_fp32_conv2d_permute(self) -> None:
        # Every ordering of the four output dimensions.
        for perm_order in itertools.permutations([0, 1, 2, 3]):
            with self.subTest(perm_order=perm_order):
                self._test(Conv2dPermute(perm_order))

    def test_qs8_conv2d_test(self) -> None:
        for has_bias in (True, False):
            with self.subTest(has_bias=has_bias):
                self._test(
                    Conv2d(bias=has_bias),
                    quant_config=get_symmetric_quantization_config(),
                )

    def test_qs8_conv2d_per_channel(self) -> None:
        self._test(
            Conv2d(),
            quant_config=get_symmetric_quantization_config(is_per_channel=True),
        )

    def test_fp32_conv2d_seq(self) -> None:
        self._test(Conv2dSeq(), conv_count=2)

    def test_qs8_conv2d_seq(self) -> None:
        self._test(
            Conv2dSeq(), conv_count=2, quant_config=get_symmetric_quantization_config()
        )

    def test_fp32_conv2d_single_int_params(self):
        # Scalar (non-tuple) kernel/stride/dilation and string padding.
        self._test(
            Conv2d(
                kernel_size=3,
                stride=2,
                padding="valid",
                dilation=1,
            )
        )

    def test_fp32_conv2d_depthwise(self):
        # Depthwise Convolution Requirements:
        # - Groups must equal In Channels
        # - Out Channels must be a positive multiple of In Channels
        self._test(Conv2d(groups=2, in_channels=2, out_channels=6))

    def test_qs8_conv2d_depthwise(self):
        self._test(
            Conv2d(groups=2, in_channels=2, out_channels=6),
            quant_config=get_symmetric_quantization_config(),
        )

    def test_fp32_conv2d_bn(self):
        # Named distinctly so it does not shadow the module-level
        # ``Conv2dBatchNorm`` that test_qs8_conv2d_bn uses.
        class SingleConv2dBatchNorm(torch.nn.Module):
            def __init__(self, in_features: int, out_features: int, kernel_size):
                super().__init__()
                self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)
                self.bn = randomize_bn(out_features)
                self.in_features = in_features
                self.kernel_size = kernel_size

            def forward(self, x):
                y = self.conv2d(x)
                y = self.bn(y)
                return y

            def get_inputs(self):
                # Spatial size is twice the kernel size in each dimension.
                return (
                    torch.randn(
                        2,
                        self.in_features,
                        self.kernel_size[0] * 2,
                        self.kernel_size[1] * 2,
                    ),
                )

        self._test(
            SingleConv2dBatchNorm(in_features=2, out_features=2, kernel_size=(2, 2))
        )

    def test_fp32_conv2d_bn_hardtanh_mean_sequence(self):
        """
        This test makes sure that we can fuse batchnorm and hardtanh
        even when copy nodes are inserted at some spots in the graph to
        change the memory format.
        """

        class Conv2dBatchNormHardTanh(torch.nn.Module):
            def __init__(self, in_channels: int, out_channels: int, kernel_size):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=[1, 1],
                    stride=[2, 2],
                )
                self.in_channels = in_channels
                self.native_batchnorm = torch.nn.BatchNorm2d(out_channels)
                self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)

            def forward(self, x):
                x = self.conv(x)
                x = self.native_batchnorm(x)
                x = self.hardtanh(x)
                # Mean over the two trailing (spatial) dims, kept as size 1.
                x = torch.mean(x, (-1, -2), keepdim=True)
                return x

            def get_inputs(self):
                return (torch.randn(2, self.in_channels, 8, 8),)

        self._test(
            Conv2dBatchNormHardTanh(in_channels=2, out_channels=1, kernel_size=(2, 2))
        )

    def test_qs8_conv2d_bn(self):
        self._test(
            Conv2dBatchNorm(),
            quant_config=get_symmetric_quantization_config(),
            conv_count=2,
        )

    def test_qs8_conv2d_relu(self):
        class ConvReLU(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    2,
                    2,
                    (2, 2),
                    bias=False,
                    padding=[1, 1],
                    stride=[4, 4],
                )
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                y = self.conv1(x)
                y = self.relu(y)
                return y

            def get_inputs(self):
                return (torch.randn(2, 2, 4, 4),)

        self._test(
            ConvReLU(),
            quant_config=get_symmetric_quantization_config(),
        )

    def test_qs8_conv2d_dw_relu(self):
        # Depthwise Convolution Requirements:
        # - Groups must equal In Channels
        # - Out Channels must be a positive multiple of In Channels
        groups = 2
        stride = [2, 2]
        padding = [1, 1]
        dilation = [1, 1]
        in_channels = groups
        out_channels = 3 * in_channels
        width = 8
        height = 8
        batches = 1

        class ModelConvReLU(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=(3, 3),
                    stride=stride,
                    padding=padding,
                    groups=groups,
                    dilation=dilation,
                    bias=True,
                )
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                y = self.conv1(x)
                y = self.relu(y)
                return y

            def get_inputs(self):
                # Scaled by 11, presumably to widen the value range seen by
                # quantization -- confirm intent.
                return (torch.randn(batches, in_channels, height, width) * 11,)

        for per_channel_quant in (False, True):
            with self.subTest(per_channel_quant=per_channel_quant):
                model = ModelConvReLU()
                self._test(
                    model,
                    quant_config=get_symmetric_quantization_config(
                        is_per_channel=per_channel_quant
                    ),
                )

    def test_qs8_conv2d_relu_seq(self):
        class ConvReLUSeq(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.model = torch.nn.Sequential(
                    torch.nn.Conv2d(1, 1, 1),
                    torch.nn.ReLU(),
                    torch.nn.Conv2d(1, 64, 1),
                    torch.nn.ReLU(),
                )

            def forward(self, x):
                return self.model(x)

            def get_inputs(self):
                return (torch.randn(1, 1, 1, 1),)

        self._test(
            ConvReLUSeq(),
            quant_config=get_symmetric_quantization_config(),
            conv_count=2,
        )

    def test_qs8_conv2d_relu_multi_users(self):
        class Conv2dReluMultiUsers(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 64, 1)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                # conv1's output has two users: the ReLU and the final add.
                conv_default = self.conv1(x)
                y = self.relu(conv_default)
                conv_default_2 = self.conv2(y)
                return conv_default + conv_default_2

            def get_inputs(self):
                return (torch.randn(1, 1, 1, 1),)

        self._test(
            Conv2dReluMultiUsers(),
            quant_config=get_symmetric_quantization_config(),
            conv_count=2,
        )