[dynamo] Reland #104317 - Lazy disable_dynamo API out-of-dynamo (#104664)
The internal landing failed because of torch.deploy issues with disable_dynamo in fx/* and _jit/* files. This reland removes disable_dynamo from both and adds a comment in the code explaining why.
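As a usage illustration, here is a minimal sketch of the two decorator forms introduced by this change; `helper` and `shallow_helper` are hypothetical names, and the import of torch._dynamo is deferred from import time to call time:

    import torch

    @torch._disable_dynamo
    def helper(x):
        # Dynamo is skipped for this frame and, by default, for everything it
        # calls; torch._dynamo is imported lazily inside the wrapper at call time.
        return x + 1

    @torch._disable_dynamo(recursive=False)
    def shallow_helper(x):
        # recursive=False skips Dynamo only for this frame, not for the functions
        # it calls, as done for DistributedDataParallel._inside_ddp_forward below.
        return helper(x) * 2

Per the new docstring in torch/_compile.py, this decorator is meant for use inside torch itself; external code should keep using torch._dynamo.disable.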
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104664
Approved by: https://github.com/wconstab
diff --git a/torch/__init__.py b/torch/__init__.py
index 358b61a..2eea94a 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -1336,6 +1336,15 @@
     if not name.startswith("_"):
         __all__.append(name)
+
+
+################################################################################
+# Import TorchDynamo's lazy APIs to avoid circular dependencies
+################################################################################
+
+# needs to be before from .functional import * to avoid circular dependencies
+from ._compile import _disable_dynamo
+
 ################################################################################
 # Import interface functions defined in Python
 ################################################################################
diff --git a/torch/_compile.py b/torch/_compile.py
new file mode 100644
index 0000000..354d64e
--- /dev/null
+++ b/torch/_compile.py
@@ -0,0 +1,30 @@
+"""
+APIs related to torch.compile which lazily import torch._dynamo to avoid
+circular dependencies.
+"""
+import functools
+
+
+def _disable_dynamo(fn=None, recursive=True):
+    """
+    This API should only be used inside torch; external users should still use
+    torch._dynamo.disable. The main goal of this API is to avoid the circular
+    import issues that are common when using _dynamo.disable inside torch
+    itself.
+
+    This API avoids them by deferring the import of torch._dynamo from import
+    time to the invocation of the decorated function.
+    """
+    if fn is not None:
+
+        @functools.wraps(fn)
+        def inner(*args, **kwargs):
+            import torch._dynamo
+
+            return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
+
+        return inner
+    else:
+        # decorator usage like @_disable_dynamo(recursive=False). The resulting
+        # object expects the original decorated function as the arg.
+        return functools.partial(_disable_dynamo, recursive=recursive)
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index f4351b1..31a91ed 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -57,7 +57,6 @@
 log = logging.getLogger(__name__)
 
 from torch._dispatch.python import enable_python_dispatcher
-from torch.fx.experimental import proxy_tensor
 
 always_optimize_code_objects = utils.ExactWeakKeyDictionary()
 null_context = contextlib.nullcontext
@@ -1212,33 +1211,23 @@
     @staticmethod
     @functools.lru_cache(None)
     def patch():
-        # Disable TorchDynamo on some torch.* compilers generated frames
+        # A better way to disable the following would be to decorate the source
+        # functions with @torch._disable_dynamo. However, this causes issues
+        # with torch.deploy internally.
         torch.jit.trace = disable(torch.jit.trace)
         torch.jit.trace_module = disable(torch.jit.trace_module)
         torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph)
-
-        # symbolic_trace creates new frames. We disable Dynamo on such frames
         torch.fx._symbolic_trace.Tracer.trace = disable(
             torch.fx._symbolic_trace.Tracer.trace
         )
-
-        torch.onnx.export_to_pretty_string = disable(torch.onnx.export_to_pretty_string)
         torch.distributions.Distribution.set_default_validate_args(False)
-        proxy_tensor.dispatch_trace = disable(proxy_tensor.dispatch_trace)
-
         optimizers = [
             opt
             for opt in torch.optim.__dict__.values()
             if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
         ]
-        # disable dynamo for the wrapper that helps give dynamo hints about entering DDP
-        if hasattr(DistributedDataParallel, "_inside_ddp_forward"):
-            DistributedDataParallel._inside_ddp_forward = disable(
-                DistributedDataParallel._inside_ddp_forward, recursive=False
-            )
-
         # Note: this excludes the optimizers that are unsupported in excluded_opts below
         from ..optim import (
             adadelta,
@@ -1284,11 +1273,6 @@
             if opt in excluded_opts:
                 opt.step = disable(opt.step)
 
-            opt.zero_grad = disable(opt.zero_grad)
-            opt.state_dict = disable(opt.state_dict)
-            opt.load_state_dict = disable(opt.load_state_dict)
-            opt.add_param_group = disable(opt.add_param_group)
-
             if hasattr(opt, "_init_group"):
                 opt._init_group = disable(opt._init_group)
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 5ecf2c1..f590cda 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -459,6 +459,7 @@
         return super().create_arg(a)
 
 
+@torch._disable_dynamo
 def dispatch_trace(
         root: Union[torch.nn.Module, Callable],
         tracer: Tracer,
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index bed03f9..6f9fdcc 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -1344,6 +1344,7 @@
     # for the 'module_to_run' underneath
     # see torch._dynamo/eval_frame.py TorchPatcher.patch for more details
     @contextmanager
+    @torch._disable_dynamo(recursive=False)
     def _inside_ddp_forward(self):
         DistributedDataParallel._active_ddp_module = self
         try:
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index c88aaf4..eca5c77 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -1231,6 +1231,7 @@
 @_beartype.beartype
+@torch._disable_dynamo
 def export_to_pretty_string(
     model,
     args,
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index 93c6693..44ae074 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -382,6 +382,7 @@
         self._optimizer_step_post_hooks[handle.id] = hook
         return handle
 
+    @torch._disable_dynamo
     def state_dict(self):
         r"""Returns the state of the optimizer as a :class:`dict`.
@@ -439,6 +440,7 @@
         else:
             return value.to(device=param.device)
+    @torch._disable_dynamo
     def load_state_dict(self, state_dict):
         r"""Loads the optimizer state.
@@ -495,6 +497,7 @@
             update_group(g, ng) for g, ng in zip(groups, saved_groups)]
         self.__setstate__({'state': state, 'param_groups': param_groups})
 
+    @torch._disable_dynamo
     def zero_grad(self, set_to_none: bool = True):
         r"""Resets the gradients of all optimized :class:`torch.Tensor` s.
@@ -549,6 +552,7 @@
"""
raise NotImplementedError
+ @torch._disable_dynamo
def add_param_group(self, param_group):
r"""Add a param group to the :class:`Optimizer` s `param_groups`.