Support module dict iter (#99503)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99503
Approved by: https://github.com/Chillee, https://github.com/jansel
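
The change teaches Dynamo to unpack a torch.nn.ModuleDict when it is iterated inside a compiled forward, so loops over the dict itself or over its .keys() / .values() views are traced. A minimal sketch of the user-facing behavior, modeled on the new tests below (the module, tensor shape, and backend counter are illustrative, not part of the patch):

    import torch
    import torch._dynamo
    import torch._dynamo.testing

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # ModuleDict iterates like a plain dict: over its keys, in insertion order.
            self.activations = torch.nn.ModuleDict(
                [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
            )

        def forward(self, x):
            for name in self.activations:  # dict iteration is now handled by Dynamo
                x = self.activations[name](x)
            return x

    cnt = torch._dynamo.testing.CompileCounter()
    eager_res = MyModule()(torch.ones(10, 10))
    optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
    assert torch.allclose(eager_res, optim_res)
    assert cnt.frame_count == 1  # one graph for the whole forward

CompileCounter simply records how many frames Dynamo compiles; a frame count of 1 means the forward, including the ModuleDict iteration, compiled as a single graph.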
diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py
index 8d13c7f..67fdfae 100644
--- a/test/dynamo/test_modules.py
+++ b/test/dynamo/test_modules.py
@@ -1687,6 +1687,72 @@
         self.assertTrue(grad_sizes.keys() == backward_hook_handles.keys())
         self.assertTrue(pre_grad_sizes.keys() == pre_backward_hook_handles.keys())
 
+    def test_module_dict_iter_name(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.activations = torch.nn.ModuleDict(
+                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
+                )
+
+            def forward(self, x):
+                for activation_name in self.activations:
+                    x = self.activations[activation_name](x)
+                return x
+
+        cnt = torch._dynamo.testing.CompileCounter()
+        # Eager
+        eager_res = MyModule()(torch.ones(10, 10))
+
+        # Compile
+        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
+        self.assertEqual(eager_res, optim_res)
+        self.assertEqual(cnt.frame_count, 1)
+
+    def test_module_dict_iter_keys(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.activations = torch.nn.ModuleDict(
+                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
+                )
+
+            def forward(self, x):
+                for activation_name in self.activations.keys():
+                    x = self.activations[activation_name](x)
+                return x
+
+        cnt = torch._dynamo.testing.CompileCounter()
+        # Eager
+        eager_res = MyModule()(torch.ones(10, 10))
+
+        # Compile
+        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
+        self.assertEqual(eager_res, optim_res)
+        self.assertEqual(cnt.frame_count, 1)
+
+    def test_module_dict_iter_values(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.activations = torch.nn.ModuleDict(
+                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
+                )
+
+            def forward(self, x):
+                for activation in self.activations.values():
+                    x = activation(x)
+                return x
+
+        cnt = torch._dynamo.testing.CompileCounter()
+        # Eager
+        eager_res = MyModule()(torch.ones(10, 10))
+
+        # Compile
+        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
+        self.assertEqual(eager_res, optim_res)
+        self.assertEqual(cnt.frame_count, 1)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py
index 7249b86..deee4da 100644
--- a/torch/_dynamo/variables/nn_module.py
+++ b/torch/_dynamo/variables/nn_module.py
@@ -82,6 +82,20 @@
         # implement list/iter/tuple/etc calls
         base = tx.output.get_submodule(self.module_key)
         options = VariableTracker.propagate([self])
+        if isinstance(base, torch.nn.ModuleDict):
+            result = []
+            for name, submod in base.items():
+                name_var = variables.ConstantVariable(name)
+                tx.output.register_attr_or_module(
+                    submod,
+                    self.module_key,
+                    name,
+                    source=NNModuleSource(GetItemSource(self.source, name)),
+                    **options,
+                )
+                result.append(name_var)
+            return result
+
         assert isinstance(
             base, (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential)
         ), typestr(base)
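
For each (name, submodule) pair, the new ModuleDict branch registers the submodule under a GetItemSource keyed by its name and returns the names as ConstantVariables, so later self.activations[name] lookups inside the compiled forward can resolve to submodules Dynamo has already registered. Returning the key names matches eager semantics: iterating an nn.ModuleDict, like a plain dict, yields its keys. A quick eager-mode check of that behavior (illustrative, not part of the patch):

    import torch

    acts = torch.nn.ModuleDict(
        [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
    )

    # Iteration and .keys() both yield the key names, in insertion order;
    # .values() yields the submodules themselves.
    assert list(acts) == ["lrelu", "prelu"]
    assert list(acts.keys()) == ["lrelu", "prelu"]
    assert [type(m).__name__ for m in acts.values()] == ["LeakyReLU", "PReLU"]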