Enable registering stackbased kernels with lambdas (#26658)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26658

By using SFINAE to restrict the lambda registration overloads to kernels that aren't stackbased kernels,
an attempt to register a stackbased lambda kernel now correctly falls back to the stackbased registration function and works as expected.
ghstack-source-id: 90610843

Test Plan: unit tests

Differential Revision: D17533871

fbshipit-source-id: 1bfe3106b0576d46798a51bdaa5b7b5508164766
diff --git a/aten/src/ATen/core/boxing/kernel_stackbased_test.cpp b/aten/src/ATen/core/boxing/kernel_stackbased_test.cpp
index 5745f1c..e5213e8 100644
--- a/aten/src/ATen/core/boxing/kernel_stackbased_test.cpp
+++ b/aten/src/ATen/core/boxing/kernel_stackbased_test.cpp
@@ -62,6 +62,26 @@
   expectCallsIncrement(TensorTypeId::CPUTensorId);
 }
 
+TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegisteredAsLambda_thenCanBeCalled) {
+  auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId,
+    [] (OperatorKernel*, Stack* stack) { // captureless lambda taking (OperatorKernel*, Stack*), registered for the CPU dispatch key
+      int input = torch::jit::pop(*stack).toInt(); // the last schema argument (`input`) is popped first
+      torch::jit::pop(*stack); // pop the dummy tensor
+      torch::jit::push(*stack, input + 1); // push the incremented value back as the op's int result
+    }));
+  expectCallsIncrement(TensorTypeId::CPUTensorId); // helper defined outside this hunk; presumably calls the op and checks the increment
+}
+
+TEST(OperatorRegistrationTest_StackBasedKernel, givenCatchAllKernel_whenRegisteredAsLambda_thenCanBeCalled) {
+  auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().catchAllKernel(
+    [] (OperatorKernel*, Stack* stack) { // captureless lambda taking (OperatorKernel*, Stack*), registered without a dispatch key
+      int input = torch::jit::pop(*stack).toInt(); // the last schema argument (`input`) is popped first
+      torch::jit::pop(*stack); // pop the dummy tensor
+      torch::jit::push(*stack, input + 1); // push the incremented value back as the op's int result
+    }));
+  expectCallsIncrement(TensorTypeId::CPUTensorId); // helper defined outside this hunk; presumably calls the op and checks the increment
+}
+
 TEST(OperatorRegistrationTest_StackBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
   auto registrar = RegisterOperators()
       .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, &incrementKernel))
diff --git a/aten/src/ATen/core/op_registration/op_registration.h b/aten/src/ATen/core/op_registration/op_registration.h
index 41450d8..04fee09 100644
--- a/aten/src/ATen/core/op_registration/op_registration.h
+++ b/aten/src/ATen/core/op_registration/op_registration.h
@@ -325,7 +325,10 @@
      */
     template<class Lambda>
     // enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
-    guts::enable_if_t<guts::is_functor<guts::decay_t<Lambda>>::value, Options&&> kernel(TensorTypeId dispatch_key, Lambda&& functor) && {
+    guts::enable_if_t<
+        guts::is_functor<guts::decay_t<Lambda>>::value
+        && !std::is_same<typename guts::infer_function_traits_t<guts::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>::value,
+        Options&&> kernel(TensorTypeId dispatch_key, Lambda&& functor) && {
       static_assert(!std::is_base_of<OperatorKernel, guts::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");
 
       // We don't support stateful lambdas (i.e. lambdas with a capture), because their
@@ -362,7 +365,10 @@
      */
     template<class Lambda>
     // enable_if: only enable it if Lambda is a functor (note: lambdas are functors)
-    guts::enable_if_t<guts::is_functor<guts::decay_t<Lambda>>::value, Options&&> catchAllKernel(Lambda&& lambda) && {
+    guts::enable_if_t<
+        guts::is_functor<guts::decay_t<Lambda>>::value
+        && !std::is_same<typename guts::infer_function_traits_t<guts::decay_t<Lambda>>::func_type, KernelFunction::BoxedKernelFunction>::value,
+        Options&&> catchAllKernel(Lambda&& lambda) && {
       static_assert(!std::is_base_of<OperatorKernel, guts::decay_t<Lambda>>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel<Functor>() API instead.");
 
       // We don't support stateful lambdas (i.e. lambdas with a capture), because their