blob: 1031852176f4416087ab6c1a89645bc6e48a5eaf [file]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from executorch.backends.xnnpack.test.tester import Tester
class TestMaxPool2d(unittest.TestCase):
    """
    Tests for lowering torch.nn.MaxPool2d through the Tester pipeline, covering
    configurations that are expected to be delegated and ones that are not.
    """

    class MaxPool2d(torch.nn.Module):
        """
        Thin wrapper around torch.nn.MaxPool2d returning only the pooled output.
        """

        def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
            super().__init__()
            self.max_pool2d_module = torch.nn.MaxPool2d(
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
            )

        def forward(self, x):
            return self.max_pool2d_module(x)

    class MaxPool2dUnsupported(torch.nn.Module):
        """
        MaxPool2d built with return_indices=True whose forward returns the indices
        (element 1 of the module's result) instead of the pooled values.
        """

        def __init__(self, kernel_size=3, stride=1, padding=0, dilation=1):
            super().__init__()
            self.max_pool2d_module = torch.nn.MaxPool2d(
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                return_indices=True,
            )

        def forward(self, x):
            # Index 1 selects the indices tensor; index 0 would be the pooled output.
            return self.max_pool2d_module(x)[1]

    class MaxPool2dUnsupportedCeilMode(torch.nn.Module):
        """
        MaxPool2d with kernel size 2, stride 2 and ceil_mode=True.
        """

        def __init__(self):
            super().__init__()
            self.max_pool2d_module = torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)

        def forward(self, x):
            return self.max_pool2d_module(x)

    def _test_maxpool2d(self, inputs):
        """
        Run the supported MaxPool2d configuration (kernel 3, stride 1, padding 0,
        dilation 1) through the full pipeline and verify that it is delegated.

        The checks expect exactly one aten.max_pool2d op after export and, after
        lowering, exactly one delegate call with no edge-dialect
        max_pool2d_with_indices op left in the graph.

        NOTE(review): the previous note here said that export generates
        aten.max_pool2d_with_indices and that the remove_getitem_op pass transforms it
        into aten.max_pool2d (if supported). The export-stage check below looks for
        aten.max_pool2d directly -- confirm which description is current.

        Args:
            inputs: tuple of example input tensors handed to Tester.
        """
        (
            Tester(self.MaxPool2d(3, 1, 0, 1), inputs)
            .export()
            .check_count({"torch.ops.aten.max_pool2d.default": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default"
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_maxpool2d(self):
        """
        Supported MaxPool2d configuration with a float16 input.
        """
        inputs = (torch.randn(4, 3, 24, 24).to(torch.float16),)
        self._test_maxpool2d(inputs)

    def test_fp32_maxpool2d(self):
        """
        Supported MaxPool2d configuration with a default-dtype randn input.
        """
        inputs = (torch.randn(4, 3, 24, 24),)
        self._test_maxpool2d(inputs)

    def test_fp32_maxpool2d_unsupported(self):
        """
        MaxPool2d with return_indices is not generally supported (see maxpool2d_with_indices constraint).
        """
        inputs = (torch.randn(4, 3, 24, 24),)
        (
            Tester(self.MaxPool2dUnsupported(), inputs)
            .export()
            .check_count({"torch.ops.aten.max_pool2d_with_indices.default": 1})
            .to_edge_transform_and_lower()
            # We expect it not to be delegated: no delegate call is created and the
            # edge op stays in the graph (same pair of checks as the ceil mode test).
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1
                }
            )
        )

    def test_fp32_maxpool2d_unsupported_ceilmode(self):
        """
        MaxPool2d with ceil mode is not generally supported (see maxpool2d constraint).
        """
        inputs = (torch.randn(1, 32, 23, 23),)
        (
            Tester(self.MaxPool2dUnsupportedCeilMode(), inputs)
            .export()
            .check_count({"torch.ops.aten.max_pool2d.default": 1})
            .to_edge_transform_and_lower()
            # We expect it not to be delegated.
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 0})
            .check_count(
                {
                    "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default": 1
                }
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_qs8_maxpool2d(self):
        """
        Quantized MaxPool2d (preceded by an add) should be delegated for several
        kernel_size / stride / padding combinations.
        """

        class MaxPool(torch.nn.Module):
            def __init__(self, maxpool_params):
                super().__init__()
                self.max = torch.nn.MaxPool2d(*maxpool_params)

            def forward(self, x):
                z = x + x
                return self.max(z)

        # Parameter order is kernel_size, stride, padding.
        for maxpool_params in [(4,), (4, 2), (4, 2, 2)]:
            # subTest labels a failure with the parameter set that caused it and lets
            # the remaining parameter sets still run.
            with self.subTest(maxpool_params=maxpool_params):
                inputs = (torch.randn(1, 2, 8, 8),)
                (
                    Tester(MaxPool(maxpool_params), inputs)
                    .quantize()
                    .export()
                    .check_count({"torch.ops.aten.max_pool2d.default": 1})
                    .check(["torch.ops.quantized_decomposed"])
                    .to_edge_transform_and_lower()
                    .check_count(
                        {"torch.ops.higher_order.executorch_call_delegate": 1}
                    )
                    .check_not(
                        [
                            "executorch_exir_dialects_edge__ops_aten_max_pool2d_with_indices_default",
                            "torch.ops.quantized_decomposed",
                        ]
                    )
                    .to_executorch()
                    .serialize()
                    .run_method_and_compare_outputs()
                )