from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI | |
from torchgen.context import native_function_manager | |
from torchgen.utils import T | |
import functools | |
from typing import Callable | |
# Like torchgen.context.with_native_function, but for
# NativeFunctionWithDifferentiabilityInfo.
def with_native_function_with_differentiability_info(
    func: Callable[[NFWDI], T]
) -> Callable[[NFWDI], T]:
    """Decorate ``func`` so each call runs under ``native_function_manager``.

    The returned callable takes a ``NativeFunctionWithDifferentiabilityInfo``.
    It enters ``native_function_manager`` for that object's ``.func`` attribute,
    then invokes ``func`` on the same object and returns its result.
    ``functools.wraps`` copies ``func``'s metadata (name, docstring, ...)
    onto the returned callable.

    Args:
        func: Callable that consumes a ``NativeFunctionWithDifferentiabilityInfo``.

    Returns:
        A wrapper with the same call signature as ``func``.
    """

    def _invoke_under_manager(fn_info: NFWDI) -> T:
        # The return stays inside the ``with`` body so that however the
        # manager handles an exception raised by ``func``, the outcome is
        # exactly what a plain ``with ...: return func(...)`` would produce.
        with native_function_manager(fn_info.func):
            return func(fn_info)

    decorate = functools.wraps(func)
    return decorate(_invoke_under_manager)