fix __len__, __contains__, getitem inherited from interface class derived from nn container (closes #40603) (#40789)

Summary:
Define static script implementation of __len__ and __contains__ on any subclass derived from a type such as ModuleList, Sequential, or ModuleDict.  Implement getitem for classes derived from ModuleDict.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40789

Reviewed By: eellison

Differential Revision: D22325159

Pulled By: wconstab

fbshipit-source-id: fc1562c29640fe800e13b5a1dd48e595c2c7239b
diff --git a/test/jit/test_module_containers.py b/test/jit/test_module_containers.py
new file mode 100644
index 0000000..42d2c32
--- /dev/null
+++ b/test/jit/test_module_containers.py
@@ -0,0 +1,407 @@
+from typing import List
+from collections import OrderedDict
+import torch
+import torch.nn as nn
+from torch.testing._internal.jit_utils import JitTestCase
+
+class TestModuleContainers(JitTestCase):
+    def test_sequential_intermediary_types(self):
+        class A(torch.nn.Module):
+            def __init__(self):
+                super(A, self).__init__()
+
+            def forward(self, x):
+                return x + 3
+
+        class B(torch.nn.Module):
+            def __init__(self):
+                super(B, self).__init__()
+
+            def forward(self, x):
+                return {"1": x}
+
+        class C(torch.nn.Module):
+            def __init__(self):
+                super(C, self).__init__()
+                self.foo = torch.nn.Sequential(A(), B())
+
+            def forward(self, x):
+                return self.foo(x)
+
+        self.checkModule(C(), (torch.tensor(1),))
+
+    def test_moduledict(self):
+        class Inner(torch.nn.Module):
+            def forward(self, x):
+                return x + 10
+
+        class Inner2(torch.nn.Module):
+            def forward(self, x):
+                return x * 2
+
+        class Inner3(torch.nn.Module):
+            def forward(self, x):
+                return (x - 4) * 3
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                modules = OrderedDict([
+                    ('one', Inner()),
+                    ('two', Inner2()),
+                    ('three', Inner3()),
+                ])
+                self.moduledict = nn.ModuleDict(modules)
+
+            def forward(self, x, skip_name):
+                # type: (Tensor, str)
+                names = torch.jit.annotate(List[str], [])
+                values = []
+                for name in self.moduledict:
+                    names.append(name)
+
+                for name, mod in self.moduledict.items():
+                    if name != skip_name:
+                        names.append(name)
+                        x = mod(x)
+                        values.append(x)
+
+                for mod in self.moduledict.values():
+                    x = mod(x)
+                    values.append(x)
+
+                for key in self.moduledict.keys():
+                    names.append(key)
+
+                return x, names
+
+        class M2(M):
+            def __init__(self):
+                super(M2, self).__init__()
+
+            def forward(self, x, skip_name):
+                # type: (Tensor, str)
+                names = torch.jit.annotate(List[str], [])
+                values = []
+                x2 = x
+                iter = 0
+                for name in self.moduledict:
+                    names.append(name)
+
+                for i, (name, mod) in enumerate(self.moduledict.items()):
+                    iter += i
+                    if name != skip_name:
+                        names.append(name)
+                        x = mod(x)
+                        values.append(x)
+
+                for i, mod in enumerate(self.moduledict.values()):
+                    iter += i
+                    x = mod(x)
+                    values.append(x)
+
+                for i, key in enumerate(self.moduledict.keys()):
+                    iter += i
+                    names.append(key)
+
+                for mod, mod in zip(self.moduledict.values(), self.moduledict.values()):
+                    iter += i
+                    x2 = mod(mod(x2))
+
+                return x, x2, names, iter
+
+
+        for name in ["", "one", "two", "three"]:
+            inp = torch.tensor(1)
+            self.checkModule(M(), (inp, name))
+            self.checkModule(M2(), (inp, name))
+
+    def test_custom_container_forward(self):
+        class Inner(torch.nn.Module):
+            def forward(self, x):
+                return x + 10
+
+        class CustomSequential(nn.Sequential):
+            def __init__(self):
+                super(CustomSequential, self).__init__(
+                    nn.ReLU(), Inner())
+
+            def forward(self, x):
+                x = x + 3
+                for mod in self:
+                    x = mod(x)
+                return x - 5
+
+        self.checkModule(CustomSequential(), (torch.tensor(.5),))
+
+        class CustomModuleList(nn.ModuleList):
+            def __init__(self):
+                super(CustomModuleList, self).__init__(
+                    [nn.ReLU(), Inner()])
+
+            def forward(self, x):
+                x = x + 3
+                for mod in self:
+                    x = mod(x)
+                return x - 5
+
+        self.checkModule(CustomModuleList(), (torch.tensor(.5),))
+
+        class CustomModuleDict(nn.ModuleDict):
+            def __init__(self):
+                super(CustomModuleDict, self).__init__(
+                    OrderedDict([
+                        ('one', Inner()),
+                        ('two', nn.ReLU()),
+                        ('three', Inner()),
+                    ]))
+
+            def forward(self, x):
+                x = x + 3
+                names = torch.jit.annotate(List[str], [])
+                for name, mod in self.items():
+                    x = mod(x)
+                    names.append(name)
+                return names, x - 5
+
+        self.checkModule(CustomModuleDict(), (torch.tensor(.5),))
+
+    def test_script_module_list_sequential(self):
+        class M(torch.jit.ScriptModule):
+            def __init__(self, mod_list):
+                super(M, self).__init__()
+                self.mods = mod_list
+
+            @torch.jit.script_method
+            def forward(self, v):
+                for m in self.mods:
+                    v = m(v)
+                return v
+
+        with torch.jit.optimized_execution(False):
+            m = M(nn.Sequential(nn.ReLU()))
+            self.assertExportImportModule(m, (torch.randn(2, 2),))
+
+    def test_script_modulelist_index(self):
+        class Sub(torch.nn.Module):
+            def __init__(self, i):
+                super(Sub, self).__init__()
+                self.i = i
+
+            def forward(self, thing):
+                return thing - self.i
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.mods = nn.ModuleList([Sub(i) for i in range(10)])
+
+            def forward(self, v):
+                v = self.mods[4].forward(v)
+                v = self.mods[-1].forward(v)
+                v = self.mods[-9].forward(v)
+                return v
+
+        x = torch.tensor(1)
+        self.checkModule(M(), (x,))
+
+        class MForward(torch.nn.Module):
+            def __init__(self):
+                super(MForward, self).__init__()
+                self.mods = nn.ModuleList([Sub(i) for i in range(10)])
+
+            def forward(self, v):
+                v = self.mods[4](v)
+                v = self.mods[-1](v)
+                v = self.mods[-9](v)
+                return v
+
+        self.checkModule(MForward(), (torch.tensor(1),))
+
+        class M2(M):
+            def __init__(self):
+                super(M2, self).__init__()
+
+            def forward(self, v):
+                return self.mods[-11].forward(v)
+
+        with self.assertRaisesRegex(Exception, "Index -11 out of range"):
+            torch.jit.script(M2())
+
+
+        class M2(M):
+            def __init__(self):
+                super(M2, self).__init__()
+
+            def forward(self, v):
+                return self.mods[-11].forward(v)
+
+        with self.assertRaisesRegex(Exception, "Index -11 out of range"):
+            torch.jit.script(M2())
+
+    def test_module_interface_special_methods(self):
+        class CustomModuleInterface(torch.nn.Module):
+            def __init__(self):
+                super(CustomModuleInterface, self).__init__()
+
+        class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
+            def __init__(self, modules=None):
+                CustomModuleInterface.__init__(self)
+                torch.nn.ModuleList.__init__(self, modules)
+
+        class CustomSequential(CustomModuleInterface, torch.nn.Sequential):
+            def __init__(self, modules=None):
+                CustomModuleInterface.__init__(self)
+                torch.nn.Sequential.__init__(self, modules)
+
+        class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
+            def __init__(self, modules=None):
+                CustomModuleInterface.__init__(self)
+                torch.nn.ModuleDict.__init__(self, modules)
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                # work around aliasing issue for 'is' operator by scripting ReLU up front
+                self.submod = torch.jit.script(torch.nn.ReLU())
+                self.modulelist = CustomModuleList([self.submod])
+                self.sequential = CustomSequential(self.submod)
+                self.moduledict = CustomModuleDict({"submod": self.submod})
+
+            def forward(self, inputs):
+                assert self.modulelist[0] is self.submod, "__getitem__ failing for ModuleList"
+                assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
+                for module in self.modulelist:
+                    assert module is self.submod, "__iter__ failing for ModuleList"
+
+                assert self.sequential[0] is self.submod, "__getitem__ failing for Sequential"
+                assert len(self.sequential) == 1, "__len__ failing for Sequential"
+                for module in self.sequential:
+                    assert module is self.submod, "__iter__ failing for Sequential"
+
+                assert self.moduledict["submod"] is self.submod, "__getitem__ failing for ModuleDict"
+                assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"
+
+                # note: unable to index moduledict with a string variable currently
+                i = 0
+                for key in self.moduledict:
+                    i += 1
+                assert i == len(self.moduledict), "iteration failing for ModuleDict"
+
+                assert "submod" in self.moduledict, "__contains__ fails for ModuleDict"
+
+                for key in self.moduledict.keys():
+                    assert key == "submod", "keys() fails for ModuleDict"
+
+                for item in self.moduledict.items():
+                    assert item[0] == "submod", "items() fails for ModuleDict"
+                    assert item[1] is self.submod, "items() fails for ModuleDict"
+
+                for value in self.moduledict.values():
+                    assert value is self.submod, "values() fails for ModuleDict"
+
+                return inputs
+
+        m = MyModule()
+        self.checkModule(m, [torch.randn(2, 2)])
+
+    def test_special_method_with_override(self):
+        class CustomModuleInterface(torch.nn.Module):
+            def __init__(self):
+                super(CustomModuleInterface, self).__init__()
+
+        class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
+            def __init__(self, modules=None):
+                CustomModuleInterface.__init__(self)
+                torch.nn.ModuleList.__init__(self, modules)
+
+            def __len__(self):
+                # this is arbitrary, just to check that the overridden py __len__ from
+                # CustomModuleList takes precedence over the automatically generated
+                # __len__ added by the jit compiler
+                return 2
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                # work around aliasing issue for 'is' operator by scripting ReLU up front
+                self.submod = torch.jit.script(torch.nn.ReLU())
+                self.modulelist = CustomModuleList([self.submod])
+
+            def forward(self, inputs):
+                assert len(self.modulelist) == 2, "__len__ failing for ModuleList"
+                return inputs
+
+        m = MyModule()
+        self.checkModule(m, [torch.randn(2, 2)])
+        mm = torch.jit.script(m)
+
+    def test_moduledict_getitem(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                self.relu = torch.jit.script(torch.nn.ReLU())
+                self.tanh = torch.jit.script(torch.nn.Tanh())
+                self.moduledict = torch.nn.ModuleDict({"relu": self.relu,
+                                                       "tanh": self.tanh})
+
+            def forward(self, input):
+                assert self.moduledict['relu'] is self.relu
+                assert self.moduledict['tanh'] is self.tanh
+                return input
+
+        m = MyModule()
+        self.checkModule(m, [torch.randn(2, 2)])
+
+    def test_moduledict_keyerror(self):
+        class BadModule(torch.nn.Module):
+            def __init__(self):
+                super(BadModule, self).__init__()
+                self.moduledict = torch.nn.ModuleDict({"foo": None,
+                                                       "bar": None})
+
+            def forward(self, input):
+                assert self.moduledict['blah'] == "blah", "this is a keyerror"
+
+        with self.assertRaisesRegex(RuntimeError, "Key Error, blah"):
+            b = BadModule()
+            torch.jit.script(b)
+
+        class AnotherBadModule(torch.nn.Module):
+            def __init__(self):
+                super(AnotherBadModule, self).__init__()
+                self.moduledict = torch.nn.ModuleDict({"foo": None,
+                                                       "bar": None})
+
+            def forward(self, input):
+                idx = 'blah'
+                assert self.moduledict[idx] == "blah", "this is a string literal error"
+
+        with self.assertRaisesRegex(RuntimeError, "Unable to extract string literal index. "
+                                                  "ModuleDict indexing is only supported with string literals."):
+            b = AnotherBadModule()
+            torch.jit.script(b)
+
+    def test_empty_dict_override_contains(self):
+        class CustomModuleInterface(torch.nn.Module):
+            def __init__(self):
+                super(CustomModuleInterface, self).__init__()
+
+        class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
+            def __init__(self, modules=None):
+                CustomModuleInterface.__init__(self)
+                torch.nn.ModuleDict.__init__(self, modules)
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                # work around aliasing issue for 'is' operator by scripting ReLU up front
+                self.submod = torch.jit.script(torch.nn.ReLU())
+                self.moduledict = CustomModuleDict()
+
+            def forward(self, inputs):
+                assert "submod" not in self.moduledict, "__contains__ fails for ModuleDict"
+                return inputs
+
+        m = MyModule()
+        self.checkModule(m, [torch.randn(2, 2)])
diff --git a/test/test_jit.py b/test/test_jit.py
index 9333bc0..1b354be 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8229,142 +8229,6 @@
             m = M()
             self.assertEqual(m(), 10)
 
-    def test_moduledict(self):
-        class Inner(torch.nn.Module):
-            def forward(self, x):
-                return x + 10
-
-        class Inner2(torch.nn.Module):
-            def forward(self, x):
-                return x * 2
-
-        class Inner3(torch.nn.Module):
-            def forward(self, x):
-                return (x - 4) * 3
-
-        class M(torch.nn.Module):
-            def __init__(self):
-                super(M, self).__init__()
-                modules = OrderedDict([
-                    ('one', Inner()),
-                    ('two', Inner2()),
-                    ('three', Inner3()),
-                ])
-                self.moduledict = nn.ModuleDict(modules)
-
-            def forward(self, x, skip_name):
-                # type: (Tensor, str)
-                names = torch.jit.annotate(List[str], [])
-                values = []
-                for name in self.moduledict:
-                    names.append(name)
-
-                for name, mod in self.moduledict.items():
-                    if name != skip_name:
-                        names.append(name)
-                        x = mod(x)
-                        values.append(x)
-
-                for mod in self.moduledict.values():
-                    x = mod(x)
-                    values.append(x)
-
-                for key in self.moduledict.keys():
-                    names.append(key)
-
-                return x, names
-
-        class M2(M):
-            def __init__(self):
-                super(M2, self).__init__()
-
-            def forward(self, x, skip_name):
-                # type: (Tensor, str)
-                names = torch.jit.annotate(List[str], [])
-                values = []
-                x2 = x
-                iter = 0
-                for name in self.moduledict:
-                    names.append(name)
-
-                for i, (name, mod) in enumerate(self.moduledict.items()):
-                    iter += i
-                    if name != skip_name:
-                        names.append(name)
-                        x = mod(x)
-                        values.append(x)
-
-                for i, mod in enumerate(self.moduledict.values()):
-                    iter += i
-                    x = mod(x)
-                    values.append(x)
-
-                for i, key in enumerate(self.moduledict.keys()):
-                    iter += i
-                    names.append(key)
-
-                for mod, mod in zip(self.moduledict.values(), self.moduledict.values()):
-                    iter += i
-                    x2 = mod(mod(x2))
-
-                return x, x2, names, iter
-
-
-        for name in ["", "one", "two", "three"]:
-            inp = torch.tensor(1)
-            self.checkModule(M(), (inp, name))
-            self.checkModule(M2(), (inp, name))
-
-    def test_custom_container_forward(self):
-        class Inner(torch.nn.Module):
-            def forward(self, x):
-                return x + 10
-
-        class CustomSequential(nn.Sequential):
-            def __init__(self):
-                super(CustomSequential, self).__init__(
-                    nn.ReLU(), Inner())
-
-            def forward(self, x):
-                x = x + 3
-                for mod in self:
-                    x = mod(x)
-                return x - 5
-
-        self.checkModule(CustomSequential(), (torch.tensor(.5),))
-
-        class CustomModuleList(nn.ModuleList):
-            def __init__(self):
-                super(CustomModuleList, self).__init__(
-                    [nn.ReLU(), Inner()])
-
-            def forward(self, x):
-                x = x + 3
-                for mod in self:
-                    x = mod(x)
-                return x - 5
-
-        self.checkModule(CustomModuleList(), (torch.tensor(.5),))
-
-        class CustomModuleDict(nn.ModuleDict):
-            def __init__(self):
-                super(CustomModuleDict, self).__init__(
-                    OrderedDict([
-                        ('one', Inner()),
-                        ('two', nn.ReLU()),
-                        ('three', Inner()),
-                    ]))
-
-            def forward(self, x):
-                x = x + 3
-                names = torch.jit.annotate(List[str], [])
-                for name, mod in self.items():
-                    x = mod(x)
-                    names.append(name)
-                return names, x - 5
-
-        self.checkModule(CustomModuleDict(), (torch.tensor(.5),))
-
     def test_override_magic(self):
         class OverrideMagic(nn.Module):
             def __init__(self):
@@ -8421,63 +8285,6 @@
             with self.assertRaisesRegex(Exception, "object is not iterable"):
                 print(list(m))
 
-    def test_script_modulelist_index(self):
-        class Sub(torch.nn.Module):
-            def __init__(self, i):
-                super(Sub, self).__init__()
-                self.i = i
-
-            def forward(self, thing):
-                return thing - self.i
-
-        class M(torch.nn.Module):
-            def __init__(self):
-                super(M, self).__init__()
-                self.mods = nn.ModuleList([Sub(i) for i in range(10)])
-
-            def forward(self, v):
-                v = self.mods[4].forward(v)
-                v = self.mods[-1].forward(v)
-                v = self.mods[-9].forward(v)
-                return v
-
-        x = torch.tensor(1)
-        self.checkModule(M(), (x,))
-
-        class MForward(torch.nn.Module):
-            def __init__(self):
-                super(MForward, self).__init__()
-                self.mods = nn.ModuleList([Sub(i) for i in range(10)])
-
-            def forward(self, v):
-                v = self.mods[4](v)
-                v = self.mods[-1](v)
-                v = self.mods[-9](v)
-                return v
-
-        self.checkModule(MForward(), (torch.tensor(1),))
-
-        class M2(M):
-            def __init__(self):
-                super(M2, self).__init__()
-
-            def forward(self, v):
-                return self.mods[-11].forward(v)
-
-        with self.assertRaisesRegex(Exception, "Index -11 out of range"):
-            torch.jit.script(M2())
-
-
-        class M2(M):
-            def __init__(self):
-                super(M2, self).__init__()
-
-            def forward(self, v):
-                return self.mods[-11].forward(v)
-
-        with self.assertRaisesRegex(Exception, "Index -11 out of range"):
-            torch.jit.script(M2())
-
     def test_attr_qscheme_script(self):
         class Foo(torch.nn.Module):
             def __init__(self):
@@ -8740,22 +8547,6 @@
         with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
             M()
 
-    def test_script_module_list_sequential(self):
-        class M(torch.jit.ScriptModule):
-            def __init__(self, mod_list):
-                super(M, self).__init__()
-                self.mods = mod_list
-
-            @torch.jit.script_method
-            def forward(self, v):
-                for m in self.mods:
-                    v = m(v)
-                return v
-
-        with torch.jit.optimized_execution(False):
-            m = M(nn.Sequential(nn.ReLU()))
-            self.assertExportImportModule(m, (torch.randn(2, 2),))
-
     def test_attr_module_constants(self):
         class M2(torch.jit.ScriptModule):
             def __init__(self, mod_list):
diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp
index 6c11ad5..4b5d872 100644
--- a/torch/csrc/jit/python/python_sugared_value.cpp
+++ b/torch/csrc/jit/python/python_sugared_value.cpp
@@ -235,6 +235,24 @@
     Value* idx) {
   if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
     return getSugaredDict(loc, m)->getModules()->getitem(loc, m, idx);
+  } else if (
+      concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
+    if (auto ivalue = toIValue(idx)) {
+      auto sd = getSugaredDict(loc, m);
+      auto idx_str = ivalue->toStringRef();
+      auto keys_iter = sd->keys_;
+      auto module_values_iter = sd->modules_;
+      for (size_t i = 0; i < keys_iter->tup_.size(); ++i) {
+        auto key = keys_iter->tup_.at(i);
+        auto key_str = toIValue(key->asValue(loc, m))->toStringRef();
+        if (key_str == idx_str) {
+          return module_values_iter->tup_.at(i);
+        }
+      }
+      throw ErrorReport(loc) << "Key Error, " << idx_str;
+    }
+    throw ErrorReport(loc)
+        << "Unable to extract string literal index. ModuleDict indexing is only supported with string literals.";
   }
   throw ErrorReport(loc)
       << "Only ModuleList, Sequential, and ModuleDict modules are subscriptable";
diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py
index bc442ea..f566925 100644
--- a/torch/jit/_recursive.py
+++ b/torch/jit/_recursive.py
@@ -314,7 +314,6 @@
         concrete_type_builder = infer_concrete_type_builder(nn_module)
         concrete_type_builder.set_poisoned()
         concrete_type = concrete_type_builder.build()
-
     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
 
 def create_script_module_impl(nn_module, concrete_type, stubs_fn):
@@ -351,6 +350,7 @@
             else:
                 # use the default recursive rule to compile the module
                 scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
+
             cpp_module.setattr(name, scripted)
             script_module._modules[name] = scripted
 
@@ -377,6 +377,19 @@
         torch._C._run_emit_module_hook(cpp_module)
         concrete_type_store.methods_compiled.add(concrete_type)
 
+    # Special handling so methods like __len__ work in script methods on classes derived from containers
+    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)) and \
+            '__len__' not in cpp_module._method_names():
+        script_module.define("def __len__(self):\n   return {}\n".format(len(nn_module)))
+    if isinstance(nn_module, torch.nn.ModuleDict) and \
+            '__contains__' not in cpp_module._method_names():
+        if len(nn_module.keys()):
+            keys = repr(list(nn_module.keys()))
+            script_module.define("def __contains__(self, key: str):\n   return key in {}\n".format(keys))
+        else:
+            script_module.define("def __contains__(self, key: str):\n   return False\n")
+
+
     # Make the compiled methods available to the Python ScriptModule class.
     for stub in stubs:
         if stub.original_method is None: