Make proxy_tensor.py not depend on SymPy (#112036)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112036
Approved by: https://github.com/malfet, https://github.com/peterbell10
ghstack dependencies: #112035
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index 0349075..2173cdb 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -4,7 +4,7 @@
import torch
from copy import deepcopy
from torch.library import Library, impl, fallthrough_kernel
-from torch.fx.experimental.proxy_tensor import ShapeEnv
+from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch import SymInt
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.cuda.jiterator import _create_jit_fn
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 7b8146e..80a674c 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -29,7 +29,6 @@
_push_mode,
)
-from .symbolic_shapes import ShapeEnv, SymNode
from ._sym_dispatch_mode import SymDispatchMode
from torch.fx import Proxy
import torch.fx.traceback as fx_traceback
@@ -81,6 +80,8 @@
# on a tensor, and it affects the metadata on the proxy.
tracer.tensor_tracker[obj] = proxy
else:
+ # Avoid importing sympy at a module level
+ from .symbolic_shapes import SymNode
# NB: Never clobber pre-existing proxy. Although the proxies
# are in principle equivalent, when we do graph partitioning
# we need there not to be spurious dependencies on tangent inputs.
@@ -92,6 +93,8 @@
tracer.symnode_tracker[obj] = proxy
def has_proxy_slot(obj, tracer):
+ # Avoid importing sympy at a module level
+ from .symbolic_shapes import SymNode
assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
return get_proxy_slot(obj, tracer, False, lambda _: True)
@@ -102,6 +105,8 @@
if isinstance(obj, torch.Tensor):
tracker = tracer.tensor_tracker
else:
+ # Avoid importing sympy at a module level
+ from .symbolic_shapes import SymNode
assert isinstance(obj, SymNode), type(obj)
tracker = tracer.symnode_tracker
@@ -769,6 +774,9 @@
@functools.wraps(f)
def wrapped(*args):
+ # Avoid importing sympy at a module level
+ from .symbolic_shapes import ShapeEnv
+
phs = pytree.tree_map(lambda _: fx.PH, args) # type: ignore[attr-defined]
fx_tracer = PythonKeyTracer()
fake_tensor_mode: Any = nullcontext()