fix __len__, __contains__, __getitem__ inherited from an interface class derived from an nn container (closes #40603) (#40789)
Summary:
Define static script implementations of __len__ and __contains__ for any subclass of a container type such as ModuleList, Sequential, or ModuleDict, and implement __getitem__ for classes derived from ModuleDict.
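
As an illustrative sketch of what this enables (MyInterface, MyList, and Wrapper are hypothetical names, not part of this change), a class deriving from both a user-defined interface class and an nn container can now use these methods inside a scripted forward:

    import torch
    import torch.nn as nn

    class MyInterface(nn.Module):
        pass

    class MyList(MyInterface, nn.ModuleList):
        def __init__(self, modules=None):
            MyInterface.__init__(self)
            nn.ModuleList.__init__(self, modules)

    class Wrapper(nn.Module):
        def __init__(self):
            super(Wrapper, self).__init__()
            self.mods = MyList([nn.ReLU()])

        def forward(self, x):
            # a static __len__ is generated at script time, so len() works
            # even though self.mods is not a plain ModuleList
            assert len(self.mods) == 1
            for mod in self.mods:
                x = mod(x)
            return x

    torch.jit.script(Wrapper())  # previously failed to resolve len(self.mods)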
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40789
Reviewed By: eellison
Differential Revision: D22325159
Pulled By: wconstab
fbshipit-source-id: fc1562c29640fe800e13b5a1dd48e595c2c7239b
diff --git a/test/jit/test_module_containers.py b/test/jit/test_module_containers.py
new file mode 100644
index 0000000..42d2c32
--- /dev/null
+++ b/test/jit/test_module_containers.py
@@ -0,0 +1,396 @@
+from typing import List
+from collections import OrderedDict
+import torch
+import torch.nn as nn
+from torch.testing._internal.jit_utils import JitTestCase
+
+class TestModuleContainers(JitTestCase):
+    def test_sequential_intermediary_types(self):
+        class A(torch.nn.Module):
+            def __init__(self):
+                super(A, self).__init__()
+
+            def forward(self, x):
+                return x + 3
+
+        class B(torch.nn.Module):
+            def __init__(self):
+                super(B, self).__init__()
+
+            def forward(self, x):
+                return {"1": x}
+
+        class C(torch.nn.Module):
+            def __init__(self):
+                super(C, self).__init__()
+                self.foo = torch.nn.Sequential(A(), B())
+
+            def forward(self, x):
+                return self.foo(x)
+
+        self.checkModule(C(), (torch.tensor(1),))
+
+    def test_moduledict(self):
+        class Inner(torch.nn.Module):
+            def forward(self, x):
+                return x + 10
+
+        class Inner2(torch.nn.Module):
+            def forward(self, x):
+                return x * 2
+
+        class Inner3(torch.nn.Module):
+            def forward(self, x):
+                return (x - 4) * 3
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                modules = OrderedDict([
+                    ('one', Inner()),
+                    ('two', Inner2()),
+                    ('three', Inner3()),
+                ])
+                self.moduledict = nn.ModuleDict(modules)
+
+            def forward(self, x, skip_name):
+                # type: (Tensor, str) -> Tuple[Tensor, List[str]]
+                names = torch.jit.annotate(List[str], [])
+                values = []
+                for name in self.moduledict:
+                    names.append(name)
+
+                for name, mod in self.moduledict.items():
+                    if name != skip_name:
+                        names.append(name)
+                        x = mod(x)
+                        values.append(x)
+
+                for mod in self.moduledict.values():
+                    x = mod(x)
+                    values.append(x)
+
+                for key in self.moduledict.keys():
+                    names.append(key)
+
+                return x, names
+
+        class M2(M):
+            def __init__(self):
+                super(M2, self).__init__()
+
+            def forward(self, x, skip_name):
+                # type: (Tensor, str) -> Tuple[Tensor, Tensor, List[str], int]
+                names = torch.jit.annotate(List[str], [])
+                values = []
+                x2 = x
+                iter = 0
+                for name in self.moduledict:
+                    names.append(name)
+
+                for i, (name, mod) in enumerate(self.moduledict.items()):
+                    iter += i
+                    if name != skip_name:
+                        names.append(name)
+                        x = mod(x)
+                        values.append(x)
+
+                for i, mod in enumerate(self.moduledict.values()):
+                    iter += i
+                    x = mod(x)
+                    values.append(x)
+
+                for i, key in enumerate(self.moduledict.keys()):
+                    iter += i
+                    names.append(key)
+
+                for mod, mod in zip(self.moduledict.values(), self.moduledict.values()):
+                    iter += i
+                    x2 = mod(mod(x2))
+
+                return x, x2, names, iter
+
+
+        for name in ["", "one", "two", "three"]:
+            inp = torch.tensor(1)
+            self.checkModule(M(), (inp, name))
+            self.checkModule(M2(), (inp, name))
+
+    def test_custom_container_forward(self):
+        class Inner(torch.nn.Module):
+            def forward(self, x):
+                return x + 10
+
+        class CustomSequential(nn.Sequential):
+            def __init__(self):
+                super(CustomSequential, self).__init__(
+                    nn.ReLU(), Inner())
+
+            def forward(self, x):
+                x = x + 3
+                for mod in self:
+                    x = mod(x)
+                return x - 5
+
+        self.checkModule(CustomSequential(), (torch.tensor(.5),))
+
+        class CustomModuleList(nn.ModuleList):
+            def __init__(self):
+                super(CustomModuleList, self).__init__(
+                    [nn.ReLU(), Inner()])
+
+            def forward(self, x):
+                x = x + 3
+                for mod in self:
+                    x = mod(x)
+                return x - 5
+
+        self.checkModule(CustomModuleList(), (torch.tensor(.5),))
+
+        class CustomModuleDict(nn.ModuleDict):
+            def __init__(self):
+                super(CustomModuleDict, self).__init__(
+                    OrderedDict([
+                        ('one', Inner()),
+                        ('two', nn.ReLU()),
+                        ('three', Inner()),
+                    ]))
+
+            def forward(self, x):
+                x = x + 3
+                names = torch.jit.annotate(List[str], [])
+                for name, mod in self.items():
+                    x = mod(x)
+                    names.append(name)
+                return names, x - 5
+
+        self.checkModule(CustomModuleDict(), (torch.tensor(.5),))
+
+    def test_script_module_list_sequential(self):
+        class M(torch.jit.ScriptModule):
+            def __init__(self, mod_list):
+                super(M, self).__init__()
+                self.mods = mod_list
+
+            @torch.jit.script_method
+            def forward(self, v):
+                for m in self.mods:
+                    v = m(v)
+                return v
+
+        with torch.jit.optimized_execution(False):
+            m = M(nn.Sequential(nn.ReLU()))
+            self.assertExportImportModule(m, (torch.randn(2, 2),))
+
+    def test_script_modulelist_index(self):
+        class Sub(torch.nn.Module):
+            def __init__(self, i):
+                super(Sub, self).__init__()
+                self.i = i
+
+            def forward(self, thing):
+                return thing - self.i
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.mods = nn.ModuleList([Sub(i) for i in range(10)])
+
+            def forward(self, v):
+                v = self.mods[4].forward(v)
+                v = self.mods[-1].forward(v)
+                v = self.mods[-9].forward(v)
+                return v
+
+        x = torch.tensor(1)
+        self.checkModule(M(), (x,))
+
+        class MForward(torch.nn.Module):
+            def __init__(self):
+                super(MForward, self).__init__()
+                self.mods = nn.ModuleList([Sub(i) for i in range(10)])
+
+            def forward(self, v):
+                v = self.mods[4](v)
+                v = self.mods[-1](v)
+                v = self.mods[-9](v)
+                return v
+
+        self.checkModule(MForward(), (torch.tensor(1),))
+
+        class M2(M):
+            def __init__(self):
+                super(M2, self).__init__()
+
+            def forward(self, v):
+                return self.mods[-11].forward(v)
+
+        with self.assertRaisesRegex(Exception, "Index -11 out of range"):
+            torch.jit.script(M2())
+
+    def test_module_interface_special_methods(self):
+        class CustomModuleInterface(torch.nn.Module):
+            def __init__(self):
+                super(CustomModuleInterface, self).__init__()
+
+        class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
+            def __init__(self, modules=None):
+                CustomModuleInterface.__init__(self)
+                torch.nn.ModuleList.__init__(self, modules)
+
+        class CustomSequential(CustomModuleInterface, torch.nn.Sequential):
+            def __init__(self, modules=None):
+                CustomModuleInterface.__init__(self)
+                torch.nn.Sequential.__init__(self, modules)
+
+        class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
+            def __init__(self, modules=None):
+                CustomModuleInterface.__init__(self)
+                torch.nn.ModuleDict.__init__(self, modules)
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                # work around aliasing issue for 'is' operator by scripting ReLU up front
+                self.submod = torch.jit.script(torch.nn.ReLU())
+                self.modulelist = CustomModuleList([self.submod])
+                self.sequential = CustomSequential(self.submod)
+                self.moduledict = CustomModuleDict({"submod": self.submod})
+
+            def forward(self, inputs):
+                assert self.modulelist[0] is self.submod, "__getitem__ failing for ModuleList"
+                assert len(self.modulelist) == 1, "__len__ failing for ModuleList"
+                for module in self.modulelist:
+                    assert module is self.submod, "__iter__ failing for ModuleList"
+
+                assert self.sequential[0] is self.submod, "__getitem__ failing for Sequential"
+                assert len(self.sequential) == 1, "__len__ failing for Sequential"
+                for module in self.sequential:
+                    assert module is self.submod, "__iter__ failing for Sequential"
+
+                assert self.moduledict["submod"] is self.submod, "__getitem__ failing for ModuleDict"
+                assert len(self.moduledict) == 1, "__len__ failing for ModuleDict"
+
+                # note: unable to index moduledict with a string variable currently
+                i = 0
+                for key in self.moduledict:
+                    i += 1
+                assert i == len(self.moduledict), "iteration failing for ModuleDict"
+
+                assert "submod" in self.moduledict, "__contains__ fails for ModuleDict"
+
+                for key in self.moduledict.keys():
+                    assert key == "submod", "keys() fails for ModuleDict"
+
+                for item in self.moduledict.items():
+                    assert item[0] == "submod", "items() fails for ModuleDict"
+                    assert item[1] is self.submod, "items() fails for ModuleDict"
+
+                for value in self.moduledict.values():
+                    assert value is self.submod, "values() fails for ModuleDict"
+
+                return inputs
+
+        m = MyModule()
+        self.checkModule(m, [torch.randn(2, 2)])
+
+    def test_special_method_with_override(self):
+        class CustomModuleInterface(torch.nn.Module):
+            def __init__(self):
+                super(CustomModuleInterface, self).__init__()
+
+        class CustomModuleList(CustomModuleInterface, torch.nn.ModuleList):
+            def __init__(self, modules=None):
+                CustomModuleInterface.__init__(self)
+                torch.nn.ModuleList.__init__(self, modules)
+
+            def __len__(self):
+                # this is arbitrary, just to check that the overridden py __len__ from
+                # CustomModuleList takes precedence over the automatically generated
+                # __len__ added by the jit compiler
+                return 2
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                # work around aliasing issue for 'is' operator by scripting ReLU up front
+                self.submod = torch.jit.script(torch.nn.ReLU())
+                self.modulelist = CustomModuleList([self.submod])
+
+            def forward(self, inputs):
+                assert len(self.modulelist) == 2, "__len__ failing for ModuleList"
+                return inputs
+
+        m = MyModule()
+        self.checkModule(m, [torch.randn(2, 2)])
+        mm = torch.jit.script(m)
+
+    def test_moduledict_getitem(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                self.relu = torch.jit.script(torch.nn.ReLU())
+                self.tanh = torch.jit.script(torch.nn.Tanh())
+                self.moduledict = torch.nn.ModuleDict({"relu": self.relu,
+                                                       "tanh": self.tanh})
+
+            def forward(self, input):
+                assert self.moduledict['relu'] is self.relu
+                assert self.moduledict['tanh'] is self.tanh
+                return input
+
+        m = MyModule()
+        self.checkModule(m, [torch.randn(2, 2)])
+
+    def test_moduledict_keyerror(self):
+        class BadModule(torch.nn.Module):
+            def __init__(self):
+                super(BadModule, self).__init__()
+                self.moduledict = torch.nn.ModuleDict({"foo": None,
+                                                       "bar": None})
+
+            def forward(self, input):
+                assert self.moduledict['blah'] == "blah", "this is a keyerror"
+
+        with self.assertRaisesRegex(RuntimeError, "Key Error, blah"):
+            b = BadModule()
+            torch.jit.script(b)
+
+        class AnotherBadModule(torch.nn.Module):
+            def __init__(self):
+                super(AnotherBadModule, self).__init__()
+                self.moduledict = torch.nn.ModuleDict({"foo": None,
+                                                       "bar": None})
+
+            def forward(self, input):
+                idx = 'blah'
+                assert self.moduledict[idx] == "blah", "this is a string literal error"
+
+        with self.assertRaisesRegex(RuntimeError, "Unable to extract string literal index. "
+                                                  "ModuleDict indexing is only supported with string literals."):
+            b = AnotherBadModule()
+            torch.jit.script(b)
+
+    def test_empty_dict_override_contains(self):
+        class CustomModuleInterface(torch.nn.Module):
+            def __init__(self):
+                super(CustomModuleInterface, self).__init__()
+
+        class CustomModuleDict(CustomModuleInterface, torch.nn.ModuleDict):
+            def __init__(self, modules=None):
+                CustomModuleInterface.__init__(self)
+                torch.nn.ModuleDict.__init__(self, modules)
+
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                # work around aliasing issue for 'is' operator by scripting ReLU up front
+                self.submod = torch.jit.script(torch.nn.ReLU())
+                self.moduledict = CustomModuleDict()
+
+            def forward(self, inputs):
+                assert "submod" not in self.moduledict, "__contains__ fails for ModuleDict"
+                return inputs
+
+        m = MyModule()
+        self.checkModule(m, [torch.randn(2, 2)])
diff --git a/test/test_jit.py b/test/test_jit.py
index 9333bc0..1b354be 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -8229,142 +8229,6 @@
         m = M()
         self.assertEqual(m(), 10)
-    def test_moduledict(self):
-        class Inner(torch.nn.Module):
-            def forward(self, x):
-                return x + 10
-
-        class Inner2(torch.nn.Module):
-            def forward(self, x):
-                return x * 2
-
-        class Inner3(torch.nn.Module):
-            def forward(self, x):
-                return (x - 4) * 3
-
-        class M(torch.nn.Module):
-            def __init__(self):
-                super(M, self).__init__()
-                modules = OrderedDict([
-                    ('one', Inner()),
-                    ('two', Inner2()),
-                    ('three', Inner3()),
-                ])
-                self.moduledict = nn.ModuleDict(modules)
-
-            def forward(self, x, skip_name):
-                # type: (Tensor, str) -> Tuple[Tensor, List[str]]
-                names = torch.jit.annotate(List[str], [])
-                values = []
-                for name in self.moduledict:
-                    names.append(name)
-
-                for name, mod in self.moduledict.items():
-                    if name != skip_name:
-                        names.append(name)
-                        x = mod(x)
-                        values.append(x)
-
-                for mod in self.moduledict.values():
-                    x = mod(x)
-                    values.append(x)
-
-                for key in self.moduledict.keys():
-                    names.append(key)
-
-                return x, names
-
-        class M2(M):
-            def __init__(self):
-                super(M2, self).__init__()
-
-            def forward(self, x, skip_name):
-                # type: (Tensor, str) -> Tuple[Tensor, Tensor, List[str], int]
-                names = torch.jit.annotate(List[str], [])
-                values = []
-                x2 = x
-                iter = 0
-                for name in self.moduledict:
-                    names.append(name)
-
-                for i, (name, mod) in enumerate(self.moduledict.items()):
-                    iter += i
-                    if name != skip_name:
-                        names.append(name)
-                        x = mod(x)
-                        values.append(x)
-
-                for i, mod in enumerate(self.moduledict.values()):
-                    iter += i
-                    x = mod(x)
-                    values.append(x)
-
-                for i, key in enumerate(self.moduledict.keys()):
-                    iter += i
-                    names.append(key)
-
-                for mod, mod in zip(self.moduledict.values(), self.moduledict.values()):
-                    iter += i
-                    x2 = mod(mod(x2))
-
-                return x, x2, names, iter
-
-
-        for name in ["", "one", "two", "three"]:
-            inp = torch.tensor(1)
-            self.checkModule(M(), (inp, name))
-            self.checkModule(M2(), (inp, name))
-
-    def test_custom_container_forward(self):
-        class Inner(torch.nn.Module):
-            def forward(self, x):
-                return x + 10
-
-        class CustomSequential(nn.Sequential):
-            def __init__(self):
-                super(CustomSequential, self).__init__(
-                    nn.ReLU(), Inner())
-
-            def forward(self, x):
-                x = x + 3
-                for mod in self:
-                    x = mod(x)
-                return x - 5
-
-        self.checkModule(CustomSequential(), (torch.tensor(.5),))
-
-        class CustomModuleList(nn.ModuleList):
-            def __init__(self):
-                super(CustomModuleList, self).__init__(
-                    [nn.ReLU(), Inner()])
-
-            def forward(self, x):
-                x = x + 3
-                for mod in self:
-                    x = mod(x)
-                return x - 5
-
-        self.checkModule(CustomModuleList(), (torch.tensor(.5),))
-
-        class CustomModuleDict(nn.ModuleDict):
-            def __init__(self):
-                super(CustomModuleDict, self).__init__(
-                    OrderedDict([
-                        ('one', Inner()),
-                        ('two', nn.ReLU()),
-                        ('three', Inner()),
-                    ]))
-
-            def forward(self, x):
-                x = x + 3
-                names = torch.jit.annotate(List[str], [])
-                for name, mod in self.items():
-                    x = mod(x)
-                    names.append(name)
-                return names, x - 5
-
-        self.checkModule(CustomModuleDict(), (torch.tensor(.5),))
-
     def test_override_magic(self):
         class OverrideMagic(nn.Module):
             def __init__(self):
@@ -8421,63 +8285,6 @@
         with self.assertRaisesRegex(Exception, "object is not iterable"):
             print(list(m))
-    def test_script_modulelist_index(self):
-        class Sub(torch.nn.Module):
-            def __init__(self, i):
-                super(Sub, self).__init__()
-                self.i = i
-
-            def forward(self, thing):
-                return thing - self.i
-
-        class M(torch.nn.Module):
-            def __init__(self):
-                super(M, self).__init__()
-                self.mods = nn.ModuleList([Sub(i) for i in range(10)])
-
-            def forward(self, v):
-                v = self.mods[4].forward(v)
-                v = self.mods[-1].forward(v)
-                v = self.mods[-9].forward(v)
-                return v
-
-        x = torch.tensor(1)
-        self.checkModule(M(), (x,))
-
-        class MForward(torch.nn.Module):
-            def __init__(self):
-                super(MForward, self).__init__()
-                self.mods = nn.ModuleList([Sub(i) for i in range(10)])
-
-            def forward(self, v):
-                v = self.mods[4](v)
-                v = self.mods[-1](v)
-                v = self.mods[-9](v)
-                return v
-
-        self.checkModule(MForward(), (torch.tensor(1),))
-
-        class M2(M):
-            def __init__(self):
-                super(M2, self).__init__()
-
-            def forward(self, v):
-                return self.mods[-11].forward(v)
-
-        with self.assertRaisesRegex(Exception, "Index -11 out of range"):
-            torch.jit.script(M2())
-
     def test_attr_qscheme_script(self):
         class Foo(torch.nn.Module):
             def __init__(self):
@@ -8740,22 +8547,6 @@
         with self.assertRaisesRegex(RuntimeError, "'int' object is not iterable"):
             M()
-    def test_script_module_list_sequential(self):
-        class M(torch.jit.ScriptModule):
-            def __init__(self, mod_list):
-                super(M, self).__init__()
-                self.mods = mod_list
-
-            @torch.jit.script_method
-            def forward(self, v):
-                for m in self.mods:
-                    v = m(v)
-                return v
-
-        with torch.jit.optimized_execution(False):
-            m = M(nn.Sequential(nn.ReLU()))
-            self.assertExportImportModule(m, (torch.randn(2, 2),))
-
     def test_attr_module_constants(self):
         class M2(torch.jit.ScriptModule):
             def __init__(self, mod_list):
diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp
index 6c11ad5..4b5d872 100644
--- a/torch/csrc/jit/python/python_sugared_value.cpp
+++ b/torch/csrc/jit/python/python_sugared_value.cpp
@@ -235,6 +235,24 @@
     Value* idx) {
   if (concreteType_->getIterableModuleKind() == IterableModuleKind::LIST) {
     return getSugaredDict(loc, m)->getModules()->getitem(loc, m, idx);
+  } else if (
+      concreteType_->getIterableModuleKind() == IterableModuleKind::DICT) {
+    if (auto ivalue = toIValue(idx)) {
+      auto sd = getSugaredDict(loc, m);
+      auto idx_str = ivalue->toStringRef();
+      auto keys_iter = sd->keys_;
+      auto module_values_iter = sd->modules_;
+      for (size_t i = 0; i < keys_iter->tup_.size(); ++i) {
+        auto key = keys_iter->tup_.at(i);
+        auto key_str = toIValue(key->asValue(loc, m))->toStringRef();
+        if (key_str == idx_str) {
+          return module_values_iter->tup_.at(i);
+        }
+      }
+      throw ErrorReport(loc) << "Key Error, " << idx_str;
+    }
+    throw ErrorReport(loc)
+        << "Unable to extract string literal index. ModuleDict indexing is only supported with string literals.";
   }
   throw ErrorReport(loc)
       << "Only ModuleList, Sequential, and ModuleDict modules are subscriptable";
diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py
index bc442ea..f566925 100644
--- a/torch/jit/_recursive.py
+++ b/torch/jit/_recursive.py
@@ -314,7 +314,6 @@
         concrete_type_builder = infer_concrete_type_builder(nn_module)
         concrete_type_builder.set_poisoned()
         concrete_type = concrete_type_builder.build()
-
     return create_script_module_impl(nn_module, concrete_type, stubs_fn)
 def create_script_module_impl(nn_module, concrete_type, stubs_fn):
@@ -351,6 +350,7 @@
             else:
                 # use the default recursive rule to compile the module
                 scripted = create_script_module_impl(orig_value, sub_concrete_type, infer_methods_to_compile)
+
             cpp_module.setattr(name, scripted)
             script_module._modules[name] = scripted
@@ -377,6 +377,19 @@
         torch._C._run_emit_module_hook(cpp_module)
         concrete_type_store.methods_compiled.add(concrete_type)
+    # Special handling so methods like __len__ work in script methods on classes derived from containers
+    if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)) and \
+            '__len__' not in cpp_module._method_names():
+        script_module.define("def __len__(self):\n    return {}\n".format(len(nn_module)))
+    if isinstance(nn_module, torch.nn.ModuleDict) and \
+            '__contains__' not in cpp_module._method_names():
+        if len(nn_module.keys()):
+            keys = repr(list(nn_module.keys()))
+            script_module.define("def __contains__(self, key: str):\n    return key in {}\n".format(keys))
+        else:
+            script_module.define("def __contains__(self, key: str):\n    return False\n")
+
+
     # Make the compiled methods available to the Python ScriptModule class.
     for stub in stubs:
         if stub.original_method is None:
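
Note on the mechanism: the generated methods above are ordinary TorchScript definitions baked from the eager module's state at script time. For a ModuleDict subclass constructed with keys 'one' and 'two' (hypothetical keys, chosen only for illustration), the define() calls would produce the equivalent of:

    def __len__(self):
        return 2

    def __contains__(self, key: str):
        return key in ['one', 'two']

Baking the length and key list into the method bodies is sound because a scripted container's contents are fixed once compilation happens.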