Add mechanism to turn any RAII guard into a Python Context Manager (#102037)
This PR:
- adds a mechanism to turn any RAII guard into a Python Context Manager
- exposes ExcludeDispatchKeyGuard as a context manager
(torch._C._ExcludeDispatchKeyGuard) and replaces the in-tree usages of
the older torch._C.ExcludeDispatchKeyGuard binding (see the example
below).
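For reference, here is the before/after at a typical call site (a
sketch based on the call sites updated in this diff; the `...` stands
in for whatever work must run with the key excluded):

```python
import torch
from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard

keyset = DispatchKeySet(DispatchKey.Functionalize)

# Before: the guard's lifetime was managed by hand, relying on
# try/finally plus `del` (i.e. CPython refcounting) to run the C++
# destructor at the right time.
guard = torch._C.ExcludeDispatchKeyGuard(keyset)
try:
    ...  # work that must run with Functionalize excluded
finally:
    del guard

# After: the guard's lifetime is exactly the `with` block.
with _ExcludeDispatchKeyGuard(keyset):
    ...  # work that must run with Functionalize excluded
```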
The mechanism: given a RAII guard, we construct a context-manager
object that holds an optional guard. When we enter the context manager,
we populate the guard; when we exit, we reset it.
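In Python terms, the generated context manager behaves roughly like the
following sketch (names are illustrative; the real implementation is
the C++ RAIIContextManager template added in
torch/csrc/utils/python_raii.h below):

```python
class _RAIIContextManagerSketch:
    """Rough Python analogue of torch::impl::RAIIContextManager."""

    def __init__(self, guard_cls, *args):
        self._guard_cls = guard_cls
        self._args = args    # constructor args, saved for later
        self._guard = None   # the "optional guard": empty until __enter__

    def __enter__(self):
        # Populate the guard, i.e. actually construct the RAII object.
        self._guard = self._guard_cls(*self._args)

    def __exit__(self, exc_type, exc_value, traceback):
        # Reset the optional, destroying the guard and running its RAII
        # cleanup, even if the body raised.
        self._guard = None
```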
We don't delete torch._C.ExcludeDispatchKeyGuard, for BC reasons
(people are still using it in fbcode). If this code sticks (it relies
on C++17's std::apply, which worries me a bit), then I'll apply the
same change to the other RAII guards we have; otherwise, we can write
our own std::apply.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102037
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
diff --git a/functorch/experimental/_cond.py b/functorch/experimental/_cond.py
index 9726c9b..f17eca5 100644
--- a/functorch/experimental/_cond.py
+++ b/functorch/experimental/_cond.py
@@ -5,7 +5,7 @@
import torch.utils._pytree as pytree
-from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
+from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
@@ -128,8 +128,8 @@
assert all(not f.requires_grad for f in flat_operands
if isinstance(f, torch.Tensor))
- guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU))
- return cond(pred, true_fn, false_fn, *operands)
+ with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.AutogradCPU)):
+ return cond(pred, true_fn, false_fn, *operands)
@cond.py_impl(ProxyTorchDispatchMode)
@@ -245,8 +245,7 @@
unwrapped_inputs = _unwrap_all_tensors_from_functional(inputs, reapply_views=reapply_views)
unwrapped_pred = _unwrap_all_tensors_from_functional(pred, reapply_views=reapply_views)
mode = 'mutations_and_views' if reapply_views else 'mutations'
- guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
- try:
+ with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
functional_true = functionalize(true_fn, remove=mode)
functional_false = functionalize(false_fn, remove=mode)
for branch in [true_fn, false_fn]:
@@ -261,9 +260,6 @@
cond_return = cond(unwrapped_pred, functional_true, functional_false, unwrapped_inputs)
return _wrap_all_tensors_to_functional(cond_return, level=0)
- finally:
- del guard
-
@cond.py_impl(torch._C._functorch.TransformType.Functionalize)
def cond_functionalize(interpreter, pred, true_fn, false_fn, inputs):
diff --git a/functorch/experimental/_map.py b/functorch/experimental/_map.py
index ea5fb2b..c8eef22 100644
--- a/functorch/experimental/_map.py
+++ b/functorch/experimental/_map.py
@@ -2,7 +2,7 @@
import torch
import torch.utils._pytree as pytree
-from torch._C import DispatchKey, DispatchKeySet, ExcludeDispatchKeyGuard
+from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard
from torch._functorch.eager_transforms import _unwrap_all_tensors_from_functional, _wrap_all_tensors_to_functional, functionalize
from torch._functorch.aot_autograd import create_joint, AOTConfig
from torch._ops import HigherOrderOperator
@@ -268,8 +268,7 @@
unwrapped_args = _unwrap_all_tensors_from_functional(pos_args, reapply_views=reapply_views)
mode = 'mutations_and_views' if reapply_views else 'mutations'
- guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
- try:
+ with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
functional_map_fn = functionalize(f, remove=mode)
with disable_proxy_modes_tracing():
example_inputs = (*_unstack_pytree(unwrapped_xs)[0], *unwrapped_args)
@@ -286,8 +285,6 @@
map_return = map_impl(functional_map_fn, num_mapped, *unwrapped_xs, *unwrapped_args)
return _wrap_all_tensors_to_functional(map_return, level=0)
- finally:
- del guard
@map_impl.py_impl(torch._C._functorch.TransformType.Functionalize)
def map_functionalize(interpreter, f, num_mapped, *args):
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
index 726b28e..ce6ff33 100644
--- a/test/functorch/test_eager_transforms.py
+++ b/test/functorch/test_eager_transforms.py
@@ -27,7 +27,7 @@
from torch._subclasses.fake_tensor import FakeTensorMode
from functools import partial
from functorch.experimental import replace_all_batch_norm_modules_
-from torch._C import ExcludeDispatchKeyGuard, DispatchKeySet, DispatchKey
+from torch._C import _ExcludeDispatchKeyGuard, DispatchKeySet, DispatchKey
import functorch
from functorch import (
@@ -3373,16 +3373,13 @@
x = torch.randn([], device=device)
expected = f(x)
- guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize))
- try:
+ with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
gm = make_fx(functorch.functionalize(f))(x)
self.assertTrue('sin_' not in gm.code)
self.assertEqual(gm(x), expected)
local_exclude_set = torch._C._dispatch_tls_local_exclude_set()
self.assertTrue(local_exclude_set.has(DispatchKey.Functionalize))
- finally:
- del guard
def test_can_use_vmap_when_key_is_excluded(self, device):
def f(x):
@@ -3391,14 +3388,11 @@
x = torch.randn(3, device=device)
expected = vmap(f)(x)
- guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.FuncTorchBatched))
- try:
+ with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.FuncTorchBatched)):
result = vmap(f)(x)
self.assertEqual(result, expected)
local_exclude_set = torch._C._dispatch_tls_local_exclude_set()
self.assertTrue(local_exclude_set.has(DispatchKey.FuncTorchBatched))
- finally:
- del guard
def test_can_use_grad_when_key_is_excluded(self, device):
def f(x):
@@ -3407,14 +3401,11 @@
x = torch.randn([], device=device)
expected = grad(f)(x)
- guard = ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Autograd))
- try:
+ with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Autograd)):
result = grad(f)(x)
self.assertEqual(result, expected)
local_exclude_set = torch._C._dispatch_tls_local_exclude_set()
self.assertTrue(local_exclude_set.has(DispatchKey.Autograd))
- finally:
- del guard
class TestMakeFunctional(TestCase):
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 7fafbbf..84134b6 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1363,7 +1363,7 @@
dispatch_b: _dispatchkey,
) -> _bool: ...
-class ExcludeDispatchKeyGuard: ...
+class _ExcludeDispatchKeyGuard: ...
class _AutoDispatchBelowAutograd: ...
def _dispatch_print_registrations_for_dispatch_key(dispatch_key: str = "") -> None: ...
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 51350e5..eeef77d 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -19,6 +19,7 @@
#include <pybind11/stl.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
+#include <torch/csrc/utils/python_raii.h>
#include <iostream>
@@ -645,9 +646,15 @@
[](c10::DispatchKey a, c10::DispatchKey b) {
return c10::isIncludedInAlias(a, b);
});
+
+ // DEPRECATED, please don't use this. Instead use
+ // torch._C._ExcludeDispatchKeyGuard
py::class_<c10::impl::ExcludeDispatchKeyGuard>(m, "ExcludeDispatchKeyGuard")
.def(py::init<c10::DispatchKeySet>());
+ py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(
+ m, "_ExcludeDispatchKeyGuard");
+
py::class_<at::AutoDispatchBelowAutograd>(m, "_AutoDispatchBelowAutograd")
.def(py::init<>());
diff --git a/torch/csrc/utils/python_raii.h b/torch/csrc/utils/python_raii.h
new file mode 100644
index 0000000..e35a3b1
--- /dev/null
+++ b/torch/csrc/utils/python_raii.h
@@ -0,0 +1,46 @@
+#include <c10/util/Optional.h>
+#include <torch/csrc/utils/pybind.h>
+#include <tuple>
+
+namespace torch {
+namespace impl {
+
+template <typename GuardT, typename... Args>
+struct RAIIContextManager {
+ explicit RAIIContextManager(Args&&... args)
+ : args_(std::forward<Args>(args)...) {}
+
+ void enter() {
+ auto emplace = [&](Args... args) {
+ return guard_.emplace(std::forward<Args>(args)...);
+ };
+ std::apply(std::move(emplace), args_);
+ }
+
+ void exit() {
+ guard_ = c10::nullopt;
+ }
+
+ private:
+ c10::optional<GuardT> guard_;
+ std::tuple<Args...> args_;
+};
+
+// Turns a C++ RAII guard into a Python context manager.
+// See _ExcludeDispatchKeyGuard in python_dispatch.cpp for example.
+template <typename GuardT, typename... GuardArgs>
+void py_context_manager(const py::module& m, const char* name) {
+ using ContextManagerT = RAIIContextManager<GuardT, GuardArgs...>;
+ py::class_<ContextManagerT>(m, name)
+ .def(py::init<GuardArgs...>())
+ .def("__enter__", [](ContextManagerT& guard) { guard.enter(); })
+ .def(
+ "__exit__",
+ [](ContextManagerT& guard,
+ py::object exc_type,
+ py::object exc_value,
+ py::object traceback) { guard.exit(); });
+}
+
+} // namespace impl
+} // namespace torch
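As a quick sanity check of the new binding (a sketch using the same
introspection helpers as the tests above; it assumes Functionalize is
not already excluded when the block is entered):

```python
import torch
from torch._C import DispatchKey, DispatchKeySet, _ExcludeDispatchKeyGuard

with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
    # Inside the block the guard is alive, so Functionalize is in the
    # thread-local exclude set.
    assert torch._C._dispatch_tls_local_exclude_set().has(DispatchKey.Functionalize)

# __exit__ resets the optional guard, which runs the C++ destructor and
# restores the previous exclude set.
assert not torch._C._dispatch_tls_local_exclude_set().has(DispatchKey.Functionalize)
```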