fix bug in CompilationUnit::define (#21886)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21886
ghimport-source-id: fefbd758bbe2fbcaaad84a376ac5f69c40bccb80
Test Plan: Imported from OSS
Differential Revision: D15867647
Pulled By: suo
fbshipit-source-id: 3e0f5bbc98ec93ccf26442c4c574626e45e53888
diff --git a/torch/csrc/jit/script/compilation_unit.h b/torch/csrc/jit/script/compilation_unit.h
index 4247c60..b348cd4 100644
--- a/torch/csrc/jit/script/compilation_unit.h
+++ b/torch/csrc/jit/script/compilation_unit.h
@@ -152,6 +152,13 @@
}
private:
+ std::shared_ptr<Function> define(
+ const Def& def,
+ const ResolverPtr& resolver,
+ const Self& self,
+ const std::unordered_map<std::string, std::shared_ptr<Function>>&
+ function_table) const;
+
Function& register_function(std::shared_ptr<Function> fn) {
TORCH_CHECK(
0 == dict_.count(fn->name()),
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index f4bb7b8..2e0083e 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -3043,47 +3043,70 @@
define(source, nativeResolver(), nullptr);
}
+std::shared_ptr<Function> CompilationUnit::define(
+ const Def& def,
+ const ResolverPtr& resolver,
+ const Self& self,
+ const std::unordered_map<std::string, std::shared_ptr<Function>>&
+ function_table) const {
+ const std::string& name = def.name().name();
+ TORCH_INTERNAL_ASSERT(resolver);
+ auto _resolver = resolver;
+ if (!self) {
+ // if self is defined, then these are methods and do not go into the
+ // global namespace otherwise, they get defined together so we add them to
+ // the function table so the methods can see each other
+ _resolver =
+ std::make_shared<FunctionResolver>(resolver.get(), function_table);
+ }
+ auto creator = [def, _resolver, self](Function& method) {
+ to_ir(def, _resolver, self, method);
+ };
+ return std::make_shared<Function>(
+ name, is_optimized(), std::make_shared<Graph>(), creator);
+}
+
void CompilationUnit::define(
const std::vector<Def>& definitions,
const std::vector<ResolverPtr>& resolvers,
const Self& self) {
AT_ASSERT(definitions.size() == resolvers.size());
- auto resolver_it = resolvers.begin();
- std::vector<Function*> methods;
- std::unordered_map<std::string, std::shared_ptr<Function>> function_table;
-
// We need to compile `__init__` first, since it can determine what attributes
// are available to other methods. So reorder the definitions accordingly.
- std::vector<Def> ordered_defs = definitions;
- const auto it = std::find_if(
- ordered_defs.begin(), ordered_defs.end(), [](const Def& def) {
- return def.name().name() == "__init__";
- });
- if (it != ordered_defs.end()) {
- std::swap(ordered_defs[0], *it);
+ c10::optional<size_t> init_idx;
+ for (size_t i = 0; i < definitions.size(); i++) {
+ const auto& def = definitions[i];
+ if (def.name().name() == "__init__") {
+ init_idx = i;
+ break;
+ }
}
- for (const Def& def : ordered_defs) {
- const std::string& name = def.name().name();
- ResolverPtr resolver = *resolver_it++;
- AT_ASSERT(resolver);
- if (!self) {
- // if self is defined, then these are methods and do not go into the
- // global namespace otherwise, they get defined together so we add them to
- // the function table so the methods can see each other
- resolver =
- std::make_shared<FunctionResolver>(resolver.get(), function_table);
- }
- auto creator = [def, resolver, self](Function& method) {
- AT_ASSERT(resolver);
- to_ir(def, resolver, self, method);
- };
- std::shared_ptr<Function> fn(
- new Function(name, is_optimized(), std::make_shared<Graph>(), creator));
+ std::vector<Function*> methods;
+ std::unordered_map<std::string, std::shared_ptr<Function>> function_table;
+ if (init_idx.has_value()) {
+ // if we have an init, do it first.
+ auto fn = define(
+ definitions[*init_idx], resolvers[*init_idx], self, function_table);
+ const auto& name = fn->name();
function_table[name] = fn;
methods.push_back(fn.get());
register_function(std::move(fn));
}
+
+ for (size_t i = 0; i < definitions.size(); i++) {
+ if (init_idx.has_value() && i == *init_idx) {
+ // skip this def since it's already been compiled
+ continue;
+ }
+
+ auto fn = define(definitions[i], resolvers[i], self, function_table);
+ const auto& name = fn->name();
+ function_table[name] = fn;
+ methods.push_back(fn.get());
+ register_function(std::move(fn));
+ }
+
for (Function* method : methods) {
method->ensure_defined();
}