Make proxy_tensor.py not depend on SymPy (#112036)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112036
Approved by: https://github.com/malfet, https://github.com/peterbell10
ghstack dependencies: #112035
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index 0349075..2173cdb 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -4,7 +4,7 @@
 import torch
 from copy import deepcopy
 from torch.library import Library, impl, fallthrough_kernel
-from torch.fx.experimental.proxy_tensor import ShapeEnv
+from torch.fx.experimental.symbolic_shapes import ShapeEnv
 from torch import SymInt
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.cuda.jiterator import _create_jit_fn
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index 7b8146e..80a674c 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -29,7 +29,6 @@
     _push_mode,
 )
 
-from .symbolic_shapes import ShapeEnv, SymNode
 from ._sym_dispatch_mode import SymDispatchMode
 from torch.fx import Proxy
 import torch.fx.traceback as fx_traceback
@@ -81,6 +80,8 @@
         # on a tensor, and it affects the metadata on the proxy.
         tracer.tensor_tracker[obj] = proxy
     else:
+        # Avoid importing sympy at a module level
+        from .symbolic_shapes import SymNode
         # NB: Never clobber pre-existing proxy.  Although the proxies
         # are in principle equivalent, when we do graph partitioning
         # we need there not to be spurious dependencies on tangent inputs.
@@ -92,6 +93,8 @@
             tracer.symnode_tracker[obj] = proxy
 
 def has_proxy_slot(obj, tracer):
+    # Avoid importing sympy at a module level
+    from .symbolic_shapes import SymNode
     assert isinstance(obj, (torch.Tensor, SymNode)), type(obj)
     return get_proxy_slot(obj, tracer, False, lambda _: True)
 
@@ -102,6 +105,8 @@
     if isinstance(obj, torch.Tensor):
         tracker = tracer.tensor_tracker
     else:
+        # Avoid importing sympy at a module level
+        from .symbolic_shapes import SymNode
         assert isinstance(obj, SymNode), type(obj)
         tracker = tracer.symnode_tracker
 
@@ -769,6 +774,9 @@
 
     @functools.wraps(f)
     def wrapped(*args):
+        # Avoid importing sympy at a module level
+        from .symbolic_shapes import ShapeEnv
+
         phs = pytree.tree_map(lambda _: fx.PH, args)  # type: ignore[attr-defined]
         fx_tracer = PythonKeyTracer()
         fake_tensor_mode: Any = nullcontext()