Add mechanism to turn any RAII guard into a Python Context Manager (#102037)

This PR:
- adds a mechanism to turn any RAII guard into a Python Context Manager
- turns ExcludeDispatchKeyGuard into a context manager
(torch._C._ExcludeDispatchKeyGuard), and purges usages of the older
torch._C.ExcludeDispatchKeyGuard binding from the codebase.

The mechanism is that given a RAII guard, we construct a context
manager object that holds an optional guard. When we enter the context
manager, we populate the guard; when we exit, we reset it.
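
For example, call sites go from manually keeping the guard object alive
to a plain with block. A minimal sketch of the new usage (the
Functionalize key is just one of the keys used at the call sites in the
diff below):

    from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard

    # Old pattern: keep the RAII guard object alive, delete it in a finally block.
    #     guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
    #     try:
    #         ...
    #     finally:
    #         del guard

    # New pattern: the guard is constructed on __enter__ and reset on __exit__.
    with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
        pass  # runs with Functionalize excluded from the local dispatch key set

Note that constructing the Python object only stores the constructor
arguments; the underlying C++ guard is not created until __enter__ and
is destroyed on __exit__.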

We don't delete torch._C.ExcludeDispatchKeyGuard for BC reasons (people
are using it in fbcode). If this code sticks (it relies on C++17's
std::apply, which worries me a bit), I'll apply the same change to the
other RAII guards we have; otherwise, we can write our own std::apply.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102037
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
diff --git a/functorch/experimental/_cond.py b/functorch/experimental/_cond.py
index 9726c9b..f17eca5 100644
--- a/functorch/experimental/_cond.py
+++ b/functorch/experimental/_cond.py
@@ -5,7 +5,7 @@
 
 import torch.utils._pytree as pytree
 
-from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
+from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard
 from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
 from torch._ops import HigherOrderOperator
 from torch._subclasses.fake_tensor import FakeTensorMode
@@ -128,8 +128,8 @@
     assert all(not f.requires_grad for f in flat_operands
                if isinstance(f, torch.Tensor))
 
-    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
-    return cond(pred, true_fn, false_fn, *operands)
+    with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU)):
+        return cond(pred, true_fn, false_fn, *operands)
 
 
 @cond.py_impl(ProxyTorchDispatchMode)
@@ -245,8 +245,7 @@
     unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
     unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)
     mode = 'mutations_and_views' if reapply_views else 'mutations'
-    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
-    try:
+    with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
         functional_true = functionalize(true_fn, remove=mode)
         functional_false = functionalize(false_fn, remove=mode)
         for branch in [true_fn, false_fn]:
@@ -261,9 +260,6 @@
         cond_return = cond(unwrapped_pred, functional_true, functional_false, unwrapped_inputs)
         return _wrap_all_tensors_to_functional(cond_return, level=0)
 
-    finally:
-        del guard
-
 
 @cond.py_impl(torch._C._functorch.TransformType.Functionalize)
 def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py
index ea5fb2b..c8eef22 100644
--- a/functorch/experimental/_map.py
+++ b/functorch/experimental/_map.py
@@ -2,7 +2,7 @@
 
 import torch
 import torch.utils._pytree as pytree
-from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
+from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard
 from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
 from torch._functorch.aot_autograd import create_joint, AOTConfig
 from torch._ops import HigherOrderOperator
@@ -268,8 +268,7 @@
     unwrapped_args = _unwrap_all_tensors_from_functional(pos_args, reapply_views=reapply_views)
     mode = 'mutations_and_views' if reapply_views else 'mutations'
 
-    guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
-    try:
+    with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
         functional_map_fn = functionalize(f, remove=mode)
         with disable_proxy_modes_tracing():
             example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
@@ -286,8 +285,6 @@
 
         map_return = map_impl(functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args)
         return _wrap_all_tensors_to_functional(map_return, level=0)
-    finally:
-        del guard
 
 @map_impl.py_impl(torch._C._functorch.TransformType.Functionalize)
 def map_functionalize(interpreter, f, num_mapped, *args):
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index 726b28e..ce6ff33 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -27,7 +27,7 @@
 from torch._subclasses.fake_tensor import FakeTensorMode
 from functools import partial
 from functorch.experimental import replace_all_batch_norm_modules_
-from torch._C import ExcludeDispatchKeyGuard, DispatchKeySet, DispatchKey
+from torch._C import _ExcludeDispatchKeyGuard, DispatchKeySet, DispatchKey
 
 import functorch
 from functorch import (
@@ -3373,16 +3373,13 @@
         x = torch.randn([], device=device)
         expected = f(x)
 
-        guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
-        try:
+        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
             gm = make_fx(functorch.functionalize(f))(x)
             self.assertTrue('sin_' not in gm.code)
             self.assertEqual(gm(x), expected)
 
             local_exclude_set = torch._C._dispatch_tls_local_exclude_set()
             self.assertTrue(local_exclude_set.has(DispatchKey.Functionalize))
-        finally:
-            del guard
 
     def test_can_use_vmap_when_key_is_excluded(self, device):
         def f(x):
@@ -3391,14 +3388,11 @@
         x = torch.randn(3, device=device)
         expected = vmap(f)(x)
 
-        guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.FuncTorchBatched))
-        try:
+        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.FuncTorchBatched)):
             result = vmap(f)(x)
             self.assertEqual(result, expected)
             local_exclude_set = torch._C._dispatch_tls_local_exclude_set()
             self.assertTrue(local_exclude_set.has(DispatchKey.FuncTorchBatched))
-        finally:
-            del guard
 
     def test_can_use_grad_when_key_is_excluded(self, device):
         def f(x):
@@ -3407,14 +3401,11 @@
         x = torch.randn([], device=device)
         expected = grad(f)(x)
 
-        guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Autograd))
-        try:
+        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Autograd)):
             result = grad(f)(x)
             self.assertEqual(result, expected)
             local_exclude_set = torch._C._dispatch_tls_local_exclude_set()
             self.assertTrue(local_exclude_set.has(DispatchKey.Autograd))
-        finally:
-            del guard
 
 
 class TestMakeFunctional(TestCase):
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 7fafbbf..84134b6 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1363,7 +1363,7 @@
     dispatch_b: _dispatchkey,
 ) -> _bool: ...
 
-class ExcludeDispatchKeyGuard: ...
+class _ExcludeDispatchKeyGuard: ...
 class _AutoDispatchBelowAutograd: ...
 
 def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ...
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 51350e5..eeef77d 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -19,6 +19,7 @@
 #include <pybind11/stl.h>
 #include <torch/csrc/jit/python/pybind_utils.h>
 #include <torch/csrc/utils/pybind.h>
+#include <torch/csrc/utils/python_raii.h>
 
 #include <iostream>
 
@@ -645,9 +646,15 @@
       [](c10::DispatchKey a, c10::DispatchKey b) {
         return c10::isIncludedInAlias(a, b);
       });
+
+  // DEPRECATED, please don't use this. Instead use
+  // torch._C._ExcludeDispatchKeyGuard
   py::class_<c10::impl::ExcludeDispatchKeyGuard>(m, "ExcludeDispatchKeyGuard")
       .def(py::init<c10::DispatchKeySet>());
 
+  py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(
+      m, "_ExcludeDispatchKeyGuard");
+
   py::class_<at::AutoDispatchBelowAutograd>(m, "_AutoDispatchBelowAutograd")
       .def(py::init<>());
 
diff --git a/torch/csrc/utils/python_raii.h b/torch/csrc/utils/python_raii.h
new file mode 100644
index 0000000..e35a3b1
--- /dev/null
+++ b/torch/csrc/utils/python_raii.h
@@ -0,0 +1,46 @@
+#include <c10/util/Optional.h>
+#include <torch/csrc/utils/pybind.h>
+#include <tuple>
+
+namespace torch {
+namespace impl {
+
+template <typename GuardT, typename... Args>
+struct RAIIContextManager {
+  explicit RAIIContextManager(Args&&... args)
+      : args_(std::forward<Args>(args)...) {}
+
+  void enter() {
+    auto emplace = [&](Args... args) {
+      return guard_.emplace(std::forward<Args>(args)...);
+    };
+    std::apply(std::move(emplace), args_);
+  }
+
+  void exit() {
+    guard_ = c10::nullopt;
+  }
+
+ private:
+  c10::optional<GuardT> guard_;
+  std::tuple<Args...> args_;
+};
+
+// Turns a C++ RAII guard into a Python context manager.
+// See _ExcludeDispatchKeyGuard in python_dispatch.cpp for example.
+template <typename GuardT, typename... GuardArgs>
+void py_context_manager(const py::module& m, const char* name) {
+  using ContextManagerT = RAIIContextManager<GuardT, GuardArgs...>;
+  py::class_<ContextManagerT>(m, name)
+      .def(py::init<GuardArgs...>())
+      .def("__enter__", [](ContextManagerT& guard) { guard.enter(); })
+      .def(
+          "__exit__",
+          [](ContextManagerT& guard,
+             py::object exc_type,
+             py::object exc_value,
+             py::object traceback) { guard.exit(); });
+}
+
+} // namespace impl
+} // namespace torch