[dynamo] Optimize overheads from _TorchDynamoContext (#118070)

Based on `python benchmarks/dynamo/microbenchmarks/overheads.py`:
- Before `18.1us`
- After `12.2us`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118070
Approved by: https://github.com/yanboliang, https://github.com/anijain2305
ghstack dependencies: #118065
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index a03e1ed..528a237 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -358,6 +358,10 @@
 if TYPE_CHECKING:
     from torch.utils._config_typing import *  # noqa: F401, F403
 
+    def _make_closure_patcher(**changes):
+        ...
+
+
 from torch.utils._config_module import install_config_module
 
 install_config_module(sys.modules[__name__])
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 0236516..084721a 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -288,19 +288,15 @@
     return unaltered_fn
 
 
-@contextlib.contextmanager
-def enable_dynamic(enable: Optional[bool] = None, export: bool = False):
-    if enable is None:
-        yield
-    elif enable:
+def make_set_enable_dynamic(enable: bool):
+    assert isinstance(enable, bool)
+    if enable:
         # Assume everything is dynamic by default
-        with config.patch(assume_static_by_default=False):
-            yield
+        return config._make_closure_patcher(assume_static_by_default=False)
     else:
-        with config.patch(
+        return config._make_closure_patcher(
             automatic_dynamic_shapes=False, assume_static_by_default=True
-        ):
-            yield
+        )
 
 
 class _TorchDynamoContext:
@@ -320,16 +316,33 @@
         assert callable(callback) or callback is False or callback is None
         self.callback: DynamoCallback = callback
         self.prior: Union[Unset, DynamoCallback] = unset
-        self.on_enter = on_enter
-        self.extra_ctx_ctor = backend_ctx_ctor
         self.first_ctx = first_ctx
         self.export = export
-        self.dynamic = dynamic
         self.compiler_config = compiler_config
-        self.set_backend_cache = backend_cache_manager(self.callback)
         self.cleanup_fns: List[Callable[[], Any]] = []
+        self.enter_exit_hooks = [backend_cache_manager(self.callback)]
         patch_fn()
 
+        if dynamic is not None:
+            self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic))
+
+        if on_enter is not nothing:
+            # this case is not common
+            def call_on_enter():
+                on_enter()
+                return nothing
+
+            self.enter_exit_hooks.append(call_on_enter)
+
+        if backend_ctx_ctor is not contextlib.nullcontext:
+            # this case is not common
+            def call_backend_ctx():
+                ctx = backend_ctx_ctor()
+                ctx.__enter__()
+                return functools.partial(ctx.__exit__, None, None, None)
+
+            self.enter_exit_hooks.append(call_backend_ctx)
+
     def __enter__(self):
         if config.raise_on_ctx_manager_usage:
             raise RuntimeError(
@@ -337,21 +350,13 @@
                 "Please refer to https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html "
                 "to use torch._dynamo.optimize(...) as an annotation/decorator. "
             )
-        self.on_enter()
+        self.cleanup_fns = [enter() for enter in self.enter_exit_hooks]
         self.prior = set_eval_frame(self.callback)
-        self.cleanup_fns.append(self.set_backend_cache())
-        self.backend_ctx = self.extra_ctx_ctor()
-        self.backend_ctx.__enter__()
-        self.dynamic_ctx = enable_dynamic(self.dynamic, self.export)
-        self.dynamic_ctx.__enter__()
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         assert self.prior is not unset
         set_eval_frame(self.prior)
         self.prior = unset
-        # TODO: This is totally not the right way to chain contexts manually
-        self.dynamic_ctx.__exit__(exc_type, exc_val, exc_tb)
-        self.backend_ctx.__exit__(exc_type, exc_val, exc_tb)
         for cleanup in self.cleanup_fns:
             cleanup()
         self.cleanup_fns.clear()
@@ -401,8 +406,6 @@
             fn = external_utils.wrap_inline(fn)
 
         callback = self.callback
-        on_enter = self.on_enter
-        backend_ctx_ctor = self.extra_ctx_ctor
 
         @functools.wraps(fn)
         def _fn(*args, **kwargs):
@@ -426,19 +429,12 @@
                 else:
                     return fn(*args, **kwargs)
 
-            on_enter()
+            cleanups = [enter() for enter in self.enter_exit_hooks]
             prior = set_eval_frame(callback)
-            cleanups = (self.set_backend_cache(),)
-            backend_ctx = backend_ctx_ctor()
-            backend_ctx.__enter__()
-            dynamic_ctx = enable_dynamic(self.dynamic, self.export)
-            dynamic_ctx.__enter__()
             try:
                 return fn(*args, **kwargs)
             finally:
                 set_eval_frame(prior)
-                dynamic_ctx.__exit__(None, None, None)
-                backend_ctx.__exit__(None, None, None)
                 for cleanup in cleanups:
                     cleanup()
 
diff --git a/torch/_dynamo/skipfiles.py b/torch/_dynamo/skipfiles.py
index 8b06a65..b3e4861 100644
--- a/torch/_dynamo/skipfiles.py
+++ b/torch/_dynamo/skipfiles.py
@@ -34,6 +34,7 @@
 import torch._inductor.test_operators
 import torch.distributed
 import torch.utils._content_store
+from ..utils import _config_module
 from .utils import getfile
 
 from .variables.functions import (
@@ -42,7 +43,6 @@
     UserMethodVariable,
 )
 
-
 """
 A note on skipfiles:
 
@@ -260,6 +260,7 @@
 SKIP_DIRS = [
     "<frozen importlib",
     "<__array_function__ internals>",
+    _config_module.__file__,
 ] + [_module_dir(m) for m in BUILTIN_SKIPLIST]
 
 SKIP_DIRS_RE = re.compile(r"match nothing^")
diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py
index 4412048..7089e1d 100644
--- a/torch/utils/_config_module.py
+++ b/torch/utils/_config_module.py
@@ -268,6 +268,37 @@
 
         return ConfigPatch()
 
+    def _make_closure_patcher(self, **changes):
+        """
+        A lower-overhead version of patch() for things on the critical path.
+
+        Usage:
+
+            # do this off the critical path
+            change_fn = config._make_closure_patcher(foo=True)
+
+            ...
+
+            revert = change_fn()
+            try:
+                ...
+            finally:
+                revert()
+
+        """
+        config = self._config
+
+        def change():
+            prior = {k: config[k] for k in changes}
+            config.update(changes)
+
+            def revert():
+                config.update(prior)
+
+            return revert
+
+        return change
+
 
 class ContextDecorator(contextlib.ContextDecorator):
     """