blob: ca18e6c265db124daedd58d86737fcf909e37183 [file]
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from executorch.backends.xnnpack.test.tester import Quantize, Tester
from transformers import MobileBertConfig, MobileBertModel # @manual
class TestMobilebert(unittest.TestCase):
    """Run a default-configured MobileBERT model through the XNNPACK ``Tester``.

    The model, its example input and the op-name set are class-level
    fixtures: they are built once, when the class body executes, and are
    shared by both test methods.
    """

    # Model built from ``MobileBertConfig()`` with no arguments and put in
    # eval mode.
    # pyre-ignore
    mobilebert = MobileBertModel(MobileBertConfig()).eval()

    # One positional input: an integer tensor of shape (1, 8).
    example_inputs = (
        torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]]),
    )

    # Edge-dialect op names handed to ``check_not`` after lowering.
    # NOTE(review): presumably these are the ops expected to be fully
    # delegated, so none should remain in the lowered graph -- confirm
    # against the Tester implementation.
    supported_ops = {
        "executorch_exir_dialects_edge__ops_aten_addmm_default",
        "executorch_exir_dialects_edge__ops_aten_add_Tensor",
        "executorch_exir_dialects_edge__ops_aten_mul_Tensor",
        "executorch_exir_dialects_edge__ops_aten_sub_Tensor",
        "executorch_exir_dialects_edge__ops_aten_div_Tensor",
        "executorch_exir_dialects_edge__ops_aten_cat_default",
        "executorch_exir_dialects_edge__ops_aten_relu_default",
        "executorch_exir_dialects_edge__ops_aten_permute_copy_default",
        "executorch_exir_dialects_edge__ops_aten__softmax_default",
        "executorch_exir_dialects_edge__ops_aten_constant_pad_nd_default",
    }

    def _lower_and_compare(self, tester: Tester) -> None:
        """Drive ``tester`` through the stages shared by both tests.

        Stages run in order: export, edge transform + lower, ``check_not``
        on ``supported_ops``, to_executorch, serialize, and finally
        ``run_method_and_compare_outputs`` on ``example_inputs``. Each stage
        is invoked on the value returned by the previous one, exactly as a
        fluent call chain would.
        """
        stage = tester.export()
        stage = stage.to_edge_transform_and_lower()
        stage = stage.check_not(list(self.supported_ops))
        stage = stage.to_executorch()
        stage = stage.serialize()
        stage.run_method_and_compare_outputs(inputs=self.example_inputs)

    def test_fp32_mobilebert(self):
        """Lower the unquantized model and compare its outputs."""
        tester = Tester(self.mobilebert, self.example_inputs)
        self._lower_and_compare(tester)

    def test_qs8_mobilebert(self):
        """Quantize first (``calibrate=False``), then lower and compare."""
        tester = Tester(self.mobilebert, self.example_inputs)
        quantized = tester.quantize(Quantize(calibrate=False))
        self._lower_and_compare(quantized)