[jit] Better match behavior of loaded ScriptModules vs. freshly created ones (#43298)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43298
IR emitter uses `ModuleValue` to represent ScriptModules and emit IR for
attribute access, submodule access, etc.
`ModuleValue` relies on two pieces of information, the JIT type of the
module, and the `ConcreteModuleType`, which encapsulates Python-only
information about the module.
ScriptModules loaded from a package used to create a dummy
ConcreteModuleType without any info in it. This led to divergences in
behavior during compilation.
This PR makes the two ways of constructing a ConcreteModuleType equivalent,
modulo any py-only information (which, by definition, is never present in
packaged files anyway).
Test Plan: Imported from OSS
Reviewed By: bertmaher
Differential Revision: D23228738
Pulled By: suo
fbshipit-source-id: f6a660f42272640ca1a1bb8c4ee7edfa2d1b07cc
diff --git a/test/test_jit.py b/test/test_jit.py
index fee5340..f070e7b 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -15201,6 +15201,47 @@
with self.assertRaisesRegex(RuntimeError, 'has no attribute'):
torch.jit.script(wrapped)
+ def test_rescripting_loaded_modules(self):
+ class InnerSubmod(nn.Module):
+ __constants__ = ['my_constant']
+
+ def __init__(self):
+ super().__init__()
+ self.register_buffer("foo", torch.ones(1))
+ self.register_parameter("bar", torch.nn.Parameter(torch.ones(1)))
+ self.baz = torch.ones(1)
+ self.my_constant = 1
+
+ def forward(self, x):
+ return x + x
+
+ class Inner(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.submod = InnerSubmod()
+
+ def forward(self, x):
+ return self.submod(x)
+
+ class Wrapper(nn.Module):
+ def __init__(self, inner):
+ super().__init__()
+ self.inner = inner
+
+ def forward(self, x):
+ # access inner elements
+ ret = self.inner.submod(x) + self.inner.submod.foo + self.inner.submod.bar + self.inner.submod.baz
+ ret = ret + self.inner.submod.my_constant
+ return ret
+
+ inner_module = torch.jit.script(Inner())
+ wrapped = Wrapper(inner_module)
+ self.checkModule(wrapped, torch.ones(1))
+
+ inner_module_loaded = self.getExportImportCopy(inner_module)
+ wrapped_loaded = Wrapper(inner_module_loaded)
+ self.assertEqual(wrapped(torch.ones(1)), wrapped_loaded(torch.ones(1)))
+
# known to be failing in tracer
EXCLUDE_TRACED = {
diff --git a/torch/csrc/jit/frontend/concrete_module_type.cpp b/torch/csrc/jit/frontend/concrete_module_type.cpp
index 96433bc..169b589 100644
--- a/torch/csrc/jit/frontend/concrete_module_type.cpp
+++ b/torch/csrc/jit/frontend/concrete_module_type.cpp
@@ -29,20 +29,7 @@
}
for (const auto& pr : constants_) {
- const auto& name = pr.first;
- const auto& val = pr.second.v_;
- auto match = tryToInferType(val);
- if (!match.success()) {
- TORCH_INTERNAL_ASSERT(
- false,
- "We need to infer the type of constant to convert the python value to IValue, but failed to infer type of ",
- py::str(val),
- "\n:",
- match.reason());
- }
- // Validation and conversion to make sure `val` is a valid constant
- // is done in python, see `torch/jit/_recursive.py`
- cls->addConstant(name, toIValue(val, match.type()));
+ cls->addConstant(pr.first, pr.second);
}
for (const auto& moduleInfo : modules_) {
@@ -57,15 +44,43 @@
std::shared_ptr<ConcreteModuleType> ConcreteModuleType::fromJitType(
TypePtr type) {
+ ConcreteModuleTypeBuilder builder;
+ builder.setPoisoned();
+
// `type` should either be a module interface or a class type
if (auto interface = type->cast<InterfaceType>()) {
TORCH_INTERNAL_ASSERT(interface->is_module());
} else {
- TORCH_INTERNAL_ASSERT(type->cast<ClassType>());
+ const auto classType = type->expect<ClassType>();
+
+ // Populate the builder metadata from the JIT type. This is to ensure
+ // ConcreteModuleTypes produced from Python and ones produced from a JIT
+ // type directly behave the same to the rest of the system.
+ for (size_t i = 0; i < classType->numAttributes(); i++) {
+ const auto& attrName = classType->getAttributeName(i);
+ const auto& attrType = classType->getAttribute(i);
+ if (attrType->is_module()) {
+ builder.addModule(attrName, ConcreteModuleType::fromJitType(attrType));
+ } else {
+ builder.addAttribute(
+ attrName,
+ attrType,
+ classType->is_parameter(i),
+ classType->is_buffer(i));
+ }
+ }
+
+ for (size_t i = 0; i < classType->numConstants(); i++) {
+ builder.addConstant(
+ classType->getConstantName(i), classType->getConstant(i));
+ }
}
+
+ // Not make_shared because the constructor is private.
auto ret = std::shared_ptr<ConcreteModuleType>(new ConcreteModuleType());
ret->jitType_ = std::move(type);
- ret->data_.setPoisoned();
+ ret->data_ = builder;
+
return ret;
}
@@ -198,6 +213,20 @@
void ConcreteModuleTypeBuilder::addConstant(
std::string name,
py::object value) {
+ auto match = tryToInferType(value);
+ if (!match.success()) {
+ TORCH_INTERNAL_ASSERT(
+ false,
+ "We need to infer the type of constant to convert the python value to IValue,"
+ " but failed to infer type of ",
+ py::str(value),
+ "\n:",
+ match.reason());
+ }
+ constants_.emplace(std::move(name), toIValue(value, match.type()));
+}
+
+void ConcreteModuleTypeBuilder::addConstant(std::string name, IValue value) {
constants_.emplace(std::move(name), std::move(value));
}
@@ -257,7 +286,7 @@
<< py::getattr(data_.pyClass_, "__name__") << "\n";
std::cout << "Constants: \n";
for (const auto& pr : data_.constants_) {
- std::cout << "\t" << pr.first << ": " << pr.second.v_ << "\n";
+ std::cout << "\t" << pr.first << ": " << pr.second << "\n";
}
std::cout << "\nAttributes: \n";
for (const auto& pr : data_.attributes_) {
@@ -286,7 +315,7 @@
// need to bind ConcreteModuleType::Constant as well.
std::unordered_map<std::string, py::object> ret;
for (const auto& pr : data_.constants_) {
- ret.emplace(pr.first, pr.second.v_);
+ ret.emplace(pr.first, toPyObject(pr.second));
}
return ret;
}
diff --git a/torch/csrc/jit/frontend/concrete_module_type.h b/torch/csrc/jit/frontend/concrete_module_type.h
index 4844e95..0410693 100644
--- a/torch/csrc/jit/frontend/concrete_module_type.h
+++ b/torch/csrc/jit/frontend/concrete_module_type.h
@@ -61,7 +61,9 @@
TORCH_INTERNAL_ASSERT(pyClass);
pyClass_ = std::move(pyClass);
}
+
void addConstant(std::string name, py::object value);
+ void addConstant(std::string name, IValue value);
void addAttribute(
std::string name,
TypePtr type,
@@ -94,19 +96,6 @@
// implements a meaningful comparison in that context.
bool equals(const ConcreteModuleTypeBuilder& other) const;
- struct Constant {
- /* implicit */ Constant(py::object v) : v_(std::move(v)) {}
- friend bool operator==(const Constant& lhs, const Constant& rhs) {
- // Perform the equivalent of `lhs == rhs` in Python.
- int rv = PyObject_RichCompareBool(lhs.v_.ptr(), rhs.v_.ptr(), Py_EQ);
- if (rv == -1) {
- throw py::error_already_set();
- }
- return rv == 1;
- }
- py::object v_;
- };
-
struct FunctionAttribute {
FunctionTypePtr function_;
py::object pyFunction_;
@@ -153,7 +142,7 @@
bool isPoisoned_ = false;
// The value of any constants defined by the module.
- std::unordered_map<std::string, Constant> constants_;
+ std::unordered_map<std::string, IValue> constants_;
// The types of any attributes
OrderedDict<std::string, Attribute> attributes_;
// Overloads, in the same format as `__overloads__` in Python
diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp
index 5e5f785..aceb0ba 100644
--- a/torch/csrc/jit/python/python_sugared_value.cpp
+++ b/torch/csrc/jit/python/python_sugared_value.cpp
@@ -451,10 +451,15 @@
if (selfType->hasAttribute(field) &&
selfType->getAttribute(field)->is_module()) {
// ...if it's a submodule, return it as a new ModuleValue.
- const auto submoduleConcreteType =
- concreteType_->findSubmoduleConcreteType(field);
+ if (const auto submoduleConcreteType =
+ concreteType_->findSubmoduleConcreteType(field)) {
+ return std::make_shared<ModuleValue>(
+ m.graph()->insertGetAttr(self_, field), submoduleConcreteType);
+ }
+
return std::make_shared<ModuleValue>(
- m.graph()->insertGetAttr(self_, field), submoduleConcreteType);
+ m.graph()->insertGetAttr(self_, field),
+ ConcreteModuleType::fromJitType(selfType->getAttribute(field)));
} else if (selfType->hasAttribute(field) || selfType->findMethod(field)) {
// ...otherwise, methods, parameters, attributes, and buffers are all
// first class so they get returned as SimpleValues
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp
index 1b28435..34803f9 100644
--- a/torch/csrc/jit/python/script_init.cpp
+++ b/torch/csrc/jit/python/script_init.cpp
@@ -1533,7 +1533,13 @@
std::shared_ptr<ConcreteModuleTypeBuilder>>(
m, "ConcreteModuleTypeBuilder")
.def(py::init<py::object>())
- .def("add_constant", &ConcreteModuleTypeBuilder::addConstant)
+ .def(
+ "add_constant",
+ [](ConcreteModuleTypeBuilder& self,
+ std::string name,
+ py::object value) {
+ self.addConstant(std::move(name), std::move(value));
+ })
.def("add_attribute", &ConcreteModuleTypeBuilder::addAttribute)
.def(
"add_function_attribute",