[Reland][Dynamo] VariableTracker.recursively_contains should be updated correctly when mutation happens (#103564) (#103717)

Summary: Reland of https://github.com/pytorch/pytorch/pull/103564

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/5c3556da9406f814e6a1286cb6762e5508d54971

Differential Revision: D46783727

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103717
Approved by: https://github.com/angelayi
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index a3b3a57..094358f 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -3992,6 +3992,23 @@
         res = opt_fn(x)
         self.assertTrue(same(ref, res))
 
+    def test_variable_tracker_recursively_contains(self):
+        # VariableTracker.recursively_contains should be updated correctly when mutation happens
+        def fn(x):
+            data = [[None] * 3] * 3
+            for i in range(3):
+                if i == 0:
+                    data[0][i] = x
+                else:
+                    data[0][i] = data[0][i - 1] + 1
+            return data[0][-1]
+
+        x = torch.rand(4)
+        ref = fn(x)
+        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
+        res = opt_fn(x)
+        self.assertTrue(same(ref, res))
+
     @unittest.skipIf(not TEST_CUDA, "requires cuda")
     @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
     def test_torch_cudnn_is_acceptable(self):
diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py
index 224c0c9..cc2fc74 100644
--- a/torch/_dynamo/variables/base.py
+++ b/torch/_dynamo/variables/base.py
@@ -76,6 +76,7 @@
         value,
         cache=None,
         skip_fn=lambda _: False,  # Whether we should skip applying to this var
+        update_contains=False,
     ):
         """
         Walk this object and call fn on all the VariableTracker
@@ -97,20 +98,25 @@
                             fn, updated_dict[key], cache, skip_fn
                         )
                 result = fn(value.clone(**updated_dict))
+                if update_contains is False:
+                    result._update_contains()
             else:
                 result = fn(value)
 
         elif istype(value, list):
-            result = [cls.apply(fn, v, cache, skip_fn) for v in value]
+            result = [cls.apply(fn, v, cache, skip_fn, update_contains) for v in value]
         elif istype(value, tuple):
-            result = tuple(cls.apply(fn, v, cache, skip_fn) for v in value)
+            result = tuple(
+                cls.apply(fn, v, cache, skip_fn, update_contains) for v in value
+            )
         elif istype(value, collections.OrderedDict):
             result = collections.OrderedDict(
-                cls.apply(fn, v, cache, skip_fn) for v in value.items()
+                cls.apply(fn, v, cache, skip_fn, update_contains) for v in value.items()
             )
         elif istype(value, dict):
             result = {
-                k: cls.apply(fn, v, cache, skip_fn) for k, v in list(value.items())
+                k: cls.apply(fn, v, cache, skip_fn, update_contains)
+                for k, v in list(value.items())
             }
         else:
             result = value
@@ -271,19 +277,32 @@
         if self.recursively_contains is None:
             self.recursively_contains = set()
 
-            def aggregate_mutables(var):
-                self.recursively_contains.update(var.recursively_contains)
-                if var.mutable_local is not None:
-                    self.recursively_contains.add(var.mutable_local)
-
-                return var
-
             VariableTracker.apply(
-                aggregate_mutables, self, skip_fn=lambda var: var is not self
+                self._aggregate_mutables, self, skip_fn=lambda var: var is not self
             )
 
         assert None not in self.recursively_contains
 
+    def _aggregate_mutables(self, var):
+        self.recursively_contains.update(var.recursively_contains)
+        if var.mutable_local is not None:
+            self.recursively_contains.add(var.mutable_local)
+
+        return var
+
+    # This is used to forcely update self.recursively_contains
+    def _update_contains(self):
+        self.recursively_contains = set()
+
+        VariableTracker.apply(
+            self._aggregate_mutables,
+            self,
+            skip_fn=lambda var: var is not self,
+            update_contains=True,
+        )
+
+        assert None not in self.recursively_contains
+
 
 def typestr(*objs):
     if len(objs) == 1: