[Reland][Dynamo] VariableTracker.recursively_contains should be updated correctly when mutation happens (#103564) (#103717)
Summary: Reland of https://github.com/pytorch/pytorch/pull/103564
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/5c3556da9406f814e6a1286cb6762e5508d54971
Differential Revision: D46783727
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103717
Approved by: https://github.com/angelayi
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index a3b3a57..094358f 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -3992,6 +3992,23 @@
res = opt_fn(x)
self.assertTrue(same(ref, res))
+ def test_variable_tracker_recursively_contains(self):
+ # VariableTracker.recursively_contains should be updated correctly when mutation happens
+ def fn(x):
+ data = [[None] * 3] * 3
+ for i in range(3):
+ if i == 0:
+ data[0][i] = x
+ else:
+ data[0][i] = data[0][i - 1] + 1
+ return data[0][-1]
+
+ x = torch.rand(4)
+ ref = fn(x)
+ opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
+ res = opt_fn(x)
+ self.assertTrue(same(ref, res))
+
@unittest.skipIf(not TEST_CUDA, "requires cuda")
@unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
def test_torch_cudnn_is_acceptable(self):
diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py
index 224c0c9..cc2fc74 100644
--- a/torch/_dynamo/variables/base.py
+++ b/torch/_dynamo/variables/base.py
@@ -76,6 +76,7 @@
value,
cache=None,
skip_fn=lambda _: False, # Whether we should skip applying to this var
+ update_contains=False,
):
"""
Walk this object and call fn on all the VariableTracker
@@ -97,20 +98,25 @@
fn, updated_dict[key], cache, skip_fn
)
result = fn(value.clone(**updated_dict))
+ if update_contains is False:
+ result._update_contains()
else:
result = fn(value)
elif istype(value, list):
- result = [cls.apply(fn, v, cache, skip_fn) for v in value]
+ result = [cls.apply(fn, v, cache, skip_fn, update_contains) for v in value]
elif istype(value, tuple):
- result = tuple(cls.apply(fn, v, cache, skip_fn) for v in value)
+ result = tuple(
+ cls.apply(fn, v, cache, skip_fn, update_contains) for v in value
+ )
elif istype(value, collections.OrderedDict):
result = collections.OrderedDict(
- cls.apply(fn, v, cache, skip_fn) for v in value.items()
+ cls.apply(fn, v, cache, skip_fn, update_contains) for v in value.items()
)
elif istype(value, dict):
result = {
- k: cls.apply(fn, v, cache, skip_fn) for k, v in list(value.items())
+ k: cls.apply(fn, v, cache, skip_fn, update_contains)
+ for k, v in list(value.items())
}
else:
result = value
@@ -271,19 +277,32 @@
if self.recursively_contains is None:
self.recursively_contains = set()
- def aggregate_mutables(var):
- self.recursively_contains.update(var.recursively_contains)
- if var.mutable_local is not None:
- self.recursively_contains.add(var.mutable_local)
-
- return var
-
VariableTracker.apply(
- aggregate_mutables, self, skip_fn=lambda var: var is not self
+ self._aggregate_mutables, self, skip_fn=lambda var: var is not self
)
assert None not in self.recursively_contains
+ def _aggregate_mutables(self, var):
+ self.recursively_contains.update(var.recursively_contains)
+ if var.mutable_local is not None:
+ self.recursively_contains.add(var.mutable_local)
+
+ return var
+
+ # This is used to forcely update self.recursively_contains
+ def _update_contains(self):
+ self.recursively_contains = set()
+
+ VariableTracker.apply(
+ self._aggregate_mutables,
+ self,
+ skip_fn=lambda var: var is not self,
+ update_contains=True,
+ )
+
+ assert None not in self.recursively_contains
+
def typestr(*objs):
if len(objs) == 1: