[torchgen] Move Executorch unboxing logic into torchgen (#90098)

This PR adds `unboxing.py` which converts a `EValue` (similar to `IValue`) to its corresponding C++ type, based on the `ExecutorchCppSignature`.

Added unit tests to it in `test_executorch_unboxing.py`. Notice that this unboxing logic should work for both ATen types and Executorch types, hence the unit tests are parametrized.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90098
Approved by: https://github.com/ezyang
diff --git a/tools/test/test_executorch_unboxing.py b/tools/test/test_executorch_unboxing.py
new file mode 100644
index 0000000..eff0145
--- /dev/null
+++ b/tools/test/test_executorch_unboxing.py
@@ -0,0 +1,176 @@
+import unittest
+from types import ModuleType
+
+from torchgen import local
+from torchgen.api import cpp as aten_cpp, types as aten_types
+from torchgen.api.types import (
+    ArgName,
+    BaseCType,
+    ConstRefCType,
+    MutRefCType,
+    NamedCType,
+)
+from torchgen.executorch.api import et_cpp as et_cpp, types as et_types
+from torchgen.executorch.api.unboxing import Unboxing
+from torchgen.model import BaseTy, BaseType, ListType, OptionalType, Type
+
+
+def aten_argumenttype_type_wrapper(
+    t: Type, *, mutable: bool, binds: ArgName, remove_non_owning_ref_types: bool = False
+) -> NamedCType:
+    return aten_cpp.argumenttype_type(
+        t,
+        mutable=mutable,
+        binds=binds,
+        remove_non_owning_ref_types=remove_non_owning_ref_types,
+    )
+
+
+ATEN_UNBOXING = Unboxing(argument_type_gen=aten_argumenttype_type_wrapper)
+ET_UNBOXING = Unboxing(argument_type_gen=et_cpp.argumenttype_type)
+
+
+class TestUnboxing(unittest.TestCase):
+    """
+    Could use torch.testing._internal.common_utils to reduce boilerplate.
+    GH CI job doesn't build torch before running tools unit tests, hence
+    manually adding these parametrized tests.
+    """
+
+    @local.parametrize(
+        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+    )
+    def test_symint_argument_translate_ctype_aten(self) -> None:
+        # test if `SymInt[]` JIT argument can be translated into C++ argument correctly.
+        # should be `IntArrayRef` due to the fact that Executorch doesn't use symint sig.
+
+        # pyre-fixme[16]: `enum.Enum` has no attribute `SymInt`
+        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
+        symint_list_type = ListType(elem=BaseType(BaseTy.SymInt), size=None)
+
+        out_name, ctype, _, _ = ATEN_UNBOXING.argumenttype_evalue_convert(
+            t=symint_list_type, arg_name="size", mutable=False
+        )
+
+        self.assertEqual(out_name, "size_list_out")
+        self.assertIsInstance(ctype, BaseCType)
+        # pyre-fixme[16]:
+        self.assertEqual(ctype, aten_types.BaseCType(aten_types.intArrayRefT))
+
+    @local.parametrize(
+        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+    )
+    def test_symint_argument_translate_ctype_executorch(self) -> None:
+        # test if `SymInt[]` JIT argument can be translated into C++ argument correctly.
+        # should be `IntArrayRef` due to the fact that Executorch doesn't use symint sig.
+
+        # pyre-fixme[16]: `enum.Enum` has no attribute `SymInt`
+        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
+        symint_list_type = ListType(elem=BaseType(BaseTy.SymInt), size=None)
+
+        out_name, ctype, _, _ = ET_UNBOXING.argumenttype_evalue_convert(
+            t=symint_list_type, arg_name="size", mutable=False
+        )
+
+        self.assertEqual(out_name, "size_list_out")
+        self.assertIsInstance(ctype, et_types.ArrayRefCType)
+        # pyre-fixme[16]:
+        self.assertEqual(
+            ctype, et_types.ArrayRefCType(elem=BaseCType(aten_types.longT))
+        )
+
+    @local.parametrize(
+        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+    )
+    def _test_const_tensor_argument_translate_ctype(
+        self, unboxing: Unboxing, types: ModuleType
+    ) -> None:
+        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
+        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
+        tensor_type = BaseType(BaseTy.Tensor)
+
+        out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert(
+            t=tensor_type, arg_name="self", mutable=False
+        )
+
+        self.assertEqual(out_name, "self_base")
+        # pyre-fixme[16]:
+        self.assertEqual(ctype, ConstRefCType(BaseCType(types.tensorT)))
+
+    def test_const_tensor_argument_translate_ctype_aten(self) -> None:
+        self._test_const_tensor_argument_translate_ctype(ATEN_UNBOXING, aten_types)
+
+    def test_const_tensor_argument_translate_ctype_executorch(self) -> None:
+        self._test_const_tensor_argument_translate_ctype(ET_UNBOXING, et_types)
+
+    @local.parametrize(
+        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+    )
+    def _test_mutable_tensor_argument_translate_ctype(
+        self, unboxing: Unboxing, types: ModuleType
+    ) -> None:
+        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
+        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
+        tensor_type = BaseType(BaseTy.Tensor)
+
+        out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert(
+            t=tensor_type, arg_name="out", mutable=True
+        )
+
+        self.assertEqual(out_name, "out_base")
+        # pyre-fixme[16]:
+        self.assertEqual(ctype, MutRefCType(BaseCType(types.tensorT)))
+
+    def test_mutable_tensor_argument_translate_ctype_aten(self) -> None:
+        self._test_mutable_tensor_argument_translate_ctype(ATEN_UNBOXING, aten_types)
+
+    def test_mutable_tensor_argument_translate_ctype_executorch(self) -> None:
+        self._test_mutable_tensor_argument_translate_ctype(ET_UNBOXING, et_types)
+
+    @local.parametrize(
+        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+    )
+    def _test_tensor_list_argument_translate_ctype(
+        self, unboxing: Unboxing, types: ModuleType
+    ) -> None:
+        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
+        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
+        tensor_list_type = ListType(elem=BaseType(BaseTy.Tensor), size=None)
+
+        out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert(
+            t=tensor_list_type, arg_name="out", mutable=True
+        )
+
+        self.assertEqual(out_name, "out_list_out")
+        # pyre-fixme[16]:
+        self.assertEqual(ctype, BaseCType(types.tensorListT))
+
+    def test_tensor_list_argument_translate_ctype_aten(self) -> None:
+        self._test_tensor_list_argument_translate_ctype(ATEN_UNBOXING, aten_types)
+
+    def test_tensor_list_argument_translate_ctype_executorch(self) -> None:
+        self._test_tensor_list_argument_translate_ctype(ET_UNBOXING, et_types)
+
+    @local.parametrize(
+        use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
+    )
+    def _test_optional_int_argument_translate_ctype(
+        self, unboxing: Unboxing, types: ModuleType
+    ) -> None:
+        # pyre-fixme[16]: `enum.Enum` has no attribute `Tensor`
+        # pyre-fixme[19]: Call `BaseType.__init__` expects 0 positional arguments, 1 was provided.
+        optional_int_type = OptionalType(elem=BaseType(BaseTy.int))
+
+        out_name, ctype, _, _ = unboxing.argumenttype_evalue_convert(
+            t=optional_int_type, arg_name="something", mutable=True
+        )
+
+        self.assertEqual(out_name, "something_opt_out")
+        # pyre-fixme[16]:
+        self.assertEqual(ctype, types.OptionalCType(BaseCType(types.longT)))
+
+    def test_optional_int_argument_translate_ctype_aten(self) -> None:
+        self._test_optional_int_argument_translate_ctype(ATEN_UNBOXING, aten_types)
+
+    def test_optional_int_argument_translate_ctype_executorch(self) -> None:
+        self._test_optional_int_argument_translate_ctype(ET_UNBOXING, et_types)
diff --git a/torchgen/executorch/api/unboxing.py b/torchgen/executorch/api/unboxing.py
new file mode 100644
index 0000000..9a8f717
--- /dev/null
+++ b/torchgen/executorch/api/unboxing.py
@@ -0,0 +1,213 @@
+from dataclasses import dataclass
+from typing import Callable, List, Sequence, Tuple
+
+from torchgen.api.types import Binding, CType, NamedCType
+from torchgen.model import (
+    Argument,
+    BaseTy,
+    BaseType,
+    ListType,
+    NativeFunction,
+    OptionalType,
+    Type,
+)
+
+connector = "\n\t"
+
+
+# Return unboxing function name for a NativeFunction
+def name(f: NativeFunction) -> str:
+    return f.func.name.unambiguous_name()
+
+
+@dataclass(frozen=True)
+class Unboxing:
+    """
+    Takes a sequence of Bindings and unbox EValues to these Bindings. Return generated code that performs correct unboxing.
+    A sample generated code:
+    // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+    void mul_out(EValue** stack) {
+        EValue& self = *stack[0];
+        EValue& other = *stack[1];
+        EValue& out = *stack[2];
+        const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
+        const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
+        torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
+
+        EXECUTORCH_SCOPE_PROF("native_call_mul.out");
+        torch::executor::mul_outf(self_base, other_base, out_base);
+
+
+    }
+    """
+
+    # this is a callable that converts a JIT argument, into its C++ type.
+    # Translates (type, mutability, binds) to NamedCType. E.g., torchgen.api.cpp.argumenttype_type.
+    argument_type_gen: Callable[
+        ...,
+        NamedCType,
+    ]
+
+    # Convert all the arguments in a NativeFunction to C++ code
+    def convert_arguments(
+        self, args: Sequence[Binding]
+    ) -> Tuple[List[Binding], List[str]]:
+        code_list = [f"EValue& {args[i].name} = *stack[{i}];" for i in range(len(args))]
+        binding_list = []
+        for arg in args:
+            # expecting only Argument
+            if not isinstance(arg.argument, Argument):
+                raise Exception(
+                    f"Unexpected argument type, expecting `Argument` but got {arg}"
+                )
+            argument: Argument = arg.argument
+            unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
+                argument.type, argument.name, mutable=argument.is_write
+            )
+            code_list.extend(decl)
+            code_list.extend(code)
+            binding_list.append(arg.with_name(unboxed_name))
+        return binding_list, code_list
+
+    def argumenttype_evalue_convert(
+        self, t: Type, arg_name: str, *, mutable: bool = False
+    ) -> Tuple[str, CType, List[str], List[str]]:
+        """
+        Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
+        (1) the C++ code necessary to unbox the argument
+        (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType
+        :param t: a `Type` of an argument
+        :param arg_name: argument name
+        :param mutable: boolean for whether this argument type is mutable
+        :return: unboxed result
+        """
+        ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type
+
+        if isinstance(t, BaseType):
+            out_name = f"{arg_name}_base"
+            code, decl = self._gen_code_base_type(
+                arg_name=arg_name, out_name=out_name, ctype=ctype
+            )
+        elif isinstance(t, OptionalType):
+            out_name = f"{arg_name}_opt_out"
+            code, decl = self._gen_code_optional_type(
+                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
+            )
+        elif isinstance(t, ListType):
+            out_name = f"{arg_name}_list_out"
+            code, decl = self._gen_code_list_type(
+                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
+            )
+        else:
+            raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")
+        return out_name, ctype, code, decl
+
+    def _gen_code_base_type(
+        self, arg_name: str, out_name: str, ctype: CType
+    ) -> Tuple[List[str], List[str]]:
+        return [
+            f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
+        ], []
+
+    def _gen_code_optional_type(
+        self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
+    ) -> Tuple[List[str], List[str]]:
+        in_name = f"{arg_name}_opt_in"
+        res_name, base_type, res_code, decl = self.argumenttype_evalue_convert(
+            t.elem, in_name
+        )
+        return (
+            f"""
+    {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
+            """.split(
+                "\n"
+            ),
+            decl,
+        )
+
+    def _gen_code_list_type(
+        self, arg_name: str, out_name: str, t: ListType, ctype: CType
+    ) -> Tuple[List[str], List[str]]:
+        in_name = f"{arg_name}_list_in"
+        elem_name = f"{arg_name}_elem"
+        code = []
+        res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
+            t.elem, elem_name
+        )
+
+        if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.Tensor:
+            code.extend(
+                f"""
+    {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toTensorList();
+                """.split(
+                    "\n"
+                )
+            )
+        elif isinstance(t.elem, BaseType) and (
+            t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt
+        ):
+            code.extend(
+                f"""
+    {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toIntList();
+                """.split(
+                    "\n"
+                )
+            )
+        elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float:
+            code.extend(
+                f"""
+    {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toDoubleList();
+                """.split(
+                    "\n"
+                )
+            )
+        elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool:
+            # handle list type with size, e.g., bool[4]
+            code.extend(
+                f"""
+    {ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toBoolList();
+                """.split(
+                    "\n"
+                )
+            )
+        # pytorch codegen:
+        # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<c10::optional<at::Tensor>>
+        elif (
+            isinstance(t.elem, OptionalType)
+            and isinstance(t.elem.elem, BaseType)
+            and t.elem.elem.name == BaseTy.Tensor
+        ):
+            code.extend(
+                f"""
+#ifdef USE_ATEN_LIB
+at::ArrayRef<c10::optional<at::Tensor>> {in_name} = {arg_name}.toListOptionalTensor();
+c10::List<c10::optional<at::Tensor>> {out_name};
+for (auto {elem_name}: {in_name}) {{
+    {out_name}.push_back({elem_name});
+}}
+#else
+torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>> {out_name} = {arg_name}.toListOptionalTensor();
+#endif
+                """.split(
+                    "\n"
+                )
+            )
+        else:
+            # use ArrayRef as default.
+            vec_name = arg_name + "_vec"
+            # need to bring vector instantiation out of scope so that ArrayRef has valid data
+            decl.append(
+                f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
+            )
+            code.extend(
+                f"""
+    for (EValue {elem_name}: {in_name}) {{
+        {connector.join(res_code)}
+        {vec_name}.push_back({res_name});
+    }}
+    {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
+                """.split(
+                    "\n"
+                )
+            )
+        return code, decl