import contextlib
import warnings
from typing import Optional

import torch
from torch._C import (
    DispatchKey,
    _get_dispatch_stack_at,
    _len_torch_dispatch_stack,
    _pop_torch_dispatch_stack,
    _push_on_torch_dispatch_stack,
)


# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
#   is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)

class TorchDispatchMode:
    """
    A ``TorchDispatchMode`` allows you to override the meaning of all
    ``__torch_dispatch__`` overridable functions within a dynamic scope,
    without having to actually create a tensor subclass or manually
    monkey-patch functions in the PyTorch API.  Some common situations
    where you should use a mode:

    * You want to override the meaning of factory functions, or other
      functions that do not otherwise take a tensor as an argument
      (these cannot be overridden with tensor subclasses).

    * You want to override the behavior of all functions without needing
      to wrap your inputs in tensor subclasses; e.g., if you are just
      interested in logging intermediate computations.

    * You want to control the order of execution of various tensor
      subclasses explicitly, rather than implicitly via the return of
      ``NotImplemented``.

    Independent subclasses of :class:`TorchDispatchMode` are compositional:
    modes can be pushed onto a stack using ``with MyMode():``.
    When you call functions in the PyTorch API inside your
    ``__torch_dispatch__`` implementation, by default, they will forward on to
    the next mode on the mode stack.  If you want to recursively call back into
    your current ``__torch_dispatch__`` implementation, either explicitly
    invoke ``self.__torch_dispatch__(...)``, or re-enter the mode so that the
    PyTorch API becomes self-referential (beware of infinite loops, in this
    case!).
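
    For example, a minimal logging mode might look like the following sketch
    (``LoggingMode`` is purely illustrative and is not defined in this module)::

        class LoggingMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                # ``func`` is the operator overload being dispatched
                print(f"dispatching {func}")
                return func(*args, **kwargs)

        with LoggingMode():
            # Every operator call in this block, including factory functions
            # such as ``torch.ones``, is routed through LoggingMode first.
            torch.ones(2, 2) + 1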
    """
    def __init__(self, _dispatch_key=None):
        if _dispatch_key is not None:
            assert isinstance(_dispatch_key, torch._C.DispatchKey)
        self.__dict__['_dispatch_key'] = _dispatch_key

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        raise NotImplementedError()

    def __enter__(self):
        _push_mode(self, self.__dict__.get("_dispatch_key", None))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        _pop_mode(self.__dict__.get("_dispatch_key", None))

    @classmethod
    def push(cls, *args, **kwargs):
        warnings.warn("`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`")
        instance = cls(*args, **kwargs)
        return instance

def _get_current_dispatch_mode():
    # Return the innermost (most recently pushed) mode on the dispatch mode
    # stack, or None if the stack is empty.
    stack_len = _len_torch_dispatch_stack()
    return _get_dispatch_stack_at(stack_len - 1) if stack_len > 0 else None


def _get_current_dispatch_mode_stack():
    # Return the full dispatch mode stack, ordered from outermost (bottom)
    # to innermost (top).
    stack_len = _len_torch_dispatch_stack()
    return [_get_dispatch_stack_at(i) for i in range(stack_len)]

def _push_mode(mode, k: Optional[DispatchKey] = None):
    if k is not None:
        from torch._ops import push_mode_for_key, get_cached_ops
        # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
        # Clear the cache of every op that has been used so far, for this particular key.
        ks = torch._C._functionality_to_backend_keys(k)
        for op in get_cached_ops():
            for key in ks:
                op._uncache_dispatch(key)
        push_mode_for_key(k, mode)
    # Note [Per-Dispatch-Key Modes Must Be Reentrant]
    # The idea here is that we are allowed to push modes onto any dispatch key's mode stack, but:
    # (1) We **always** push the mode onto the python mode stack. Operators can have fallthrough
    #     kernels registered to any dispatch key, so we use the Python mode stack as a catchall,
    #     to guarantee that every op will be seen by our mode.
    # (2) We expect the mode that you push to handle being re-entrant: if we end up invoking the mode
    #     at both the Autograd key and the Python key, nothing bad should happen.
    #     The main use case for this is pre-autograd tracing with ProxyTorchDispatchMode.
    _push_on_torch_dispatch_stack(mode)

def _pop_mode(k: Optional[DispatchKey] = None):
    m = _pop_torch_dispatch_stack()
    if k is not None:
        from torch._ops import pop_mode_for_key
        tmp = pop_mode_for_key(k)
        assert m is tmp
    return m

@contextlib.contextmanager
def _pop_mode_temporarily(k: Optional[DispatchKey] = None):
    old = _pop_mode(k)
    try:
        yield old
    finally:
        _push_mode(old, k)
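
# Illustrative usage sketch for _pop_mode_temporarily (the tensors ``x`` and
# ``y`` below are placeholders, not part of this module): temporarily remove
# the innermost mode so that an operation runs without it intercepting.
#
#     with _pop_mode_temporarily() as innermost:
#         out = torch.add(x, y)  # not intercepted by ``innermost``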


@contextlib.contextmanager
def _disable_current_modes():
    mode_len = _len_torch_dispatch_stack()
    old_modes = [_pop_mode() for _ in range(mode_len)]
    try:
        yield old_modes
    finally:
        for mode in reversed(old_modes):
            _push_mode(mode)
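
# Illustrative usage sketch for _disable_current_modes (not code from this
# module): run a block with every dispatch mode popped off the stack, e.g. to
# create plain tensors without any mode intercepting the calls. The modes are
# restored, in their original order, when the block exits.
#
#     with _disable_current_modes() as saved_modes:
#         plain = torch.empty(4)  # no active mode sees this call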


class BaseTorchDispatchMode(TorchDispatchMode):
    """A no-op mode: every intercepted call is simply redispatched via ``func(*args, **kwargs)``."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)