[dynamo] Reland #104317 - Lazy disable_dynamo API out-of-dynamo (#104664)

The previous attempt failed internally because of torch.deploy issues with disable_dynamo in fx/* and _jit/* files. This reland removes disable_dynamo for both and adds a comment in the code explaining why.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104664
Approved by: https://github.com/wconstab
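
For reference, a minimal sketch of how the new internal-only decorator is meant to be used (illustrative only; my_fn and my_gn are placeholder names, not functions touched by this PR):

    # Bare form: lazily imports torch._dynamo at call time and disables it
    # recursively for the decorated function and everything it calls.
    @torch._disable_dynamo
    def my_fn(x):
        return x + 1

    # Parametrized form: only the decorated frame is skipped, not its callees.
    @torch._disable_dynamo(recursive=False)
    def my_gn(x):
        return x * 2
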
diff --git a/torch/__init__.py b/torch/__init__.py
index 358b61a..2eea94a 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -1336,6 +1336,15 @@
     if not name.startswith("_"):
         __all__.append(name)
 
+
+
+################################################################################
+# Import TorchDynamo's lazy APIs to avoid circular dependencies
+################################################################################
+
+# needs to be before from .functional import * to avoid circular dependencies
+from ._compile import _disable_dynamo
+
 ################################################################################
 # Import interface functions defined in Python
 ################################################################################
diff --git a/torch/_compile.py b/torch/_compile.py
new file mode 100644
index 0000000..354d64e
--- /dev/null
+++ b/torch/_compile.py
@@ -0,0 +1,30 @@
+"""
+APIs related to torch.compile which lazily import torch._dynamo to avoid
+circular dependencies.
+"""
+import functools
+
+
+def _disable_dynamo(fn=None, recursive=True):
+    """
+    This API should only be used inside torch; external users should still use
+    torch._dynamo.disable. The main goal of this API is to avoid the circular
+    import issues that are common when using _dynamo.disable inside torch
+    itself.
+
+    It avoids them by delaying the import of torch._dynamo from import time to
+    the invocation of the decorated function.
+    """
+    if fn is not None:
+
+        @functools.wraps(fn)
+        def inner(*args, **kwargs):
+            import torch._dynamo
+
+            return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
+
+        return inner
+    else:
+        # decorator usage like @_disable_dynamo(recursive=False). The resulting
+        # object expects the original decorated function as the arg.
+        return functools.partial(_disable_dynamo, recursive=recursive)
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index f4351b1..31a91ed 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -57,7 +57,6 @@
 log = logging.getLogger(__name__)
 
 from torch._dispatch.python import enable_python_dispatcher
-from torch.fx.experimental import proxy_tensor
 
 always_optimize_code_objects = utils.ExactWeakKeyDictionary()
 null_context = contextlib.nullcontext
@@ -1212,33 +1211,23 @@
     @staticmethod
     @functools.lru_cache(None)
     def patch():
-        # Disable TorchDynamo on some torch.* compilers generated frames
+        # A better way to disable the following would be to decorate the source
+        # functions with @torch._disable_dynamo. However, doing so causes issues
+        # with torch.deploy internally.
         torch.jit.trace = disable(torch.jit.trace)
         torch.jit.trace_module = disable(torch.jit.trace_module)
         torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph)
-
-        # symbolic_trace creates new frames. We disable Dynamo on such frames
         torch.fx._symbolic_trace.Tracer.trace = disable(
             torch.fx._symbolic_trace.Tracer.trace
         )
-
-        torch.onnx.export_to_pretty_string = disable(torch.onnx.export_to_pretty_string)
         torch.distributions.Distribution.set_default_validate_args(False)
 
-        proxy_tensor.dispatch_trace = disable(proxy_tensor.dispatch_trace)
-
         optimizers = [
             opt
             for opt in torch.optim.__dict__.values()
             if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer)
         ]
 
-        # disable dynamo for the wrapper that helps give dynamo hints about entering DDP
-        if hasattr(DistributedDataParallel, "_inside_ddp_forward"):
-            DistributedDataParallel._inside_ddp_forward = disable(
-                DistributedDataParallel._inside_ddp_forward, recursive=False
-            )
-
         # Note: this excludes the optimizers that are unsupported in excluded_opts below
         from ..optim import (
             adadelta,
@@ -1284,11 +1273,6 @@
             if opt in excluded_opts:
                 opt.step = disable(opt.step)
 
-            opt.zero_grad = disable(opt.zero_grad)
-            opt.state_dict = disable(opt.state_dict)
-            opt.load_state_dict = disable(opt.load_state_dict)
-            opt.add_param_group = disable(opt.add_param_group)
-
             if hasattr(opt, "_init_group"):
                 opt._init_group = disable(opt._init_group)
 
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 5ecf2c1..f590cda 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -459,6 +459,7 @@
         return super().create_arg(a)
 
 
+@torch._disable_dynamo
 def dispatch_trace(
         root: Union[torch.nn.Module, Callable],
         tracer: Tracer,
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index bed03f9..6f9fdcc 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -1344,6 +1344,7 @@
     # for the 'module_to_run' underneath
     # see torch._dynamo/eval_frame.py TorchPatcher.patch for more details
     @contextmanager
+    @torch._disable_dynamo(recursive=False)
     def _inside_ddp_forward(self):
         DistributedDataParallel._active_ddp_module = self
         try:
diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py
index c88aaf4..eca5c77 100644
--- a/torch/onnx/utils.py
+++ b/torch/onnx/utils.py
@@ -1231,6 +1231,7 @@
 
 
 @_beartype.beartype
+@torch._disable_dynamo
 def export_to_pretty_string(
     model,
     args,
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index 93c6693..44ae074 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -382,6 +382,7 @@
         self._optimizer_step_post_hooks[handle.id] = hook
         return handle
 
+    @torch._disable_dynamo
     def state_dict(self):
         r"""Returns the state of the optimizer as a :class:`dict`.
 
@@ -439,6 +440,7 @@
             else:
                 return value.to(device=param.device)
 
+    @torch._disable_dynamo
     def load_state_dict(self, state_dict):
         r"""Loads the optimizer state.
 
@@ -495,6 +497,7 @@
             update_group(g, ng) for g, ng in zip(groups, saved_groups)]
         self.__setstate__({'state': state, 'param_groups': param_groups})
 
+    @torch._disable_dynamo
     def zero_grad(self, set_to_none: bool = True):
         r"""Resets the gradients of all optimized :class:`torch.Tensor` s.
 
@@ -549,6 +552,7 @@
         """
         raise NotImplementedError
 
+    @torch._disable_dynamo
     def add_param_group(self, param_group):
         r"""Add a param group to the :class:`Optimizer` s `param_groups`.