Enable registering stack-based kernels with lambdas (#26658)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26658
By SFINAE'ing the lambda registration overload so that it only accepts lambdas that aren't stack-based kernels,
an attempt to register a stack-based kernel written as a lambda now correctly falls back to the stack-based registration overload and works as expected.
ghstack-source-id: 90610843
Test Plan: unit tests
Differential Revision: D17533871
fbshipit-source-id: 1bfe3106b0576d46798a51bdaa5b7b5508164766
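
The guard added below can be illustrated with a stripped-down, self-contained sketch (not the actual ATen code; Options, BoxedFn, and Stack here are simplified stand-ins, and a plain convertibility check stands in for the infer_function_traits_t/is_same test used in the real patch):

#include <iostream>
#include <type_traits>
#include <utility>
#include <vector>

struct OperatorKernel {};
using Stack = std::vector<int>;                 // stand-in for torch::jit::Stack
using BoxedFn = void(OperatorKernel*, Stack*);  // boxed (stack-based) kernel signature

struct Options {
  // Overload 1: generic lambda kernels. SFINAE'd out when the lambda already has
  // the boxed signature, mirroring the !std::is_same<..., BoxedKernelFunction> guard.
  template <class Lambda,
            std::enable_if_t<!std::is_convertible<Lambda, BoxedFn*>::value, int> = 0>
  Options&& kernel(Lambda&&) && {
    std::cout << "registered as unboxed lambda kernel\n";
    return std::move(*this);
  }

  // Overload 2: stack-based kernels. Without the guard above, a stateless boxed
  // lambda is an exact match for overload 1 and never reaches this overload.
  Options&& kernel(BoxedFn* /*fn*/) && {
    std::cout << "registered as stack-based kernel\n";
    return std::move(*this);
  }
};

int main() {
  // Stateless lambda with the boxed signature: now falls back to overload 2.
  Options().kernel([](OperatorKernel*, Stack* stack) { stack->push_back(1); });

  // Any other lambda still takes the unboxed path through overload 1.
  Options().kernel([](int x) { return x + 1; });
}
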
diff --git a/aten/src/ATen/core/boxing/kernel_stackbased_test.cpp b/aten/src/ATen/core/boxing/kernel_stackbased_test.cpp
index 5745f1c..e5213e8 100644
--- a/aten/src/ATen/core/boxing/kernel_stackbased_test.cpp
+++ b/aten/src/ATen/core/boxing/kernel_stackbased_test.cpp
@@ -62,6 +62,26 @@
expectCallsIncrement(TensorTypeId::CPUTensorId);
}
+TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegisteredAsLambda_thenCanBeCalled) {
+ auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId,
+ [] (OperatorKernel*, Stack* stack) {
+ int input = torch::jit::pop(*stack).toInt();
+ torch::jit::pop(*stack); // pop the dummy tensor
+ torch::jit::push(*stack, input + 1);
+ }));
+ expectCallsIncrement(TensorTypeId::CPUTensorId);
+}
+
+TEST(OperatorRegistrationTest_StackBasedKernel, givenCatchAllKernel_whenRegisteredAsLambda_thenCanBeCalled) {
+ auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().catchAllKernel(
+ [] (OperatorKernel*, Stack* stack) {
+ int input = torch::jit::pop(*stack).toInt();
+ torch::jit::pop(*stack); // pop the dummy tensor
+ torch::jit::push(*stack, input + 1);
+ }));
+ expectCallsIncrement(TensorTypeId::CPUTensorId);
+}
+
TEST(OperatorRegistrationTest_StackBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
auto registrar = RegisterOperators()
.op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, &incrementKernel))
diff --git a/aten/src/ATen/core/op_registration/op_registration.h b/aten/src/ATen/core/op_registration/op_registration.h
index 41450d8..04fee09 100644
--- a/aten/src/ATen/core/op_registration/op_registration.h
+++ b/aten/src/ATen/core/op_registration/op_registration.h
@@ -325,7 +325,10 @@
*/
template<class Lambda>
// enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
- guts::enable_if_t<guts::is_functor<guts::decay_t<Lambda>>::value, Options&&> kernel(TensorTypeId dispatch_key, Lambda&& functor) && {
+ guts::enable_if_t<
+ guts::is_functor<guts::decay_t<Lambda>>::value
+ && !std::is_same<typename guts::infer_function_traits_t<guts::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>::value,
+ Options&&> kernel(TensorTypeId dispatch_key, Lambda&& functor) && {
static_assert(!std::is_base_of<OperatorKernel, guts::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");
// We don't support stateful lambdas (i.e. lambdas with a capture), because their
@@ -362,7 +365,10 @@
*/
template<class Lambda>
// enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
- guts::enable_if_t<guts::is_functor<guts::decay_t<Lambda>>::value, Options&&> catchAllKernel(Lambda&& lambda) && {
+ guts::enable_if_t<
+ guts::is_functor<guts::decay_t<Lambda>>::value
+ && !std::is_same<typename guts::infer_function_traits_t<guts::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>::value,
+ Options&&> catchAllKernel(Lambda&& lambda) && {
static_assert(!std::is_base_of<OperatorKernel, guts::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");
// We don't support stateful lambdas (i.e. lambdas with a capture), because their
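
With both guards in place, a stack-based lambda registered through either kernel(dispatch_key, ...) or catchAllKernel(...) resolves to the KernelFunction::BoxedKernelFunction overload instead of the unboxed-lambda overload. A call site then looks like the new tests above (namespace qualification is assumed here; the test file relies on using-declarations):

auto registrar = c10::RegisterOperators().op(
    "_test::my_op(Tensor dummy, int input) -> int",
    c10::RegisterOperators::options().catchAllKernel(
        [] (c10::OperatorKernel*, torch::jit::Stack* stack) {
          // Boxed kernels pop their arguments from the stack and push results back.
          int64_t input = torch::jit::pop(*stack).toInt();
          torch::jit::pop(*stack);             // discard the dummy tensor
          torch::jit::push(*stack, input + 1);
        }));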