| import torch |
| import functools |
| import inspect |
| from typing import Any, Callable, TypeVar, cast |
| |
| |
# Public names exported by ``from <this module> import *``.
__all__ = [
    "no_grad",
    "enable_grad",
    "set_grad_enabled",
]
| |
| |
| # Used for annotating the decorator usage of 'no_grad' and 'enable_grad'. |
| # See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators |
| FuncType = Callable[..., Any] |
| F = TypeVar('F', bound=FuncType) |
| |
| |
| class _DecoratorContextManager: |
| """Allow a context manager to be used as a decorator""" |
| |
| def __call__(self, func: F) -> F: |
| if inspect.isgeneratorfunction(func): |
| return self._wrap_generator(func) |
| |
| @functools.wraps(func) |
| def decorate_context(*args, **kwargs): |
| with self.__class__(): |
| return func(*args, **kwargs) |
| return cast(F, decorate_context) |
| |
| def _wrap_generator(self, func): |
| """Wrap each generator invocation with the context manager""" |
| @functools.wraps(func) |
| def generator_context(*args, **kwargs): |
| gen = func(*args, **kwargs) |
| while True: |
| try: |
| with self.__class__(): |
| x = next(gen) |
| yield x |
| except StopIteration: |
| break |
| return generator_context |
| |
| def __enter__(self) -> None: |
| raise NotImplementedError |
| |
| def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: |
| raise NotImplementedError |
| |
| |
class no_grad(_DecoratorContextManager):
    r"""Context-manager that disables gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with parentheses.)


    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...   y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """
    def __init__(self):
        # NOTE(review): super().__init__() is skipped while TorchScript is
        # compiling, presumably because the compiler cannot handle the call;
        # confirm before changing.
        if not torch._jit_internal.is_scripting():
            super().__init__()
        # Placeholder; __enter__ records the real previous mode.
        self.prev = False

    def __enter__(self):
        # Remember the mode in effect on entry so __exit__ can restore it,
        # which lets nested blocks unwind correctly.
        self.prev = torch.is_grad_enabled()
        # NOTE(review): goes through the public torch.set_grad_enabled rather
        # than torch._C._set_grad_enabled (used by enable_grad), presumably
        # for TorchScript compatibility; confirm before unifying.
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)
| |
| |
class enable_grad(_DecoratorContextManager):
    r"""Context-manager that enables gradient calculation.

    Enables gradient calculation, if it has been disabled via :class:`~no_grad`
    or :class:`~set_grad_enabled`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with parentheses.)


    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...   with torch.enable_grad():
        ...     y = x * 2
        >>> y.requires_grad
        True
        >>> y.backward()
        >>> x.grad
        tensor([2.])
        >>> @torch.enable_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> with torch.no_grad():
        ...     z = doubler(x)
        >>> z.requires_grad
        True

    """
    def __enter__(self) -> None:
        # Save the mode in effect on entry; __exit__ restores it.
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(True)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)
| |
| |
class set_grad_enabled(object):
    r"""Context-manager that sets gradient calculation to on or off.

    ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    The new mode takes effect as soon as the object is constructed, which is
    what makes the function-style call work. Leaving a ``with`` block restores
    the mode that was active when the object was constructed.

    This context manager is thread local; it will not affect computation
    in other threads.

    Arguments:
        mode (bool): Flag whether to enable grad (``True``), or disable
                     (``False``). This can be used to conditionally enable
                     gradients.


    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> is_train = False
        >>> with torch.set_grad_enabled(is_train):
        ...   y = x * 2
        >>> y.requires_grad
        False
        >>> _ = torch.set_grad_enabled(True)
        >>> y = x * 2
        >>> y.requires_grad
        True
        >>> _ = torch.set_grad_enabled(False)
        >>> y = x * 2
        >>> y.requires_grad
        False

    """

    def __init__(self, mode: bool) -> None:
        # Record the previous mode, then switch immediately: function-style
        # calls never reach __enter__.
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(mode)

    def __enter__(self) -> None:
        # Nothing to do: the mode was already applied in __init__.
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)