[dynamo] Optimize overheads from _TorchDynamoContext (#118070)
Based on `python benchmarks/dynamo/microbenchmarks/overheads.py`:
- Before `18.1us`
- After `12.2us`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118070
Approved by: https://github.com/yanboliang, https://github.com/anijain2305
ghstack dependencies: #118065
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index a03e1ed..528a237 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -358,6 +358,10 @@
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
+ def _make_closure_patcher(**changes):
+ ...
+
+
from torch.utils._config_module import install_config_module
install_config_module(sys.modules[__name__])
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 0236516..084721a 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -288,19 +288,15 @@
return unaltered_fn
-@contextlib.contextmanager
-def enable_dynamic(enable: Optional[bool] = None, export: bool = False):
- if enable is None:
- yield
- elif enable:
+def make_set_enable_dynamic(enable: bool):
+ assert isinstance(enable, bool)
+ if enable:
# Assume everything is dynamic by default
- with config.patch(assume_static_by_default=False):
- yield
+ return config._make_closure_patcher(assume_static_by_default=False)
else:
- with config.patch(
+ return config._make_closure_patcher(
automatic_dynamic_shapes=False, assume_static_by_default=True
- ):
- yield
+ )
class _TorchDynamoContext:
@@ -320,16 +316,33 @@
assert callable(callback) or callback is False or callback is None
self.callback: DynamoCallback = callback
self.prior: Union[Unset, DynamoCallback] = unset
- self.on_enter = on_enter
- self.extra_ctx_ctor = backend_ctx_ctor
self.first_ctx = first_ctx
self.export = export
- self.dynamic = dynamic
self.compiler_config = compiler_config
- self.set_backend_cache = backend_cache_manager(self.callback)
self.cleanup_fns: List[Callable[[], Any]] = []
+ self.enter_exit_hooks = [backend_cache_manager(self.callback)]
patch_fn()
+ if dynamic is not None:
+ self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic))
+
+ if on_enter is not nothing:
+ # this case is not common
+ def call_on_enter():
+ on_enter()
+ return nothing
+
+ self.enter_exit_hooks.append(call_on_enter)
+
+ if backend_ctx_ctor is not contextlib.nullcontext:
+ # this case is not common
+ def call_backend_ctx():
+ ctx = backend_ctx_ctor()
+ ctx.__enter__()
+ return functools.partial(ctx.__exit__, None, None, None)
+
+ self.enter_exit_hooks.append(call_backend_ctx)
+
def __enter__(self):
if config.raise_on_ctx_manager_usage:
raise RuntimeError(
@@ -337,21 +350,13 @@
"Please refer to https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html "
"to use torch._dynamo.optimize(...) as an annotation/decorator. "
)
- self.on_enter()
+ self.cleanup_fns = [enter() for enter in self.enter_exit_hooks]
self.prior = set_eval_frame(self.callback)
- self.cleanup_fns.append(self.set_backend_cache())
- self.backend_ctx = self.extra_ctx_ctor()
- self.backend_ctx.__enter__()
- self.dynamic_ctx = enable_dynamic(self.dynamic, self.export)
- self.dynamic_ctx.__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
assert self.prior is not unset
set_eval_frame(self.prior)
self.prior = unset
- # TODO: This is totally not the right way to chain contexts manually
- self.dynamic_ctx.__exit__(exc_type, exc_val, exc_tb)
- self.backend_ctx.__exit__(exc_type, exc_val, exc_tb)
for cleanup in self.cleanup_fns:
cleanup()
self.cleanup_fns.clear()
@@ -401,8 +406,6 @@
fn = external_utils.wrap_inline(fn)
callback = self.callback
- on_enter = self.on_enter
- backend_ctx_ctor = self.extra_ctx_ctor
@functools.wraps(fn)
def _fn(*args, **kwargs):
@@ -426,19 +429,12 @@
else:
return fn(*args, **kwargs)
- on_enter()
+ cleanups = [enter() for enter in self.enter_exit_hooks]
prior = set_eval_frame(callback)
- cleanups = (self.set_backend_cache(),)
- backend_ctx = backend_ctx_ctor()
- backend_ctx.__enter__()
- dynamic_ctx = enable_dynamic(self.dynamic, self.export)
- dynamic_ctx.__enter__()
try:
return fn(*args, **kwargs)
finally:
set_eval_frame(prior)
- dynamic_ctx.__exit__(None, None, None)
- backend_ctx.__exit__(None, None, None)
for cleanup in cleanups:
cleanup()
diff --git a/torch/_dynamo/skipfiles.py b/torch/_dynamo/skipfiles.py
index 8b06a65..b3e4861 100644
--- a/torch/_dynamo/skipfiles.py
+++ b/torch/_dynamo/skipfiles.py
@@ -34,6 +34,7 @@
import torch._inductor.test_operators
import torch.distributed
import torch.utils._content_store
+from ..utils import _config_module
from .utils import getfile
from .variables.functions import (
@@ -42,7 +43,6 @@
UserMethodVariable,
)
-
"""
A note on skipfiles:
@@ -260,6 +260,7 @@
SKIP_DIRS = [
"<frozen importlib",
"<__array_function__ internals>",
+ _config_module.__file__,
] + [_module_dir(m) for m in BUILTIN_SKIPLIST]
SKIP_DIRS_RE = re.compile(r"match nothing^")
diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py
index 4412048..7089e1d 100644
--- a/torch/utils/_config_module.py
+++ b/torch/utils/_config_module.py
@@ -268,6 +268,37 @@
return ConfigPatch()
+ def _make_closure_patcher(self, **changes):
+ """
+ A lower-overhead version of patch() for things on the critical path.
+
+ Usage:
+
+ # do this off the critical path
+ change_fn = config._make_closure_patcher(foo=True)
+
+ ...
+
+ revert = change_fn()
+ try:
+ ...
+ finally:
+ revert()
+
+ """
+ config = self._config
+
+ def change():
+ prior = {k: config[k] for k in changes}
+ config.update(changes)
+
+ def revert():
+ config.update(prior)
+
+ return revert
+
+ return change
+
class ContextDecorator(contextlib.ContextDecorator):
"""