Add mechanism to disable the "saved tensors hooks" feature (#85553)

The rationale for this is that functorch doesn't currently work with
saved variable hooks or with checkpointing, so we need some way to
disable the feature.

Concretely:
- there's a context manager that does the disabling
- this feature is disabled on a thread-local basis
- one can set an error message or use the default error message that
says the feature has been disabled

Since the feature is thread-local, I needed to update
ATen/ThreadLocalState. To make things nicer, this PR refactors all the
"saved tensors hooks" related TLS things into a single struct.

Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85553
Approved by: https://github.com/soulitzer
diff --git a/aten/src/ATen/SavedTensorHooks.cpp b/aten/src/ATen/SavedTensorHooks.cpp
index aff6ddd..d18c780 100644
--- a/aten/src/ATen/SavedTensorHooks.cpp
+++ b/aten/src/ATen/SavedTensorHooks.cpp
@@ -5,46 +5,78 @@
 namespace at {
 
 namespace {
-  // PyObject is defined in c10/util/python_stub.h
-  thread_local std::stack<std::pair<PyObject*, PyObject*>> stack;
+  thread_local impl::SavedTensorDefaultHooksTLS tls;
 
   // This flag is set to true the first time default hooks are registered
   // and left at true for the rest of the execution.
   // It's an optimization so that users who never use default hooks don't need to
   // read the thread_local variables pack_hook_ and unpack_hook_.
-  static bool is_enabled(false);
+  static bool is_initialized(false);
+}
+
+static void assertSavedTensorHooksNotDisabled() {
+  TORCH_CHECK(SavedTensorDefaultHooks::is_enabled(), tls.disabled_error_message.value());
+}
+
+bool SavedTensorDefaultHooks::is_enabled() {
+  // See NOTE: [disabled_error_message invariant]
+  return !tls.disabled_error_message.has_value();
+}
+
+void SavedTensorDefaultHooks::disable(const std::string& message) {
+  tls.disabled_error_message = message;
+  if (tls.stack.size() > 0) {
+    assertSavedTensorHooksNotDisabled();
+  }
 }
 
 void SavedTensorDefaultHooks::enable() {
-  is_enabled = true;
+  tls.disabled_error_message = c10::nullopt;
+}
+
+const optional<std::string>& SavedTensorDefaultHooks::get_disabled_error_message() {
+  return tls.disabled_error_message;
+}
+
+const impl::SavedTensorDefaultHooksTLS& SavedTensorDefaultHooks::get_tls_state() {
+  return tls;
+}
+
+void SavedTensorDefaultHooks::set_tls_state(const impl::SavedTensorDefaultHooksTLS& state) {
+  tls = state;
+}
+
+void SavedTensorDefaultHooks::lazy_initialize() {
+  is_initialized = true;
 }
 
 void SavedTensorDefaultHooks::push_hooks(PyObject* pack_hook, PyObject* unpack_hook) {
   // Reference counting is handled by the caller of `push_hooks`
-  TORCH_INTERNAL_ASSERT(is_enabled);
+  TORCH_INTERNAL_ASSERT(is_initialized);
   TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
-  stack.push(std::make_pair(pack_hook, unpack_hook));
+  assertSavedTensorHooksNotDisabled();
+  tls.stack.push(std::make_pair(pack_hook, unpack_hook));
 }
 
 void SavedTensorDefaultHooks::pop_hooks() {
   // Reference counting is handled by the caller of `pop_hooks`
-  TORCH_INTERNAL_ASSERT(is_enabled && !stack.empty());
-  stack.pop();
+  TORCH_INTERNAL_ASSERT(is_initialized && !tls.stack.empty());
+  tls.stack.pop();
 }
 
 std::pair<PyObject*, PyObject*> SavedTensorDefaultHooks::get_hooks() {
-  if (!is_enabled || stack.empty()) {
+  if (!is_initialized || tls.stack.empty()) {
     return std::make_pair(nullptr, nullptr);
   }
-  return stack.top();
+  return tls.stack.top();
 }
 
 std::stack<std::pair<PyObject*, PyObject*>> SavedTensorDefaultHooks::get_stack() {
-  return stack;
+  return tls.stack;
 }
 
 void SavedTensorDefaultHooks::set_stack(std::stack<std::pair<PyObject*, PyObject*>> stack_) {
-  stack = stack_;
+  tls.stack = stack_;
 }
 
 }
diff --git a/aten/src/ATen/SavedTensorHooks.h b/aten/src/ATen/SavedTensorHooks.h
index 0cdfa3c..6085a3f 100644
--- a/aten/src/ATen/SavedTensorHooks.h
+++ b/aten/src/ATen/SavedTensorHooks.h
@@ -1,20 +1,52 @@
 #pragma once
 
 #include <c10/macros/Export.h>
+#include <c10/util/Optional.h>
 #include <c10/util/python_stub.h>
 #include <stack>
+#include <string>
 
 #include <utility>
 
 namespace at {
 
+namespace impl {
+
+struct TORCH_API SavedTensorDefaultHooksTLS {
+  // PyObject is defined in c10/util/python_stub.h
+  std::stack<std::pair<PyObject*, PyObject*>> stack;
+
+  // See NOTE: [Disabling SavedTensorDefaultHooks] for context
+  // NOTE: [disabled_error_message invariant]
+  // disabled_error_message is nullopt IFF Saved Tensor hooks is enabled
+  // We did this for efficiency (so we didn't have to keep a separate bool
+  // around)
+  c10::optional<std::string> disabled_error_message;
+};
+
+} // namespace impl
+
 struct TORCH_API SavedTensorDefaultHooks {
   static void push_hooks(PyObject* pack_hook, PyObject* unpack_hook);
   static void pop_hooks();
   static std::pair<PyObject*, PyObject*> get_hooks();
-  static void enable();
+  static void lazy_initialize();
   static std::stack<std::pair<PyObject*, PyObject*>> get_stack();
   static void set_stack(std::stack<std::pair<PyObject*, PyObject*>>);
+
+  static const impl::SavedTensorDefaultHooksTLS& get_tls_state();
+  static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls);
+
+  // NOTE: [Disabling SavedTensorDefaultHooks]
+  // A developer of a PyTorch feature may choose to disable SavedTensorDefault
+  // hooks, especially if their feature does not work with it. If they are
+  // disabled, then the following will raise an error:
+  // - Attempting to push_hooks
+  // - calling disable(message) with a non-zero stack (from get_stack) size
+  static void disable(const std::string& error_message);
+  static void enable();
+  static bool is_enabled();
+  static const optional<std::string>& get_disabled_error_message();
 };
 
 } // namespace at
diff --git a/aten/src/ATen/ThreadLocalState.cpp b/aten/src/ATen/ThreadLocalState.cpp
index 422c1dc..9a5f079 100644
--- a/aten/src/ATen/ThreadLocalState.cpp
+++ b/aten/src/ATen/ThreadLocalState.cpp
@@ -18,7 +18,7 @@
       python_torch_function_state_(at::impl::PythonTorchFunctionTLS::get_state()) {
   rf_tls_ = at::get_record_function_tls_();
 
-  saved_tensors_default_hooks_ = at::SavedTensorDefaultHooks::get_stack();
+  saved_tensors_default_hooks_state_ = at::SavedTensorDefaultHooks::get_tls_state();
 
   torch_dispatch_mode_state_ = c10::impl::TorchDispatchModeTLS::get_state();
 }
@@ -40,7 +40,7 @@
 
   at::set_record_function_tls_(state.rf_tls_);
 
-  at::SavedTensorDefaultHooks::set_stack(state.saved_tensors_default_hooks_);
+  at::SavedTensorDefaultHooks::set_tls_state(state.saved_tensors_default_hooks_state_);
 
   c10::impl::PythonDispatcherTLS::set_state(state.python_dispatcher_state_);
 
diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h
index cfae7db..2a3ca0e 100644
--- a/aten/src/ATen/ThreadLocalState.h
+++ b/aten/src/ATen/ThreadLocalState.h
@@ -9,6 +9,7 @@
 
 #include <ATen/FuncTorchTLS.h>
 #include <ATen/PythonTorchFunctionTLS.h>
+#include <ATen/SavedTensorHooks.h>
 #include <ATen/record_function.h>
 #include <c10/core/impl/PythonDispatcherTLS.h>
 #include <c10/core/impl/TorchDispatchModeTLS.h>
@@ -65,7 +66,7 @@
   at::impl::PythonTorchFunctionTLS python_torch_function_state_;
 
   // TLS for saved tensors default hooks
-  std::stack<std::pair<PyObject*, PyObject*>> saved_tensors_default_hooks_;
+  at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_;
 
   friend class ThreadLocalStateGuard;
 };
