| import unittest |
| from types import ModuleType |
| |
| from torchgen import local |
| from torchgen.api import cpp as aten_cpp, types as aten_types |
| from torchgen.api.types import ( |
| ArgName, |
| BaseCType, |
| ConstRefCType, |
| MutRefCType, |
| NamedCType, |
| ) |
| from torchgen.executorch.api import et_cpp as et_cpp, types as et_types |
| from torchgen.executorch.api.unboxing import Unboxing |
| from torchgen.model import BaseTy, BaseType, ListType, OptionalType, Type |
| |
| |
def aten_argumenttype_type_wrapper(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
) -> NamedCType:
    """Forward to `aten_cpp.argumenttype_type` with exactly these arguments.

    `t` is passed positionally; `mutable`, `binds` and
    `remove_non_owning_ref_types` are passed through by keyword, and the
    result is returned unchanged.

    NOTE(review): presumably this wrapper exists so the ATen generator exposes
    the same call signature as the callable handed to `Unboxing` via
    `argument_type_gen` -- confirm against `Unboxing`.
    """
    named_ctype = aten_cpp.argumenttype_type(
        t,
        mutable=mutable,
        binds=binds,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
    )
    return named_ctype
| |
| |
# Unboxing whose argument types come from the ATen wrapper defined above.
ATEN_UNBOXING = Unboxing(argument_type_gen=aten_argumenttype_type_wrapper)
# Unboxing whose argument types come from the Executorch `et_cpp` generator.
ET_UNBOXING = Unboxing(argument_type_gen=et_cpp.argumenttype_type)
| |
| |
class TestUnboxing(unittest.TestCase):
    """
    Checks the name and C++ type returned by
    `Unboxing.argumenttype_evalue_convert` for the ATen (`ATEN_UNBOXING`) and
    Executorch (`ET_UNBOXING`) configurations.

    torch.testing._internal.common_utils could reduce the boilerplate here,
    but the GH CI job doesn't build torch before running the tools unit tests,
    so every parametrized case is written out by hand as an `_aten` /
    `_executorch` pair.
    """

    # NOTE: explanations on `test_*` methods are kept as comments rather than
    # docstrings so that unittest's verbose output keeps showing method names.

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def test_symint_argument_translate_ctype_aten(self) -> None:
        # A `SymInt[]` JIT argument goes through the ATen unboxing; the C++
        # argument type is expected to be a plain `IntArrayRef`.

        # pyre-fixme[16]: `enum.Enum` has no attribute `SymInt`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        symint_elem = BaseType(BaseTy.SymInt)
        symint_list = ListType(elem=symint_elem, size=None)

        conversion = ATEN_UNBOXING.argumenttype_evalue_convert(
            t=symint_list, arg_name="size", mutable=False
        )
        result_name, result_ctype, _, _ = conversion

        self.assertEqual(result_name, "size_list_out")
        self.assertIsInstance(result_ctype, BaseCType)
        # pyre-fixme[16]:
        expected_ctype = aten_types.BaseCType(aten_types.intArrayRefT)
        self.assertEqual(result_ctype, expected_ctype)

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def test_symint_argument_translate_ctype_executorch(self) -> None:
        # A `SymInt[]` JIT argument goes through the Executorch unboxing.
        # Executorch doesn't use the symint signature, so the C++ argument
        # type is expected to be an array ref of `long`.

        # pyre-fixme[16]: `enum.Enum` has no attribute `SymInt`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        symint_elem = BaseType(BaseTy.SymInt)
        symint_list = ListType(elem=symint_elem, size=None)

        conversion = ET_UNBOXING.argumenttype_evalue_convert(
            t=symint_list, arg_name="size", mutable=False
        )
        result_name, result_ctype, _, _ = conversion

        self.assertEqual(result_name, "size_list_out")
        self.assertIsInstance(result_ctype, et_types.ArrayRefCType)
        # pyre-fixme[16]:
        expected_ctype = et_types.ArrayRefCType(elem=BaseCType(aten_types.longT))
        self.assertEqual(result_ctype, expected_ctype)

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def _test_const_tensor_argument_translate_ctype(
        self, unboxing: Unboxing, types: ModuleType
    ) -> None:
        """
        Shared body for the const-tensor tests.

        Converts a non-mutable `Tensor` argument named `self` with `unboxing`
        and checks that it comes back as `self_base`, typed as a const
        reference to `types.tensorT`.
        """
        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        tensor = BaseType(BaseTy.Tensor)

        conversion = unboxing.argumenttype_evalue_convert(
            t=tensor, arg_name="self", mutable=False
        )
        result_name, result_ctype, _, _ = conversion

        self.assertEqual(result_name, "self_base")
        # pyre-fixme[16]:
        expected_ctype = ConstRefCType(BaseCType(types.tensorT))
        self.assertEqual(result_ctype, expected_ctype)

    def test_const_tensor_argument_translate_ctype_aten(self) -> None:
        self._test_const_tensor_argument_translate_ctype(
            unboxing=ATEN_UNBOXING, types=aten_types
        )

    def test_const_tensor_argument_translate_ctype_executorch(self) -> None:
        self._test_const_tensor_argument_translate_ctype(
            unboxing=ET_UNBOXING, types=et_types
        )

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def _test_mutable_tensor_argument_translate_ctype(
        self, unboxing: Unboxing, types: ModuleType
    ) -> None:
        """
        Shared body for the mutable-tensor tests.

        Converts a mutable `Tensor` argument named `out` with `unboxing` and
        checks that it comes back as `out_base`, typed as a mutable reference
        to `types.tensorT`.
        """
        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        tensor = BaseType(BaseTy.Tensor)

        conversion = unboxing.argumenttype_evalue_convert(
            t=tensor, arg_name="out", mutable=True
        )
        result_name, result_ctype, _, _ = conversion

        self.assertEqual(result_name, "out_base")
        # pyre-fixme[16]:
        expected_ctype = MutRefCType(BaseCType(types.tensorT))
        self.assertEqual(result_ctype, expected_ctype)

    def test_mutable_tensor_argument_translate_ctype_aten(self) -> None:
        self._test_mutable_tensor_argument_translate_ctype(
            unboxing=ATEN_UNBOXING, types=aten_types
        )

    def test_mutable_tensor_argument_translate_ctype_executorch(self) -> None:
        self._test_mutable_tensor_argument_translate_ctype(
            unboxing=ET_UNBOXING, types=et_types
        )

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def _test_tensor_list_argument_translate_ctype(
        self, unboxing: Unboxing, types: ModuleType
    ) -> None:
        """
        Shared body for the tensor-list tests.

        Converts a mutable `Tensor[]` argument named `out` with `unboxing` and
        checks that it comes back as `out_list_out`, typed as
        `types.tensorListT`.
        """
        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        tensor_elem = BaseType(BaseTy.Tensor)
        tensor_list = ListType(elem=tensor_elem, size=None)

        conversion = unboxing.argumenttype_evalue_convert(
            t=tensor_list, arg_name="out", mutable=True
        )
        result_name, result_ctype, _, _ = conversion

        self.assertEqual(result_name, "out_list_out")
        # pyre-fixme[16]:
        expected_ctype = BaseCType(types.tensorListT)
        self.assertEqual(result_ctype, expected_ctype)

    def test_tensor_list_argument_translate_ctype_aten(self) -> None:
        self._test_tensor_list_argument_translate_ctype(
            unboxing=ATEN_UNBOXING, types=aten_types
        )

    def test_tensor_list_argument_translate_ctype_executorch(self) -> None:
        self._test_tensor_list_argument_translate_ctype(
            unboxing=ET_UNBOXING, types=et_types
        )

    @local.parametrize(
        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
    )
    def _test_optional_int_argument_translate_ctype(
        self, unboxing: Unboxing, types: ModuleType
    ) -> None:
        """
        Shared body for the optional-int tests.

        Converts a mutable `int?` argument named `something` with `unboxing`
        and checks that it comes back as `something_opt_out`, typed as
        `types.OptionalCType` wrapping `types.longT`.
        """
        # pyre-fixme[16]: `enum.Enum` has no attribute `int`
        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
        int_elem = BaseType(BaseTy.int)
        optional_int = OptionalType(elem=int_elem)

        conversion = unboxing.argumenttype_evalue_convert(
            t=optional_int, arg_name="something", mutable=True
        )
        result_name, result_ctype, _, _ = conversion

        self.assertEqual(result_name, "something_opt_out")
        # pyre-fixme[16]:
        expected_ctype = types.OptionalCType(BaseCType(types.longT))
        self.assertEqual(result_ctype, expected_ctype)

    def test_optional_int_argument_translate_ctype_aten(self) -> None:
        self._test_optional_int_argument_translate_ctype(
            unboxing=ATEN_UNBOXING, types=aten_types
        )

    def test_optional_int_argument_translate_ctype_executorch(self) -> None:
        self._test_optional_int_argument_translate_ctype(
            unboxing=ET_UNBOXING, types=et_types
        )