Preserve dispatch state across function tracing (#122073)

If we throw an exception in the "wrong" place we can leave the dispatcher in an inconsistent state, which can cause all future dispatching to fail. Preserve and restore the dispatch state as part of `preserve_global_state` so we know it's sane afterward; see the sketch below.
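
A minimal sketch of the pattern (the `call_with_clean_dispatch_state` wrapper below is a hypothetical stand-in for the real `preserve_global_state` decorator):

```
import torch

def call_with_clean_dispatch_state(fn, *args, **kwargs):
    # The TLS dispatch include/exclude bits aren't a true stack/counter,
    # so an exception raised mid-trace can leave them permanently
    # flipped. Snapshot them on entry and restore them on exit instead.
    with torch._C._PreserveDispatchKeyGuard():
        return fn(*args, **kwargs)
```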

Also, fake_tensor's `in_kernel_invocation_manager()` was leaving a bit (`DispatchKey.Dense`) set in the dispatcher TLS, which affected follow-on code. Fixed that to reset it afterward as well; a sketch of the leak follows.
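
A hedged sketch of the leaked bit, using only the bindings touched by this PR (and assuming a default TLS state where `Dense` is not force-included):

```
import torch

with torch._C._PreserveDispatchKeyGuard():
    # Turning the meta-in-TLS include on also implies Dense; turning it
    # back off doesn't clear Dense again, which is the leak described
    # above.
    torch._C._set_meta_in_tls_dispatch_include(True)
    torch._C._set_meta_in_tls_dispatch_include(False)
# Without the guard, DispatchKey.Dense could still be in the TLS include
# set here; the guard restores the sets that were saved on entry.
assert not torch._C._dispatch_tls_local_include_set().has(
    torch._C.DispatchKey.Dense
)
```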

Repro:

before:
```
$ rm test/dynamo_skips/TestSparseCPU.test_to_dense_with_gradcheck_sparse_cpu_complex64
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_to_dense_with_gradcheck_sparse_cpu_complex64'
======== 1 passed, 6173 deselected in 5.21s =============
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_torch_inference_mode_ctx or test_to_dense_with_gradcheck_sparse_cpu_complex64'
========= 1 skipped, 6172 deselected, 1 error in 5.29s =========
```
(note that test_to_dense_with_gradcheck_sparse_cpu_complex64 passes on its own but errors when the skipped test_export.py test is included)
after:
```
$ rm test/dynamo_skips/TestSparseCPU.test_to_dense_with_gradcheck_sparse_cpu_complex64
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_to_dense_with_gradcheck_sparse_cpu_complex64'
===================== 1 passed, 6173 deselected in 5.42s =====================
$ PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_export.py test/test_sparse.py -k 'test_torch_inference_mode_ctx or test_to_dense_with_gradcheck_sparse_cpu_complex64'
===================== 1 passed, 1 skipped, 6172 deselected in 7.30s ======================
```
(note that test_to_dense_with_gradcheck_sparse_cpu_complex64 passes in both runs)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122073
Approved by: https://github.com/zou3519
diff --git a/c10/core/impl/LocalDispatchKeySet.h b/c10/core/impl/LocalDispatchKeySet.h
index ef4acbb..176d0a6 100644
--- a/c10/core/impl/LocalDispatchKeySet.h
+++ b/c10/core/impl/LocalDispatchKeySet.h
@@ -117,14 +117,16 @@
 
 struct C10_API ForceDispatchKeyGuard {
  public:
+  ForceDispatchKeyGuard()
+      : saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {}
   ForceDispatchKeyGuard(c10::impl::LocalDispatchKeySet key_set)
-      : saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {
+      : ForceDispatchKeyGuard() {
     c10::impl::_force_tls_local_dispatch_key_set(key_set);
   }
   ForceDispatchKeyGuard(
       c10::DispatchKeySet include,
       c10::DispatchKeySet exclude)
-      : saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {
+      : ForceDispatchKeyGuard() {
     auto updated_set = saved_keyset_;
     updated_set.included_ = include;
     updated_set.excluded_ = exclude;
diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py
index ce7a233..1380de4 100644
--- a/test/inductor/test_cpu_cpp_wrapper.py
+++ b/test/inductor/test_cpu_cpp_wrapper.py
@@ -133,16 +133,21 @@
         tests.setUpClass()
         tests.setUp()
         try:
-            _, code = test_torchinductor.run_and_get_cpp_code(
-                func, *func_inputs if func_inputs else []
-            )
-            self.assertEqual("CppWrapperCodeCache" in code, True)
-            self.assertTrue(
-                all(
-                    code.count(string) == code_string_count[string]
-                    for string in code_string_count
+            with torch._C._PreserveDispatchKeyGuard():
+                torch._C._dispatch_tls_set_dispatch_key_included(
+                    torch._C.DispatchKey.Dense, True
                 )
-            )
+
+                _, code = test_torchinductor.run_and_get_cpp_code(
+                    func, *func_inputs if func_inputs else []
+                )
+                self.assertEqual("CppWrapperCodeCache" in code, True)
+                self.assertTrue(
+                    all(
+                        code.count(string) == code_string_count[string]
+                        for string in code_string_count
+                    )
+                )
         finally:
             tests.tearDown()
             tests.tearDownClass()
diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py
index 78f9bee..b662e24 100644
--- a/test/inductor/test_cuda_cpp_wrapper.py
+++ b/test/inductor/test_cuda_cpp_wrapper.py
@@ -3,6 +3,7 @@
 import unittest
 from typing import NamedTuple
 
+import torch
 from torch._inductor import config
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch.testing._internal.common_device_type import (
@@ -142,16 +143,21 @@
         tests.setUpClass()
         tests.setUp()
         try:
-            _, code = test_torchinductor.run_and_get_cpp_code(
-                func, *func_inputs if func_inputs else []
-            )
-            self.assertEqual("CppWrapperCodeCache" in code, True)
-            self.assertTrue(
-                all(
-                    code.count(string) == code_string_count[string]
-                    for string in code_string_count
+            with torch._C._PreserveDispatchKeyGuard():
+                torch._C._dispatch_tls_set_dispatch_key_included(
+                    torch._C.DispatchKey.Dense, True
                 )
-            )
+
+                _, code = test_torchinductor.run_and_get_cpp_code(
+                    func, *func_inputs if func_inputs else []
+                )
+                self.assertEqual("CppWrapperCodeCache" in code, True)
+                self.assertTrue(
+                    all(
+                        code.count(string) == code_string_count[string]
+                        for string in code_string_count
+                    )
+                )
         finally:
             tests.tearDown()
             tests.tearDownClass()
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 42d0a52..b39f0f4 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1532,6 +1532,11 @@
     def __enter__(self): ...
     def __exit__(self, exc_type, exc_value, traceback): ...
 
+class _PreserveDispatchKeyGuard:
+    def __init__(self): ...
+    def __enter__(self): ...
+    def __exit__(self, exc_type, exc_value, traceback): ...
+
 class _AutoDispatchBelowAutograd:
     def __init__(self): ...
     def __enter__(self): ...
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 93ff6d9..41a37cb 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -153,35 +153,42 @@
     def _fn(*args, **kwargs):
         guards = GlobalStateGuard()
         prior_grad_mode = torch.is_grad_enabled()
-        prior_inference_mode = torch.is_inference_mode_enabled()
-        prior_deterministic = torch.are_deterministic_algorithms_enabled()
-        prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
-        py_rng_state = random.getstate()
-        torch_rng_state = torch.random.get_rng_state()
-        if torch.cuda.is_available():
-            cuda_rng_state = torch.cuda.get_rng_state()
-        allow_tf32 = torch._C._get_cublas_allow_tf32()
-        prior_fwd_from_src = torch.fx.graph_module._forward_from_src
-        torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
-        cleanup = setup_compile_debug()
-        try:
-            return fn(*args, **kwargs)
-        finally:
-            cleanup.close()
-            torch._C._set_grad_enabled(prior_grad_mode)
-            torch.torch.autograd.grad_mode._enter_inference_mode(prior_inference_mode)
-            torch.use_deterministic_algorithms(
-                prior_deterministic, warn_only=prior_warn_only
-            )
-            random.setstate(py_rng_state)
-            torch.random.set_rng_state(torch_rng_state)
+        # Just in case we get left in a bad dispatch state we want to restore
+        # it. This can happen because the dispatch bits aren't a true
+        # stack/counter - so we can't just increment/decrement them as we enter
+        # and leave.
+        with torch._C._PreserveDispatchKeyGuard():
+            prior_inference_mode = torch.is_inference_mode_enabled()
+            prior_deterministic = torch.are_deterministic_algorithms_enabled()
+            prior_warn_only = torch.is_deterministic_algorithms_warn_only_enabled()
+            py_rng_state = random.getstate()
+            torch_rng_state = torch.random.get_rng_state()
             if torch.cuda.is_available():
-                torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
-            torch._C._set_cublas_allow_tf32(allow_tf32)
-            torch.fx.graph_module._forward_from_src = prior_fwd_from_src
-            assert (
-                guards.check()
-            ), f"Global {guards.reason()}state changed while dynamo tracing, please report a bug"
+                cuda_rng_state = torch.cuda.get_rng_state()
+            allow_tf32 = torch._C._get_cublas_allow_tf32()
+            prior_fwd_from_src = torch.fx.graph_module._forward_from_src
+            torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
+            cleanup = setup_compile_debug()
+            try:
+                return fn(*args, **kwargs)
+            finally:
+                cleanup.close()
+                torch._C._set_grad_enabled(prior_grad_mode)
+                torch.torch.autograd.grad_mode._enter_inference_mode(
+                    prior_inference_mode
+                )
+                torch.use_deterministic_algorithms(
+                    prior_deterministic, warn_only=prior_warn_only
+                )
+                random.setstate(py_rng_state)
+                torch.random.set_rng_state(torch_rng_state)
+                if torch.cuda.is_available():
+                    torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
+                torch._C._set_cublas_allow_tf32(allow_tf32)
+                torch.fx.graph_module._forward_from_src = prior_fwd_from_src
+                assert (
+                    guards.check()
+                ), f"Global {guards.reason()}state changed while dynamo tracing, please report a bug"
 
     _fn._torchdynamo_orig_callable = fn  # type: ignore[attr-defined]
     return _fn
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 4bb5232..4cf1485 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -336,15 +336,17 @@
     meta_in_tls = torch._C._meta_in_tls_dispatch_include()
     assert meta_in_tls == prev_in_kernel, f"{meta_in_tls}, {prev_in_kernel}"
 
-    guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
-    fake_mode.in_kernel_invocation = True
-    torch._C._set_meta_in_tls_dispatch_include(True)
-    try:
-        yield
-    finally:
-        fake_mode.in_kernel_invocation = prev_in_kernel
-        torch._C._set_meta_in_tls_dispatch_include(prev_in_kernel)
-        del guard
+    with torch._C._DisableTorchDispatch():
+        fake_mode.in_kernel_invocation = True
+        # Unfortunately _set_meta_in_tls_dispatch_include(False) can leave
+        # `Dense` turned on (because it's implied by `Meta`)
+        with torch._C._PreserveDispatchKeyGuard():
+            torch._C._set_meta_in_tls_dispatch_include(True)
+            try:
+                yield
+            finally:
+                fake_mode.in_kernel_invocation = prev_in_kernel
+                # torch._C._set_meta_in_tls_dispatch_include(prev_in_kernel)
 
 
 # Return if the function allows Python numbers to bind to Tensors
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 3d1c6d0..2b905a9 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -721,6 +721,8 @@
       c10::impl::ForceDispatchKeyGuard,
       c10::DispatchKeySet,
       c10::DispatchKeySet>(m, "_ForceDispatchKeyGuard");
+  py_context_manager<c10::impl::ForceDispatchKeyGuard>(
+      m, "_PreserveDispatchKeyGuard");
   py_context_manager<c10::impl::IncludeDispatchKeyGuard, c10::DispatchKey>(
       m, "_IncludeDispatchKeyGuard");
   py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(