diff --git a/test/test_autograd.py b/test/test_autograd.py
index bc575c6..bce652f 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -6718,6 +6718,32 @@
         self.assertEqual(2 * 5 * 5 * a, a.grad)
         self.assertEqual(2 * 3 * 3 * b, b.grad)
 
+    def test_disabling_saved_tensor_hooks(self):
+        with torch.autograd.graph._disable_saved_tensors_hooks("error message"):
+            with self.assertRaisesRegex(RuntimeError, "error message"):
+                with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
+                    pass
+
+        self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
+
+        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
+            with self.assertRaisesRegex(RuntimeError, "error message"):
+                with torch.autograd.graph._disable_saved_tensors_hooks("error message"):
+                    pass
+
+        self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
+
+    def test_disabling_saved_tensor_hooks_nested(self):
+        with torch.autograd.graph._disable_saved_tensors_hooks("outer"):
+            with self.assertRaisesRegex(RuntimeError, "inner"):
+                with torch.autograd.graph._disable_saved_tensors_hooks("inner"):
+                    with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
+                        pass
+
+            self.assertFalse(torch._C._autograd._saved_tensors_hooks_is_enabled())
+
+        self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
+
     def test_save_on_cpu_and_checkpoint(self):
         a = torch.randn(2, 2, requires_grad=True)
 
