Add mechanism to disable the "saved tensors hooks" feature (#85553)
The rationale for this is that functorch doesn't currently work with
saved variable hooks or with checkpointing, so we need some way to
disable the feature.
Concretely:
- there's a context manager that does the disabling
- this feature is disabled on a thread-local basis
- one can set an error message or use the default error message that
says the feature has been disabled
Since it is thread-local, I needed to update ATen/ThreadLocalState. To
make things nicer, this PR refactors all the "saved tensors hooks"
related TLS state into a single struct.
Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85553
Approved by: https://github.com/soulitzer
diff --git a/aten/src/ATen/SavedTensorHooks.cpp b/aten/src/ATen/SavedTensorHooks.cpp
index aff6ddd..d18c780 100644
--- a/aten/src/ATen/SavedTensorHooks.cpp
+++ b/aten/src/ATen/SavedTensorHooks.cpp
@@ -5,46 +5,78 @@
namespace at {
namespace {
- // PyObject is defined in c10/util/python_stub.h
- thread_local std::stack<std::pair<PyObject*, PyObject*>> stack;
+ thread_local impl::SavedTensorDefaultHooksTLS tls;
// This flag is set to true the first time default hooks are registered
// and left at true for the rest of the execution.
// It's an optimization so that users who never use default hooks don't need to
// read the thread_local variables pack_hook_ and unpack_hook_.
- static bool is_enabled(false);
+ static bool is_initialized(false);
+}
+
+static void assertSavedTensorHooksNotDisabled() {
+ TORCH_CHECK(SavedTensorDefaultHooks::is_enabled(), tls.disabled_error_message.value());
+}
+
+bool SavedTensorDefaultHooks::is_enabled() {
+ // See NOTE: [disabled_error_message invariant]
+ return !tls.disabled_error_message.has_value();
+}
+
+void SavedTensorDefaultHooks::disable(const std::string& message) {
+ tls.disabled_error_message = message;
+ if (tls.stack.size() > 0) {
+ assertSavedTensorHooksNotDisabled();
+ }
}
void SavedTensorDefaultHooks::enable() {
- is_enabled = true;
+ tls.disabled_error_message = c10::nullopt;
+}
+
+const optional<std::string>& SavedTensorDefaultHooks::get_disabled_error_message() {
+ return tls.disabled_error_message;
+}
+
+const impl::SavedTensorDefaultHooksTLS& SavedTensorDefaultHooks::get_tls_state() {
+ return tls;
+}
+
+void SavedTensorDefaultHooks::set_tls_state(const impl::SavedTensorDefaultHooksTLS& state) {
+ tls = state;
+}
+
+void SavedTensorDefaultHooks::lazy_initialize() {
+ is_initialized = true;
}
void SavedTensorDefaultHooks::push_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
// Reference counting is handled by the caller of `push_hooks`
- TORCH_INTERNAL_ASSERT(is_enabled);
+ TORCH_INTERNAL_ASSERT(is_initialized);
TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
- stack.push(std::make_pair(pack_hook, unpack_hook));
+ assertSavedTensorHooksNotDisabled();
+ tls.stack.push(std::make_pair(pack_hook, unpack_hook));
}
void SavedTensorDefaultHooks::pop_hooks() {
// Reference counting is handled by the caller of `pop_hooks`
- TORCH_INTERNAL_ASSERT(is_enabled && !stack.empty());
- stack.pop();
+ TORCH_INTERNAL_ASSERT(is_initialized && !tls.stack.empty());
+ tls.stack.pop();
}
std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::get_hooks() {
- if (!is_enabled || stack.empty()) {
+ if (!is_initialized || tls.stack.empty()) {
return std::make_pair(nullptr, nullptr);
}
- return stack.top();
+ return tls.stack.top();
}
std::stack<std::pair<PyObject*, PyObject*>> SavedTensorDefaultHooks::get_stack() {
- return stack;
+ return tls.stack;
}
void SavedTensorDefaultHooks::set_stack(std::stack<std::pair<PyObject*, PyObject*>> stack_) {
- stack = stack_;
+ tls.stack = stack_;
}
}
diff --git a/aten/src/ATen/SavedTensorHooks.h b/aten/src/ATen/SavedTensorHooks.h
index 0cdfa3c..6085a3f 100644
--- a/aten/src/ATen/SavedTensorHooks.h
+++ b/aten/src/ATen/SavedTensorHooks.h
@@ -1,20 +1,52 @@
#pragma once
#include <c10/macros/Export.h>
+#include <c10/util/Optional.h>
#include <c10/util/python_stub.h>
#include <stack>
+#include <string>
#include <utility>
namespace at {
+namespace impl {
+
+struct TORCH_API SavedTensorDefaultHooksTLS {
+ // PyObject is defined in c10/util/python_stub.h
+ std::stack<std::pair<PyObject*, PyObject*>> stack;
+
+ // See NOTE: [Disabling SavedTensorDefaultHooks] for context
+ // NOTE: [disabled_error_message invariant]
+ // disabled_error_message is nullopt IFF Saved Tensor hooks is enabled
+ // We did this for efficiency (so we didn't have to keep a separate bool
+ // around)
+ c10::optional<std::string> disabled_error_message;
+};
+
+} // namespace impl
+
struct TORCH_API SavedTensorDefaultHooks {
static void push_hooks(PyObject* pack_hook, PyObject* unpack_hook);
static void pop_hooks();
static std::pair<PyObject*, PyObject*> get_hooks();
- static void enable();
+ static void lazy_initialize();
static std::stack<std::pair<PyObject*, PyObject*>> get_stack();
static void set_stack(std::stack<std::pair<PyObject*, PyObject*>>);
+
+ static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
+ static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);
+
+ // NOTE: [Disabling SavedTensorDefaultHooks]
+ // A developer of a PyTorch feature may choose to disable SavedTensorDefault
+ // hooks, especially if their feature does not work with it. If they are
+ // disabled, then the following will raise an error:
+ // - Attempting to push_hooks
+ // - calling disable(message) with a non-zero stack (from get_stack) size
+ static void disable(const std::string& error_message);
+ static void enable();
+ static bool is_enabled();
+ static const optional<std::string>& get_disabled_error_message();
};
} // namespace at
diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp
index 422c1dc..9a5f079 100644
--- a/aten/src/ATen/ThreadLocalState.cpp
+++ b/aten/src/ATen/ThreadLocalState.cpp
@@ -18,7 +18,7 @@
python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()) {
rf_tls_ = at::get_record_function_tls_();
- saved_tensors_default_hooks_ = at::SavedTensorDefaultHooks::get_stack();
+ saved_tensors_default_hooks_state_ = at::SavedTensorDefaultHooks::get_tls_state();
torch_dispatch_mode_state_ = c10::impl::TorchDispatchModeTLS::get_state();
}
@@ -40,7 +40,7 @@
at::set_record_function_tls_(state.rf_tls_);
- at::SavedTensorDefaultHooks::set_stack(state.saved_tensors_default_hooks_);
+ at::SavedTensorDefaultHooks::set_tls_state(state.saved_tensors_default_hooks_state_);
c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_);
diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h
index cfae7db..2a3ca0e 100644
--- a/aten/src/ATen/ThreadLocalState.h
+++ b/aten/src/ATen/ThreadLocalState.h
@@ -9,6 +9,7 @@
#include <ATen/FuncTorchTLS.h>
#include <ATen/PythonTorchFunctionTLS.h>
+#include <ATen/SavedTensorHooks.h>
#include <ATen/record_function.h>
#include <c10/core/impl/PythonDispatcherTLS.h>
#include <c10/core/impl/TorchDispatchModeTLS.h>
@@ -65,7 +66,7 @@
at::impl::PythonTorchFunctionTLS python_torch_function_state_;
// TLS for saved tensors default hooks
- std::stack<std::pair<PyObject*, PyObject*>> saved_tensors_default_hooks_;
+ at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
friend class ThreadLocalStateGuard;
};
diff --git a/test/test_autograd.py b/test/test_autograd.py
index bc575c6..bce652f 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -6718,6 +6718,32 @@
self.assertEqual(2 * 5 * 5 * a, a.grad)
self.assertEqual(2 * 3 * 3 * b, b.grad)
+ def test_disabling_saved_tensor_hooks(self):
+ with torch.autograd.graph._disable_saved_tensors_hooks("error message"):
+ with self.assertRaisesRegex(RuntimeError, "error message"):
+ with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
+ pass
+
+ self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
+
+ with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
+ with self.assertRaisesRegex(RuntimeError, "error message"):
+ with torch.autograd.graph._disable_saved_tensors_hooks("error message"):
+ pass
+
+ self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
+
+ def test_disabling_saved_tensor_hooks_nested(self):
+ with torch.autograd.graph._disable_saved_tensors_hooks("outer"):
+ with self.assertRaisesRegex(RuntimeError, "inner"):
+ with torch.autograd.graph._disable_saved_tensors_hooks("inner"):
+ with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
+ pass
+
+ self.assertFalse(torch._C._autograd._saved_tensors_hooks_is_enabled())
+
+ self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
+
def test_save_on_cpu_and_checkpoint(self):
a = torch.randn(2, 2, requires_grad=True)
diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi
index 20f5819..bddda0a 100644
--- a/torch/_C/_autograd.pyi
+++ b/torch/_C/_autograd.pyi
@@ -1,4 +1,4 @@
-from typing import List, Set, Callable, Any, Union
+from typing import List, Set, Callable, Any, Union, Optional
from enum import Enum
import torch
@@ -83,3 +83,7 @@
def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
def _profiler_type() -> ActiveProfilerType: ...
+
+def _saved_tensors_hooks_enable() -> None: ...
+def _saved_tensors_hooks_disable(message: str) -> None: ...
+def _saved_tensors_hooks_get_disabled_error_message() -> Optional[str]: ...
diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py
index 05c0d51..95b76f8 100644
--- a/torch/autograd/graph.py
+++ b/torch/autograd/graph.py
@@ -1,4 +1,5 @@
import torch
+import contextlib
from typing import Callable, Any
@@ -132,3 +133,17 @@
return tensor.to(device, non_blocking=pin_memory)
super().__init__(pack_to_cpu, unpack_from_cpu)
+
+
+@contextlib.contextmanager
+def _disable_saved_tensors_hooks(error_message):
+ try:
+ maybe_prev_message = torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
+ torch._C._autograd._saved_tensors_hooks_disable(error_message)
+ yield
+ finally:
+ # See NOTE: [disabled_error_message invariant]
+ if maybe_prev_message is None:
+ torch._C._autograd._saved_tensors_hooks_enable()
+ else:
+ torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 5c16e4f..a1d6de2 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -1,6 +1,7 @@
#include <torch/csrc/python_headers.h>
#include <ATen/PythonTorchFunctionTLS.h>
+#include <ATen/SavedTensorHooks.h>
#include <ATen/autocast_mode.h>
#include <ATen/core/PythonFallbackKernel.h>
#include <ATen/record_function.h>
@@ -310,6 +311,14 @@
});
m.def("_clear_callbacks", []() { at::clearCallbacks(); });
m.def(
+ "_saved_tensors_hooks_is_enabled",
+ at::SavedTensorDefaultHooks::is_enabled);
+ m.def("_saved_tensors_hooks_enable", at::SavedTensorDefaultHooks::enable);
+ m.def("_saved_tensors_hooks_disable", at::SavedTensorDefaultHooks::disable);
+ m.def(
+ "_saved_tensors_hooks_get_disabled_error_message",
+ at::SavedTensorDefaultHooks::get_disabled_error_message);
+ m.def(
"_push_saved_tensors_default_hooks",
[](py::function& pack_hook, py::function& unpack_hook) {
torch::autograd::PyDefaultSavedVariableHooks::push_hooks(
diff --git a/torch/csrc/autograd/python_saved_variable_hooks.cpp b/torch/csrc/autograd/python_saved_variable_hooks.cpp
index 2bec0e7..8f8027f 100644
--- a/torch/csrc/autograd/python_saved_variable_hooks.cpp
+++ b/torch/csrc/autograd/python_saved_variable_hooks.cpp
@@ -56,7 +56,7 @@
void PyDefaultSavedVariableHooks::push_hooks(
py::function& pack_hook,
py::function& unpack_hook) {
- at::SavedTensorDefaultHooks::enable();
+ at::SavedTensorDefaultHooks::lazy_initialize();
at::SavedTensorDefaultHooks::push_hooks(
pack_hook.release().ptr(), unpack_hook.release().ptr());
}