Add flag to temporarily enable first class modules (#21560)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21560
ghimport-source-id: a555ca33fcd3efd1147aaf90f26a8e63da1c1a67

Reviewed By: suo

Differential Revision: D15729502

Pulled By: zdevito

fbshipit-source-id: d6c11472bfc791e2ad1e9aa695b0439d72b79681
diff --git a/test/test_jit.py b/test/test_jit.py
index a97c066..84f245c 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -230,6 +230,12 @@
         s += t.sum()
     return s
 
+@contextmanager
+def enable_first_class_mode():
+    torch._C._jit_set_first_class_mode(True)
+    yield
+    torch._C._jit_set_first_class_mode(False)
+
 # helper function to generate test qparam
 def _helper_generate_qparam(script_module, input_data):
     class TestGenQParam:
@@ -2969,6 +2975,21 @@
 
         self.assertEqual(D()(v), v + v)
 
+    def test_first_class_module(self):
+        with enable_first_class_mode():
+            class Foo(torch.jit.ScriptModule):
+                def __init__(self):
+                    super(Foo, self).__init__()
+                    self.foo = nn.Parameter(torch.rand(3, 4))
+
+                @torch.jit.script_method
+                def forward(self, input):
+                    self.foo = input
+                    return self.foo
+            foo = Foo()
+            input = torch.rand(3, 4)
+            foo.forward(input)
+            self.assertEqual(input, foo.foo)
 
     def test_invalid_prefix_annotation(self):
         with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"):
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index eff1a0b..049e702 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -333,6 +333,9 @@
           "_jit_set_profiling_mode",
           [](bool profiling_flag) { getProfilingMode() = profiling_flag; })
       .def(
+          "_jit_set_first_class_mode",
+          [](bool enabled) { script::setRunAsFirstClass(enabled); })
+      .def(
           "_jit_fuser_get_fused_kernel_code",
           [](Graph& g, std::vector<at::Tensor> inps) {
             return debugGetFusedKernelCode(g, inps);
diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp
index 8812d86..c35bfab 100644
--- a/torch/csrc/jit/script/module.cpp
+++ b/torch/csrc/jit/script/module.cpp
@@ -12,6 +12,16 @@
 namespace jit {
 namespace script {
 
+// first class mode runs models as first class objects,
+// and does not force inlining everywhere. This is experimental
+// as we bring up the system since it will degrade performance
+// and may introduce bugs. test_jit.py provides context managers
+// that enable it for specific tests.
+thread_local bool experimental_run_as_first_class = false;
+void setRunAsFirstClass(bool enabled) {
+  experimental_run_as_first_class = enabled;
+}
+
 struct RecursiveMethodCallError : public std::exception {};
 void placeholderCreator(Function&) {
   throw RecursiveMethodCallError();
@@ -213,10 +223,40 @@
   return schema.cloneWithArguments(std::move(sliced));
 }
 
-Method::Method(Module* owner, Function* first_class_function)
+Method::Method(
+    Module* owner,
+    const std::shared_ptr<Function>& first_class_function)
     : owner_(owner), schema_(sliceFirst(first_class_function->getSchema())) {
-  std::tie(function_, initial_ivalues_) =
-      owner->lower_first_class_method(first_class_function);
+  if (experimental_run_as_first_class) {
+    function_ = first_class_function;
+    // initial_ivalues_ left blank
+  } else {
+    std::tie(function_, initial_ivalues_) =
+        owner->lower_first_class_method(first_class_function.get());
+  }
+}
+
+void Method::run(Stack& stack) {
+  if (experimental_run_as_first_class) {
+    stack.insert(stack.begin(), owner().module_object());
+  }
+  for (const auto& input : initial_ivalues_) {
+    push(stack, input.value());
+  }
+  function_->run(stack);
+}
+
+IValue Method::operator()(std::vector<IValue> stack, const Kwargs& kwargs) {
+  getSchema().checkAndNormalizeInputs(stack, kwargs);
+  if (experimental_run_as_first_class) {
+    stack.insert(stack.begin(), owner().module_object());
+  }
+  for (const auto& input : initial_ivalues_) {
+    push(stack, input.value());
+  }
+  // use run rather than operator() to skip the second schema check.
+  function_->run(stack);
+  return stack.front();
 }
 
 void Module::define(const std::string& src, const ResolverPtr& resolver) {
diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h
index b75dc64..a31688e 100644
--- a/torch/csrc/jit/script/module.h
+++ b/torch/csrc/jit/script/module.h
@@ -57,34 +57,19 @@
     std::function<std::shared_ptr<Module>(const std::vector<std::string>&)>;
 
 struct TORCH_API Method {
-  Method(Module* owner, Function* function);
+  Method(Module* owner, const std::shared_ptr<Function>& function);
 
   // the module that contains this method.
   Module& owner() const {
     return *owner_;
   }
 
-  void run(Stack& stack) {
-    for (auto input : initial_ivalues_) {
-      push(stack, input.value());
-    }
-    function_->run(stack);
-  }
+  void run(Stack& stack);
   void run(Stack&& stack) {
     run(stack);
   }
 
-  IValue operator()(
-      std::vector<IValue> stack,
-      const Kwargs& kwargs = Kwargs()) {
-    getSchema().checkAndNormalizeInputs(stack, kwargs);
-    for (auto input : initial_ivalues_) {
-      push(stack, input.value());
-    }
-    // use run rather than operator() to skip the second schema check.
-    function_->run(std::move(stack));
-    return stack.front();
-  }
+  IValue operator()(std::vector<IValue> stack, const Kwargs& kwargs = Kwargs());
 
   const std::vector<Slot>& initial_ivalues() const {
     return initial_ivalues_;
@@ -315,7 +300,8 @@
       return methods_[*offset].get();
     }
 
-    if (Function* fn = class_compilation_unit().find_function(name).get()) {
+    if (const std::shared_ptr<Function>& fn =
+            class_compilation_unit().find_function(name)) {
       // lock because technically this is marked const,
       // but we have to update the internal Method cache.
       // This can be removed when class_compilation_unit() is the source of
@@ -448,8 +434,7 @@
   void define(const std::string& src, const ResolverPtr& resolver = nullptr);
 
   template <typename... Types>
-  IValue create_class(const c10::QualifiedName& name, Types&&... args)
-      const {
+  IValue create_class(const c10::QualifiedName& name, Types&&... args) const {
     return create_class(name, {IValue(std::forward<Types>(args))...});
   }
 
@@ -582,6 +567,8 @@
   friend struct Method;
 };
 
+TORCH_API void setRunAsFirstClass(bool enabled);
+
 } // namespace script
 } // namespace jit
 } // namespace torch