| import torch | |
| from typing import TypeVar | |
| from contextlib import contextmanager | |
# Generic type variable for use in type annotations.
T = TypeVar("T")
def all_same_mode(modes):
    """Return whether every element of ``modes`` compares equal to the first.

    Args:
        modes: an indexable sequence (e.g. a list) of modes. ``modes[0]`` is
            only read when the sequence is non-empty.

    Returns:
        ``True`` if every element satisfies ``element == modes[0]``
        (vacuously ``True`` for an empty sequence), otherwise ``False``.
    """
    # Hand the generator straight to all() so it stops at the first mismatch,
    # instead of materializing every comparison into a tuple up front.
    return all(mode == modes[0] for mode in modes)
@contextmanager
def no_dispatch():
    """Hold a ``torch._C._DisableTorchDispatch`` guard for the ``with`` body.

    The guard object is created on entry. Its only reference is dropped on
    exit, whether the body completes normally or raises.

    Yields:
        None.
    """
    # NOTE(review): the guard presumably disables __torch_dispatch__ handling
    # for as long as the object is alive -- confirm against torch._C.
    dispatch_guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
    try:
        yield None
    finally:
        # Unbind explicitly so the guard can be released as soon as the block
        # exits (under CPython refcounting), not whenever this frame is collected.
        del dispatch_guard