Revert "[dynamo] Error when user nests FX with dynamo (#87797)"
This reverts commit 1da5aeb97b73664ff0fe2f4bb48379655cede969.
Reverted https://github.com/pytorch/pytorch/pull/87797 on behalf of https://github.com/ezyang because it breaks the nvfuser stack and needs more investigation
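
For context, the reverted change added an `error_on_nested_fx_trace` config flag and a guard in `torch/_dynamo/eval_frame.py` so that FX symbolic tracing of a dynamo-optimized function raised an explicit error. A minimal repro sketch of the interaction (illustrative only; adapted from the test deleted below):

    import torch
    import torch._dynamo

    def f(x):
        return x + x

    optimized = torch._dynamo.optimize("eager")(f)
    optimized(torch.rand(2, 3))  # normal dynamo-compiled call

    # With the guard in place (and torch._dynamo.config.error_on_nested_fx_trace
    # left at its default of True), the next line raised:
    #   RuntimeError: Detected that you are using FX to symbolically trace
    #   a dynamo-optimized function. This is not supported at the moment.
    # After this revert, nested tracing falls back to the pre-#87797 behavior.
    gm = torch.fx.symbolic_trace(optimized)
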
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index a63a6d8..a0f5922 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -2732,20 +2732,6 @@
dynamo_result = graph(x)
self.assertTrue(same(real, dynamo_result))
- def test_error_on_nested_fx_trace(self):
- input = torch.rand(2, 3)
-
- def f(x):
- x + x
-
- real = f(input)
-
- optimized = torch._dynamo.optimize("eager")(f)
- self.assertTrue(same(optimized(input), real))
-
- with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"):
- gm = torch.fx.symbolic_trace(optimized)
-
class CustomFunc(torch.autograd.Function):
@staticmethod
diff --git a/test/test_prims.py b/test/test_prims.py
index 6f400ce..6223a34 100644
--- a/test/test_prims.py
+++ b/test/test_prims.py
@@ -8,14 +8,7 @@
import torch
from torch.testing import make_tensor
-from torch.testing._internal.common_utils import (
- parametrize,
- run_tests,
- TestCase,
- TEST_SCIPY,
- skipCUDAMemoryLeakCheckIf,
- skipIfTorchDynamo,
-)
+from torch.testing._internal.common_utils import parametrize, run_tests, TestCase, TEST_SCIPY, skipCUDAMemoryLeakCheckIf
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCUDA,
@@ -394,7 +387,6 @@
actual = execute(gm, a.mT, executor="nvfuser")
self.assertEqual(expected, actual)
- @skipIfTorchDynamo
def test_nvfuser_capability_context(self, device):
# This test is to ensure that the torch calls are replaced with refs
# based on the nvfuser+prims capability
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index 1208838..87014b2 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -153,10 +153,6 @@
# How to import torchinductor, either torchinductor or torch.inductor
inductor_import = dynamo_import.replace("dynamo", "inductor")
-# If true, error with a better message if we symbolically trace over a
-# dynamo-optimized function. If false, silently suppress dynamo.
-error_on_nested_fx_trace = True
-
# root folder of the project
if "torch." in dynamo_import:
base_dir = dirname(dirname(dirname(abspath(__file__))))
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 29bb14b..fce9e43 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -14,7 +14,6 @@
import torch
import torch.utils._pytree as pytree
-from torch.fx._symbolic_trace import is_fx_tracing
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn.parallel.distributed import DistributedDataParallel
@@ -150,14 +149,6 @@
@functools.wraps(fn)
def _fn(*args, **kwargs):
- if is_fx_tracing():
- if config.error_on_nested_fx_trace:
- raise RuntimeError(
- "Detected that you are using FX to symbolically trace "
- "a dynamo-optimized function. This is not supported at the moment."
- )
- return fn
-
on_enter()
prior = set_eval_frame(callback)
backend_ctx = backend_ctx_ctor()