If we re-fakeify a FakeTensor with the same ShapeEnv, preserve symbols (#113651)
Subsumes half of https://github.com/pytorch/pytorch/pull/113605
We support fakeifying an already fake tensor, which gives you a new fake tensor mirroring the structure of the original; this is what https://github.com/pytorch/pytorch/issues/113643 needs. However, when this re-fakeification happens, we currently allocate fresh symbolic sizes for the entire fake tensor. That is the right thing to do if you are re-fakeifying onto a fresh ShapeEnv (because you're reparametrizing the sizes), but if two fake tensor modes share a shape environment, you would rather reuse the original sizes/strides/storage offset from the original fake tensor. The fix ends up being pretty simple. I recommend viewing with whitespace diff turned off.
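
As a rough sketch of the intended round trip (this just mirrors the new test_same_shape_env_preserved test added below):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic

shape_env = ShapeEnv()

# The first mode allocates a symbolic size for the dynamic dimension.
mode1 = FakeTensorMode(shape_env=shape_env)
t1 = mode1.from_tensor(torch.randn(10), dynamic_dims=[DimDynamic.DYNAMIC])

# The second mode shares the same ShapeEnv, so re-fakeifying t1 should
# reuse the existing symbol instead of allocating a new one.
mode2 = FakeTensorMode(shape_env=shape_env)
t2 = mode2.from_tensor(t1)

assert t2 is not t1                        # a new fake tensor is returned
assert str(t2.size(0)) == str(t1.size(0))  # but the symbolic size is preserved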
There's some fuzz around jagged tensor handling; that code is probably not quite right, but I fixed it for this particular case in the most straightforward way.
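
For the jagged case, the same round trip looks roughly like this (again mirroring the new test_jagged_fake_to_fake_preserved test below; jagged_from_list comes from torch.nested._internal.nested_tensor as in that test):

import torch
from torch.nested._internal.nested_tensor import jagged_from_list
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

a = torch.randn(3, 4, requires_grad=True, dtype=torch.float64)
b = torch.randn(4, 4, requires_grad=True, dtype=torch.float64)
c = torch.randn(5, 4, requires_grad=True, dtype=torch.float64)
jt, _ = jagged_from_list([a, b, c], None)

shape_env = ShapeEnv()
mode1 = FakeTensorMode(shape_env=shape_env)
t1 = mode1.from_tensor(jt)

# Re-fakeify on a second mode sharing the same ShapeEnv: the ragged size
# symbol in t1's size(1) is reused rather than re-allocated.
mode2 = FakeTensorMode(shape_env=shape_env)
t2 = mode2.from_tensor(t1)
assert str(t2.size(1)) == str(t1.size(1))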
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113651
Approved by: https://github.com/albanD, https://github.com/eellison, https://github.com/bdhirsh
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 0780271..65b6989 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -15,7 +15,7 @@
DynamicOutputShapeException,
UnsupportedOperatorException,
)
-from torch.fx.experimental.symbolic_shapes import ShapeEnv
+from torch.fx.experimental.symbolic_shapes import ShapeEnv, DimDynamic, free_symbols
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.common_device_type import ops
from torch.testing._internal.common_device_type import instantiate_device_type_tests, OpDTypes
@@ -517,6 +517,43 @@
x = torch.rand([10])
x.tolist()
+ def test_same_shape_env_preserved(self):
+ shape_env = ShapeEnv()
+ mode1 = FakeTensorMode(shape_env=shape_env)
+ t1 = mode1.from_tensor(torch.randn(10), dynamic_dims=[DimDynamic.DYNAMIC])
+ mode2 = FakeTensorMode(shape_env=shape_env)
+ t2 = mode2.from_tensor(t1)
+ # t2.size(0) is still dynamic, even though we didn't pass DYNAMIC here
+ self.assertIsNot(t2, t1)
+ self.assertIs(t1.fake_mode, mode1)
+ self.assertIs(t2.fake_mode, mode2)
+ self.assertIs(t2.size(0).node.shape_env, t1.size(0).node.shape_env)
+ self.assertEqual(str(t2.size(0)), str(t1.size(0)))
+
+ def test_jagged_fake_to_fake_preserved(self):
+ from torch.nested._internal.nested_tensor import jagged_from_list
+
+ S0, S1, S2 = 3, 4, 5
+ D = 4
+ a = torch.randn(S0, D, requires_grad=True, dtype=torch.float64)
+ b = torch.randn(S1, D, requires_grad=True, dtype=torch.float64)
+ c = torch.randn(S2, D, requires_grad=True, dtype=torch.float64)
+ offsets = None
+ jt, _ = jagged_from_list([a, b, c], offsets)
+ shape_env = ShapeEnv()
+ mode1 = FakeTensorMode(shape_env=shape_env)
+ t1 = mode1.from_tensor(jt)
+ mode2 = FakeTensorMode(shape_env=shape_env)
+ t2 = mode2.from_tensor(t1)
+ # It's not obvious that the invocation above makes it dynamic but it
+ # does!
+ self.assertTrue(free_symbols(t1.size()))
+ self.assertIsNot(t2, t1)
+ self.assertIs(t1.offsets().fake_mode, mode1)
+ self.assertIs(t2.offsets().fake_mode, mode2)
+ self.assertIs(t2.size(1).node.shape_env, t1.size(1).node.shape_env)
+ self.assertEqual(str(t2.size(1)), str(t1.size(1)))
+
def checkMetaProps(self, t1, t2):
prims.utils.compare_tensor_meta(t1, t2, check_strides=True)
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 1474996..247de42 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -1,7 +1,7 @@
import contextlib
import warnings
import weakref
-from typing import ContextManager, List, Optional, TYPE_CHECKING
+from typing import ContextManager, List, Optional, Tuple, TYPE_CHECKING
import torch
from torch._C._functorch import (
@@ -187,6 +187,8 @@
dynamic_dims: "Optional[DimList[DimDynamic]]" = None,
constraint_dims: "Optional[DimList[DimConstraint]]" = None,
):
+ from torch._subclasses.fake_tensor import FakeTensor
+
if source is None:
from torch._dynamo.source import ConstantSource
@@ -233,18 +235,25 @@
if shape_env is not None:
maybe_suppress = shape_env.suppress_guards
- def sym_sizes_strides_storage_offset(t, src):
+ def sym_sizes_strides_storage_offset(
+ t, src
+ ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]:
if shape_env is not None:
- return shape_env.create_symbolic_sizes_strides_storage_offset(
- t,
- src,
- # Assume that the set of dims that are dynamic are the same between
- # the wrapper tensor and any inner tensors.
- # We can revisit this if this assumption does not hold
- # for any important subclasses later.
- dynamic_dims=dynamic_dims,
- constraint_dims=constraint_dims,
- )
+ if isinstance(t, FakeTensor) and t.fake_mode.shape_env is shape_env:
+ # Don't reallocate the sizes; the shape envs are the same,
+ # so reuse the old sizes/strides/etc
+ return (t.size(), t.stride(), t.storage_offset())
+ else:
+ return shape_env.create_symbolic_sizes_strides_storage_offset(
+ t,
+ src,
+ # Assume that the set of dims that are dynamic are the same between
+ # the wrapper tensor and any inner tensors.
+ # We can revisit this if this assumption does not hold
+ # for any important subclasses later.
+ dynamic_dims=dynamic_dims,
+ constraint_dims=constraint_dims,
+ )
else:
assert dynamic_dims is None
assert constraint_dims is None
@@ -474,8 +483,15 @@
# so we can insert some special processing on ctx
attrs, ctx = t.__tensor_flatten__()
transformed_tensors_dict = {}
+ orig_shape_env = None
for attr in attrs:
inner_t = getattr(t, attr)
+ if orig_shape_env is None:
+ orig_shape_env = (
+ inner_t.fake_mode.shape_env
+ if isinstance(inner_t, FakeTensor)
+ else None
+ )
transformed_tensors_dict[attr] = callback(
lambda: empty_create(
inner_t, AttrSource(source, attr)
@@ -483,22 +499,27 @@
)
# We expect JaggedTensor to have a 'ragged_size' in
# its context
- assert isinstance(ctx, dict) and "ragged_size" in ctx
- assert (
- isinstance(t._size[1], torch.SymInt)
- and t._size[1].node.singleton_int() is not None
- )
- # Replace the eager ragged size with our freshly
- # allocated jagged size that has a source
- ctx["ragged_size"] = shape_env.create_symintnode(
- shape_env.create_symbol(
- t._size[1],
- TensorPropertySource(
- source, TensorProperty.SIZE, 1
+ assert isinstance(ctx, dict)
+ assert "ragged_size" in ctx
+ assert isinstance(t._size[1], torch.SymInt)
+ if orig_shape_env is shape_env:
+ # It's already fake and the shape envs line up, reuse the old size
+ # Do not assert singleton_int; it may already
+ # be a variable
+ ctx["ragged_size"] = t._size[1]
+ else:
+ assert t._size[1].node.singleton_int() is not None
+ # Replace the eager ragged size with our freshly
+ # allocated jagged size that has a source
+ ctx["ragged_size"] = shape_env.create_symintnode(
+ shape_env.create_symbol(
+ t._size[1],
+ TensorPropertySource(
+ source, TensorProperty.SIZE, 1
+ ),
),
- ),
- hint=t._size[1],
- )
+ hint=t._size[1],
+ )
r = type(t).__tensor_unflatten__(
transformed_tensors_dict, ctx
)
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index d76ed61..47980dc 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -2017,7 +2017,7 @@
# TODO: This should be DYNAMIC, using DUCK for BC
dynamic_strides_offset = DimDynamic.STATIC if all(r == DimDynamic.STATIC for r in dynamic_dims) else DimDynamic.DUCK
- assert len(dynamic_dims) == dim
+ assert len(dynamic_dims) == dim, f"{len(dynamic_dims)} != {dim}"
assert len(constraint_dims) == dim
from torch._dynamo.source import TensorPropertySource, TensorProperty