diff --git a/torch/_C/_autograd.pyi b/torch/_C/_autograd.pyi
index 20f5819..bddda0a 100644
--- a/torch/_C/_autograd.pyi
+++ b/torch/_C/_autograd.pyi
@@ -1,4 +1,4 @@
-from typing import List, Set, Callable, Any, Union
+from typing import List, Set, Callable, Any, Union, Optional
 from enum import Enum
 
 import torch
@@ -83,3 +83,7 @@
 def _enable_profiler_legacy(config: ProfilerConfig) -> None: ...
 def _disable_profiler_legacy() -> List[List[ProfilerEvent]]: ...
 def _profiler_type() -> ActiveProfilerType: ...
+
+def _saved_tensors_hooks_enable() -> None: ...
+def _saved_tensors_hooks_disable(message: str) -> None: ...
+def _saved_tensors_hooks_get_disabled_error_message() -> Optional[str]: ...
diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py
index 05c0d51..95b76f8 100644
--- a/torch/autograd/graph.py
+++ b/torch/autograd/graph.py
@@ -1,4 +1,5 @@
 import torch
+import contextlib
 from typing import Callable, Any
 
 
@@ -132,3 +133,17 @@
             return tensor.to(device, non_blocking=pin_memory)
 
         super().__init__(pack_to_cpu, unpack_from_cpu)
+
+
+@contextlib.contextmanager
+def _disable_saved_tensors_hooks(error_message):
+    try:
+        maybe_prev_message = torch._C._autograd._saved_tensors_hooks_get_disabled_error_message()
+        torch._C._autograd._saved_tensors_hooks_disable(error_message)
+        yield
+    finally:
+        # See NOTE: [disabled_error_message invariant]
+        if maybe_prev_message is None:
+            torch._C._autograd._saved_tensors_hooks_enable()
+        else:
+            torch._C._autograd._saved_tensors_hooks_disable(maybe_prev_message)
diff --git a/torch/csrc/autograd/init.cpp b/torch/csrc/autograd/init.cpp
index 5c16e4f..a1d6de2 100644
--- a/torch/csrc/autograd/init.cpp
+++ b/torch/csrc/autograd/init.cpp
@@ -1,6 +1,7 @@
 #include <torch/csrc/python_headers.h>
 
 #include <ATen/PythonTorchFunctionTLS.h>
+#include <ATen/SavedTensorHooks.h>
 #include <ATen/autocast_mode.h>
 #include <ATen/core/PythonFallbackKernel.h>
 #include <ATen/record_function.h>
@@ -310,6 +311,14 @@
   });
   m.def("_clear_callbacks", []() { at::clearCallbacks(); });
   m.def(
+      "_saved_tensors_hooks_is_enabled",
+      at::SavedTensorDefaultHooks::is_enabled);
+  m.def("_saved_tensors_hooks_enable", at::SavedTensorDefaultHooks::enable);
+  m.def("_saved_tensors_hooks_disable", at::SavedTensorDefaultHooks::disable);
+  m.def(
+      "_saved_tensors_hooks_get_disabled_error_message",
+      at::SavedTensorDefaultHooks::get_disabled_error_message);
+  m.def(
       "_push_saved_tensors_default_hooks",
       [](py::function& pack_hook, py::function& unpack_hook) {
         torch::autograd::PyDefaultSavedVariableHooks::push_hooks(
diff --git a/torch/csrc/autograd/python_saved_variable_hooks.cpp b/torch/csrc/autograd/python_saved_variable_hooks.cpp
index 2bec0e7..8f8027f 100644
--- a/torch/csrc/autograd/python_saved_variable_hooks.cpp
+++ b/torch/csrc/autograd/python_saved_variable_hooks.cpp
@@ -56,7 +56,7 @@
 void PyDefaultSavedVariableHooks::push_hooks(
     py::function& pack_hook,
     py::function& unpack_hook) {
-  at::SavedTensorDefaultHooks::enable();
+  at::SavedTensorDefaultHooks::lazy_initialize();
   at::SavedTensorDefaultHooks::push_hooks(
       pack_hook.release().ptr(), unpack_hook.release().ptr());
 }