[torchgen] Move Executorch unboxing logic into torchgen (#90098)
This PR adds `unboxing.py`, which converts an `EValue` (similar to `IValue`) to its corresponding C++ type, based on the `ExecutorchCppSignature`.
Unit tests are added in `test_executorch_unboxing.py`. Note that this unboxing logic should work for both ATen types and Executorch types, hence the unit tests are parametrized.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90098
Approved by: https://github.com/ezyang
diff --git a/tools/test/test_executorch_unboxing.py b/tools/test/test_executorch_unboxing.py
new file mode 100644
index 0000000..eff0145
--- /dev/null
+++ b/tools/test/test_executorch_unboxing.py
@@ -0,0 +1,176 @@
+import unittest
+from types import ModuleType
+
+from torchgen import local
+from torchgen.api import cpp as aten_cpp, types as aten_types
+from torchgen.api.types import (
+ ArgName,
+ BaseCType,
+ ConstRefCType,
+ MutRefCType,
+ NamedCType,
+)
+from torchgen.executorch.api import et_cpp as et_cpp, types as et_types
+from torchgen.executorch.api.unboxing import Unboxing
+from torchgen.model import BaseTy, BaseType, ListType, OptionalType, Type
+
+
+def aten_argumenttype_type_wrapper(
+    t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False
+) -> NamedCType:
+    """Forward every argument, by keyword, to `aten_cpp.argumenttype_type`."""
+    strip_refs = remove_non_owning_ref_types
+    named_ctype = aten_cpp.argumenttype_type(
+        t, mutable=mutable, binds=binds, remove_non_owning_ref_types=strip_refs
+    )
+    return named_ctype
+
+
+# Shared `Unboxing` instances under test: one built on the ATen argument-type
+# generator (via the wrapper above) and one built on the Executorch generator.
+ATEN_UNBOXING = Unboxing(argument_type_gen=aten_argumenttype_type_wrapper)
+ET_UNBOXING = Unboxing(argument_type_gen=et_cpp.argumenttype_type)
+
+
+class TestUnboxing(unittest.TestCase):
+    """
+    Checks `Unboxing.argumenttype_evalue_convert` for ATen and Executorch types.
+
+    Could use torch.testing._internal.common_utils to reduce boilerplate.
+    GH CI job doesn't build torch before running tools unit tests, hence
+    manually adding these parametrized tests: each scenario is written once in a
+    private helper and run by a `*_aten` and a `*_executorch` test method.
+    """
+
+    def _check_symint_list_ctype(self, unboxing: Unboxing, expected: object) -> None:
+        # Shared body of the `SymInt[]` tests: the JIT argument must be translated
+        # into `expected`, the integer-list C++ type of the flavor under test.
+        # pyre-fixme[16]: `enum.Enum` has no attribute `SymInt`
+        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
+        symint_type = BaseType(BaseTy.SymInt)
+        symint_list_type = ListType(elem=symint_type, size=None)
+
+        out_name, ctype, _code, _decl = unboxing.argumenttype_evalue_convert(
+            symint_list_type, "size", mutable=False
+        )
+
+        self.assertEqual(out_name, "size_list_out")
+        self.assertIsInstance(ctype, type(expected))
+        self.assertEqual(ctype, expected)
+
+    @local.parametrize(
+        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+    )
+    def test_symint_argument_translate_ctype_aten(self) -> None:
+        # With ATen types, `SymInt[]` is expected to translate to `IntArrayRef`.
+        # pyre-fixme[16]:
+        expected = aten_types.BaseCType(aten_types.intArrayRefT)
+        self._check_symint_list_ctype(ATEN_UNBOXING, expected)
+
+    @local.parametrize(
+        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+    )
+    def test_symint_argument_translate_ctype_executorch(self) -> None:
+        # With Executorch types, `SymInt[]` is expected to translate to an
+        # `ArrayRefCType` whose element type is `longT`.
+        # pyre-fixme[16]:
+        expected = et_types.ArrayRefCType(elem=BaseCType(aten_types.longT))
+        self._check_symint_list_ctype(ET_UNBOXING, expected)
+
+    @local.parametrize(
+        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+    )
+    def _test_const_tensor_argument_translate_ctype(
+        self, unboxing: Unboxing, types: ModuleType
+    ) -> None:
+        # A non-mutable `Tensor` argument is expected to become a const reference.
+        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
+        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
+        tensor_type = BaseType(BaseTy.Tensor)
+
+        out_name, ctype, _code, _decl = unboxing.argumenttype_evalue_convert(
+            tensor_type, "self", mutable=False
+        )
+
+        self.assertEqual(out_name, "self_base")
+        # pyre-fixme[16]:
+        expected = ConstRefCType(BaseCType(types.tensorT))
+        self.assertEqual(ctype, expected)
+
+    def test_const_tensor_argument_translate_ctype_aten(self) -> None:
+        self._test_const_tensor_argument_translate_ctype(
+            unboxing=ATEN_UNBOXING, types=aten_types
+        )
+
+    def test_const_tensor_argument_translate_ctype_executorch(self) -> None:
+        self._test_const_tensor_argument_translate_ctype(
+            unboxing=ET_UNBOXING, types=et_types
+        )
+
+    @local.parametrize(
+        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+    )
+    def _test_mutable_tensor_argument_translate_ctype(
+        self, unboxing: Unboxing, types: ModuleType
+    ) -> None:
+        # A mutable `Tensor` argument is expected to become a mutable reference.
+        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
+        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
+        tensor_type = BaseType(BaseTy.Tensor)
+
+        out_name, ctype, _code, _decl = unboxing.argumenttype_evalue_convert(
+            tensor_type, "out", mutable=True
+        )
+
+        self.assertEqual(out_name, "out_base")
+        # pyre-fixme[16]:
+        expected = MutRefCType(BaseCType(types.tensorT))
+        self.assertEqual(ctype, expected)
+
+    def test_mutable_tensor_argument_translate_ctype_aten(self) -> None:
+        self._test_mutable_tensor_argument_translate_ctype(
+            unboxing=ATEN_UNBOXING, types=aten_types
+        )
+
+    def test_mutable_tensor_argument_translate_ctype_executorch(self) -> None:
+        self._test_mutable_tensor_argument_translate_ctype(
+            unboxing=ET_UNBOXING, types=et_types
+        )
+
+    @local.parametrize(
+        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+    )
+    def _test_tensor_list_argument_translate_ctype(
+        self, unboxing: Unboxing, types: ModuleType
+    ) -> None:
+        # A `Tensor[]` argument is expected to become the flavor's `tensorListT`.
+        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
+        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
+        tensor_elem_type = BaseType(BaseTy.Tensor)
+        tensor_list_type = ListType(elem=tensor_elem_type, size=None)
+
+        out_name, ctype, _code, _decl = unboxing.argumenttype_evalue_convert(
+            tensor_list_type, "out", mutable=True
+        )
+
+        self.assertEqual(out_name, "out_list_out")
+        # pyre-fixme[16]:
+        expected = BaseCType(types.tensorListT)
+        self.assertEqual(ctype, expected)
+
+    def test_tensor_list_argument_translate_ctype_aten(self) -> None:
+        self._test_tensor_list_argument_translate_ctype(
+            unboxing=ATEN_UNBOXING, types=aten_types
+        )
+
+    def test_tensor_list_argument_translate_ctype_executorch(self) -> None:
+        self._test_tensor_list_argument_translate_ctype(
+            unboxing=ET_UNBOXING, types=et_types
+        )
+
+    @local.parametrize(
+        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+    )
+    def _test_optional_int_argument_translate_ctype(
+        self, unboxing: Unboxing, types: ModuleType
+    ) -> None:
+        # An `int?` argument is expected to become an `OptionalCType` of `longT`.
+        # pyre-fixme[16]: `enum.Enum` has no attribute `int`
+        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
+        int_type = BaseType(BaseTy.int)
+        optional_int_type = OptionalType(elem=int_type)
+
+        out_name, ctype, _code, _decl = unboxing.argumenttype_evalue_convert(
+            optional_int_type, "something", mutable=True
+        )
+
+        self.assertEqual(out_name, "something_opt_out")
+        # pyre-fixme[16]:
+        expected = types.OptionalCType(BaseCType(types.longT))
+        self.assertEqual(ctype, expected)
+
+    def test_optional_int_argument_translate_ctype_aten(self) -> None:
+        self._test_optional_int_argument_translate_ctype(
+            unboxing=ATEN_UNBOXING, types=aten_types
+        )
+
+    def test_optional_int_argument_translate_ctype_executorch(self) -> None:
+        self._test_optional_int_argument_translate_ctype(
+            unboxing=ET_UNBOXING, types=et_types
+        )
diff --git a/torchgen/executorch/api/unboxing.py b/torchgen/executorch/api/unboxing.py
new file mode 100644
index 0000000..9a8f717
--- /dev/null
+++ b/torchgen/executorch/api/unboxing.py
@@ -0,0 +1,213 @@
+from dataclasses import dataclass
+from typing import Callable, List, Sequence, Tuple
+
+from torchgen.api.types import Binding, CType, NamedCType
+from torchgen.model import (
+ Argument,
+ BaseTy,
+ BaseType,
+ ListType,
+ NativeFunction,
+ OptionalType,
+ Type,
+)
+
+# Newline + tab separator used to join the per-element unboxing statements that
+# are emitted inside the generated C++ loop body for generic list arguments.
+connector = "\n\t"
+
+
+def name(f: NativeFunction) -> str:
+    """Return the unboxing function name for a NativeFunction."""
+    operator_name = f.func.name
+    return operator_name.unambiguous_name()
+
+
+@dataclass(frozen=True)
+class Unboxing:
+    """
+    Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing.
+
+    A sample generated code:
+    // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    void mul_out(EValue** stack) {
+        EValue& self = *stack[0];
+        EValue& other = *stack[1];
+        EValue& out = *stack[2];
+        const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
+        const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
+        torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
+
+        EXECUTORCH_SCOPE_PROF("native_call_mul.out");
+        torch::executor::mul_outf(self_base, other_base, out_base);
+    }
+
+    Of the sample above, this class only produces the `EValue& ... = *stack[i];`
+    aliases and the unboxing statements that follow them; the enclosing function
+    and the kernel call are not emitted here.
+    """
+
+    # this is a callable that converts a JIT argument, into its C++ type.
+    # It is invoked as `argument_type_gen(t, mutable=..., binds=...)` and must
+    # return a NamedCType. E.g., torchgen.api.cpp.argumenttype_type.
+    argument_type_gen: Callable[..., NamedCType]
+
+    def convert_arguments(
+        self, args: Sequence[Binding]
+    ) -> Tuple[List[Binding], List[str]]:
+        """
+        Convert all the arguments in a NativeFunction to C++ code.
+
+        :param args: bindings to unbox; the i-th binding is read from `stack[i]`
+        :return: a tuple of
+            (1) the input bindings, renamed to their unboxed variable names
+            (2) the generated C++ lines: one `EValue&` alias per argument first,
+                then, per argument, its declarations followed by its unboxing code
+        :raises Exception: if a binding does not wrap an `Argument`, or if an
+            argument type cannot be handled
+        """
+        code_list = [
+            f"EValue& {arg.name} = *stack[{i}];" for i, arg in enumerate(args)
+        ]
+        binding_list: List[Binding] = []
+        for arg in args:
+            # expecting only Argument
+            if not isinstance(arg.argument, Argument):
+                raise Exception(
+                    f"Unexpected argument type, expecting `Argument` but got {arg}"
+                )
+            argument: Argument = arg.argument
+            unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
+                argument.type, argument.name, mutable=argument.is_write
+            )
+            # Declarations are emitted before the statements that rely on them.
+            code_list.extend(decl)
+            code_list.extend(code)
+            binding_list.append(arg.with_name(unboxed_name))
+        return binding_list, code_list
+
+    def argumenttype_evalue_convert(
+        self, t: Type, arg_name: str, *, mutable: bool = False
+    ) -> Tuple[str, CType, List[str], List[str]]:
+        """
+        Takes in the type, name and mutability corresponding to an argument, and
+        generates the C++ code necessary to unbox it.
+
+        :param t: a `Type` of an argument
+        :param arg_name: argument name
+        :param mutable: boolean for whether this argument type is mutable
+        :return: a tuple of
+            (1) the name of the newly created unboxed variable
+            (2) the CType that `argument_type_gen` computed for the argument
+            (3) the C++ statements that unbox the argument
+            (4) C++ declarations that callers emit ahead of those statements
+        :raises Exception: if `t` is not a BaseType, OptionalType or ListType
+        """
+        ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type
+
+        if isinstance(t, BaseType):
+            out_name = f"{arg_name}_base"
+            code, decl = self._gen_code_base_type(
+                arg_name=arg_name, out_name=out_name, ctype=ctype
+            )
+        elif isinstance(t, OptionalType):
+            out_name = f"{arg_name}_opt_out"
+            code, decl = self._gen_code_optional_type(
+                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
+            )
+        elif isinstance(t, ListType):
+            out_name = f"{arg_name}_list_out"
+            code, decl = self._gen_code_list_type(
+                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
+            )
+        else:
+            raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")
+        return out_name, ctype, code, decl
+
+    def _gen_code_base_type(
+        self, arg_name: str, out_name: str, ctype: CType
+    ) -> Tuple[List[str], List[str]]:
+        """
+        Unbox a base-typed argument with `EValue::to<T>()`.
+
+        :return: (unboxing statements, declarations); there are no declarations.
+        """
+        return [
+            f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
+        ], []
+
+    def _gen_code_optional_type(
+        self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
+    ) -> Tuple[List[str], List[str]]:
+        """
+        Unbox an optional argument with `EValue::toOptional<T>()`, where `T` is
+        the C++ type of the optional's element.
+
+        :return: (unboxing statements, declarations of the element type)
+        """
+        in_name = f"{arg_name}_opt_in"
+        # Only the element's C++ type and declarations are used; the element's
+        # own unboxed name and unboxing statements are not emitted.
+        _, base_type, _, decl = self.argumenttype_evalue_convert(t.elem, in_name)
+        return (
+            f"""
+ {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
+ """.split(
+                "\n"
+            ),
+            decl,
+        )
+
+    def _gen_code_list_type(
+        self, arg_name: str, out_name: str, t: ListType, ctype: CType
+    ) -> Tuple[List[str], List[str]]:
+        """
+        Unbox a list argument.
+
+        Lists of Tensor, int, SymInt, float and bool use a dedicated `EValue`
+        list accessor. `Tensor?[]` has its own ATen / non-ATen code path. Any
+        other element type is unboxed element by element into a `std::vector`.
+
+        :return: (unboxing statements, declarations)
+        """
+        in_name = f"{arg_name}_list_in"
+        elem_name = f"{arg_name}_elem"
+        code: List[str] = []
+        res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
+            t.elem, elem_name
+        )
+
+        # EValue accessor to call for a list whose elements have the given base type.
+        list_accessors = {
+            BaseTy.Tensor: "toTensorList",
+            BaseTy.int: "toIntList",
+            BaseTy.SymInt: "toIntList",
+            BaseTy.float: "toDoubleList",
+            # handle list type with size, e.g., bool[4]
+            BaseTy.bool: "toBoolList",
+        }
+
+        if isinstance(t.elem, BaseType) and t.elem.name in list_accessors:
+            code.extend(
+                f"""
+ {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.{list_accessors[t.elem.name]}();
+ """.split(
+                    "\n"
+                )
+            )
+        # pytorch codegen:
+        # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<c10::optional<at::Tensor>>
+        elif (
+            isinstance(t.elem, OptionalType)
+            and isinstance(t.elem.elem, BaseType)
+            and t.elem.elem.name == BaseTy.Tensor
+        ):
+            code.extend(
+                f"""
+#ifdef USE_ATEN_LIB
+at::ArrayRef<c10::optional<at::Tensor>> {in_name} = {arg_name}.toListOptionalTensor();
+c10::List<c10::optional<at::Tensor>> {out_name};
+for (auto {elem_name}: {in_name}) {{
+ {out_name}.push_back({elem_name});
+}}
+#else
+torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>> {out_name} = {arg_name}.toListOptionalTensor();
+#endif
+ """.split(
+                    "\n"
+                )
+            )
+        else:
+            # use ArrayRef as default.
+            vec_name = arg_name + "_vec"
+            # need to bring vector instantiation out of scope so that ArrayRef has valid data
+            decl.append(
+                f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
+            )
+            # NOTE(review): the loop below iterates over `{in_name}`, but nothing
+            # generated by this class declares that variable on this code path.
+            # Confirm where `<arg_name>_list_in` is expected to come from.
+            code.extend(
+                f"""
+ for (EValue {elem_name}: {in_name}) {{
+ {connector.join(res_code)}
+ {vec_name}.push_back({res_name});
+ }}
+ {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
+ """.split(
+                    "\n"
+                )
+            )
+        return code, decl