Fix nested fqn discovery (#125957)

I think I missed some fix!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125957
Approved by: https://github.com/sanketpurandare, https://github.com/janeyx99
diff --git a/test/test_module_tracker.py b/test/test_module_tracker.py
index e465b12..09b5503 100644
--- a/test/test_module_tracker.py
+++ b/test/test_module_tracker.py
@@ -25,10 +25,12 @@
             def __init__(self):
                 super().__init__()
                 self.a = Foo()
-                self.b = Foo()
+                self.b = torch.nn.ModuleDict({"nest": Foo()})
+                self.c = torch.nn.ModuleList([Foo()])
 
             def forward(self, x):
-                return self.b(self.a(x))
+                x = self.c[0](x)
+                return self.b["nest"](self.a(x))
 
         mod = Mod()
 
@@ -43,20 +45,24 @@
         self.assertEqual(
             seen_fw,
             [
+                ({"Global", "Mod", "Mod.c.0"}, False),
                 ({"Global", "Mod", "Mod.a"}, False),
-                ({"Global", "Mod", "Mod.b"}, False),
+                ({"Global", "Mod", "Mod.b.nest"}, False),
+                ({"Global", "Mod", "Mod.c.0"}, False),
                 ({"Global", "Mod", "Mod.a"}, False),
-                ({"Global", "Mod", "Mod.b"}, False),
+                ({"Global", "Mod", "Mod.b.nest"}, False),
             ],
         )
 
         self.assertEqual(
             seen_bw,
             [
-                ({"Global", "Mod", "Mod.b"}, True),
+                ({"Global", "Mod", "Mod.b.nest"}, True),
                 ({"Global", "Mod", "Mod.a"}, True),
-                ({"Global", "Mod", "Mod.b"}, True),
+                ({"Global", "Mod", "Mod.c.0"}, True),
+                ({"Global", "Mod", "Mod.b.nest"}, True),
                 ({"Global", "Mod", "Mod.a"}, True),
+                ({"Global", "Mod", "Mod.c.0"}, True),
             ],
         )
 
diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py
index 078effe..b79d143 100644
--- a/torch/utils/module_tracker.py
+++ b/torch/utils/module_tracker.py
@@ -52,7 +52,7 @@
     def __init__(self):
         self.parents = {"Global"}
         self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
-        self._seen_modules = set()
+        self._seen_modules: weakref.WeakSet = weakref.WeakSet()
         self._has_callback = False
 
     def _maybe_set_engine_callback(self):
@@ -81,6 +81,8 @@
         if mod not in self._seen_modules:
             for name, submod in mod.named_children():
                 self._known_modules[submod] = f"{mod_name}.{name}"
+                self._get_mod_name(submod)
+            self._seen_modules.add(mod)
         return mod_name
 
     def _get_append_fn(self, name, is_bw):