blob: 30dfa5503a921591b6fb47192b2a353c03312a27 [file]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from executorch.backends.xnnpack.test.tester import Tester
class TestMaximum(unittest.TestCase):
    """XNNPACK delegation tests for ``torch.maximum``.

    Each test builds a one-op module, runs it through the shared
    ``_test_maximum`` pipeline (export -> edge lowering -> ExecuTorch ->
    serialize -> run and compare outputs), and checks that the ``maximum``
    op ends up inside a single XNNPACK delegate call.
    """

    class Maximum(torch.nn.Module):
        """Minimal module wrapping ``torch.maximum(x, y)``."""

        def forward(self, x, y):
            return torch.maximum(x, y)

    def _test_maximum(self, inputs):
        """Export, lower, serialize and run ``Maximum`` on ``inputs``.

        Args:
            inputs: Tuple of two tensors passed positionally to
                ``Maximum.forward``.
        """
        (
            Tester(self.Maximum(), inputs)
            .export()
            # The exported graph should contain exactly one aten.maximum.
            .check_count({"torch.ops.aten.maximum.default": 1})
            .to_edge_transform_and_lower()
            # After lowering, the op is expected to live inside a single
            # delegate call ...
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            # ... with no edge-dialect maximum left outside the delegate.
            .check_not(["executorch_exir_dialects_edge__ops_aten_maximum_default"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_maximum(self):
        inputs = (
            torch.randn(2, 3, 4).to(torch.float16),
            torch.randn(2, 3, 4).to(torch.float16),
        )
        self._test_maximum(inputs)

    def test_fp32_maximum(self):
        inputs = (
            torch.randn(2, 3, 4),
            torch.randn(2, 3, 4),
        )
        self._test_maximum(inputs)

    def test_fp32_maximum_broadcast(self):
        # The second operand has size 1 along dim 1, so torch.maximum
        # broadcasts it to (2, 3, 4); this exercises broadcasting in the
        # delegated kernel through the same pipeline as the other tests.
        inputs = (
            torch.randn(2, 3, 4),
            torch.randn(2, 1, 4),
        )
        self._test_maximum(inputs)