Fix nested fqn discovery (#125957)
I think I missed some fix!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125957
Approved by: https://github.com/sanketpurandare, https://github.com/janeyx99
diff --git a/test/test_module_tracker.py b/test/test_module_tracker.py
index e465b12..09b5503 100644
--- a/test/test_module_tracker.py
+++ b/test/test_module_tracker.py
@@ -25,10 +25,12 @@
def __init__(self):
super().__init__()
self.a = Foo()
- self.b = Foo()
+ self.b = torch.nn.ModuleDict({"nest": Foo()})
+ self.c = torch.nn.ModuleList([Foo()])
def forward(self, x):
- return self.b(self.a(x))
+ x = self.c[0](x)
+ return self.b["nest"](self.a(x))
mod = Mod()
@@ -43,20 +45,24 @@
self.assertEqual(
seen_fw,
[
+ ({"Global", "Mod", "Mod.c.0"}, False),
({"Global", "Mod", "Mod.a"}, False),
- ({"Global", "Mod", "Mod.b"}, False),
+ ({"Global", "Mod", "Mod.b.nest"}, False),
+ ({"Global", "Mod", "Mod.c.0"}, False),
({"Global", "Mod", "Mod.a"}, False),
- ({"Global", "Mod", "Mod.b"}, False),
+ ({"Global", "Mod", "Mod.b.nest"}, False),
],
)
self.assertEqual(
seen_bw,
[
- ({"Global", "Mod", "Mod.b"}, True),
+ ({"Global", "Mod", "Mod.b.nest"}, True),
({"Global", "Mod", "Mod.a"}, True),
- ({"Global", "Mod", "Mod.b"}, True),
+ ({"Global", "Mod", "Mod.c.0"}, True),
+ ({"Global", "Mod", "Mod.b.nest"}, True),
({"Global", "Mod", "Mod.a"}, True),
+ ({"Global", "Mod", "Mod.c.0"}, True),
],
)
diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py
index 078effe..b79d143 100644
--- a/torch/utils/module_tracker.py
+++ b/torch/utils/module_tracker.py
@@ -52,7 +52,7 @@
def __init__(self):
self.parents = {"Global"}
self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
- self._seen_modules = set()
+ self._seen_modules: weakref.WeakSet = weakref.WeakSet()
self._has_callback = False
def _maybe_set_engine_callback(self):
@@ -81,6 +81,8 @@
if mod not in self._seen_modules:
for name, submod in mod.named_children():
self._known_modules[submod] = f"{mod_name}.{name}"
+ self._get_mod_name(submod)
+ self._seen_modules.add(mod)
return mod_name
def _get_append_fn(self, name, is_bw):