[jit] Better match behavior of loaded ScriptModules vs. freshly created ones (#43298)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43298

The IR emitter uses `ModuleValue` to represent ScriptModules and to emit
IR for attribute access, submodule access, etc.

`ModuleValue` relies on two pieces of information: the JIT type of the
module, and the `ConcreteModuleType`, which encapsulates Python-only
information about the module.

Previously, a ScriptModule loaded from a package was given a dummy
`ConcreteModuleType` with no information in it. As a result, code that
used a loaded module could compile differently from code that used a
freshly scripted one.

This PR makes the two ways of constructing a `ConcreteModuleType` (from
a Python module and directly from a JIT type) equivalent, except for any
Python-only information (which, by definition, is never present in
packaged files anyway).

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D23228738

Pulled By: suo

fbshipit-source-id: f6a660f42272640ca1a1bb8c4ee7edfa2d1b07cc
diff --git a/test/test_jit.py b/test/test_jit.py
index fee5340..f070e7b 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -15201,6 +15201,47 @@
         with self.assertRaisesRegex(RuntimeError, 'has no attribute'):
             torch.jit.script(wrapped)
 
+    def test_rescripting_loaded_modules(self):
+        class InnerSubmod(nn.Module):
+            __constants__ = ['my_constant']
+
+            def __init__(self):
+                super().__init__()
+                self.register_buffer("foo", torch.ones(1))
+                self.register_parameter("bar", torch.nn.Parameter(torch.ones(1)))
+                self.baz = torch.ones(1)
+                self.my_constant = 1
+
+            def forward(self, x):
+                return x + x
+
+        class Inner(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.submod = InnerSubmod()
+
+            def forward(self, x):
+                return self.submod(x)
+
+        class Wrapper(nn.Module):
+            def __init__(self, inner):
+                super().__init__()
+                self.inner = inner
+
+            def forward(self, x):
+                # access inner elements
+                ret = self.inner.submod(x) + self.inner.submod.foo + self.inner.submod.bar + self.inner.submod.baz
+                ret = ret + self.inner.submod.my_constant
+                return ret
+
+        inner_module = torch.jit.script(Inner())
+        wrapped = Wrapper(inner_module)
+        self.checkModule(wrapped, torch.ones(1))
+
+        inner_module_loaded = self.getExportImportCopy(inner_module)
+        wrapped_loaded = Wrapper(inner_module_loaded)
+        self.assertEqual(wrapped(torch.ones(1)), wrapped_loaded(torch.ones(1)))
+
 
 # known to be failing in tracer
 EXCLUDE_TRACED = {
diff --git a/torch/csrc/jit/frontend/concrete_module_type.cpp b/torch/csrc/jit/frontend/concrete_module_type.cpp
index 96433bc..169b589 100644
--- a/torch/csrc/jit/frontend/concrete_module_type.cpp
+++ b/torch/csrc/jit/frontend/concrete_module_type.cpp
@@ -29,20 +29,7 @@
   }
 
   for (const auto& pr : constants_) {
-    const auto& name = pr.first;
-    const auto& val = pr.second.v_;
-    auto match = tryToInferType(val);
-    if (!match.success()) {
-      TORCH_INTERNAL_ASSERT(
-          false,
-          "We need to infer the type of constant to convert the python value to IValue, but failed to infer type of ",
-          py::str(val),
-          "\n:",
-          match.reason());
-    }
-    // Validation and conversion to make sure `val` is a valid constant
-    // is done in python, see `torch/jit/_recursive.py`
-    cls->addConstant(name, toIValue(val, match.type()));
+    cls->addConstant(pr.first, pr.second);
   }
 
   for (const auto& moduleInfo : modules_) {
@@ -57,15 +44,43 @@
 
 std::shared_ptr<ConcreteModuleType> ConcreteModuleType::fromJitType(
     TypePtr type) {
+  ConcreteModuleTypeBuilder builder;
+  builder.setPoisoned();
+
   // `type` should either be a module interface or a class type
   if (auto interface = type->cast<InterfaceType>()) {
     TORCH_INTERNAL_ASSERT(interface->is_module());
   } else {
-    TORCH_INTERNAL_ASSERT(type->cast<ClassType>());
+    const auto classType = type->expect<ClassType>();
+
+    // Populate the builder metadata from the JIT type. This is to ensure
+    // ConcreteModuleTypes produced from Python and ones produced from a JIT
+    // type directly behave the same to the rest of the system.
+    for (size_t i = 0; i < classType->numAttributes(); i++) {
+      const auto& attrName = classType->getAttributeName(i);
+      const auto& attrType = classType->getAttribute(i);
+      if (attrType->is_module()) {
+        builder.addModule(attrName, ConcreteModuleType::fromJitType(attrType));
+      } else {
+        builder.addAttribute(
+            attrName,
+            attrType,
+            classType->is_parameter(i),
+            classType->is_buffer(i));
+      }
+    }
+
+    for (size_t i = 0; i < classType->numConstants(); i++) {
+      builder.addConstant(
+          classType->getConstantName(i), classType->getConstant(i));
+    }
   }
+
+  // Not make_shared because the constructor is private.
   auto ret = std::shared_ptr<ConcreteModuleType>(new ConcreteModuleType());
   ret->jitType_ = std::move(type);
-  ret->data_.setPoisoned();
+  ret->data_ = builder;
+
   return ret;
 }
 
@@ -198,6 +213,20 @@
 void ConcreteModuleTypeBuilder::addConstant(
     std::string name,
     py::object value) {
+  auto match = tryToInferType(value);
+  if (!match.success()) {
+    TORCH_INTERNAL_ASSERT(
+        false,
+        "We need to infer the type of constant to convert the python value to IValue,"
+        " but failed to infer type of ",
+        py::str(value),
+        "\n:",
+        match.reason());
+  }
+  constants_.emplace(std::move(name), toIValue(value, match.type()));
+}
+
+void ConcreteModuleTypeBuilder::addConstant(std::string name, IValue value) {
   constants_.emplace(std::move(name), std::move(value));
 }
 
@@ -257,7 +286,7 @@
             << py::getattr(data_.pyClass_, "__name__") << "\n";
   std::cout << "Constants: \n";
   for (const auto& pr : data_.constants_) {
-    std::cout << "\t" << pr.first << ": " << pr.second.v_ << "\n";
+    std::cout << "\t" << pr.first << ": " << pr.second << "\n";
   }
   std::cout << "\nAttributes: \n";
   for (const auto& pr : data_.attributes_) {
@@ -286,7 +315,7 @@
   // need to bind ConcreteModuleType::Constant as well.
   std::unordered_map<std::string, py::object> ret;
   for (const auto& pr : data_.constants_) {
-    ret.emplace(pr.first, pr.second.v_);
+    ret.emplace(pr.first, toPyObject(pr.second));
   }
   return ret;
 }
diff --git a/torch/csrc/jit/frontend/concrete_module_type.h b/torch/csrc/jit/frontend/concrete_module_type.h
index 4844e95..0410693 100644
--- a/torch/csrc/jit/frontend/concrete_module_type.h
+++ b/torch/csrc/jit/frontend/concrete_module_type.h
@@ -61,7 +61,9 @@
     TORCH_INTERNAL_ASSERT(pyClass);
     pyClass_ = std::move(pyClass);
   }
+
   void addConstant(std::string name, py::object value);
+  void addConstant(std::string name, IValue value);
   void addAttribute(
       std::string name,
       TypePtr type,
@@ -94,19 +96,6 @@
   // implements a meaningful comparison in that context.
   bool equals(const ConcreteModuleTypeBuilder& other) const;
 
-  struct Constant {
-    /* implicit */ Constant(py::object v) : v_(std::move(v)) {}
-    friend bool operator==(const Constant& lhs, const Constant& rhs) {
-      // Perform the equivalent of `lhs == rhs` in Python.
-      int rv = PyObject_RichCompareBool(lhs.v_.ptr(), rhs.v_.ptr(), Py_EQ);
-      if (rv == -1) {
-        throw py::error_already_set();
-      }
-      return rv == 1;
-    }
-    py::object v_;
-  };
-
   struct FunctionAttribute {
     FunctionTypePtr function_;
     py::object pyFunction_;
@@ -153,7 +142,7 @@
   bool isPoisoned_ = false;
 
   // The value of any constants defined by the module.
-  std::unordered_map<std::string, Constant> constants_;
+  std::unordered_map<std::string, IValue> constants_;
   // The types of any attributes
   OrderedDict<std::string, Attribute> attributes_;
   // Overloads, in the same format as `__overloads__` in Python
diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp
index 5e5f785..aceb0ba 100644
--- a/torch/csrc/jit/python/python_sugared_value.cpp
+++ b/torch/csrc/jit/python/python_sugared_value.cpp
@@ -451,10 +451,15 @@
   if (selfType->hasAttribute(field) &&
       selfType->getAttribute(field)->is_module()) {
     // ...if it's a submodule, return it as a new ModuleValue.
-    const auto submoduleConcreteType =
-        concreteType_->findSubmoduleConcreteType(field);
+    if (const auto submoduleConcreteType =
+            concreteType_->findSubmoduleConcreteType(field)) {
+      return std::make_shared<ModuleValue>(
+          m.graph()->insertGetAttr(self_, field), submoduleConcreteType);
+    }
+
     return std::make_shared<ModuleValue>(
-        m.graph()->insertGetAttr(self_, field), submoduleConcreteType);
+        m.graph()->insertGetAttr(self_, field),
+        ConcreteModuleType::fromJitType(selfType->getAttribute(field)));
   } else if (selfType->hasAttribute(field) || selfType->findMethod(field)) {
     // ...otherwise, methods, parameters, attributes, and buffers are all
     // first class so they get returned as SimpleValues
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp
index 1b28435..34803f9 100644
--- a/torch/csrc/jit/python/script_init.cpp
+++ b/torch/csrc/jit/python/script_init.cpp
@@ -1533,7 +1533,13 @@
       std::shared_ptr<ConcreteModuleTypeBuilder>>(
       m, "ConcreteModuleTypeBuilder")
       .def(py::init<py::object>())
-      .def("add_constant", &ConcreteModuleTypeBuilder::addConstant)
+      .def(
+          "add_constant",
+          [](ConcreteModuleTypeBuilder& self,
+             std::string name,
+             py::object value) {
+            self.addConstant(std::move(name), std::move(value));
+          })
       .def("add_attribute", &ConcreteModuleTypeBuilder::addAttribute)
       .def(
           "add_function_attribute",