[JIT] Support for registering C++ lambdas as methods on custom C++ class
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/32553
Test Plan: Imported from OSS
Differential Revision: D19543269
Pulled By: jamesr66a
fbshipit-source-id: 7e566650295e9d1c4f2f716470e061308a6210a0
diff --git a/test/cpp/jit/test_custom_class.cpp b/test/cpp/jit/test_custom_class.cpp
index ff77296..f1b46b54 100644
--- a/test/cpp/jit/test_custom_class.cpp
+++ b/test/cpp/jit/test_custom_class.cpp
@@ -85,8 +85,16 @@
.def("merge", &Stack<std::string>::merge)
.def("__getstate__", &Stack<std::string>::__getstate__)
.def("__setstate__", &Stack<std::string>::__setstate__)
- .def("return_a_tuple", &Stack<std::string>::return_a_tuple);
-
+ .def("return_a_tuple", &Stack<std::string>::return_a_tuple)
+ .def(
+ "top",
+ [](const c10::intrusive_ptr<Stack<std::string>>& self)
+ -> std::string { return self->stack_.back(); });
+// clang-format off
+ // The following will fail with a static assert telling you you have to
+ // take an intrusive_ptr<Stack> as the first argument.
+ // .def("foo", [](int64_t a) -> int64_t{ return 3;});
+// clang-format on
} // namespace
} // namespace jit
diff --git a/test/test_jit.py b/test/test_jit.py
index 4743bae..26156d1 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -4990,6 +4990,16 @@
x = torch.rand(3, 4)
self.assertEqual(scripted(x), eic(x))
+ @skipIfRocm
+ @unittest.skipIf(IS_WINDOWS, "TODO: Fix this test case")
+ def test_torchbind_lambda_method(self):
+ def foo():
+ ss = torch.classes._TorchScriptTesting_StackString(["mom"])
+ return ss.top()
+
+ scripted = torch.jit.script(foo)
+ self.assertEqual(scripted(), "mom")
+
def test_jitter_bug(self):
@torch.jit.script
def fn2(input, kernel_size):
diff --git a/torch/custom_class.h b/torch/custom_class.h
index 3d8fe85..c052063 100644
--- a/torch/custom_class.h
+++ b/torch/custom_class.h
@@ -9,6 +9,7 @@
#include <c10/util/C++17.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/TypeList.h>
+#include <c10/util/TypeTraits.h>
#include <torch/csrc/jit/custom_class.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/script/compilation_unit.h>
@@ -101,9 +102,28 @@
defineMethod<void>("__init__", std::move(func));
return *this;
}
- template <typename Func>
- class_& def(std::string name, Func f) {
- auto res = def_(name, f, detail::args_t<decltype(f)>{});
+ template <
+ typename Method,
+ std::enable_if_t<
+ std::is_member_function_pointer<std::decay_t<Method>>::value,
+ bool> = false>
+ class_& def(std::string name, Method&& m) {
+ auto res = def_(
+ std::move(name),
+ std::forward<Method>(m),
+ detail::args_t<std::remove_reference_t<decltype(m)>>{});
+ return *this;
+ }
+ template <
+ typename Func,
+ std::enable_if_t<
+ !std::is_member_function_pointer<std::decay_t<Func>>::value,
+ bool> = false>
+ class_& def(std::string name, Func&& f) {
+ auto res = def_(
+ std::move(name),
+ std::forward<Func>(f),
+ detail::args_t<std::remove_reference_t<decltype(&Func::operator())>>{});
return *this;
}
@@ -145,12 +165,43 @@
auto method = classCU()->create_function(qualClassName + "." + name, graph);
classTypePtr->addMethod(method);
}
- template <typename Func, typename R, typename... Types>
+
+ template <
+ typename Func,
+ typename R,
+ typename... Types,
+ std::enable_if_t<
+ std::is_member_function_pointer<std::decay_t<Func>>::value,
+ bool> = false>
class_& def_(std::string name, Func f, detail::types<R, Types...> funcInfo) {
- auto func = [f](c10::intrusive_ptr<CurClass> cur, Types... args) {
+ auto func = [f = std::move(f)](
+ c10::intrusive_ptr<CurClass> cur, Types... args) {
return at::guts::invoke(f, *cur, args...);
};
- defineMethod<R>(name, std::move(func));
+ defineMethod<R>(std::move(name), std::move(func));
+ return *this;
+ }
+
+ template <typename R, typename Head, typename... Tail>
+ void assert_self_type(detail::types<R, Head, Tail...> funcInfo) {
+ static_assert(
+ std::is_same<std::decay_t<Head>, c10::intrusive_ptr<CurClass>>::value,
+ "First argument of a registered lambda method must be an intrusive_ptr<> of the corresponding class.");
+ }
+
+ template <
+ typename Func,
+ typename R,
+ typename... Types,
+ std::enable_if_t<
+ !std::is_member_function_pointer<std::decay_t<Func>>::value,
+ bool> = false>
+ class_& def_(
+ std::string name,
+ Func&& f,
+ detail::types<R, Types...> funcInfo) {
+ assert_self_type(funcInfo);
+ defineMethod<R>(std::move(name), std::forward<Func>(f));
return *this;
}
};