fix bug in CompilationUnit::define (#21886)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21886
ghimport-source-id: fefbd758bbe2fbcaaad84a376ac5f69c40bccb80

Test Plan: Imported from OSS

Differential Revision: D15867647

Pulled By: suo

fbshipit-source-id: 3e0f5bbc98ec93ccf26442c4c574626e45e53888
diff --git a/torch/csrc/jit/script/compilation_unit.h b/torch/csrc/jit/script/compilation_unit.h
index 4247c60..b348cd4 100644
--- a/torch/csrc/jit/script/compilation_unit.h
+++ b/torch/csrc/jit/script/compilation_unit.h
@@ -152,6 +152,13 @@
   }
 
  private:
+  std::shared_ptr<Function> define(
+      const Def& def,
+      const ResolverPtr& resolver,
+      const Self& self,
+      const std::unordered_map<std::string, std::shared_ptr<Function>>&
+          function_table) const;
+
   Function& register_function(std::shared_ptr<Function> fn) {
     TORCH_CHECK(
         0 == dict_.count(fn->name()),
diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp
index f4bb7b8..2e0083e 100644
--- a/torch/csrc/jit/script/compiler.cpp
+++ b/torch/csrc/jit/script/compiler.cpp
@@ -3043,47 +3043,70 @@
   define(source, nativeResolver(), nullptr);
 }
 
+std::shared_ptr<Function> CompilationUnit::define(
+    const Def& def,
+    const ResolverPtr& resolver,
+    const Self& self,
+    const std::unordered_map<std::string, std::shared_ptr<Function>>&
+        function_table) const {
+  const std::string& name = def.name().name();
+  TORCH_INTERNAL_ASSERT(resolver);
+  auto _resolver = resolver;
+  if (!self) {
+    // if self is defined, then these are methods and do not go into the
+    // global namespace otherwise, they get defined together so we add them to
+    // the function table so the methods can see each other
+    _resolver =
+        std::make_shared<FunctionResolver>(resolver.get(), function_table);
+  }
+  auto creator = [def, _resolver, self](Function& method) {
+    to_ir(def, _resolver, self, method);
+  };
+  return std::make_shared<Function>(
+      name, is_optimized(), std::make_shared<Graph>(), creator);
+}
+
 void CompilationUnit::define(
     const std::vector<Def>& definitions,
     const std::vector<ResolverPtr>& resolvers,
     const Self& self) {
   AT_ASSERT(definitions.size() == resolvers.size());
-  auto resolver_it = resolvers.begin();
-  std::vector<Function*> methods;
-  std::unordered_map<std::string, std::shared_ptr<Function>> function_table;
-
   // We need to compile `__init__` first, since it can determine what attributes
   // are available to other methods. So reorder the definitions accordingly.
-  std::vector<Def> ordered_defs = definitions;
-  const auto it = std::find_if(
-      ordered_defs.begin(), ordered_defs.end(), [](const Def& def) {
-        return def.name().name() == "__init__";
-      });
-  if (it != ordered_defs.end()) {
-    std::swap(ordered_defs[0], *it);
+  c10::optional<size_t> init_idx;
+  for (size_t i = 0; i < definitions.size(); i++) {
+    const auto& def = definitions[i];
+    if (def.name().name() == "__init__") {
+      init_idx = i;
+      break;
+    }
   }
 
-  for (const Def& def : ordered_defs) {
-    const std::string& name = def.name().name();
-    ResolverPtr resolver = *resolver_it++;
-    AT_ASSERT(resolver);
-    if (!self) {
-      // if self is defined, then these are methods and do not go into the
-      // global namespace otherwise, they get defined together so we add them to
-      // the function table so the methods can see each other
-      resolver =
-          std::make_shared<FunctionResolver>(resolver.get(), function_table);
-    }
-    auto creator = [def, resolver, self](Function& method) {
-      AT_ASSERT(resolver);
-      to_ir(def, resolver, self, method);
-    };
-    std::shared_ptr<Function> fn(
-        new Function(name, is_optimized(), std::make_shared<Graph>(), creator));
+  std::vector<Function*> methods;
+  std::unordered_map<std::string, std::shared_ptr<Function>> function_table;
+  if (init_idx.has_value()) {
+    // if we have an init, do it first.
+    auto fn = define(
+        definitions[*init_idx], resolvers[*init_idx], self, function_table);
+    const auto& name = fn->name();
     function_table[name] = fn;
     methods.push_back(fn.get());
     register_function(std::move(fn));
   }
+
+  for (size_t i = 0; i < definitions.size(); i++) {
+    if (init_idx.has_value() && i == *init_idx) {
+      // skip this def since it's already been compiled
+      continue;
+    }
+
+    auto fn = define(definitions[i], resolvers[i], self, function_table);
+    const auto& name = fn->name();
+    function_table[name] = fn;
+    methods.push_back(fn.get());
+    register_function(std::move(fn));
+  }
+
   for (Function* method : methods) {
     method->ensure_defined();
   }