Split ConcreteModuleType into two types (#29824)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29824

We have two distinct phases/uses for ConcreteModuleType:
1. We are building it up and using it to check whether we can
reuse JIT types. (RawConcreteModuleType)
2. We are using it to satisfy ModuleValue::attr queries.
(ConcreteModuleType)

These types share an underlying `ConcreteModuleTypeData` which
actually stores the relevant info.

Previously they were the same type because I was lazy, but it's been the
source of a bug. So split them to formalize the differing invariants for
the two phases.

Test Plan: Imported from OSS

Differential Revision: D18575010

Pulled By: suo

fbshipit-source-id: 3e4ebcd36e78b947150d8f0dbb74ecccad23e7c4
diff --git a/torch/csrc/jit/script/concrete_module_type.cpp b/torch/csrc/jit/script/concrete_module_type.cpp
index 2ab9e90..1e0b96b 100644
--- a/torch/csrc/jit/script/concrete_module_type.cpp
+++ b/torch/csrc/jit/script/concrete_module_type.cpp
@@ -3,16 +3,7 @@
 namespace torch {
 namespace jit {
 namespace script {
-
-ClassTypePtr ConcreteModuleType::getJitType() const {
-  TORCH_INTERNAL_ASSERT(jitType_);
-  return jitType_;
-}
-
-ClassTypePtr ConcreteModuleType::createNewTypeFromThis() {
-  TORCH_INTERNAL_ASSERT(!jitType_);
-  TORCH_INTERNAL_ASSERT(pyClass_);
-
+ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
   auto cu = get_python_cu();
   py::object pyQualName = py::module::import("torch._jit_internal")
                               .attr("_qualified_name")(pyClass_);
@@ -41,21 +32,85 @@
         moduleInfo.name_, moduleInfo.getJitType(), /*is_parameter=*/false);
   }
 
-  jitType_ = std::move(cls);
+  return cls;
+}
+
+ConcreteModuleType::ConcreteModuleType(ConcreteModuleTypeBuilder data)
+    : data_(std::move(data)) {
+  jitType_ = data_.createTypeFromThis();
+}
+
+TypePtr ConcreteModuleTypeBuilder::ModuleInfo::getJitType() const {
+  return meta_ == nullptr ? type_ : meta_->getJitType();
+}
+
+bool operator==(
+    const ConcreteModuleTypeBuilder::ModuleInfo& lhs,
+    const ConcreteModuleTypeBuilder::ModuleInfo& rhs) {
+  if (lhs.meta_ != nullptr && rhs.meta_ != nullptr) {
+    return lhs.meta_->equals(*rhs.meta_);
+  } else if (lhs.type_ != nullptr && rhs.type_ != nullptr) {
+    return *(lhs.type_) == *(rhs.type_);
+  } else {
+    return false;
+  }
+}
+
+bool ConcreteModuleTypeBuilder::equals(
+    const ConcreteModuleTypeBuilder& other) const {
+  if (isPoisoned_ || other.isPoisoned_) {
+    return false;
+  }
+
+  // clang-format off
+    // These are vaguely ordered so that cheap, discriminating checks happen first.
+    bool equal =
+      pyClass_.is(other.pyClass_) &&
+      iterableModuleKind_ == other.iterableModuleKind_ &&
+      constants_ == other.constants_ &&
+      attributes_ == other.attributes_ &&
+      overloads_ == other.overloads_ &&
+      functionAttributes_ == other.functionAttributes_;
+  // clang-format on
+  if (!equal) {
+    return false;
+  }
+
+  // We store modules in order of insertion (to make compilation
+  // deterministic). However, for the purposes of equality, insertion order
+  // should not matter, so sort them by name.
+  // We put this check last because it involves the most work.
+  auto thisSorted = modules_;
+  std::sort(
+      thisSorted.begin(),
+      thisSorted.end(),
+      [](const ModuleInfo& a, const ModuleInfo& b) {
+        return a.name_ < b.name_;
+      });
+
+  auto otherSorted = other.modules_;
+  std::sort(
+      otherSorted.begin(),
+      otherSorted.end(),
+      [](const ModuleInfo& a, const ModuleInfo& b) {
+        return a.name_ < b.name_;
+      });
+
+  return thisSorted == otherSorted;
+}
+
+ClassTypePtr ConcreteModuleType::getJitType() const {
   return jitType_;
 }
 
 py::object ConcreteModuleType::getPyClass() const {
-  TORCH_INTERNAL_ASSERT(jitType_);
-  TORCH_INTERNAL_ASSERT(pyClass_);
-  return pyClass_;
+  return data_.pyClass_;
 }
 
 c10::optional<std::vector<std::string>> ConcreteModuleType::findOverloads(
     const std::string& name) const {
-  TORCH_INTERNAL_ASSERT(jitType_);
-  const auto it = overloads_.find(name);
-  if (it != overloads_.end()) {
+  const auto it = data_.overloads_.find(name);
+  if (it != data_.overloads_.end()) {
     return it->second;
   }
   return c10::nullopt;
@@ -63,9 +118,8 @@
 
 c10::optional<Function*> ConcreteModuleType::findFunctionAttribute(
     const std::string& name) const {
-  TORCH_INTERNAL_ASSERT(jitType_);
-  const auto it = functionAttributes_.find(name);
-  if (it != functionAttributes_.end()) {
+  const auto it = data_.functionAttributes_.find(name);
+  if (it != data_.functionAttributes_.end()) {
     return it->second.function_->function();
   }
   return c10::nullopt;
@@ -73,9 +127,8 @@
 
 c10::optional<std::string> ConcreteModuleType::findFailedAttribute(
     const std::string& name) const {
-  TORCH_INTERNAL_ASSERT(jitType_);
-  const auto it = failedAttributes_.find(name);
-  if (it != failedAttributes_.end()) {
+  const auto it = data_.failedAttributes_.find(name);
+  if (it != data_.failedAttributes_.end()) {
     return it->second;
   }
   return c10::nullopt;
@@ -83,130 +136,114 @@
 
 std::shared_ptr<ConcreteModuleType> ConcreteModuleType::
     findSubmoduleConcreteType(const std::string& name) const {
-  TORCH_INTERNAL_ASSERT(jitType_);
   const auto it = std::find_if(
-      modules_.cbegin(), modules_.cend(), [&](const ModuleInfo& info) {
+      data_.modules_.cbegin(),
+      data_.modules_.cend(),
+      [&](const ConcreteModuleTypeBuilder::ModuleInfo& info) {
         return info.name_ == name;
       });
-  if (it == modules_.end()) {
+  if (it == data_.modules_.end()) {
     return nullptr;
   }
   return it->meta_;
 }
 
-void ConcreteModuleType::setIterableModuleKind(IterableModuleKind kind) {
-  TORCH_INTERNAL_ASSERT(!jitType_);
+void ConcreteModuleTypeBuilder::setIterableModuleKind(IterableModuleKind kind) {
   iterableModuleKind_ = kind;
 }
 
 IterableModuleKind ConcreteModuleType::getIterableModuleKind() const {
-  TORCH_INTERNAL_ASSERT(jitType_);
-  return iterableModuleKind_;
+  return data_.iterableModuleKind_;
 }
 
-void ConcreteModuleType::setPoisoned() {
-  TORCH_INTERNAL_ASSERT(!jitType_)
+void ConcreteModuleTypeBuilder::setPoisoned() {
   isPoisoned_ = true;
 }
 
-void ConcreteModuleType::addJitType(ClassTypePtr type) {
-  TORCH_INTERNAL_ASSERT(!jitType_)
-  jitType_ = std::move(type);
-}
-
-void ConcreteModuleType::addPyClass(py::object pyClass) {
-  TORCH_INTERNAL_ASSERT(!jitType_);
-  pyClass_ = std::move(pyClass);
-}
-
-void ConcreteModuleType::addConstant(std::string name, py::object value) {
-  TORCH_INTERNAL_ASSERT(!jitType_);
+void ConcreteModuleTypeBuilder::addConstant(std::string name, py::object value) {
   constants_.emplace(std::move(name), std::move(value));
 }
 
-void ConcreteModuleType::addAttribute(
+void ConcreteModuleTypeBuilder::addAttribute(
     std::string name,
     TypePtr type,
     bool isParameter) {
   TORCH_INTERNAL_ASSERT(type);
-  TORCH_INTERNAL_ASSERT(!jitType_);
   // Function attributes should be handled separately
   TORCH_INTERNAL_ASSERT(type->cast<FunctionType>() == nullptr);
   attributes_.emplace(
-      std::move(name), Attribute(unshapedType(type), isParameter));
+      std::move(name),
+      ConcreteModuleTypeBuilder::Attribute(unshapedType(type), isParameter));
 }
 
-void ConcreteModuleType::addFunctionAttribute(
+void ConcreteModuleTypeBuilder::addFunctionAttribute(
     std::string name,
     const TypePtr& type,
     py::object pyFunction) {
   TORCH_INTERNAL_ASSERT(type);
-  TORCH_INTERNAL_ASSERT(!jitType_);
   functionAttributes_.emplace(
       std::move(name),
-      FunctionAttribute{type->expect<FunctionType>(), std::move(pyFunction)});
+      ConcreteModuleTypeBuilder::FunctionAttribute{type->expect<FunctionType>(),
+                                                std::move(pyFunction)});
 }
 
-void ConcreteModuleType::addModule(
+void ConcreteModuleTypeBuilder::addModule(
     std::string name,
     std::shared_ptr<ConcreteModuleType> meta) {
-  TORCH_INTERNAL_ASSERT(!jitType_);
-  modules_.emplace_back(ModuleInfo{std::move(name), std::move(meta)});
+  modules_.emplace_back(
+      ConcreteModuleTypeBuilder::ModuleInfo{std::move(name), std::move(meta)});
 }
 
-void ConcreteModuleType::addModuleInterface(
+void ConcreteModuleTypeBuilder::addModuleInterface(
     std::string name,
     const TypePtr& type) {
-  TORCH_INTERNAL_ASSERT(!jitType_);
   TORCH_INTERNAL_ASSERT(type->cast<InterfaceType>() && type->is_module());
-  modules_.emplace_back(ModuleInfo{std::move(name), type});
+  modules_.emplace_back(
+      ConcreteModuleTypeBuilder::ModuleInfo{std::move(name), type});
 }
 
-
-void ConcreteModuleType::addOverload(
+void ConcreteModuleTypeBuilder::addOverload(
     std::string methodName,
     std::vector<std::string> overloadedMethodNames) {
-  TORCH_INTERNAL_ASSERT(!jitType_);
   overloads_.emplace(std::move(methodName), std::move(overloadedMethodNames));
 }
 
-void ConcreteModuleType::addFailedAttribute(
+void ConcreteModuleTypeBuilder::addFailedAttribute(
     std::string name,
     std::string failureReason) {
-  TORCH_INTERNAL_ASSERT(!jitType_);
   failedAttributes_.emplace(std::move(name), std::move(failureReason));
 }
 
 c10::optional<py::object> ConcreteModuleType::findConstant(
     const std::string& name) const {
-  auto it = constants_.find(name);
-  if (it != constants_.end()) {
+  auto it = data_.constants_.find(name);
+  if (it != data_.constants_.end()) {
     return it->second.v_;
   }
   return c10::nullopt;
 }
 
 void ConcreteModuleType::dump() const {
-  std::cout << "ConcreteModuleType for: " << py::getattr(pyClass_, "__name__") << "\n";
+  std::cout << "ConcreteModuleType for: " << py::getattr(data_.pyClass_, "__name__") << "\n";
   std::cout << "Constants: \n";
-  for (const auto& pr : constants_) {
+  for (const auto& pr : data_.constants_) {
     std::cout << "\t" << pr.first << ": " << pr.second.v_ << "\n";
   }
   std::cout << "\nAttributes: \n";
-  for (const auto& pr : attributes_) {
+  for (const auto& pr : data_.attributes_) {
     std::cout << "\t" << pr.first << ": " << pr.second.type_->python_str()
               << "\n";
   }
   std::cout << "\nSubmodules: \n";
-  for (const auto& info : modules_) {
+  for (const auto& info : data_.modules_) {
     std::cout << "\t" << info.name_ << ": "
           << info.getJitType()->python_str() << "\n";
   }
   std::cout << "\nOverloads: \n";
-  for (const auto& pr : overloads_) {
+  for (const auto& pr : data_.overloads_) {
     std::cout << "\t" << pr.first << ": " << pr.second << "\n";
   }
-  std::string isPoisoned = isPoisoned_ ? "true" : "false";
+  std::string isPoisoned = data_.isPoisoned_ ? "true" : "false";
   std::cout << "isPoisoned: " << isPoisoned << "\n";
   if (jitType_) {
     std::cout << "jit type: " << jitType_->python_str() << "\n";
@@ -215,11 +252,10 @@
 
 std::unordered_map<std::string, py::object> ConcreteModuleType::getConstantsPy()
     const {
-  TORCH_INTERNAL_ASSERT(jitType_);
   // Convert to a more pybind-friendly representation, so we don't
   // need to bind ConcreteModuleType::Constant as well.
   std::unordered_map<std::string, py::object> ret;
-  for (const auto& pr : constants_) {
+  for (const auto& pr : data_.constants_) {
     ret.emplace(pr.first, pr.second.v_);
   }
   return ret;
@@ -227,11 +263,10 @@
 
 std::unordered_map<std::string, std::pair<TypePtr, bool>> ConcreteModuleType::
     getAttributesPy() const {
-  TORCH_INTERNAL_ASSERT(jitType_);
   // Convert to a more pybind-friendly representation, so we don't
   // need to bind ConcreteModuleType::Attribute as well.
   std::unordered_map<std::string, std::pair<TypePtr, bool>> ret;
-  for (auto& pr : attributes_) {
+  for (auto& pr : data_.attributes_) {
     ret.emplace(
         pr.first,
         std::pair<TypePtr, bool>(pr.second.type_, pr.second.isParam_));
@@ -241,10 +276,9 @@
 
 std::vector<std::pair<std::string, TypePtr>> ConcreteModuleType::getModulesPy()
     const {
-  TORCH_INTERNAL_ASSERT(jitType_);
   std::vector<std::pair<std::string, TypePtr>> ret;
 
-  for (const ModuleInfo& info: modules_) {
+  for (const auto& info : data_.modules_) {
     ret.emplace_back(std::make_pair(info.name_, info.getJitType()));
   }
   return ret;
diff --git a/torch/csrc/jit/script/concrete_module_type.h b/torch/csrc/jit/script/concrete_module_type.h
index d098c40..a587772 100644
--- a/torch/csrc/jit/script/concrete_module_type.h
+++ b/torch/csrc/jit/script/concrete_module_type.h
@@ -9,7 +9,9 @@
 namespace torch {
 namespace jit {
 namespace script {
+
 enum class IterableModuleKind { NONE, LIST, DICT };
+class ConcreteModuleType;
 
 // You can think of an nn.Module as a template that corresponds to a family of
 // JIT types. The template "arguments" are things like the constant values.
@@ -41,21 +43,22 @@
 // ModuleValue::attr calls. This is so we can guarantee that if two Module's
 // share a JIT type (and thus a ConcreteModuleType), then they behave the same
 // way when you access attributes on them.
-class VISIBILITY_HIDDEN ConcreteModuleType {
- public:
-  // ConcreteModuleType has two states.
-  // 1. Building: First we build it up, during the ScriptModule conversion
-  // process
-  //      ... to transition, we freeze the type by associating with a JIT type.
-  // 2. Querying: Then we ask use it as a source of truth during method
-  // compilation.
-  //              During this time the ModuleType is effectively const.
-  // Yes, it could be two different types. Not terribly worth the verbosity.
 
-  /**
-   * Builder methods (jitType must be null)
-   */
-  void addPyClass(py::object pyClass);
+// ConcreteModuleType has two phases.
+// 1. Creation: First we build it up, during the ScriptModule conversion
+// process. This is represented by ConcreteModuleTypeBuilder.
+//    ...then the converter calls ConcreteModuleTypeBuilder::build(), producing a
+//       ConcreteModuleType ready for querying.
+// 2. Querying: We use ConcreteModuleType as a source of truth for
+// ModuleValue::attr calls during method compilation.
+
+// Represents a concrete type during in the process for construction. We use
+// this to decide whether we can share types between modules.
+class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder {
+ public:
+  explicit ConcreteModuleTypeBuilder(py::object pyClass) {
+    pyClass_ = std::move(pyClass);
+  }
   void addConstant(std::string name, py::object value);
   void addAttribute(std::string name, TypePtr type, bool isParameter);
   void addFunctionAttribute(
@@ -73,95 +76,20 @@
       std::vector<std::string> overloadedMethodNames);
   void addFailedAttribute(std::string name, std::string failureReason);
   void setIterableModuleKind(IterableModuleKind kind);
+
+  // If a ConcreteModuleType is poisoned, it will never compare equal to any other
+  // concrete type
   void setPoisoned();
 
-  /**
-   * Freezing methods
-   */
-
-  // Based on the data in this ConcreteType, create an equivalent JIT type and
-  // associate this module type with it.
-  ClassTypePtr createNewTypeFromThis();
-  // Associate the provided type with this ConcreteType
-  void addJitType(ClassTypePtr type);
-
-  /**
-   * Query methods (jitType must be non-null)
-   */
-  ClassTypePtr getJitType() const;
-  py::object getPyClass() const;
-  IterableModuleKind getIterableModuleKind() const;
-  c10::optional<py::object> findConstant(const std::string& name) const;
-  c10::optional<std::vector<std::string>> findOverloads(
-      const std::string& name) const;
-  c10::optional<Function*> findFunctionAttribute(const std::string& name) const;
-  std::shared_ptr<ConcreteModuleType> findSubmoduleConcreteType(
-      const std::string& name) const;
-  c10::optional<std::string> findFailedAttribute(const std::string& name) const;
-
-  // These getters are only here to return things as types that can be
-  // automatically converted by pybind.
-  std::unordered_map<std::string, py::object> getConstantsPy() const;
-  std::unordered_map<std::string, std::pair<TypePtr, bool>> getAttributesPy()
-      const;
-  std::vector<std::pair<std::string, TypePtr>> getModulesPy() const;
+  std::shared_ptr<ConcreteModuleType> build() const {
+    return std::make_shared<ConcreteModuleType>(*this);
+  }
 
   // This determines whether two modules can share a type. The container structs
   // used by ConcreteModuleType have been defined such that operator==
   // implements a meaningful comparison in that context.
-  friend bool operator==(
-      const ConcreteModuleType& lhs,
-      const ConcreteModuleType& rhs) {
-    if (lhs.jitType_ == rhs.jitType_) {
-      // If the computed types are the same, these modules can (obviously) share
-      // a type.
-      return true;
-    }
+  bool equals(const ConcreteModuleTypeBuilder& other) const;
 
-    if (lhs.isPoisoned_ || rhs.isPoisoned_) {
-      return false;
-    }
-
-    // clang-format off
-    // These are vaguely ordered so that cheap, discriminating checks happen first.
-    bool equal =
-      lhs.pyClass_.is(rhs.pyClass_) &&
-      lhs.iterableModuleKind_ == rhs.iterableModuleKind_ &&
-      lhs.constants_ == rhs.constants_ &&
-      lhs.attributes_ == rhs.attributes_ &&
-      lhs.overloads_ == rhs.overloads_ &&
-      lhs.functionAttributes_ == rhs.functionAttributes_;
-    // clang-format on
-    if (!equal) {
-      return false;
-    }
-
-    // We store modules in order of insertion (to make compilation
-    // deterministic). However, for the purposes of equality, insertion order
-    // should not matter, so sort them by name.
-    // We put this check last because it involves the most work.
-    auto lhsSorted = lhs.modules_;
-    std::sort(
-        lhsSorted.begin(),
-        lhsSorted.end(),
-        [](const ModuleInfo& a, const ModuleInfo& b) {
-          return a.name_ < b.name_;
-        });
-
-    auto rhsSorted = rhs.modules_;
-    std::sort(
-        rhsSorted.begin(),
-        rhsSorted.end(),
-        [](const ModuleInfo& a, const ModuleInfo& b) {
-          return a.name_ < b.name_;
-        });
-
-    return lhsSorted == rhsSorted;
-  }
-
-  void dump() const;
-
- private:
   struct Constant {
     /* implicit */ Constant(py::object v) : v_(std::move(v)) {}
     friend bool operator==(const Constant& lhs, const Constant& rhs) {
@@ -206,28 +134,18 @@
 
     ModuleInfo(std::string name, const TypePtr& type)
         : name_(std::move(name)), meta_(nullptr), type_(type) {}
-    friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs) {
-      if (lhs.meta_ != nullptr && rhs.meta_ != nullptr) {
-        return *(lhs.meta_) == *(rhs.meta_);
-      } else if (lhs.type_ != nullptr && rhs.type_ != nullptr) {
-        return  *(lhs.type_) == *(rhs.type_);
-      } else {
-        return false;
-      }
-    }
-
-    TypePtr getJitType() const {
-      return meta_ == nullptr? type_ : meta_->getJitType();
-    }
+    TypePtr getJitType() const;
     std::string name_;
+    friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs);
 
     // Module Info contains either an ConcreateModuleType or a type (which is
     // a Module Interface), these two are union relationship.
     std::shared_ptr<ConcreteModuleType> meta_;
     TypePtr type_;
-
   };
 
+ private:
+  ClassTypePtr createTypeFromThis() const;
   // If true, this type will never compare equally to anything else. This is
   // used if we want to ensure that this type is not shared (for example, if it
   // came from a traced module)
@@ -256,11 +174,54 @@
   // The original `nn.Module` class that we derived this ScriptModule from.
   py::object pyClass_;
 
-  // The JIT type derived from this ConcreteModuleType.
-  ClassTypePtr jitType_ = nullptr;
   // NOTE: If you ever add any more state to this struct, you need to make sure
-  // operator== still makes sense! The only field that can be excluded from it
-  // is `jitType_`.
+  // operator== still makes sense!
+  friend ConcreteModuleType;
+};
+
+// Represents a finalized concrete type, used to service ModuleValue::attr calls
+// during method compilation.
+class VISIBILITY_HIDDEN ConcreteModuleType {
+ public:
+  explicit ConcreteModuleType(ConcreteModuleTypeBuilder data);
+
+  ClassTypePtr getJitType() const;
+  py::object getPyClass() const;
+  IterableModuleKind getIterableModuleKind() const;
+  c10::optional<py::object> findConstant(const std::string& name) const;
+  c10::optional<std::vector<std::string>> findOverloads(
+      const std::string& name) const;
+  c10::optional<Function*> findFunctionAttribute(const std::string& name) const;
+  std::shared_ptr<ConcreteModuleType> findSubmoduleConcreteType(
+      const std::string& name) const;
+  c10::optional<std::string> findFailedAttribute(const std::string& name) const;
+
+  // These getters are only here to return things as types that can be
+  // automatically converted by pybind.
+  std::unordered_map<std::string, py::object> getConstantsPy() const;
+  std::unordered_map<std::string, std::pair<TypePtr, bool>> getAttributesPy()
+      const;
+  std::vector<std::pair<std::string, TypePtr>> getModulesPy() const;
+
+  bool equals(const ConcreteModuleType& other) const {
+    if (jitType_ == other.jitType_) {
+      // If the computed types are the same, these modules can (obviously) share
+      // a type.
+      return true;
+    }
+
+    return data_.equals(other.data_);
+  }
+  bool equals(const ConcreteModuleTypeBuilder& other) const {
+    return data_.equals(other);
+  }
+
+  void dump() const;
+
+ private:
+  // The JIT type derived from this ConcreteModuleType.
+  ConcreteModuleTypeBuilder data_;
+  ClassTypePtr jitType_;
 };
 
 } // namespace script
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index beb7e79..f495fd2 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -1052,42 +1052,51 @@
     return Module(get_python_cu(), type);
   });
 
+  py::class_<ConcreteModuleTypeBuilder, std::shared_ptr<ConcreteModuleTypeBuilder>>(
+      m, "ConcreteModuleTypeBuilder")
+      .def(py::init<py::object>())
+      .def("add_constant", &ConcreteModuleTypeBuilder::addConstant)
+      .def("add_attribute", &ConcreteModuleTypeBuilder::addAttribute)
+      .def(
+          "add_function_attribute",
+          &ConcreteModuleTypeBuilder::addFunctionAttribute)
+      .def("add_module", &ConcreteModuleTypeBuilder::addModule)
+      .def("add_module_interface", &ConcreteModuleTypeBuilder::addModuleInterface)
+      .def("add_overload", &ConcreteModuleTypeBuilder::addOverload)
+      .def("set_poisoned", &ConcreteModuleTypeBuilder::setPoisoned)
+      .def("add_failed_attribute", &ConcreteModuleTypeBuilder::addFailedAttribute)
+      .def(
+          "set_module_dict",
+          [](ConcreteModuleTypeBuilder& self) {
+            self.setIterableModuleKind(IterableModuleKind::DICT);
+          })
+      .def("build", &ConcreteModuleTypeBuilder::build)
+      .def(
+          "equals",
+          [](const ConcreteModuleTypeBuilder& self,
+             const ConcreteModuleTypeBuilder& other) { return self.equals(other); })
+      .def("set_module_list", [](ConcreteModuleTypeBuilder& self) {
+        self.setIterableModuleKind(IterableModuleKind::LIST);
+      });
+
   py::class_<ConcreteModuleType, std::shared_ptr<ConcreteModuleType>>(
       m, "ConcreteModuleType")
-      .def(py::init<>())
       .def_property_readonly("py_class", &ConcreteModuleType::getPyClass)
       .def_property_readonly("jit_type", &ConcreteModuleType::getJitType)
       .def("get_constants", &ConcreteModuleType::getConstantsPy)
       .def("get_attributes", &ConcreteModuleType::getAttributesPy)
       .def("get_modules", &ConcreteModuleType::getModulesPy)
-      .def("add_constant", &ConcreteModuleType::addConstant)
-      .def("add_attribute", &ConcreteModuleType::addAttribute)
-      .def("add_function_attribute", &ConcreteModuleType::addFunctionAttribute)
-      .def("add_module", &ConcreteModuleType::addModule)
-      .def("add_module_interface", &ConcreteModuleType::addModuleInterface)
-      .def("add_pyclass", &ConcreteModuleType::addPyClass)
-      .def("add_overload", &ConcreteModuleType::addOverload)
-      .def("add_jit_type", &ConcreteModuleType::addJitType)
-      .def("set_poisoned", &ConcreteModuleType::setPoisoned)
-      .def(
-          "set_module_dict",
-          [](ConcreteModuleType& self) {
-            self.setIterableModuleKind(IterableModuleKind::DICT);
-          })
-      .def(
-          "set_module_list",
-          [](ConcreteModuleType& self) {
-            self.setIterableModuleKind(IterableModuleKind::LIST);
-          })
-      .def(
-          "create_new_type_from_this",
-          &ConcreteModuleType::createNewTypeFromThis)
-      .def("add_failed_attribute", &ConcreteModuleType::addFailedAttribute)
       .def("dump", &ConcreteModuleType::dump)
       .def(
           "equals",
           [](const ConcreteModuleType& self, const ConcreteModuleType& other) {
-            return self == other;
+            return self.equals(other);
+          })
+      .def(
+          "equals",
+          [](const ConcreteModuleType& self,
+             const ConcreteModuleTypeBuilder& other) {
+            return self.equals(other);
           })
       .def(
           "_create_methods",
diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py
index 4b1b6b5..277dab1 100644
--- a/torch/jit/_recursive.py
+++ b/torch/jit/_recursive.py
@@ -59,18 +59,17 @@
         3. a list or tuple of (2)
         """.format(type(v).__name__, attr, constants)))
 
-def infer_raw_concrete_type(nn_module):
+def infer_concrete_type_builder(nn_module):
     """
-    Build a ConcreteModuleType from an nn.Module. This ConcreteModuleType
-    doesn't have a JIT type associated with it yet, it must be filled in
-    by the caller.
+    Build a ConcreteModuleTypeBuilder from an nn.Module. This
+    ConcreteModuleType doesn't have a JIT type associated with it yet, it
+    must be filled in by the caller.
     """
-    concrete_type = torch._C.ConcreteModuleType()
-    concrete_type.add_pyclass(type(nn_module))
+    concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
     if isinstance(nn_module, (torch.nn.ModuleDict)):
-        concrete_type.set_module_dict()
+        concrete_type_builder.set_module_dict()
     if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
-        concrete_type.set_module_list()
+        concrete_type_builder.set_module_list()
 
     class_annotations = getattr(nn_module, '__annotations__', {})
 
@@ -96,7 +95,7 @@
             continue
         assert isinstance(item, torch.Tensor)
         attr_type = infer_type(name, item)
-        concrete_type.add_attribute(name, attr_type, True)
+        concrete_type_builder.add_attribute(name, attr_type, True)
         added_names.add(name)
 
     for name, item in nn_module._buffers.items():
@@ -109,18 +108,18 @@
             continue
         assert isinstance(item, torch.Tensor)
         attr_type = infer_type(name, item)
-        concrete_type.add_attribute(name, attr_type, False)
+        concrete_type_builder.add_attribute(name, attr_type, False)
         added_names.add(name)
 
     for name, item in nn_module._modules.items():
         attr_type = infer_type(name, item)
         if attr_type is not None:
             # if the type can be inferred, it should be a module interface type
-            concrete_type.add_module_interface(name, attr_type)
+            concrete_type_builder.add_module_interface(name, attr_type)
         else:
             # otherwise we get the concrete module type for item and add it to concrete_type
             sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
-            concrete_type.add_module(name, sub_concrete_type)
+            concrete_type_builder.add_module(name, sub_concrete_type)
 
         added_names.add(name)
 
@@ -146,7 +145,7 @@
                           "Consider removing it.".format(name))
             continue
         value = getattr(nn_module, name)
-        concrete_type.add_constant(name, _get_valid_constant(name, value))
+        concrete_type_builder.add_constant(name, _get_valid_constant(name, value))
         added_names.add(name)
 
     # populate overloads
@@ -154,7 +153,7 @@
     # update with any annotated overloads
     overloads.update(get_overload_name_mapping(get_overload_annotations(nn_module)))
     for name, overloaded_names in overloads.items():
-        concrete_type.add_overload(name, overloaded_names)
+        concrete_type_builder.add_overload(name, overloaded_names)
 
 
     for name, value in nn_module.__dict__.items():
@@ -171,7 +170,7 @@
         if inspect.isfunction(value):
             try:
                 scripted_fn = torch.jit.script(value)
-                concrete_type.add_function_attribute(
+                concrete_type_builder.add_function_attribute(
                     name,
                     torch._C._jit_try_infer_type(scripted_fn),
                     value)
@@ -182,14 +181,14 @@
                 hint = ("(This function exists as an attribute on the Python module, "
                         "but we failed to compile it to a TorchScript function. "
                         "\nThe error stack is reproduced here:\n{}").format(e)
-                concrete_type.add_failed_attribute(name, hint)
+                concrete_type_builder.add_failed_attribute(name, hint)
                 pass
 
             continue
 
         # Handle Script function attributes
         if isinstance(value, torch.jit.ScriptFunction):
-            concrete_type.add_function_attribute(
+            concrete_type_builder.add_function_attribute(
                 name,
                 torch._C._jit_try_infer_type(value),
                 value)
@@ -198,14 +197,14 @@
         # If we got here, this is a regular "data" attribute, Add it to the concrete type
         attr_type = infer_type(name, value)
         if attr_type is not None:
-            concrete_type.add_attribute(name, attr_type, False)
+            concrete_type_builder.add_attribute(name, attr_type, False)
         else:
             # TODO: could add more detail here. For example, what the user should do
             # when the pytype is `list` or `NoneType`
             hint = ("(This attribute exists on the Python module, "
                     "but we failed to convert Python type: '{}' "
                     "to a TorchScript type.)").format(type(value).__name__)
-            concrete_type.add_failed_attribute(name, hint)
+            concrete_type_builder.add_failed_attribute(name, hint)
 
     # Add @property methods as failed attributes, to give a better error message.
     for name, value in type(nn_module).__dict__.items():
@@ -213,9 +212,9 @@
             hint = ("\n(This attribute exists on the Python module, but it's an @property "
                     "method. @property methods are not yet supported in TorchScript. "
                     "Please file a feature request on Github)")
-            concrete_type.add_failed_attribute(name, hint)
+            concrete_type_builder.add_failed_attribute(name, hint)
 
-    return concrete_type
+    return concrete_type_builder
 
 class ConcreteTypeStore(object):
     def __init__(self):
@@ -234,7 +233,7 @@
                 hasattr(nn_module, "_concrete_type"):
             return nn_module._concrete_type
 
-        raw_concrete_type = infer_raw_concrete_type(nn_module)
+        concrete_type_builder = infer_concrete_type_builder(nn_module)
 
         nn_module_type = type(nn_module)
         if nn_module_type not in self.type_store:
@@ -243,13 +242,13 @@
         # Search the type store for an already-available JIT type
         known_types = self.type_store[nn_module_type]
         for known_type in known_types:
-            if raw_concrete_type.equals(known_type):
+            if known_type.equals(concrete_type_builder):
                 return known_type
 
         # We didn't find anything; generate a new JIT type from this concrete type
-        raw_concrete_type.create_new_type_from_this()
-        self.type_store[nn_module_type].append(raw_concrete_type)
-        return raw_concrete_type
+        concrete_type = concrete_type_builder.build()
+        self.type_store[nn_module_type].append(concrete_type)
+        return concrete_type
 
 concrete_type_store = ConcreteTypeStore()
 
@@ -272,11 +271,12 @@
         stubs:  ScriptMethodStubs to compile as part of the conversion process.
     """
     check_module_initialized(nn_module)
-    # Get a ConcreteType without a JIT type. We will generate one ourselves
-    # and fill it in.
-    concrete_type = infer_raw_concrete_type(nn_module)
-    concrete_type.set_poisoned()
-    concrete_type.create_new_type_from_this()
+    # Get a concrete type directly, without trying to re-use an existing JIT
+    # type from the type store.
+    concrete_type_builder = infer_concrete_type_builder(nn_module)
+    concrete_type_builder.set_poisoned()
+    concrete_type = concrete_type_builder.build()
+
     cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
     return create_script_module_impl(nn_module, concrete_type, cpp_module, stubs)