[JIT] Support for registering C++ lambdas as methods on custom C++ class

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32553

Test Plan: Imported from OSS

Differential Revision: D19543269

Pulled By: jamesr66a

fbshipit-source-id: 7e566650295e9d1c4f2f716470e061308a6210a0
diff --git a/test/cpp/jit/test_custom_class.cpp b/test/cpp/jit/test_custom_class.cpp
index ff77296..f1b46b54 100644
--- a/test/cpp/jit/test_custom_class.cpp
+++ b/test/cpp/jit/test_custom_class.cpp
@@ -85,8 +85,16 @@
         .def("merge", &Stack<std::string>::merge)
         .def("__getstate__", &Stack<std::string>::__getstate__)
         .def("__setstate__", &Stack<std::string>::__setstate__)
-        .def("return_a_tuple", &Stack<std::string>::return_a_tuple);
-
+        .def("return_a_tuple", &Stack<std::string>::return_a_tuple)
+        .def(
+            "top",
+            [](const c10::intrusive_ptr<Stack<std::string>>& self)
+                -> std::string { return self->stack_.back(); });
+// clang-format off
+        // The following will fail with a static assert telling you you have to
+        // take an intrusive_ptr<Stack> as the first argument.
+        // .def("foo", [](int64_t a) -> int64_t{ return 3;});
+// clang-format on
 } // namespace
 
 } // namespace jit
diff --git a/test/test_jit.py b/test/test_jit.py
index 4743bae..26156d1 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4990,6 +4990,16 @@
         x = torch.rand(3, 4)
         self.assertEqual(scripted(x), eic(x))
 
+    @skipIfRocm
+    @unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case")
+    def test_torchbind_lambda_method(self):
+        def foo():
+            ss = torch.classes._TorchScriptTesting_StackString(["mom"])
+            return ss.top()
+
+        scripted = torch.jit.script(foo)
+        self.assertEqual(scripted(), "mom")
+
     def test_jitter_bug(self):
         @torch.jit.script
         def fn2(input, kernel_size):
diff --git a/torch/custom_class.h b/torch/custom_class.h
index 3d8fe85..c052063 100644
--- a/torch/custom_class.h
+++ b/torch/custom_class.h
@@ -9,6 +9,7 @@
 #include <c10/util/C++17.h>
 #include <c10/util/Metaprogramming.h>
 #include <c10/util/TypeList.h>
+#include <c10/util/TypeTraits.h>
 #include <torch/csrc/jit/custom_class.h>
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/script/compilation_unit.h>
@@ -101,9 +102,28 @@
     defineMethod<void>("__init__", std::move(func));
     return *this;
   }
-  template <typename Func>
-  class_& def(std::string name, Func f) {
-    auto res = def_(name, f, detail::args_t<decltype(f)>{});
+  template <
+      typename Method,
+      std::enable_if_t<
+          std::is_member_function_pointer<std::decay_t<Method>>::value,
+          bool> = false>
+  class_& def(std::string name, Method&& m) {
+    auto res = def_(
+        std::move(name),
+        std::forward<Method>(m),
+        detail::args_t<std::remove_reference_t<decltype(m)>>{});
+    return *this;
+  }
+  template <
+      typename Func,
+      std::enable_if_t<
+          !std::is_member_function_pointer<std::decay_t<Func>>::value,
+          bool> = false>
+  class_& def(std::string name, Func&& f) {
+    auto res = def_(
+        std::move(name),
+        std::forward<Func>(f),
+        detail::args_t<std::remove_reference_t<decltype(&Func::operator())>>{});
     return *this;
   }
 
@@ -145,12 +165,43 @@
     auto method = classCU()->create_function(qualClassName + "." + name, graph);
     classTypePtr->addMethod(method);
   }
-  template <typename Func, typename R, typename... Types>
+
+  template <
+      typename Func,
+      typename R,
+      typename... Types,
+      std::enable_if_t<
+          std::is_member_function_pointer<std::decay_t<Func>>::value,
+          bool> = false>
   class_& def_(std::string name, Func f, detail::types<R, Types...> funcInfo) {
-    auto func = [f](c10::intrusive_ptr<CurClass> cur, Types... args) {
+    auto func = [f = std::move(f)](
+                    c10::intrusive_ptr<CurClass> cur, Types... args) {
       return at::guts::invoke(f, *cur, args...);
     };
-    defineMethod<R>(name, std::move(func));
+    defineMethod<R>(std::move(name), std::move(func));
+    return *this;
+  }
+
+  template <typename R, typename Head, typename... Tail>
+  void assert_self_type(detail::types<R, Head, Tail...> funcInfo) {
+    static_assert(
+        std::is_same<std::decay_t<Head>, c10::intrusive_ptr<CurClass>>::value,
+        "First argument of a registered lambda method must be an intrusive_ptr<> of the corresponding class.");
+  }
+
+  template <
+      typename Func,
+      typename R,
+      typename... Types,
+      std::enable_if_t<
+          !std::is_member_function_pointer<std::decay_t<Func>>::value,
+          bool> = false>
+  class_& def_(
+      std::string name,
+      Func&& f,
+      detail::types<R, Types...> funcInfo) {
+    assert_self_type(funcInfo);
+    defineMethod<R>(std::move(name), std::forward<Func>(f));
     return *this;
   }
 };