Custom op autograd tests (#30519)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30519
Re-enable them and write a few additional ones
ghstack-source-id: 95143051
Test Plan: unit tests
Differential Revision: D18729561
fbshipit-source-id: 8cefd8320913d72a450a3324bfd7c88faed072d7
diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp
index 57c1092..7df1948 100644
--- a/aten/src/ATen/core/op_registration/op_registration_test.cpp
+++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp
@@ -764,42 +764,85 @@
}
bool called_autograd = false;
-bool called_catchall = false;
+bool called_nonautograd = false;
-void catchall_kernel(Tensor a) {
- called_catchall = true;
+void nonautograd_kernel(Tensor a) {
+ called_nonautograd = true;
}
void autograd_kernel(Tensor a) {
called_autograd = true;
}
-// TODO Reenable these
-// TEST(OperatorRegistrationTest, whenRegisteringAutogradKernel_thenCanCallAutogradKernel) {
-// auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
-// .impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
-//
-// auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
-// ASSERT_TRUE(op.has_value());
-//
-// called_autograd = false;
-// c10::Dispatcher::singleton().callUnboxed<void, Tensor>(*op, dummyTensor(TensorTypeId::VariableTensorId));
-// EXPECT_TRUE(called_autograd);
-// }
-//
-// TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallAutogradKernel) {
-// auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
-// .impl_unboxedOnlyCatchAllKernel<decltype(catchall_kernel), &catchall_kernel>()
-// .impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
-//
-// auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
-// ASSERT_TRUE(op.has_value());
-//
-// called_catchall = called_autograd = false;
-// c10::Dispatcher::singleton().callUnboxed<void, Tensor>(*op, dummyTensor(TensorTypeId::VariableTensorId));
-// EXPECT_FALSE(called_catchall);
-// EXPECT_TRUE(called_autograd);
-// }
+TEST(OperatorRegistrationTest, whenRegisteringAutogradKernel_thenCanCallAutogradKernel) {
+ auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
+ .impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
+
+ auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
+ ASSERT_TRUE(op.has_value());
+
+ called_autograd = false;
+ c10::Dispatcher::singleton().callUnboxed<void, Tensor>(*op, dummyTensor(TensorTypeId::CPUTensorId)); // note: all tensors have VariableTypeId set
+ EXPECT_TRUE(called_autograd);
+}
+
+TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallAutogradKernel) {
+ auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
+ .impl_unboxedOnlyKernel<decltype(nonautograd_kernel), nonautograd_kernel>(TensorTypeId::CPUTensorId)
+ .impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
+
+ auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
+ ASSERT_TRUE(op.has_value());
+
+ called_nonautograd = called_autograd = false;
+ c10::Dispatcher::singleton().callUnboxed<void, Tensor>(*op, dummyTensor(TensorTypeId::CPUTensorId)); // note: all tensors have VariableTypeId set
+ EXPECT_FALSE(called_nonautograd);
+ EXPECT_TRUE(called_autograd);
+}
+
+TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallRegularKernel) {
+ auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
+ .impl_unboxedOnlyKernel<decltype(nonautograd_kernel), nonautograd_kernel>(TensorTypeId::CPUTensorId)
+ .impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
+
+ auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
+ ASSERT_TRUE(op.has_value());
+
+ called_nonautograd = called_autograd = false;
+ at::AutoNonVariableTypeMode _var_guard(true);
+ c10::Dispatcher::singleton().callUnboxed<void, Tensor>(*op, dummyTensor(TensorTypeId::CPUTensorId));
+ EXPECT_TRUE(called_nonautograd);
+ EXPECT_FALSE(called_autograd);
+}
+
+TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_thenCanCallAutogradKernel) {
+ auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
+ .impl_unboxedOnlyCatchAllKernel<decltype(nonautograd_kernel), nonautograd_kernel>()
+ .impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
+
+ auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
+ ASSERT_TRUE(op.has_value());
+
+ called_nonautograd = called_autograd = false;
+ c10::Dispatcher::singleton().callUnboxed<void, Tensor>(*op, dummyTensor(TensorTypeId::CPUTensorId)); // note: all tensors have VariableTypeId set
+ EXPECT_FALSE(called_nonautograd);
+ EXPECT_TRUE(called_autograd);
+}
+
+TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_thenCanCallCatchallKernel) {
+ auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options()
+ .impl_unboxedOnlyCatchAllKernel<decltype(nonautograd_kernel), nonautograd_kernel>()
+ .impl_unboxedOnlyKernel<decltype(autograd_kernel), &autograd_kernel>(TensorTypeId::VariableTensorId));
+
+ auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
+ ASSERT_TRUE(op.has_value());
+
+ called_nonautograd = called_autograd = false;
+ at::AutoNonVariableTypeMode _var_guard(true);
+ c10::Dispatcher::singleton().callUnboxed<void, Tensor>(*op, dummyTensor(TensorTypeId::CPUTensorId));
+ EXPECT_TRUE(called_nonautograd);
+ EXPECT_FALSE(called_autograd);
+}
/**
* This is used to check that a given type works correctly when passed as input