Support module dict iter (#99503)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99503
Approved by: https://github.com/Chillee, https://github.com/jansel
diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py
index 8d13c7f..67fdfae 100644
--- a/test/dynamo/test_modules.py
+++ b/test/dynamo/test_modules.py
@@ -1687,6 +1687,72 @@
         self.assertTrue(grad_sizes.keys() == backward_hook_handles.keys())
         self.assertTrue(pre_grad_sizes.keys() == pre_backward_hook_handles.keys())
 
+    def test_module_dict_iter_name(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.activations = torch.nn.ModuleDict(
+                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
+                )
+
+            def forward(self, x):
+                for activation_name in self.activations:
+                    x = self.activations[activation_name](x)
+                return x
+
+        cnt = torch._dynamo.testing.CompileCounter()
+        # Eager
+        eager_res = MyModule()(torch.ones(10, 10))
+
+        # Compile
+        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
+        self.assertEqual(eager_res, optim_res)
+        self.assertEqual(cnt.frame_count, 1)
+
+    def test_module_dict_iter_keys(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.activations = torch.nn.ModuleDict(
+                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
+                )
+
+            def forward(self, x):
+                for activation_name in self.activations.keys():
+                    x = self.activations[activation_name](x)
+                return x
+
+        cnt = torch._dynamo.testing.CompileCounter()
+        # Eager
+        eager_res = MyModule()(torch.ones(10, 10))
+
+        # Compile
+        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
+        self.assertEqual(eager_res, optim_res)
+        self.assertEqual(cnt.frame_count, 1)
+
+    def test_module_dict_iter_values(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.activations = torch.nn.ModuleDict(
+                    [["lrelu", torch.nn.LeakyReLU()], ["prelu", torch.nn.PReLU()]]
+                )
+
+            def forward(self, x):
+                for activation in self.activations.values():
+                    x = activation(x)
+                return x
+
+        cnt = torch._dynamo.testing.CompileCounter()
+        # Eager
+        eager_res = MyModule()(torch.ones(10, 10))
+
+        # Compile
+        optim_res = torch._dynamo.optimize(cnt)(MyModule())(torch.ones(10, 10))
+        self.assertEqual(eager_res, optim_res)
+        self.assertEqual(cnt.frame_count, 1)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py
index 7249b86..deee4da 100644
--- a/torch/_dynamo/variables/nn_module.py
+++ b/torch/_dynamo/variables/nn_module.py
@@ -82,6 +82,20 @@
         # implement list/iter/tuple/etc calls
         base = tx.output.get_submodule(self.module_key)
         options = VariableTracker.propagate([self])
+        if isinstance(base, torch.nn.ModuleDict):
+            result = []
+            for name, submod in base.items():
+                name_var = variables.ConstantVariable(name)
+                tx.output.register_attr_or_module(
+                    submod,
+                    self.module_key,
+                    name,
+                    source=NNModuleSource(GetItemSource(self.source, name)),
+                    **options,
+                )
+                result.append(name_var)
+            return result
+
         assert isinstance(
             base, (torch.nn.ModuleList, torch.nn.ParameterList, torch.nn.Sequential)
         ), typestr(base)