|  | from dataclasses import dataclass | 
|  | from typing import Callable, List, Sequence, Tuple | 
|  |  | 
|  | from torchgen.api.types import Binding, CType, NamedCType | 
|  | from torchgen.model import ( | 
|  | Argument, | 
|  | BaseTy, | 
|  | BaseType, | 
|  | ListType, | 
|  | NativeFunction, | 
|  | OptionalType, | 
|  | Type, | 
|  | ) | 
|  |  | 
# Separator for joining lines of generated C++ (e.g. ``connector.join(res_code)``):
# each subsequent line starts on a new, tab-indented line.
connector: str = "\n\t"
|  |  | 
|  |  | 
|  | # Return unboxing function name for a NativeFunction | 
|  | def name(f: NativeFunction) -> str: | 
|  | return f.func.name.unambiguous_name() | 
|  |  | 
|  |  | 
@dataclass(frozen=True)
class Unboxing:
    """
    Takes a sequence of Bindings and unboxes EValues into these Bindings.
    Returns generated C++ code that performs the unboxing.

    Sample generated code:

        // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
        void mul_out(EValue** stack) {
            EValue& self = *stack[0];
            EValue& other = *stack[1];
            EValue& out = *stack[2];
            const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
            const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
            torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();

            EXECUTORCH_SCOPE_PROF("native_call_mul.out");
            torch::executor::mul_outf(self_base, other_base, out_base);
        }

    Only the ``EValue&`` stack bindings and the unboxed-variable declarations
    are produced by this class (see ``convert_arguments``). The function
    signature, profiling scope and kernel call are emitted elsewhere.
    """

    # Callable that converts a JIT argument into its C++ type. It is invoked as
    # ``argument_type_gen(type, mutable=..., binds=...)`` and must return a
    # NamedCType, e.g. torchgen.api.cpp.argumenttype_type.
    argument_type_gen: Callable[
        ...,
        NamedCType,
    ]

    # Convert all the arguments in a NativeFunction to C++ code
    def convert_arguments(
        self, args: Sequence[Binding]
    ) -> Tuple[List[Binding], List[str]]:
        """
        Generate C++ code that unboxes every argument of a NativeFunction.

        :param args: bindings for the kernel arguments, in stack order. Each
            binding's ``argument`` must be an ``Argument``.
        :return: a tuple ``(bindings, code)``.
            ``bindings`` are the input bindings renamed to the unboxed C++
            variables.
            ``code`` is the list of C++ lines. It starts with one
            ``EValue& <binding name> = *stack[i];`` line per argument. Then,
            for each argument, come any hoisted declarations followed by the
            unboxing statements.
        :raises Exception: if a binding's argument is not an ``Argument``.
        """
        # Name every stack slot first so the unboxing code can refer to
        # arguments by name.
        code_list = [f"EValue& {arg.name} = *stack[{i}];" for i, arg in enumerate(args)]
        binding_list: List[Binding] = []
        for arg in args:
            # expecting only Argument
            if not isinstance(arg.argument, Argument):
                raise Exception(
                    f"Unexpected argument type, expecting `Argument` but got {arg}"
                )
            argument: Argument = arg.argument
            unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
                argument.type, argument.name, mutable=argument.is_write
            )
            # Hoisted declarations (e.g. std::vector backing storage) must
            # precede the code that fills and uses them.
            code_list.extend(decl)
            code_list.extend(code)
            binding_list.append(arg.with_name(unboxed_name))
        return binding_list, code_list

    def argumenttype_evalue_convert(
        self, t: Type, arg_name: str, *, mutable: bool = False
    ) -> Tuple[str, CType, List[str], List[str]]:
        """
        Generate C++ code that unboxes the EValue variable ``arg_name`` of JIT type ``t``.

        :param t: a `Type` of an argument
        :param arg_name: name of the C++ ``EValue`` variable holding the argument
        :param mutable: boolean for whether this argument type is mutable
        :return: a tuple ``(out_name, ctype, code, decl)``.
            ``out_name`` is the name of the newly created unboxed C++ variable.
            ``ctype`` is its C++ type, as given by ``argument_type_gen``.
            ``code`` is the list of C++ lines performing the unboxing.
            ``decl`` is a list of C++ declarations that the caller emits
            before ``code``, e.g. storage that must stay alive after a loop.
        :raises Exception: if ``t`` is not a BaseType, OptionalType or ListType.
        """
        ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type

        if isinstance(t, BaseType):
            out_name = f"{arg_name}_base"
            code, decl = self._gen_code_base_type(
                arg_name=arg_name, out_name=out_name, ctype=ctype
            )
        elif isinstance(t, OptionalType):
            out_name = f"{arg_name}_opt_out"
            code, decl = self._gen_code_optional_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        elif isinstance(t, ListType):
            out_name = f"{arg_name}_list_out"
            code, decl = self._gen_code_list_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        else:
            raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")
        return out_name, ctype, code, decl

    def _gen_code_base_type(
        self, arg_name: str, out_name: str, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """Unbox a non-container EValue with ``EValue::to<T>()``. Needs no hoisted declarations."""
        return [
            f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
        ], []

    def _gen_code_optional_type(
        self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """
        Unbox an optional EValue with ``EValue::toOptional<T>()``.

        The element type is converted recursively only to obtain its C++ type
        and any hoisted declarations. The element's own unboxing code is
        discarded.
        """
        in_name = f"{arg_name}_opt_in"
        _, base_type, _, decl = self.argumenttype_evalue_convert(t.elem, in_name)
        return (
            f"""
{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
            """.split(
                "\n"
            ),
            decl,
        )

    def _gen_code_list_type(
        self, arg_name: str, out_name: str, t: ListType, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """
        Unbox a list EValue.

        Lists of Tensor, int/SymInt, float and bool use the matching EValue
        list accessor. ``Tensor?[]`` has separate ATen and ExecuTorch code
        paths. Any other element type falls back to unboxing each element
        into a hoisted ``std::vector`` and constructing ``ctype`` from it.
        """
        in_name = f"{arg_name}_list_in"
        elem_name = f"{arg_name}_elem"
        code: List[str] = []
        # Always convert the element type first. This also validates that the
        # element type is supported, and provides the pieces for the default path.
        res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
            t.elem, elem_name
        )

        # EValue accessors for lists of plain BaseType elements. int and SymInt
        # share toIntList(). toBoolList() also handles sized lists such as bool[4].
        list_accessors = {
            BaseTy.Tensor: "toTensorList",
            BaseTy.int: "toIntList",
            BaseTy.SymInt: "toIntList",
            BaseTy.float: "toDoubleList",
            BaseTy.bool: "toBoolList",
        }

        if isinstance(t.elem, BaseType) and t.elem.name in list_accessors:
            accessor = list_accessors[t.elem.name]
            code.extend(
                f"""
{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.{accessor}();
                """.split(
                    "\n"
                )
            )
        # pytorch codegen:
        # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<c10::optional<at::Tensor>>
        elif (
            isinstance(t.elem, OptionalType)
            and isinstance(t.elem.elem, BaseType)
            and t.elem.elem.name == BaseTy.Tensor
        ):
            code.extend(
                f"""
#ifdef USE_ATEN_LIB
at::ArrayRef<c10::optional<at::Tensor>> {in_name} = {arg_name}.toListOptionalTensor();
c10::List<c10::optional<at::Tensor>> {out_name};
for (auto {elem_name}: {in_name}) {{
    {out_name}.push_back({elem_name});
}}
#else
torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>> {out_name} = {arg_name}.toListOptionalTensor();
#endif
                """.split(
                    "\n"
                )
            )
        else:
            # use ArrayRef as default.
            vec_name = arg_name + "_vec"
            # need to bring vector instantiation out of scope so that ArrayRef has valid data
            decl.append(
                f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
            )
            # NOTE(review): `in_name` ({arg_name}_list_in) is never declared in
            # the generated code on this path. Only the Tensor?[] branch defines
            # it, so the emitted loop presumably does not compile as-is.
            # TODO: declare it from `arg_name` once the right EValue list
            # accessor is confirmed.
            code.extend(
                f"""
for (EValue {elem_name}: {in_name}) {{
    {connector.join(res_code)}
    {vec_name}.push_back({res_name});
}}
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
                """.split(
                    "\n"
                )
            )
        return code, decl