blob: 66f4f81aa0fbbc2863d895512044c23705195082 [file] [log] [blame]
from tools.codegen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI
from tools.codegen.context import native_function_manager
from tools.codegen.utils import T
import functools
from typing import Callable
# Like tools.codegen.context.with_native_function, but for
# NativeFunctionWithDifferentiabilityInfo.
def with_native_function_with_differentiability_info(func: Callable[[NFWDI], T]) -> Callable[[NFWDI], T]:
    """Wrap ``func`` so each call runs inside ``native_function_manager``.

    The returned callable accepts a NativeFunctionWithDifferentiabilityInfo
    ``f``. It enters ``native_function_manager(f.func)`` for the duration of
    the call and returns whatever ``func(f)`` returns. The wrapper carries over
    ``func``'s ``__name__``, ``__doc__``, ``__wrapped__`` and related metadata,
    exactly as ``functools.wraps`` would.

    Args:
        func: Callable taking a NativeFunctionWithDifferentiabilityInfo.

    Returns:
        A callable with the same signature and return value as ``func``.
    """
    def wrapper(f: NFWDI) -> T:
        # Only the underlying NativeFunction (``f.func``) is handed to the
        # manager; the differentiability info itself is passed through as-is.
        manager = native_function_manager(f.func)
        with manager:
            return func(f)

    # Equivalent to decorating ``wrapper`` with ``@functools.wraps(func)``.
    return functools.update_wrapper(wrapper, func)