Return mode object from __enter__ (#80998)
This makes `with Mode() as m:` work.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80998
Approved by: https://github.com/samdow
diff --git a/test/test_overrides.py b/test/test_overrides.py
index 6e992b4..af58a10 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -1323,8 +1323,7 @@
class A(TorchFunctionMode):
pass
- x = A()
- with x:
+ with A() as x:
pass
with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"):
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index 109e7b0..9cdbf8a 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -1143,8 +1143,7 @@
__torch_function__ = torch._C._disabled_torch_function_impl
a = SubTensor(torch.randn(2))
- mode = PoliteMode()
- with mode:
+ with PoliteMode() as mode:
a.abs()
self.assertEqual(mode.pre_count, 2)
diff --git a/torch/overrides.py b/torch/overrides.py
index 5ed7f64..ba9ad31 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -1859,6 +1859,7 @@
else:
self.ancestors = self.inner.ancestors.union({self.inner})
_set_torch_function_mode(self)
+ return self
def __exit__(self, exc_type, exc_val, exc_tb):
_set_torch_function_mode(self.inner)
diff --git a/torch/utils/_mode_utils.py b/torch/utils/_mode_utils.py
index c15ac26..21c4018 100644
--- a/torch/utils/_mode_utils.py
+++ b/torch/utils/_mode_utils.py
@@ -1,9 +1,11 @@
import functools
import torch
-from typing import Iterator
+from typing import Iterator, TypeVar
from dataclasses import dataclass
from contextlib import contextmanager
+T = TypeVar('T')
+
# This file has all the logic to dedupe logic between torch dispatch and
# torch function modes
#
@@ -51,7 +53,7 @@
# shared version of enable_torch_function/enable_torch_dispatch_mode in order to deduplicate the code.
# The differences between the modes are captured by `mode_info` and then queried when they're
# needed during the function's invocation
-def _enable_mode(mode, mode_info: _ModeInfo, *, replace=None, ignore_preexisting=False) -> Iterator[None]:
+def _enable_mode(mode: T, mode_info: _ModeInfo, *, replace=None, ignore_preexisting=False) -> Iterator[T]:
if not (
mode is None or
isinstance(mode, mode_info.mode_class) or
@@ -61,7 +63,7 @@
f'or None as an argument got {type(mode)} instead')
old = mode_info.get_mode()
if old is mode:
- yield
+ yield mode # type: ignore[misc]
return
if old is not None and not ignore_preexisting and old is not replace:
if isinstance(mode, mode_info.mode_class):
@@ -86,7 +88,7 @@
)
mode_info.set_mode(mode)
try:
- yield
+ yield mode # type: ignore[misc]
finally:
mode_info.set_mode(old)
diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py
index 0bf0795..496d97f 100644
--- a/torch/utils/_python_dispatch.py
+++ b/torch/utils/_python_dispatch.py
@@ -159,6 +159,7 @@
else:
self.ancestors = self.inner.ancestors.union({self.inner})
_set_torch_dispatch_mode(self)
+ return self
def __exit__(self, exc_type, exc_val, exc_tb):
_set_torch_dispatch_mode(self.inner)