| # This module contains functions that *will be allowed* by dynamo |
| import torch.utils._pytree as pytree |
| except ModuleNotFoundError: |
| np = None # type: ignore[assignment] |
| def is_compiling() -> bool: |
| Create an extra frame around fn that is not in skipfiles |
def inner(*args, **kwargs):
    # Pure pass-through to the closed-over ``fn``; this wrapper exists
    # solely to add an extra Python call frame around ``fn`` (see the
    # enclosing function's docstring: the frame must not be in skipfiles).
    # Do not inline or simplify away — the extra frame IS the behavior.
    return fn(*args, **kwargs)
| def call_hook(hook, *args): |
| Used by compiled autograd to handle hook returning None |
| r"""Decorator that turns a function from ``np.ndarray``s to ``np.ndarray``s into a function |
| from ``torch.Tensor``s to ``torch.Tensor``s. |
| def wrap(*args, **kwargs): |
| args, kwargs = pytree.tree_map_only( |
| torch.Tensor, lambda x: x.numpy(), (args, kwargs) |
| return pytree.tree_map_only(np.ndarray, lambda x: torch.as_tensor(x), out) |