If we re-fakeify a FakeTensor with the same ShapeEnv, preserve symbols (#113651)

Subsumes half of https://github.com/pytorch/pytorch/pull/113605

We support fakeifying an already fake tensor: this gives you a new fake tensor that mirrors the structure of the original, which is what https://github.com/pytorch/pytorch/issues/113643 needs. However, when this re-fakeification happens, we naively allocate fresh symbolic sizes for the entire fake tensor. That is the right thing to do if you are re-fakeifying onto a fresh ShapeEnv (because you're reparametrizing the sizes, for example), but if you have two fake tensor modes sharing a shape environment, you would rather reuse the original sizes/strides/storage offset from the original fake tensor. The fix ends up being pretty simple. I recommend viewing with whitespace diff turned off.
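
For concreteness, here is a minimal sketch of the behavior this enables, mirroring the test_same_shape_env_preserved test added below (the FakeTensorMode import path is an assumption; the test file already has it in scope):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode  # assumed import path
    from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv

    shape_env = ShapeEnv()

    # Fakeify a real tensor under mode1, marking dim 0 dynamic; this allocates
    # a symbolic size (e.g. s0) in shape_env.
    mode1 = FakeTensorMode(shape_env=shape_env)
    t1 = mode1.from_tensor(torch.randn(10), dynamic_dims=[DimDynamic.DYNAMIC])

    # Re-fakeify the already-fake t1 under a second mode sharing shape_env.
    mode2 = FakeTensorMode(shape_env=shape_env)
    t2 = mode2.from_tensor(t1)

    # With this patch, t2 reuses t1's symbolic size instead of getting a
    # freshly allocated symbol, so the two sizes print identically.
    assert str(t2.size(0)) == str(t1.size(0))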

There's some fuzz around jagged tensor handling; that code is probably not quite right, but I fixed it for this particular case in the most straightforward way.
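
The jagged case, sketched from the test_jagged_fake_to_fake_preserved test below (jagged_from_list is the private constructor the test uses):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode  # assumed import path
    from torch.fx.experimental.symbolic_shapes import ShapeEnv
    from torch.nested._internal.nested_tensor import jagged_from_list

    # Build an eager jagged NestedTensor from components of differing lengths.
    components = [
        torch.randn(n, 4, requires_grad=True, dtype=torch.float64) for n in (3, 4, 5)
    ]
    jt, _ = jagged_from_list(components, None)

    shape_env = ShapeEnv()
    t1 = FakeTensorMode(shape_env=shape_env).from_tensor(jt)
    t2 = FakeTensorMode(shape_env=shape_env).from_tensor(t1)

    # t1's ragged dim is already a symbol backed by shape_env; re-fakeifying
    # under the same ShapeEnv now reuses it rather than asserting it is a
    # singleton int and allocating a fresh symbol.
    assert str(t2.size(1)) == str(t1.size(1))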

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113651
Approved by: https://github.com/albanD, https://github.com/eellison, https://github.com/bdhirsh
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 0780271..65b6989 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -15,7 +15,7 @@
     DynamicOutputShapeException,
     UnsupportedOperatorException,
 )
-from torch.fx.experimental.symbolic_shapes import ShapeEnv
+from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic, free_symbols
 from torch.testing._internal.custom_op_db import custom_op_db
 from torch.testing._internal.common_device_type import ops
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, OpDTypes
@@ -517,6 +517,43 @@
             x = torch.rand([10])
             x.tolist()
 
+    def test_same_shape_env_preserved(self):
+        shape_env = ShapeEnv()
+        mode1 = FakeTensorMode(shape_env=shape_env)
+        t1 = mode1.from_tensor(torch.randn(10), dynamic_dims=[DimDynamic.DYNAMIC])
+        mode2 = FakeTensorMode(shape_env=shape_env)
+        t2 = mode2.from_tensor(t1)
+        # t2.size(0) is still dynamic, even though we didn't pass DYNAMIC here
+        self.assertIsNot(t2, t1)
+        self.assertIs(t1.fake_mode, mode1)
+        self.assertIs(t2.fake_mode, mode2)
+        self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env)
+        self.assertEqual(str(t2.size(0)), str(t1.size(0)))
+
+    def test_jagged_fake_to_fake_preserved(self):
+        from torch.nested._internal.nested_tensor import jagged_from_list
+
+        S0, S1, S2 = 3, 4, 5
+        D = 4
+        a = torch.randn(S0, D, requires_grad=True, dtype=torch.float64)
+        b = torch.randn(S1, D, requires_grad=True, dtype=torch.float64)
+        c = torch.randn(S2, D, requires_grad=True, dtype=torch.float64)
+        offsets = None
+        jt, _ = jagged_from_list([a, b, c], offsets)
+        shape_env = ShapeEnv()
+        mode1 = FakeTensorMode(shape_env=shape_env)
+        t1 = mode1.from_tensor(jt)
+        mode2 = FakeTensorMode(shape_env=shape_env)
+        t2 = mode2.from_tensor(t1)
+        # It's not obvious that the invocation above makes it dynamic but it
+        # does!
+        self.assertTrue(free_symbols(t1.size()))
+        self.assertIsNot(t2, t1)
+        self.assertIs(t1.offsets().fake_mode, mode1)
+        self.assertIs(t2.offsets().fake_mode, mode2)
+        self.assertIs(t2.size(1).node.shape_env, t1.size(1).node.shape_env)
+        self.assertEqual(str(t2.size(1)), str(t1.size(1)))
+
     def checkMetaProps(self, t1, t2):
         prims.utils.compare_tensor_meta(t1, t2, check_strides=True)
 
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 1474996..247de42 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -1,7 +1,7 @@
 import contextlib
 import warnings
 import weakref
-from typing import ContextManager, List, Optional, TYPE_CHECKING
+from typing import ContextManager, List, Optional, Tuple, TYPE_CHECKING
 
 import torch
 from torch._C._functorch import (
@@ -187,6 +187,8 @@
         dynamic_dims: "Optional[DimList[DimDynamic]]" = None,
         constraint_dims: "Optional[DimList[DimConstraint]]" = None,
     ):
+        from torch._subclasses.fake_tensor import FakeTensor
+
         if source is None:
             from torch._dynamo.source import ConstantSource
 
@@ -233,18 +235,25 @@
         if shape_env is not None:
             maybe_suppress = shape_env.suppress_guards
 
-        def sym_sizes_strides_storage_offset(t, src):
+        def sym_sizes_strides_storage_offset(
+            t, src
+        ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
             if shape_env is not None:
-                return shape_env.create_symbolic_sizes_strides_storage_offset(
-                    t,
-                    src,
-                    # Assume that the set of dims that are dynamic are the same between
-                    # the wrapper tensor and any inner tensors.
-                    # We can revisit this if this assumption does not hold
-                    # for any important subclasses later.
-                    dynamic_dims=dynamic_dims,
-                    constraint_dims=constraint_dims,
-                )
+                if isinstance(t, FakeTensor) and t.fake_mode.shape_env is shape_env:
+                    # Don't reallocate the sizes; the shape envs are the same,
+                    # so reuse the old sizes/strides/etc
+                    return (t.size(), t.stride(), t.storage_offset())
+                else:
+                    return shape_env.create_symbolic_sizes_strides_storage_offset(
+                        t,
+                        src,
+                        # Assume that the set of dims that are dynamic are the same between
+                        # the wrapper tensor and any inner tensors.
+                        # We can revisit this if this assumption does not hold
+                        # for any important subclasses later.
+                        dynamic_dims=dynamic_dims,
+                        constraint_dims=constraint_dims,
+                    )
             else:
                 assert dynamic_dims is None
                 assert constraint_dims is None
@@ -474,8 +483,15 @@
                             # so we can insert some special processing on ctx
                             attrs, ctx = t.__tensor_flatten__()
                             transformed_tensors_dict = {}
+                            orig_shape_env = None
                             for attr in attrs:
                                 inner_t = getattr(t, attr)
+                                if orig_shape_env is None:
+                                    orig_shape_env = (
+                                        inner_t.fake_mode.shape_env
+                                        if isinstance(inner_t, FakeTensor)
+                                        else None
+                                    )
                                 transformed_tensors_dict[attr] = callback(
                                     lambda: empty_create(
                                         inner_t, AttrSource(source, attr)
@@ -483,22 +499,27 @@
                                 )
                             # We expect JaggedTensor to have a 'ragged_size' in
                             # its context
-                            assert isinstance(ctx, dict) and "ragged_size" in ctx
-                            assert (
-                                isinstance(t._size[1], torch.SymInt)
-                                and t._size[1].node.singleton_int() is not None
-                            )
-                            # Replace the eager ragged size with our freshly
-                            # allocated jagged size that has a source
-                            ctx["ragged_size"] = shape_env.create_symintnode(
-                                shape_env.create_symbol(
-                                    t._size[1],
-                                    TensorPropertySource(
-                                        source, TensorProperty.SIZE, 1
+                            assert isinstance(ctx, dict)
+                            assert "ragged_size" in ctx
+                            assert isinstance(t._size[1], torch.SymInt)
+                            if orig_shape_env is shape_env:
+                                # It's already fake and the shape envs line up, reuse the old size
+                                # Do not assert singleton_int; it may already
+                                # be a variable
+                                ctx["ragged_size"] = t._size[1]
+                            else:
+                                assert t._size[1].node.singleton_int() is not None
+                                # Replace the eager ragged size with our freshly
+                                # allocated jagged size that has a source
+                                ctx["ragged_size"] = shape_env.create_symintnode(
+                                    shape_env.create_symbol(
+                                        t._size[1],
+                                        TensorPropertySource(
+                                            source, TensorProperty.SIZE, 1
+                                        ),
                                     ),
-                                ),
-                                hint=t._size[1],
-                            )
+                                    hint=t._size[1],
+                                )
                             r = type(t).__tensor_unflatten__(
                                 transformed_tensors_dict, ctx
                             )
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index d76ed61..47980dc 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -2017,7 +2017,7 @@
         # TODO: This should be DYNAMIC, using DUCK for BC
         dynamic_strides_offset = DimDynamic.STATIC if all(r == DimDynamic.STATIC for r in dynamic_dims) else DimDynamic.DUCK
 
-        assert len(dynamic_dims) == dim
+        assert len(dynamic_dims) == dim, f"{len(dynamic_dims)} != {dim}"
         assert len(constraint_dims) == dim
 
         from torch._dynamo.source import TensorPropertySource, TensorProperty