| from tools.codegen.model import ( |
| FunctionSchema, BaseTy, BaseType, NativeFunction, Argument, Tag, |
| ) |
| from tools.codegen.api.types import ( |
| Binding, NamedCType, ConstRefCType, BaseCType, CType, tensorT, longT |
| ) |
| from tools.codegen.api import dispatcher |
| from typing import List, Optional |
| |
| |
# This file describes the translation of JIT schema to the APIs used
# when creating view lambdas that are used by the functionalization pass.
# There are two types of lambdas: forward lambdas and reverse lambdas.
# These APIs mostly follow the dispatcher API, with a few quirks:
| # - The lambda capture has to convert reference types to value types |
| # - While the forward lambda just directly calls into the at::_ops API |
| # (following the dispatcher convention), the logic here for the reverse lambda |
| # is responsible for generating both the call-site, and the declarations |
| # (which are implemented manually in the at::functionalization::impl namespace). |
| |
# The lambdas generated for each view op in the functionalization pass are of the form
# [capture_arguments](outer_arguments) -> returns_type {
#     return name(inner_arguments);
# }
| |
# Fixed bindings for the lambda parameters that are not derived from an operator schema.
def _lambda_binding(binding_name: str, ctype: CType) -> Binding:
    """Build a Binding whose C++ name is `binding_name` and whose C++ type is `ctype`."""
    # Every binding defined below carries the same schema-level Argument:
    # a Tensor named 'base' with no default and no annotation.
    # NOTE(review): this looks like a placeholder (it does not match the
    # 'mutated_view' / 'mutated_view_idx' names or the long type) -- confirm
    # that nothing downstream reads it before changing it.
    placeholder_arg = Argument(name='base', type=BaseType(BaseTy.Tensor), default=None, annotation=None)
    return Binding(
        name=binding_name,
        nctype=NamedCType(name=binding_name, type=ctype),
        argument=placeholder_arg,
        default=None,
    )


# `base`: a ConstRefCType-wrapped tensorT.
base_binding = _lambda_binding('base', ConstRefCType(BaseCType(tensorT)))
# `mutated_view`: a ConstRefCType-wrapped tensorT.
mutated_view_binding = _lambda_binding('mutated_view', ConstRefCType(BaseCType(tensorT)))
# `mutated_view_idx`: a plain longT.
mutated_view_idx_binding = _lambda_binding('mutated_view_idx', BaseCType(longT))
| |
def name(f: NativeFunction, *, functional_op: NativeFunction, is_reverse: bool, include_namespace: bool) -> str:
    """Return the name of the inner function that the generated lambda calls.

    The lambda capture itself is anonymous, so the "name" of a lambda is the
    name of the function it forwards to.

    Args:
        f: The view op the lambda is generated for.
        functional_op: The op whose name is used instead of `f` when `f` is
            tagged `Tag.inplace_view`.
        is_reverse: True for the reverse lambda, False for the forward lambda.
        include_namespace: Whether to return the fully qualified C++ name.
            Must be True for the forward lambda (enforced with an assert).

    Returns:
        The forward lambda gets `at::_ops::<op>::call`. The reverse lambda gets
        `<op>_inverse`, prefixed with
        `at::functionalization::FunctionalInverses::` when `include_namespace`
        is set.
    """
    # For inplace_view ops, the lambda calls out to the corresponding functional view op.
    target = f if f.tag is not Tag.inplace_view else functional_op
    op_name = target.func.name.unambiguous_name()

    if not is_reverse:
        # The forward lambda calls directly into the at::_ops API, so the
        # namespace is always required.
        assert include_namespace
        return f'at::_ops::{op_name}::call'

    # For the reverse lambda we codegen both the call-sites (which need the
    # full namespace) and the declarations (which don't).
    if include_namespace:
        return f'at::functionalization::FunctionalInverses::{op_name}_inverse'
    return f'{op_name}_inverse'
| |
| |
def capture_arguments(func: FunctionSchema, *, is_reverse: bool) -> List[Binding]:
    """Return the bindings captured by the lambda: every argument except the first.

    Captured values must not be C++ reference types, otherwise the capture would
    hold a dangling reference. Each binding is therefore produced with
    `remove_non_owning_ref_types=True`, so reference types (e.g. IntArrayRef)
    become value types (e.g. vector<int64_t>).

    Args:
        func: Schema of the view op. Its first flat argument must be a plain
            Tensor (checked with an assert).
        is_reverse: Accepted for signature parity; not used by this function.
    """
    all_args = func.arguments.flat_all
    # The leading argument is the `self` tensor, which is never captured.
    assert all_args[0].type == BaseType(BaseTy.Tensor)
    return [
        dispatcher.argument(arg, remove_non_owning_ref_types=True)
        for arg in all_args[1:]
    ]
| |
| |
def returns_type(func: FunctionSchema) -> CType:
    """Return the C++ return type of the lambda, which is always a single tensor.

    Asserts that `func` has at least one return and that every return is
    tensor-like. For multi-tensor outputs each tensor is tracked individually,
    so the lambda's return type is `BaseCType(tensorT)` regardless of how many
    values `func` returns.
    """
    rets = func.returns
    # Assertion: all view ops return tensor-like outputs.
    assert len(rets) >= 1
    assert all(r.type.is_tensor_like() for r in rets)
    return BaseCType(tensorT)
| |
| |
def outer_arguments(*, is_reverse: bool) -> List[Binding]:
    """Return the lambda's own parameter list.

    The forward lambda takes (base, mutated_view_idx); the reverse lambda takes
    (base, mutated_view, mutated_view_idx).
    """
    bindings = [base_binding]
    if is_reverse:
        # Only the reverse lambda receives the mutated view.
        bindings.append(mutated_view_binding)
    bindings.append(mutated_view_idx_binding)
    return bindings
| |
| |
def inner_call_index(func: FunctionSchema) -> Optional[Binding]:
    """Return the index binding when `func` produces multiple tensors, else None.

    For view ops that return multiple tensors (like `split`), a separate lambda
    is generated for each output, and replaying the op requires indexing into
    its output. That is the case when `func` has more than one return, or
    exactly one return whose type is list-like.
    """
    returns = func.returns
    if len(returns) > 1:
        return mutated_view_idx_binding
    if len(returns) == 1 and returns[0].type.is_list_like():
        return mutated_view_idx_binding
    return None
| |
| |
def inner_arguments(func: FunctionSchema, is_reverse: bool) -> List[Binding]:
    """Return the bindings passed to the function that the lambda calls.

    The forward lambda calls the at::_ops API and the reverse lambda calls the
    view inverse API; both follow the dispatcher API for the non-self arguments.

    Resulting order:
        forward: base, <non-self args>
        reverse: base, mutated_view, [mutated_view_idx], <non-self args>
    where mutated_view_idx is present only when `inner_call_index(func)`
    returns a binding.

    Args:
        func: Schema of the view op. Its first flat argument must be a plain
            Tensor (checked with an assert).
        is_reverse: True for the reverse lambda, False for the forward lambda.
    """
    all_args = func.arguments.flat_all
    assert all_args[0].type == BaseType(BaseTy.Tensor)
    trailing = [dispatcher.argument(arg) for arg in all_args[1:]]

    # The original tensor argument is swapped out for the lambda argument "base".
    leading = [base_binding]
    if is_reverse:
        # The reverse lambda additionally passes "mutated_view".
        leading.append(mutated_view_binding)
        # Calling convention: view_inverse functions for view ops with multiple
        # tensor outputs take an extra index argument.
        index_binding = inner_call_index(func)
        if index_binding is not None:
            leading.append(index_binding)
    return leading + trailing