Split ConcreteModuleType into two types (#29824)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29824
We have two distinct phases/uses for ConcreteModuleType:
1. We are building it up and using it to check whether we can
reuse JIT types. (RawConcreteModuleType)
2. We are using it to satisfy ModuleValue::attr queries.
(ConcreteModuleType)
These types share an underlying `ConcreteModuleTypeData` which
actually stores the relevant info.
Previously they were the same type because I was lazy, but it's been the
source of a bug. So split them to formalize the differing invariants for
the two phases.
Test Plan: Imported from OSS
Differential Revision: D18575010
Pulled By: suo
fbshipit-source-id: 3e4ebcd36e78b947150d8f0dbb74ecccad23e7c4
diff --git a/torch/csrc/jit/script/concrete_module_type.cpp b/torch/csrc/jit/script/concrete_module_type.cpp
index 2ab9e90..1e0b96b 100644
--- a/torch/csrc/jit/script/concrete_module_type.cpp
+++ b/torch/csrc/jit/script/concrete_module_type.cpp
@@ -3,16 +3,7 @@
namespace torch {
namespace jit {
namespace script {
-
-ClassTypePtr ConcreteModuleType::getJitType() const {
- TORCH_INTERNAL_ASSERT(jitType_);
- return jitType_;
-}
-
-ClassTypePtr ConcreteModuleType::createNewTypeFromThis() {
- TORCH_INTERNAL_ASSERT(!jitType_);
- TORCH_INTERNAL_ASSERT(pyClass_);
-
+ClassTypePtr ConcreteModuleTypeBuilder::createTypeFromThis() const {
auto cu = get_python_cu();
py::object pyQualName = py::module::import("torch._jit_internal")
.attr("_qualified_name")(pyClass_);
@@ -41,21 +32,85 @@
moduleInfo.name_, moduleInfo.getJitType(), /*is_parameter=*/false);
}
- jitType_ = std::move(cls);
+ return cls;
+}
+
+ConcreteModuleType::ConcreteModuleType(ConcreteModuleTypeBuilder data)
+ : data_(std::move(data)) {
+ jitType_ = data_.createTypeFromThis();
+}
+
+TypePtr ConcreteModuleTypeBuilder::ModuleInfo::getJitType() const {
+ return meta_ == nullptr ? type_ : meta_->getJitType();
+}
+
+bool operator==(
+ const ConcreteModuleTypeBuilder::ModuleInfo& lhs,
+ const ConcreteModuleTypeBuilder::ModuleInfo& rhs) {
+ if (lhs.meta_ != nullptr && rhs.meta_ != nullptr) {
+ return lhs.meta_->equals(*rhs.meta_);
+ } else if (lhs.type_ != nullptr && rhs.type_ != nullptr) {
+ return *(lhs.type_) == *(rhs.type_);
+ } else {
+ return false;
+ }
+}
+
+bool ConcreteModuleTypeBuilder::equals(
+ const ConcreteModuleTypeBuilder& other) const {
+ if (isPoisoned_ || other.isPoisoned_) {
+ return false;
+ }
+
+ // clang-format off
+ // These are vaguely ordered so that cheap, discriminating checks happen first.
+ bool equal =
+ pyClass_.is(other.pyClass_) &&
+ iterableModuleKind_ == other.iterableModuleKind_ &&
+ constants_ == other.constants_ &&
+ attributes_ == other.attributes_ &&
+ overloads_ == other.overloads_ &&
+ functionAttributes_ == other.functionAttributes_;
+ // clang-format on
+ if (!equal) {
+ return false;
+ }
+
+ // We store modules in order of insertion (to make compilation
+ // deterministic). However, for the purposes of equality, insertion order
+ // should not matter, so sort them by name.
+ // We put this check last because it involves the most work.
+ auto thisSorted = modules_;
+ std::sort(
+ thisSorted.begin(),
+ thisSorted.end(),
+ [](const ModuleInfo& a, const ModuleInfo& b) {
+ return a.name_ < b.name_;
+ });
+
+ auto otherSorted = other.modules_;
+ std::sort(
+ otherSorted.begin(),
+ otherSorted.end(),
+ [](const ModuleInfo& a, const ModuleInfo& b) {
+ return a.name_ < b.name_;
+ });
+
+ return thisSorted == otherSorted;
+}
+
+ClassTypePtr ConcreteModuleType::getJitType() const {
return jitType_;
}
py::object ConcreteModuleType::getPyClass() const {
- TORCH_INTERNAL_ASSERT(jitType_);
- TORCH_INTERNAL_ASSERT(pyClass_);
- return pyClass_;
+ return data_.pyClass_;
}
c10::optional<std::vector<std::string>> ConcreteModuleType::findOverloads(
const std::string& name) const {
- TORCH_INTERNAL_ASSERT(jitType_);
- const auto it = overloads_.find(name);
- if (it != overloads_.end()) {
+ const auto it = data_.overloads_.find(name);
+ if (it != data_.overloads_.end()) {
return it->second;
}
return c10::nullopt;
@@ -63,9 +118,8 @@
c10::optional<Function*> ConcreteModuleType::findFunctionAttribute(
const std::string& name) const {
- TORCH_INTERNAL_ASSERT(jitType_);
- const auto it = functionAttributes_.find(name);
- if (it != functionAttributes_.end()) {
+ const auto it = data_.functionAttributes_.find(name);
+ if (it != data_.functionAttributes_.end()) {
return it->second.function_->function();
}
return c10::nullopt;
@@ -73,9 +127,8 @@
c10::optional<std::string> ConcreteModuleType::findFailedAttribute(
const std::string& name) const {
- TORCH_INTERNAL_ASSERT(jitType_);
- const auto it = failedAttributes_.find(name);
- if (it != failedAttributes_.end()) {
+ const auto it = data_.failedAttributes_.find(name);
+ if (it != data_.failedAttributes_.end()) {
return it->second;
}
return c10::nullopt;
@@ -83,130 +136,114 @@
std::shared_ptr<ConcreteModuleType> ConcreteModuleType::
findSubmoduleConcreteType(const std::string& name) const {
- TORCH_INTERNAL_ASSERT(jitType_);
const auto it = std::find_if(
- modules_.cbegin(), modules_.cend(), [&](const ModuleInfo& info) {
+ data_.modules_.cbegin(),
+ data_.modules_.cend(),
+ [&](const ConcreteModuleTypeBuilder::ModuleInfo& info) {
return info.name_ == name;
});
- if (it == modules_.end()) {
+ if (it == data_.modules_.end()) {
return nullptr;
}
return it->meta_;
}
-void ConcreteModuleType::setIterableModuleKind(IterableModuleKind kind) {
- TORCH_INTERNAL_ASSERT(!jitType_);
+void ConcreteModuleTypeBuilder::setIterableModuleKind(IterableModuleKind kind) {
iterableModuleKind_ = kind;
}
IterableModuleKind ConcreteModuleType::getIterableModuleKind() const {
- TORCH_INTERNAL_ASSERT(jitType_);
- return iterableModuleKind_;
+ return data_.iterableModuleKind_;
}
-void ConcreteModuleType::setPoisoned() {
- TORCH_INTERNAL_ASSERT(!jitType_)
+void ConcreteModuleTypeBuilder::setPoisoned() {
isPoisoned_ = true;
}
-void ConcreteModuleType::addJitType(ClassTypePtr type) {
- TORCH_INTERNAL_ASSERT(!jitType_)
- jitType_ = std::move(type);
-}
-
-void ConcreteModuleType::addPyClass(py::object pyClass) {
- TORCH_INTERNAL_ASSERT(!jitType_);
- pyClass_ = std::move(pyClass);
-}
-
-void ConcreteModuleType::addConstant(std::string name, py::object value) {
- TORCH_INTERNAL_ASSERT(!jitType_);
+void ConcreteModuleTypeBuilder::addConstant(std::string name, py::object value) {
constants_.emplace(std::move(name), std::move(value));
}
-void ConcreteModuleType::addAttribute(
+void ConcreteModuleTypeBuilder::addAttribute(
std::string name,
TypePtr type,
bool isParameter) {
TORCH_INTERNAL_ASSERT(type);
- TORCH_INTERNAL_ASSERT(!jitType_);
// Function attributes should be handled separately
TORCH_INTERNAL_ASSERT(type->cast<FunctionType>() == nullptr);
attributes_.emplace(
- std::move(name), Attribute(unshapedType(type), isParameter));
+ std::move(name),
+ ConcreteModuleTypeBuilder::Attribute(unshapedType(type), isParameter));
}
-void ConcreteModuleType::addFunctionAttribute(
+void ConcreteModuleTypeBuilder::addFunctionAttribute(
std::string name,
const TypePtr& type,
py::object pyFunction) {
TORCH_INTERNAL_ASSERT(type);
- TORCH_INTERNAL_ASSERT(!jitType_);
functionAttributes_.emplace(
std::move(name),
- FunctionAttribute{type->expect<FunctionType>(), std::move(pyFunction)});
+ ConcreteModuleTypeBuilder::FunctionAttribute{type->expect<FunctionType>(),
+ std::move(pyFunction)});
}
-void ConcreteModuleType::addModule(
+void ConcreteModuleTypeBuilder::addModule(
std::string name,
std::shared_ptr<ConcreteModuleType> meta) {
- TORCH_INTERNAL_ASSERT(!jitType_);
- modules_.emplace_back(ModuleInfo{std::move(name), std::move(meta)});
+ modules_.emplace_back(
+ ConcreteModuleTypeBuilder::ModuleInfo{std::move(name), std::move(meta)});
}
-void ConcreteModuleType::addModuleInterface(
+void ConcreteModuleTypeBuilder::addModuleInterface(
std::string name,
const TypePtr& type) {
- TORCH_INTERNAL_ASSERT(!jitType_);
TORCH_INTERNAL_ASSERT(type->cast<InterfaceType>() && type->is_module());
- modules_.emplace_back(ModuleInfo{std::move(name), type});
+ modules_.emplace_back(
+ ConcreteModuleTypeBuilder::ModuleInfo{std::move(name), type});
}
-
-void ConcreteModuleType::addOverload(
+void ConcreteModuleTypeBuilder::addOverload(
std::string methodName,
std::vector<std::string> overloadedMethodNames) {
- TORCH_INTERNAL_ASSERT(!jitType_);
overloads_.emplace(std::move(methodName), std::move(overloadedMethodNames));
}
-void ConcreteModuleType::addFailedAttribute(
+void ConcreteModuleTypeBuilder::addFailedAttribute(
std::string name,
std::string failureReason) {
- TORCH_INTERNAL_ASSERT(!jitType_);
failedAttributes_.emplace(std::move(name), std::move(failureReason));
}
c10::optional<py::object> ConcreteModuleType::findConstant(
const std::string& name) const {
- auto it = constants_.find(name);
- if (it != constants_.end()) {
+ auto it = data_.constants_.find(name);
+ if (it != data_.constants_.end()) {
return it->second.v_;
}
return c10::nullopt;
}
void ConcreteModuleType::dump() const {
- std::cout << "ConcreteModuleType for: " << py::getattr(pyClass_, "__name__") << "\n";
+ std::cout << "ConcreteModuleType for: " << py::getattr(data_.pyClass_, "__name__") << "\n";
std::cout << "Constants: \n";
- for (const auto& pr : constants_) {
+ for (const auto& pr : data_.constants_) {
std::cout << "\t" << pr.first << ": " << pr.second.v_ << "\n";
}
std::cout << "\nAttributes: \n";
- for (const auto& pr : attributes_) {
+ for (const auto& pr : data_.attributes_) {
std::cout << "\t" << pr.first << ": " << pr.second.type_->python_str()
<< "\n";
}
std::cout << "\nSubmodules: \n";
- for (const auto& info : modules_) {
+ for (const auto& info : data_.modules_) {
std::cout << "\t" << info.name_ << ": "
<< info.getJitType()->python_str() << "\n";
}
std::cout << "\nOverloads: \n";
- for (const auto& pr : overloads_) {
+ for (const auto& pr : data_.overloads_) {
std::cout << "\t" << pr.first << ": " << pr.second << "\n";
}
- std::string isPoisoned = isPoisoned_ ? "true" : "false";
+ std::string isPoisoned = data_.isPoisoned_ ? "true" : "false";
std::cout << "isPoisoned: " << isPoisoned << "\n";
if (jitType_) {
std::cout << "jit type: " << jitType_->python_str() << "\n";
@@ -215,11 +252,10 @@
std::unordered_map<std::string, py::object> ConcreteModuleType::getConstantsPy()
const {
- TORCH_INTERNAL_ASSERT(jitType_);
// Convert to a more pybind-friendly representation, so we don't
// need to bind ConcreteModuleType::Constant as well.
std::unordered_map<std::string, py::object> ret;
- for (const auto& pr : constants_) {
+ for (const auto& pr : data_.constants_) {
ret.emplace(pr.first, pr.second.v_);
}
return ret;
@@ -227,11 +263,10 @@
std::unordered_map<std::string, std::pair<TypePtr, bool>> ConcreteModuleType::
getAttributesPy() const {
- TORCH_INTERNAL_ASSERT(jitType_);
// Convert to a more pybind-friendly representation, so we don't
// need to bind ConcreteModuleType::Attribute as well.
std::unordered_map<std::string, std::pair<TypePtr, bool>> ret;
- for (auto& pr : attributes_) {
+ for (auto& pr : data_.attributes_) {
ret.emplace(
pr.first,
std::pair<TypePtr, bool>(pr.second.type_, pr.second.isParam_));
@@ -241,10 +276,9 @@
std::vector<std::pair<std::string, TypePtr>> ConcreteModuleType::getModulesPy()
const {
- TORCH_INTERNAL_ASSERT(jitType_);
std::vector<std::pair<std::string, TypePtr>> ret;
- for (const ModuleInfo& info: modules_) {
+ for (const auto& info : data_.modules_) {
ret.emplace_back(std::make_pair(info.name_, info.getJitType()));
}
return ret;
diff --git a/torch/csrc/jit/script/concrete_module_type.h b/torch/csrc/jit/script/concrete_module_type.h
index d098c40..a587772 100644
--- a/torch/csrc/jit/script/concrete_module_type.h
+++ b/torch/csrc/jit/script/concrete_module_type.h
@@ -9,7 +9,9 @@
namespace torch {
namespace jit {
namespace script {
+
enum class IterableModuleKind { NONE, LIST, DICT };
+class ConcreteModuleType;
// You can think of an nn.Module as a template that corresponds to a family of
// JIT types. The template "arguments" are things like the constant values.
@@ -41,21 +43,22 @@
// ModuleValue::attr calls. This is so we can guarantee that if two Module's
// share a JIT type (and thus a ConcreteModuleType), then they behave the same
// way when you access attributes on them.
-class VISIBILITY_HIDDEN ConcreteModuleType {
- public:
- // ConcreteModuleType has two states.
- // 1. Building: First we build it up, during the ScriptModule conversion
- // process
- // ... to transition, we freeze the type by associating with a JIT type.
- // 2. Querying: Then we ask use it as a source of truth during method
- // compilation.
- // During this time the ModuleType is effectively const.
- // Yes, it could be two different types. Not terribly worth the verbosity.
- /**
- * Builder methods (jitType must be null)
- */
- void addPyClass(py::object pyClass);
+// ConcreteModuleType has two phases.
+// 1. Creation: First we build it up, during the ScriptModule conversion
+// process. This is represented by ConcreteModuleTypeBuilder.
+// ...then the converter calls ConcreteModuleTypeBuilder::build(), producing a
+// ConcreteModuleType ready for querying.
+// 2. Querying: We use ConcreteModuleType as a source of truth for
+// ModuleValue::attr calls during method compilation.
+
+// Represents a concrete type during in the process for construction. We use
+// this to decide whether we can share types between modules.
+class VISIBILITY_HIDDEN ConcreteModuleTypeBuilder {
+ public:
+ explicit ConcreteModuleTypeBuilder(py::object pyClass) {
+ pyClass_ = std::move(pyClass);
+ }
void addConstant(std::string name, py::object value);
void addAttribute(std::string name, TypePtr type, bool isParameter);
void addFunctionAttribute(
@@ -73,95 +76,20 @@
std::vector<std::string> overloadedMethodNames);
void addFailedAttribute(std::string name, std::string failureReason);
void setIterableModuleKind(IterableModuleKind kind);
+
+ // If a ConcreteModuleType is poisoned, it will never compare equal to any other
+ // concrete type
void setPoisoned();
- /**
- * Freezing methods
- */
-
- // Based on the data in this ConcreteType, create an equivalent JIT type and
- // associate this module type with it.
- ClassTypePtr createNewTypeFromThis();
- // Associate the provided type with this ConcreteType
- void addJitType(ClassTypePtr type);
-
- /**
- * Query methods (jitType must be non-null)
- */
- ClassTypePtr getJitType() const;
- py::object getPyClass() const;
- IterableModuleKind getIterableModuleKind() const;
- c10::optional<py::object> findConstant(const std::string& name) const;
- c10::optional<std::vector<std::string>> findOverloads(
- const std::string& name) const;
- c10::optional<Function*> findFunctionAttribute(const std::string& name) const;
- std::shared_ptr<ConcreteModuleType> findSubmoduleConcreteType(
- const std::string& name) const;
- c10::optional<std::string> findFailedAttribute(const std::string& name) const;
-
- // These getters are only here to return things as types that can be
- // automatically converted by pybind.
- std::unordered_map<std::string, py::object> getConstantsPy() const;
- std::unordered_map<std::string, std::pair<TypePtr, bool>> getAttributesPy()
- const;
- std::vector<std::pair<std::string, TypePtr>> getModulesPy() const;
+ std::shared_ptr<ConcreteModuleType> build() const {
+ return std::make_shared<ConcreteModuleType>(*this);
+ }
// This determines whether two modules can share a type. The container structs
// used by ConcreteModuleType have been defined such that operator==
// implements a meaningful comparison in that context.
- friend bool operator==(
- const ConcreteModuleType& lhs,
- const ConcreteModuleType& rhs) {
- if (lhs.jitType_ == rhs.jitType_) {
- // If the computed types are the same, these modules can (obviously) share
- // a type.
- return true;
- }
+ bool equals(const ConcreteModuleTypeBuilder& other) const;
- if (lhs.isPoisoned_ || rhs.isPoisoned_) {
- return false;
- }
-
- // clang-format off
- // These are vaguely ordered so that cheap, discriminating checks happen first.
- bool equal =
- lhs.pyClass_.is(rhs.pyClass_) &&
- lhs.iterableModuleKind_ == rhs.iterableModuleKind_ &&
- lhs.constants_ == rhs.constants_ &&
- lhs.attributes_ == rhs.attributes_ &&
- lhs.overloads_ == rhs.overloads_ &&
- lhs.functionAttributes_ == rhs.functionAttributes_;
- // clang-format on
- if (!equal) {
- return false;
- }
-
- // We store modules in order of insertion (to make compilation
- // deterministic). However, for the purposes of equality, insertion order
- // should not matter, so sort them by name.
- // We put this check last because it involves the most work.
- auto lhsSorted = lhs.modules_;
- std::sort(
- lhsSorted.begin(),
- lhsSorted.end(),
- [](const ModuleInfo& a, const ModuleInfo& b) {
- return a.name_ < b.name_;
- });
-
- auto rhsSorted = rhs.modules_;
- std::sort(
- rhsSorted.begin(),
- rhsSorted.end(),
- [](const ModuleInfo& a, const ModuleInfo& b) {
- return a.name_ < b.name_;
- });
-
- return lhsSorted == rhsSorted;
- }
-
- void dump() const;
-
- private:
struct Constant {
/* implicit */ Constant(py::object v) : v_(std::move(v)) {}
friend bool operator==(const Constant& lhs, const Constant& rhs) {
@@ -206,28 +134,18 @@
ModuleInfo(std::string name, const TypePtr& type)
: name_(std::move(name)), meta_(nullptr), type_(type) {}
- friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs) {
- if (lhs.meta_ != nullptr && rhs.meta_ != nullptr) {
- return *(lhs.meta_) == *(rhs.meta_);
- } else if (lhs.type_ != nullptr && rhs.type_ != nullptr) {
- return *(lhs.type_) == *(rhs.type_);
- } else {
- return false;
- }
- }
-
- TypePtr getJitType() const {
- return meta_ == nullptr? type_ : meta_->getJitType();
- }
+ TypePtr getJitType() const;
std::string name_;
+ friend bool operator==(const ModuleInfo& lhs, const ModuleInfo& rhs);
// Module Info contains either an ConcreateModuleType or a type (which is
// a Module Interface), these two are union relationship.
std::shared_ptr<ConcreteModuleType> meta_;
TypePtr type_;
-
};
+ private:
+ ClassTypePtr createTypeFromThis() const;
// If true, this type will never compare equally to anything else. This is
// used if we want to ensure that this type is not shared (for example, if it
// came from a traced module)
@@ -256,11 +174,54 @@
// The original `nn.Module` class that we derived this ScriptModule from.
py::object pyClass_;
- // The JIT type derived from this ConcreteModuleType.
- ClassTypePtr jitType_ = nullptr;
// NOTE: If you ever add any more state to this struct, you need to make sure
- // operator== still makes sense! The only field that can be excluded from it
- // is `jitType_`.
+ // operator== still makes sense!
+ friend ConcreteModuleType;
+};
+
+// Represents a finalized concrete type, used to service ModuleValue::attr calls
+// during method compilation.
+class VISIBILITY_HIDDEN ConcreteModuleType {
+ public:
+ explicit ConcreteModuleType(ConcreteModuleTypeBuilder data);
+
+ ClassTypePtr getJitType() const;
+ py::object getPyClass() const;
+ IterableModuleKind getIterableModuleKind() const;
+ c10::optional<py::object> findConstant(const std::string& name) const;
+ c10::optional<std::vector<std::string>> findOverloads(
+ const std::string& name) const;
+ c10::optional<Function*> findFunctionAttribute(const std::string& name) const;
+ std::shared_ptr<ConcreteModuleType> findSubmoduleConcreteType(
+ const std::string& name) const;
+ c10::optional<std::string> findFailedAttribute(const std::string& name) const;
+
+ // These getters are only here to return things as types that can be
+ // automatically converted by pybind.
+ std::unordered_map<std::string, py::object> getConstantsPy() const;
+ std::unordered_map<std::string, std::pair<TypePtr, bool>> getAttributesPy()
+ const;
+ std::vector<std::pair<std::string, TypePtr>> getModulesPy() const;
+
+ bool equals(const ConcreteModuleType& other) const {
+ if (jitType_ == other.jitType_) {
+ // If the computed types are the same, these modules can (obviously) share
+ // a type.
+ return true;
+ }
+
+ return data_.equals(other.data_);
+ }
+ bool equals(const ConcreteModuleTypeBuilder& other) const {
+ return data_.equals(other);
+ }
+
+ void dump() const;
+
+ private:
+ // The JIT type derived from this ConcreteModuleType.
+ ConcreteModuleTypeBuilder data_;
+ ClassTypePtr jitType_;
};
} // namespace script
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index beb7e79..f495fd2 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -1052,42 +1052,51 @@
return Module(get_python_cu(), type);
});
+ py::class_<ConcreteModuleTypeBuilder, std::shared_ptr<ConcreteModuleTypeBuilder>>(
+ m, "ConcreteModuleTypeBuilder")
+ .def(py::init<py::object>())
+ .def("add_constant", &ConcreteModuleTypeBuilder::addConstant)
+ .def("add_attribute", &ConcreteModuleTypeBuilder::addAttribute)
+ .def(
+ "add_function_attribute",
+ &ConcreteModuleTypeBuilder::addFunctionAttribute)
+ .def("add_module", &ConcreteModuleTypeBuilder::addModule)
+ .def("add_module_interface", &ConcreteModuleTypeBuilder::addModuleInterface)
+ .def("add_overload", &ConcreteModuleTypeBuilder::addOverload)
+ .def("set_poisoned", &ConcreteModuleTypeBuilder::setPoisoned)
+ .def("add_failed_attribute", &ConcreteModuleTypeBuilder::addFailedAttribute)
+ .def(
+ "set_module_dict",
+ [](ConcreteModuleTypeBuilder& self) {
+ self.setIterableModuleKind(IterableModuleKind::DICT);
+ })
+ .def("build", &ConcreteModuleTypeBuilder::build)
+ .def(
+ "equals",
+ [](const ConcreteModuleTypeBuilder& self,
+ const ConcreteModuleTypeBuilder& other) { return self.equals(other); })
+ .def("set_module_list", [](ConcreteModuleTypeBuilder& self) {
+ self.setIterableModuleKind(IterableModuleKind::LIST);
+ });
+
py::class_<ConcreteModuleType, std::shared_ptr<ConcreteModuleType>>(
m, "ConcreteModuleType")
- .def(py::init<>())
.def_property_readonly("py_class", &ConcreteModuleType::getPyClass)
.def_property_readonly("jit_type", &ConcreteModuleType::getJitType)
.def("get_constants", &ConcreteModuleType::getConstantsPy)
.def("get_attributes", &ConcreteModuleType::getAttributesPy)
.def("get_modules", &ConcreteModuleType::getModulesPy)
- .def("add_constant", &ConcreteModuleType::addConstant)
- .def("add_attribute", &ConcreteModuleType::addAttribute)
- .def("add_function_attribute", &ConcreteModuleType::addFunctionAttribute)
- .def("add_module", &ConcreteModuleType::addModule)
- .def("add_module_interface", &ConcreteModuleType::addModuleInterface)
- .def("add_pyclass", &ConcreteModuleType::addPyClass)
- .def("add_overload", &ConcreteModuleType::addOverload)
- .def("add_jit_type", &ConcreteModuleType::addJitType)
- .def("set_poisoned", &ConcreteModuleType::setPoisoned)
- .def(
- "set_module_dict",
- [](ConcreteModuleType& self) {
- self.setIterableModuleKind(IterableModuleKind::DICT);
- })
- .def(
- "set_module_list",
- [](ConcreteModuleType& self) {
- self.setIterableModuleKind(IterableModuleKind::LIST);
- })
- .def(
- "create_new_type_from_this",
- &ConcreteModuleType::createNewTypeFromThis)
- .def("add_failed_attribute", &ConcreteModuleType::addFailedAttribute)
.def("dump", &ConcreteModuleType::dump)
.def(
"equals",
[](const ConcreteModuleType& self, const ConcreteModuleType& other) {
- return self == other;
+ return self.equals(other);
+ })
+ .def(
+ "equals",
+ [](const ConcreteModuleType& self,
+ const ConcreteModuleTypeBuilder& other) {
+ return self.equals(other);
})
.def(
"_create_methods",
diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py
index 4b1b6b5..277dab1 100644
--- a/torch/jit/_recursive.py
+++ b/torch/jit/_recursive.py
@@ -59,18 +59,17 @@
3. a list or tuple of (2)
""".format(type(v).__name__, attr, constants)))
-def infer_raw_concrete_type(nn_module):
+def infer_concrete_type_builder(nn_module):
"""
- Build a ConcreteModuleType from an nn.Module. This ConcreteModuleType
- doesn't have a JIT type associated with it yet, it must be filled in
- by the caller.
+ Build a ConcreteModuleTypeBuilder from an nn.Module. This
+ ConcreteModuleType doesn't have a JIT type associated with it yet, it
+ must be filled in by the caller.
"""
- concrete_type = torch._C.ConcreteModuleType()
- concrete_type.add_pyclass(type(nn_module))
+ concrete_type_builder = torch._C.ConcreteModuleTypeBuilder(type(nn_module))
if isinstance(nn_module, (torch.nn.ModuleDict)):
- concrete_type.set_module_dict()
+ concrete_type_builder.set_module_dict()
if isinstance(nn_module, (torch.nn.ModuleList, torch.nn.Sequential)):
- concrete_type.set_module_list()
+ concrete_type_builder.set_module_list()
class_annotations = getattr(nn_module, '__annotations__', {})
@@ -96,7 +95,7 @@
continue
assert isinstance(item, torch.Tensor)
attr_type = infer_type(name, item)
- concrete_type.add_attribute(name, attr_type, True)
+ concrete_type_builder.add_attribute(name, attr_type, True)
added_names.add(name)
for name, item in nn_module._buffers.items():
@@ -109,18 +108,18 @@
continue
assert isinstance(item, torch.Tensor)
attr_type = infer_type(name, item)
- concrete_type.add_attribute(name, attr_type, False)
+ concrete_type_builder.add_attribute(name, attr_type, False)
added_names.add(name)
for name, item in nn_module._modules.items():
attr_type = infer_type(name, item)
if attr_type is not None:
# if the type can be inferred, it should be a module interface type
- concrete_type.add_module_interface(name, attr_type)
+ concrete_type_builder.add_module_interface(name, attr_type)
else:
# otherwise we get the concrete module type for item and add it to concrete_type
sub_concrete_type = concrete_type_store.get_or_create_concrete_type(item)
- concrete_type.add_module(name, sub_concrete_type)
+ concrete_type_builder.add_module(name, sub_concrete_type)
added_names.add(name)
@@ -146,7 +145,7 @@
"Consider removing it.".format(name))
continue
value = getattr(nn_module, name)
- concrete_type.add_constant(name, _get_valid_constant(name, value))
+ concrete_type_builder.add_constant(name, _get_valid_constant(name, value))
added_names.add(name)
# populate overloads
@@ -154,7 +153,7 @@
# update with any annotated overloads
overloads.update(get_overload_name_mapping(get_overload_annotations(nn_module)))
for name, overloaded_names in overloads.items():
- concrete_type.add_overload(name, overloaded_names)
+ concrete_type_builder.add_overload(name, overloaded_names)
for name, value in nn_module.__dict__.items():
@@ -171,7 +170,7 @@
if inspect.isfunction(value):
try:
scripted_fn = torch.jit.script(value)
- concrete_type.add_function_attribute(
+ concrete_type_builder.add_function_attribute(
name,
torch._C._jit_try_infer_type(scripted_fn),
value)
@@ -182,14 +181,14 @@
hint = ("(This function exists as an attribute on the Python module, "
"but we failed to compile it to a TorchScript function. "
"\nThe error stack is reproduced here:\n{}").format(e)
- concrete_type.add_failed_attribute(name, hint)
+ concrete_type_builder.add_failed_attribute(name, hint)
pass
continue
# Handle Script function attributes
if isinstance(value, torch.jit.ScriptFunction):
- concrete_type.add_function_attribute(
+ concrete_type_builder.add_function_attribute(
name,
torch._C._jit_try_infer_type(value),
value)
@@ -198,14 +197,14 @@
# If we got here, this is a regular "data" attribute, Add it to the concrete type
attr_type = infer_type(name, value)
if attr_type is not None:
- concrete_type.add_attribute(name, attr_type, False)
+ concrete_type_builder.add_attribute(name, attr_type, False)
else:
# TODO: could add more detail here. For example, what the user should do
# when the pytype is `list` or `NoneType`
hint = ("(This attribute exists on the Python module, "
"but we failed to convert Python type: '{}' "
"to a TorchScript type.)").format(type(value).__name__)
- concrete_type.add_failed_attribute(name, hint)
+ concrete_type_builder.add_failed_attribute(name, hint)
# Add @property methods as failed attributes, to give a better error message.
for name, value in type(nn_module).__dict__.items():
@@ -213,9 +212,9 @@
hint = ("\n(This attribute exists on the Python module, but it's an @property "
"method. @property methods are not yet supported in TorchScript. "
"Please file a feature request on Github)")
- concrete_type.add_failed_attribute(name, hint)
+ concrete_type_builder.add_failed_attribute(name, hint)
- return concrete_type
+ return concrete_type_builder
class ConcreteTypeStore(object):
def __init__(self):
@@ -234,7 +233,7 @@
hasattr(nn_module, "_concrete_type"):
return nn_module._concrete_type
- raw_concrete_type = infer_raw_concrete_type(nn_module)
+ concrete_type_builder = infer_concrete_type_builder(nn_module)
nn_module_type = type(nn_module)
if nn_module_type not in self.type_store:
@@ -243,13 +242,13 @@
# Search the type store for an already-available JIT type
known_types = self.type_store[nn_module_type]
for known_type in known_types:
- if raw_concrete_type.equals(known_type):
+ if known_type.equals(concrete_type_builder):
return known_type
# We didn't find anything; generate a new JIT type from this concrete type
- raw_concrete_type.create_new_type_from_this()
- self.type_store[nn_module_type].append(raw_concrete_type)
- return raw_concrete_type
+ concrete_type = concrete_type_builder.build()
+ self.type_store[nn_module_type].append(concrete_type)
+ return concrete_type
concrete_type_store = ConcreteTypeStore()
@@ -272,11 +271,12 @@
stubs: ScriptMethodStubs to compile as part of the conversion process.
"""
check_module_initialized(nn_module)
- # Get a ConcreteType without a JIT type. We will generate one ourselves
- # and fill it in.
- concrete_type = infer_raw_concrete_type(nn_module)
- concrete_type.set_poisoned()
- concrete_type.create_new_type_from_this()
+ # Get a concrete type directly, without trying to re-use an existing JIT
+ # type from the type store.
+ concrete_type_builder = infer_concrete_type_builder(nn_module)
+ concrete_type_builder.set_poisoned()
+ concrete_type = concrete_type_builder.build()
+
cpp_module = torch._C._create_module_with_type(concrete_type.jit_type)
return create_script_module_impl(nn_module, concrete_type, cpp_module, stubs)