| import unittest |
| |
| from torchgen.executorch.api.types import ExecutorchCppSignature |
| from torchgen.local import parametrize |
| from torchgen.model import Location, NativeFunction |
| |
| DEFAULT_NATIVE_FUNCTION, _ = NativeFunction.from_yaml( |
| {"func": "foo.out(Tensor input, *, Tensor(a!) out) -> Tensor(a!)"}, |
| loc=Location(__file__, 1), |
| valid_tags=set(), |
| ) |
| |
| |
| class ExecutorchCppSignatureTest(unittest.TestCase): |
| def setUp(self) -> None: |
| self.sig = ExecutorchCppSignature.from_native_function(DEFAULT_NATIVE_FUNCTION) |
| |
| def test_runtime_signature_contains_runtime_context(self) -> None: |
| # test if `KernelRuntimeContext` argument exists in `RuntimeSignature` |
| with parametrize( |
| use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False |
| ): |
| args = self.sig.arguments(include_context=True) |
| self.assertEqual(len(args), 3) |
| self.assertTrue(any(a.name == "context" for a in args)) |
| |
| def test_runtime_signature_does_not_contain_runtime_context(self) -> None: |
| # test if `KernelRuntimeContext` argument is missing in `RuntimeSignature` |
| with parametrize( |
| use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False |
| ): |
| args = self.sig.arguments(include_context=False) |
| self.assertEqual(len(args), 2) |
| self.assertFalse(any(a.name == "context" for a in args)) |
| |
| def test_runtime_signature_declaration_correct(self) -> None: |
| with parametrize( |
| use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False |
| ): |
| decl = self.sig.decl(include_context=True) |
| self.assertEqual( |
| decl, |
| ( |
| "torch::executor::Tensor & foo_outf(" |
| "torch::executor::KernelRuntimeContext & context, " |
| "const torch::executor::Tensor & input, " |
| "torch::executor::Tensor & out)" |
| ), |
| ) |
| no_context_decl = self.sig.decl(include_context=False) |
| self.assertEqual( |
| no_context_decl, |
| ( |
| "torch::executor::Tensor & foo_outf(" |
| "const torch::executor::Tensor & input, " |
| "torch::executor::Tensor & out)" |
| ), |
| ) |