| import contextlib | 
 |  | 
 | import functools | 
 | from typing import Callable, Dict, Iterator, Optional, TypeVar, Union | 
 |  | 
 | import torchgen.local as local | 
 | from torchgen.model import ( | 
 |     BackendIndex, | 
 |     DispatchKey, | 
 |     NativeFunction, | 
 |     NativeFunctionsGroup, | 
 |     NativeFunctionsViewGroup, | 
 | ) | 
 | from torchgen.utils import context, S, T | 
 |  | 
 | # Helper functions for defining generators on things in the model | 
 |  | 
# Constrained type variable for the "native-function-like" object that the
# decorators in this module receive as their first argument (and pass on to
# native_function_manager). Because it is constrained rather than bound, a
# type checker resolves it to exactly one of the listed types at each use
# site, so a decorated function keeps its precise argument type.
F = TypeVar(
    "F",
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    Union[NativeFunction, NativeFunctionsGroup],
    Union[NativeFunction, NativeFunctionsViewGroup],
)
 |  | 
# Constrained type variable for the second argument of functions decorated
# with with_native_function_and. That argument is forwarded unchanged and
# does not participate in setting up the native-function context.
F2 = TypeVar(
    "F2",
    NativeFunction,
    NativeFunctionsGroup,
    Optional[NativeFunction],
    bool,
)
 |  | 
 |  | 
 | @contextlib.contextmanager | 
 | def native_function_manager( | 
 |     g: Union[NativeFunctionsGroup, NativeFunctionsViewGroup, NativeFunction] | 
 | ) -> Iterator[None]: | 
 |     if isinstance(g, NativeFunctionsGroup): | 
 |         # By default, we associate all errors with structured native functions | 
 |         # with the out variant.  In some cases, it might be better to have | 
 |         # a more specific place to hang things; if so, use | 
 |         # native_function_manager again on the inside | 
 |         f = g.out | 
 |     elif isinstance(g, NativeFunctionsViewGroup): | 
 |         # We associate errors with the view operator | 
 |         f = g.view | 
 |     else: | 
 |         f = g | 
 |     with context(lambda: f"in native_functions.yaml line {f.loc}:\n  {f.func}"): | 
 |         with local.parametrize( | 
 |             use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors | 
 |         ): | 
 |             yield | 
 |  | 
 |  | 
 | # Given a function that operates on NativeFunction, wrap it into a new function | 
 | # that sets some appropriate context managers for that native function. | 
 | # YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound | 
 | # (you will get an error if we try to access the local variables without having | 
 | # set them). | 
 | def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]: | 
 |     @functools.wraps(func) | 
 |     def wrapper(f: F) -> T: | 
 |         with native_function_manager(f): | 
 |             return func(f) | 
 |  | 
 |     return wrapper | 
 |  | 
 |  | 
 | def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]: | 
 |     @functools.wraps(func) | 
 |     def wrapper(f: F, f2: F2) -> T: | 
 |         # The first native_function is assumed to be the one with the appropriate context. | 
 |         with native_function_manager(f): | 
 |             return func(f, f2) | 
 |  | 
 |     return wrapper | 
 |  | 
 |  | 
 | def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]: | 
 |     @functools.wraps(func) | 
 |     def wrapper(slf: S, f: F) -> T: | 
 |         with native_function_manager(f): | 
 |             return func(slf, f) | 
 |  | 
 |     return wrapper | 
 |  | 
 |  | 
 | # Convenience decorator for functions that explicitly take in a BackendIndex, | 
 | # instead of indirectly taking one in as a closure | 
 | def with_native_function_and_index( | 
 |     func: Callable[[F, BackendIndex], T] | 
 | ) -> Callable[[F, BackendIndex], T]: | 
 |     @functools.wraps(func) | 
 |     def wrapper(f: F, backend_index: BackendIndex) -> T: | 
 |         with native_function_manager(f): | 
 |             return func(f, backend_index) | 
 |  | 
 |     return wrapper | 
 |  | 
 |  | 
 | # Convenience decorator for functions that explicitly take in a Dict of BackendIndices | 
 | def with_native_function_and_indices( | 
 |     func: Callable[[F, Dict[DispatchKey, BackendIndex]], T] | 
 | ) -> Callable[[F, Dict[DispatchKey, BackendIndex]], T]: | 
 |     @functools.wraps(func) | 
 |     def wrapper(f: F, backend_indices: Dict[DispatchKey, BackendIndex]) -> T: | 
 |         with native_function_manager(f): | 
 |             return func(f, backend_indices) | 
 |  | 
 |     return wrapper |