When doing typed typecheck, also check signature with symint removed (#109727)
See the test case for what we didn't catch (SymInt vs const SymInt&
mismatch.)
It's necessary to test for both, because we will fall back to the
non-SymInt signature if there is no SymInt unboxed kernel available.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109727
Approved by: https://github.com/zou3519
diff --git a/.github/ci_commit_pins/fbgemm.txt b/.github/ci_commit_pins/fbgemm.txt
index ce0b3c1..2a0601e 100644
--- a/.github/ci_commit_pins/fbgemm.txt
+++ b/.github/ci_commit_pins/fbgemm.txt
@@ -1 +1 @@
-1b2746f642cc2c99fe9d1a0c34359c0de45341c2
+0346155d7f15fbe8be72687e665078edbe1ca5aa
diff --git a/aten/src/ATen/core/boxing/KernelFunction.h b/aten/src/ATen/core/boxing/KernelFunction.h
index f1bfc9e..d8d0a3d 100644
--- a/aten/src/ATen/core/boxing/KernelFunction.h
+++ b/aten/src/ATen/core/boxing/KernelFunction.h
@@ -18,10 +18,10 @@
template <typename T>
using has_symint =
guts::disjunction<
- std::is_same<c10::SymInt, std::decay_t<T>>,
- std::is_same<c10::SymIntArrayRef, std::decay_t<T>>,
- std::is_same<at::OptionalSymIntArrayRef, std::decay_t<T>>,
- std::is_same<c10::optional<c10::SymInt>, std::decay_t<T>>
+ std::is_same<c10::SymInt, T>,
+ std::is_same<c10::SymIntArrayRef, T>,
+ std::is_same<at::OptionalSymIntArrayRef, T>,
+ std::is_same<c10::optional<c10::SymInt>, T>
>;
template <typename T>
@@ -65,6 +65,14 @@
typename guts::infer_function_traits<T>::type::parameter_types
>;
+template <typename T>
+struct fn_remove_symint;
+
+template <typename Ret, typename... Args>
+struct fn_remove_symint<Ret(Args...)> {
+ using type = Ret(typename remove_symint<Args>::type...);
+};
+
/**
* KernelFunction is similar to std::function but stores a kernel function.
* You can create a KernelFunction from a boxed or unboxed function/functor/lambda
diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h
index 81d0553..5308499 100644
--- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h
+++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h
@@ -179,10 +179,6 @@
"You tried to register a kernel with an unsupported input type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
};
- // The following specialisations of assert_is_valid_input_type are technically not
- // necessary since we would hit the base case and show an error message
- // there if they didn't exist, but we can show a better error message
- // in some common error scenarios.
template<class T, bool AllowDeprecatedTypes>
struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<float, T>::value>> {
// There is no reason to support float when we have double. Keep the API lean.
@@ -204,6 +200,14 @@
static_assert(guts::false_t<T>::value,
"You tried to register a kernel with an unsupported integral input type. Please use int64_t instead.");
};
+ template<class T, bool AllowDeprecatedTypes>
+ struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<const c10::SymInt&, T>::value>> {
+ static_assert(guts::false_t<T>::value,
+ "You tried to register a kernel taking c10::SymInt by reference. Please accept it by value instead.");
+ };
+
+ // TODO: it probably would be good to tighten this up quite a bit more with
+ // an explicit list for everything
//
// assert_is_valid_output_type
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h
index 04bfade..cc3ddf7 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.h
+++ b/aten/src/ATen/core/dispatch/Dispatcher.h
@@ -438,6 +438,9 @@
// will be done by the time a typed() handle is acquired.
#if !defined C10_MOBILE
operatorDef_->op.assertSignatureIsCorrect<FuncType>();
+ if (fn_has_symint<FuncType>::value) {
+ operatorDef_->op.assertSignatureIsCorrect<typename fn_remove_symint<FuncType>::type>();
+ }
#endif
return TypedOperatorHandle<FuncType>(operatorIterator_);
}
diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp
index ea44155..a1c9c63 100644
--- a/aten/src/ATen/core/op_registration/op_registration_test.cpp
+++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp
@@ -2157,6 +2157,70 @@
ASSERT_TRUE(std::includes(all_ops.begin(), all_ops.end(), cpu_ops.begin(), cpu_ops.end(), cmp_lambda));
}
+Tensor symint_op(const Tensor& self, int64_t length) {
+ return self.clone();
+}
+
+TEST(OperatorRegistrationTest, TestSymNonSymCompatibility) {
+ auto m = MAKE_TORCH_LIBRARY(_test);
+ m.def("_test::symint_op(Tensor self, SymInt length) -> Tensor");
+ auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);
+ m_cpu.impl("symint_op", c10::DispatchKey::CPU, TORCH_FN(symint_op));
+
+ auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
+ "_test::symint_op", "");
+
+ opHandle.typed<Tensor(const Tensor&, int64_t)>().call(dummyTensor(c10::DispatchKey::CPU), 4);
+ opHandle.typed<Tensor(const Tensor&, c10::SymInt)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
+
+ expectThrows<c10::Error>([&] {
+ opHandle.typed<Tensor(const Tensor&, const c10::SymInt&)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
+ }, "Tried to access or call an operator with a wrong signature");
+}
+
+Tensor symint_op2(const Tensor& self, c10::SymInt length) {
+ return self.clone();
+}
+
+TEST(OperatorRegistrationTest, TestSymSymCompatibility) {
+ auto m = MAKE_TORCH_LIBRARY(_test);
+ m.def("_test::symint_op(Tensor self, SymInt length) -> Tensor");
+ auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);
+ m_cpu.impl("symint_op", c10::DispatchKey::CPU, TORCH_FN(symint_op2));
+
+ auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
+ "_test::symint_op", "");
+
+ opHandle.typed<Tensor(const Tensor&, int64_t)>().call(dummyTensor(c10::DispatchKey::CPU), 4);
+ opHandle.typed<Tensor(const Tensor&, c10::SymInt)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
+ // TODO: We should reject this on principle, but today it accidentally works
+ // due to going through the boxed calling convention.
+ //
+ // First, we attempt to test if const SymInt& has SymInt. It does not,
+ // because we only accept something as SymInt if it has exactly SymInt in
+ // its signature. So we check if there is a non-symint kernel. But there is
+ // no non-SymInt kernel, because we only registered a real SymInt kernel.
+ // When this occurs, we fall back to the boxed calling convention. And the
+ // boxed calling convention can deal with const SymInt& fine, as during
+ // boxing it will just create a SymInt to push onto the argument stack and
+ // everything is fine.
+ opHandle.typed<Tensor(const Tensor&, const c10::SymInt&)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
+}
+
+Tensor symint_op3(const Tensor& self, const c10::SymInt& length) {
+ return self.clone();
+}
+
+TEST(OperatorRegistrationTest, TestSymSymRefCompatibility) {
+ auto m = MAKE_TORCH_LIBRARY(_test);
+ m.def("_test::symint_op(Tensor self, SymInt length) -> Tensor");
+ auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);
+
+ expectThrows<c10::Error>([&] {
+ m_cpu.impl("symint_op", c10::DispatchKey::CPU, TORCH_FN(symint_op3));
+ }, "doesn't match the expected function schema");
+}
+
}
#pragma GCC diagnostic pop