blob: 899a119ed448c570ce6d0f53e81d524a79f2e449 [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from executorch.backends.xnnpack.test.tester import Tester
class TestHardswish(unittest.TestCase):
    """Lowering tests for ``hardswish`` through the ``Tester`` pipeline.

    Both the module form (``torch.nn.Hardswish``) and the functional form
    (``torch.nn.functional.hardswish``) are exercised. Each is expected to
    export to exactly one ``aten.hardswish`` node that is fully absorbed into
    a single delegate call, then serialized and run with outputs compared.
    """

    class Hardswish(torch.nn.Module):
        """Applies ``torch.nn.Hardswish`` held as a submodule."""

        def __init__(self):
            super().__init__()
            self.hardswish = torch.nn.Hardswish()

        def forward(self, x):
            return self.hardswish(x)

    class HardswishFunctional(torch.nn.Module):
        """Calls ``torch.nn.functional.hardswish`` directly (no submodule)."""

        def forward(self, x):
            return torch.nn.functional.hardswish(x)

    def _test_hardswish(self, inputs, module=None):
        """Export, lower, serialize and run ``module`` on ``inputs``.

        Args:
            inputs: Tuple of example input tensors handed to ``Tester``.
            module: The ``torch.nn.Module`` under test. Defaults to a fresh
                ``Hardswish()`` instance, so existing callers that pass only
                ``inputs`` keep their original behavior.
        """
        if module is None:
            module = self.Hardswish()
        (
            Tester(module, inputs)
            .export()
            # Both the module and functional forms should export to one aten op.
            .check_count({"torch.ops.aten.hardswish.default": 1})
            .to_edge_transform_and_lower()
            # The graph should collapse into exactly one delegate call ...
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            # ... leaving no undelegated edge-dialect hardswish behind.
            .check_not(
                [
                    "executorch_exir_dialects_edge__ops_aten_hardswish_default",
                ]
            )
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_hardswish(self):
        inputs = (torch.randn(1, 3, 3).to(torch.float16),)
        self._test_hardswish(inputs)

    def test_fp32_hardswish(self):
        inputs = (torch.randn(1, 3, 3),)
        self._test_hardswish(inputs)

    def test_fp32_hardswish_functional(self):
        inputs = (torch.randn(1, 3, 3),)
        # Reuse the shared pipeline rather than duplicating every stage/check.
        self._test_hardswish(inputs, module=self.HardswishFunctional())