blob: 303ebbc45d63e239338df3c91c226167ee1776e4 [file] [log] [blame]
import contextlib
from typing import Tuple, Union
import torch
from torch._C._functorch import (
get_single_level_autograd_function_allowed,
set_single_level_autograd_function_allowed,
unwrap_if_dead,
)
from torch.utils._exposed_in import exposed_in
# Public names of this module; ``exposed_in`` is re-exported from
# ``torch.utils._exposed_in`` rather than defined here.
__all__ = ["exposed_in", "argnums_t",
           "enable_single_level_autograd_function", "unwrap_dead_wrappers"]
@contextlib.contextmanager
def enable_single_level_autograd_function():
    """Temporarily set the single-level autograd.Function flag to ``True``.

    The flag is the one read and written by
    ``get_single_level_autograd_function_allowed`` /
    ``set_single_level_autograd_function_allowed``. On exit from the ``with``
    block, the value seen on entry is restored, including when the body
    raises.

    Yields:
        None.
    """
    # Capture the previous state *before* entering ``try``. If this read were
    # inside the ``try`` and raised, the ``finally`` clause would reference an
    # unbound ``prev_state``, and the UnboundLocalError would mask the
    # original exception.
    prev_state = get_single_level_autograd_function_allowed()
    try:
        set_single_level_autograd_function_allowed(True)
        yield
    finally:
        set_single_level_autograd_function_allowed(prev_state)
def unwrap_dead_wrappers(args):
    """Return ``args`` as a tuple with every tensor passed through ``unwrap_if_dead``.

    Elements that are not ``torch.Tensor`` instances are kept as-is, and the
    input order is preserved. (Presumably ``unwrap_if_dead`` strips functorch
    wrappers whose transform level is no longer alive; confirm against the
    C++ implementation.)
    """
    # A plain loop is used instead of tree_map_only for performance; only the
    # top level of ``args`` is inspected.
    unwrapped = []
    for item in args:
        if isinstance(item, torch.Tensor):
            item = unwrap_if_dead(item)
        unwrapped.append(item)
    return tuple(unwrapped)
# Type alias for an ``argnums``-style value: a single int, or a tuple of ints.
argnums_t = Union[
    int,
    Tuple[int, ...],
]