blob: 13326d05f6a5ef3a3c7ab6783832f2e02328eba2 [file] [log] [blame]
import contextlib
import torch
from torch._C._functorch import (
set_autograd_function_allowed,
get_autograd_function_allowed,
unwrap_if_dead,
)
@contextlib.contextmanager
def enable_autograd_function():
    """Temporarily set the functorch "autograd Function allowed" flag to True.

    The current value is read with ``get_autograd_function_allowed()``. It is
    then set to True for the duration of the ``with`` block. The previous value
    is restored on exit, including when the body of the ``with`` block raises.

    Yields:
        None.
    """
    # Read the previous state *before* entering ``try``. If the getter raised
    # inside the ``try``, the ``finally`` clause would reference an unbound
    # ``prev_state`` and mask the original error with UnboundLocalError.
    prev_state = get_autograd_function_allowed()
    try:
        set_autograd_function_allowed(True)
        yield
    finally:
        set_autograd_function_allowed(prev_state)
def unwrap_dead_wrappers(args):
    """Return ``args`` as a tuple, passing each Tensor through ``unwrap_if_dead``.

    Elements that are not ``torch.Tensor`` instances are kept as-is. Only the
    top level of ``args`` is inspected; nested containers are not traversed.
    (``unwrap_if_dead`` presumably strips functorch wrappers whose transform
    level is no longer active — confirm against ``torch._C._functorch``.)

    Args:
        args: Any iterable of values.

    Returns:
        A tuple of the same length as ``args``.
    """
    # NB: a flat single pass instead of tree_map_only, for performance reasons.
    return tuple(
        [unwrap_if_dead(item) if isinstance(item, torch.Tensor) else item for item in args]
    )
# Allows an API defined in a private submodule to be exposed publicly, as
# defined by PyTorch's public API policy.
#
# This is a temporary solution while we decide whether it should become the
# long-term solution, or whether PyTorch's public API policy should be amended
# instead. The concern is that this approach may not be very robust, because it
# is not clear what __module__ is used for. However, both NumPy and JAX
# overwrite the __module__ attribute of their APIs without problems, so it
# seems fine.
def exposed_in(module):
    """Return a decorator that overwrites the decorated object's ``__module__``.

    Args:
        module: The value to assign to ``fn.__module__``. This is typically
            the dotted name of the public module the API should appear to
            belong to.

    Returns:
        A decorator that mutates its argument in place and returns that same
        object (no wrapper is created).
    """
    def _set_module(fn):
        setattr(fn, "__module__", module)
        return fn

    return _set_module