Custom op autograd tests (#30519)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30519

Re-enable them and write a few additional ones
ghstack-source-id: 95143051

Test Plan: unit tests

Differential Revision: D18729561

fbshipit-source-id: 8cefd8320913d72a450a3324bfd7c88faed072d7
diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp
index 57c1092..7df1948 100644
--- a/aten/src/ATen/core/op_registration/op_registration_test.cpp
+++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp
@@ -764,42 +764,85 @@
 }
 
 bool called_autograd = false;
-bool called_catchall = false;
+bool called_nonautograd = false;
 
-void catchall_kernel(Tensor a) {
-  called_catchall = true;
+void nonautograd_kernel(Tensor a) {
+  called_nonautograd = true;
 }
 
 void autograd_kernel(Tensor a) {
   called_autograd = true;
 }
 
-// TODO Reenable these
-// TEST(OperatorRegistrationTest, whenRegisteringAutogradKernel_thenCanCallAutogradKernel) {
-//   auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
-//     .impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
-//
-//   auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
-//   ASSERT_TRUE(op.has_value());
-//
-//   called_autograd = false;
-//   c10::Dispatcher::singleton().callUnboxed<void, Tensor>(*op, dummyTensor(TensorTypeId::VariableTensorId));
-//   EXPECT_TRUE(called_autograd);
-// }
-//
-// TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallAutogradKernel) {
-//   auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
-//     .impl_unboxedOnlyCatchAllKernel<decltype(catchall_kernel), &catchall_kernel>()
-//     .impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
-//
-//   auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
-//   ASSERT_TRUE(op.has_value());
-//
-//   called_catchall = called_autograd = false;
-//   c10::Dispatcher::singleton().callUnboxed<void, Tensor>(*op, dummyTensor(TensorTypeId::VariableTensorId));
-//   EXPECT_FALSE(called_catchall);
-//   EXPECT_TRUE(called_autograd);
-// }
+TEST(OperatorRegistrationTest, whenRegisteringAutogradKernel_thenCanCallAutogradKernel) {
+  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
+    .impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
+
+  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
+  ASSERT_TRUE(op.has_value());
+
+  called_autograd = false;
+  c10::Dispatcher::singleton().callUnboxed<void, Tensor>(*op, dummyTensor(TensorTypeId::CPUTensorId)); // note: all tensors have VariableTypeId set
+  EXPECT_TRUE(called_autograd);
+}
+
+TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallAutogradKernel) {
+  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
+    .impl_unboxedOnlyKernel<decltype(nonautograd_kernel), nonautograd_kernel>(TensorTypeId::CPUTensorId)
+    .impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
+
+  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
+  ASSERT_TRUE(op.has_value());
+
+  called_nonautograd = called_autograd = false;
+  c10::Dispatcher::singleton().callUnboxed<void, Tensor>(*op, dummyTensor(TensorTypeId::CPUTensorId)); // note: all tensors have VariableTypeId set
+  EXPECT_FALSE(called_nonautograd);
+  EXPECT_TRUE(called_autograd);
+}
+
+TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallRegularKernel) {
+  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
+    .impl_unboxedOnlyKernel<decltype(nonautograd_kernel), nonautograd_kernel>(TensorTypeId::CPUTensorId)
+    .impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
+
+  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
+  ASSERT_TRUE(op.has_value());
+
+  called_nonautograd = called_autograd = false;
+  at::AutoNonVariableTypeMode _var_guard(true);
+  c10::Dispatcher::singleton().callUnboxed<void, Tensor>(*op, dummyTensor(TensorTypeId::CPUTensorId));
+  EXPECT_TRUE(called_nonautograd);
+  EXPECT_FALSE(called_autograd);
+}
+
+TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_thenCanCallAutogradKernel) {
+  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
+    .impl_unboxedOnlyCatchAllKernel<decltype(nonautograd_kernel), nonautograd_kernel>()
+    .impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
+
+  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
+  ASSERT_TRUE(op.has_value());
+
+  called_nonautograd = called_autograd = false;
+  c10::Dispatcher::singleton().callUnboxed<void, Tensor>(*op, dummyTensor(TensorTypeId::CPUTensorId));  // note: all tensors have VariableTypeId set
+  EXPECT_FALSE(called_nonautograd);
+  EXPECT_TRUE(called_autograd);
+}
+
+TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_thenCanCallCatchallKernel) {
+  auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
+    .impl_unboxedOnlyCatchAllKernel<decltype(nonautograd_kernel), nonautograd_kernel>()
+    .impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
+
+  auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
+  ASSERT_TRUE(op.has_value());
+
+  called_nonautograd = called_autograd = false;
+  at::AutoNonVariableTypeMode _var_guard(true);
+  c10::Dispatcher::singleton().callUnboxed<void, Tensor>(*op, dummyTensor(TensorTypeId::CPUTensorId));
+  EXPECT_TRUE(called_nonautograd);
+  EXPECT_FALSE(called_autograd);
+}
 
 /**
  * This is used to check that a given type works correctly when passed as input