Return mode object from __enter__ (#80998)

This makes `with Mode() as m:` work.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80998
Approved by: https://github.com/samdow
diff --git a/test/test_overrides.py b/test/test_overrides.py
index 6e992b4..af58a10 100644
--- a/test/test_overrides.py
+++ b/test/test_overrides.py
@@ -1323,8 +1323,7 @@
         class A(TorchFunctionMode):
             pass
 
-        x = A()
-        with x:
+        with A() as x:
             pass
 
         with self.assertRaisesRegex(RuntimeError, "has already been used as a mode. Please use a fresh version"):
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index 109e7b0..9cdbf8a 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -1143,8 +1143,7 @@
             __torch_function__ = torch._C._disabled_torch_function_impl
 
         a = SubTensor(torch.randn(2))
-        mode = PoliteMode()
-        with mode:
+        with PoliteMode() as mode:
             a.abs()
 
         self.assertEqual(mode.pre_count, 2)
diff --git a/torch/overrides.py b/torch/overrides.py
index 5ed7f64..ba9ad31 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -1859,6 +1859,7 @@
             else:
                 self.ancestors = self.inner.ancestors.union({self.inner})
         _set_torch_function_mode(self)
+        return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         _set_torch_function_mode(self.inner)
diff --git a/torch/utils/_mode_utils.py b/torch/utils/_mode_utils.py
index c15ac26..21c4018 100644
--- a/torch/utils/_mode_utils.py
+++ b/torch/utils/_mode_utils.py
@@ -1,9 +1,11 @@
 import functools
 import torch
-from typing import Iterator
+from typing import Iterator, TypeVar
 from dataclasses import dataclass
 from contextlib import contextmanager
 
+T = TypeVar('T')
+
 # This file has all the logic to dedupe logic between torch dispatch and
 # torch function modes
 #
@@ -51,7 +53,7 @@
 # shared version of enable_torch_function/enable_torch_dispatch_mode in order to deduplicate the code.
 # The differences between the modes are captured by `mode_info` and then queried when they're
 # needed during the function's invocation
-def _enable_mode(mode, mode_info: _ModeInfo, *, replace=None, ignore_preexisting=False) -> Iterator[None]:
+def _enable_mode(mode: T, mode_info: _ModeInfo, *, replace=None, ignore_preexisting=False) -> Iterator[T]:
     if not (
         mode is None or
         isinstance(mode, mode_info.mode_class) or
@@ -61,7 +63,7 @@
                          f'or None as an argument got {type(mode)} instead')
     old = mode_info.get_mode()
     if old is mode:
-        yield
+        yield mode  # type: ignore[misc]
         return
     if old is not None and not ignore_preexisting and old is not replace:
         if isinstance(mode, mode_info.mode_class):
@@ -86,7 +88,7 @@
         )
     mode_info.set_mode(mode)
     try:
-        yield
+        yield mode  # type: ignore[misc]
     finally:
         mode_info.set_mode(old)
 
diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py
index 0bf0795..496d97f 100644
--- a/torch/utils/_python_dispatch.py
+++ b/torch/utils/_python_dispatch.py
@@ -159,6 +159,7 @@
             else:
                 self.ancestors = self.inner.ancestors.union({self.inner})
         _set_torch_dispatch_mode(self)
+        return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         _set_torch_dispatch_mode(self.inner)