When doing typed typecheck, also check signature with symint removed (#109727)

See the test case for what we didn't catch (SymInt vs const SymInt&
mismatch.)

It's necessary to test for both, because we will fall back to the
non-SymInt signature if there is no SymInt unboxed kernel available.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109727
Approved by: https://github.com/zou3519
diff --git a/.github/ci_commit_pins/fbgemm.txt b/.github/ci_commit_pins/fbgemm.txt
index ce0b3c1..2a0601e 100644
--- a/.github/ci_commit_pins/fbgemm.txt
+++ b/.github/ci_commit_pins/fbgemm.txt
@@ -1 +1 @@
-1b2746f642cc2c99fe9d1a0c34359c0de45341c2
+0346155d7f15fbe8be72687e665078edbe1ca5aa
diff --git a/aten/src/ATen/core/boxing/KernelFunction.h b/aten/src/ATen/core/boxing/KernelFunction.h
index f1bfc9e..d8d0a3d 100644
--- a/aten/src/ATen/core/boxing/KernelFunction.h
+++ b/aten/src/ATen/core/boxing/KernelFunction.h
@@ -18,10 +18,10 @@
 template <typename T>
 using has_symint =
   guts::disjunction<
-    std::is_same<c10::SymInt, std::decay_t<T>>,
-    std::is_same<c10::SymIntArrayRef, std::decay_t<T>>,
-    std::is_same<at::OptionalSymIntArrayRef, std::decay_t<T>>,
-    std::is_same<c10::optional<c10::SymInt>, std::decay_t<T>>
+    std::is_same<c10::SymInt, T>,
+    std::is_same<c10::SymIntArrayRef, T>,
+    std::is_same<at::OptionalSymIntArrayRef, T>,
+    std::is_same<c10::optional<c10::SymInt>, T>
   >;
 
 template <typename T>
@@ -65,6 +65,14 @@
   typename guts::infer_function_traits<T>::type::parameter_types
 >;
 
+template <typename T>
+struct fn_remove_symint;
+
+template <typename Ret, typename... Args>
+struct fn_remove_symint<Ret(Args...)> {
+  using type = Ret(typename remove_symint<Args>::type...);
+};
+
 /**
  * KernelFunction is similar to std::function but stores a kernel function.
  * You can create a KernelFunction from a boxed or unboxed function/functor/lambda
diff --git a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h
index 81d0553..5308499 100644
--- a/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h
+++ b/aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h
@@ -179,10 +179,6 @@
       "You tried to register a kernel with an unsupported input type: std::array<Scalar, N>. Please use std::array<int64_t, N> instead.");
   };
 
-  // The following specialisations of assert_is_valid_input_type are technically not
-  // necessary since we would hit the base case and show an error message
-  // there if they didn't exist, but we can show a better error message
-  // in some common error scenarios.
   template<class T, bool AllowDeprecatedTypes>
   struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<float, T>::value>> {
     // There is no reason to support float when we have double. Keep the API lean.
@@ -204,6 +200,14 @@
     static_assert(guts::false_t<T>::value,
       "You tried to register a kernel with an unsupported integral input type. Please use int64_t instead.");
   };
+  template<class T, bool AllowDeprecatedTypes>
+  struct assert_is_valid_input_type<T, AllowDeprecatedTypes, std::enable_if_t<std::is_same<const c10::SymInt&, T>::value>> {
+    static_assert(guts::false_t<T>::value,
+      "You tried to register a kernel taking c10::SymInt by reference. Please accept it by value instead.");
+  };
+
+  // TODO: it probably would be good to tighten this up quite a bit more with
+  // an explicit list for everything
 
   //
   // assert_is_valid_output_type
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h
index 04bfade..cc3ddf7 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.h
+++ b/aten/src/ATen/core/dispatch/Dispatcher.h
@@ -438,6 +438,9 @@
     // will be done by the time a typed() handle is acquired.
 #if !defined C10_MOBILE
     operatorDef_->op.assertSignatureIsCorrect<FuncType>();
+    if (fn_has_symint<FuncType>::value) {
+      operatorDef_->op.assertSignatureIsCorrect<typename fn_remove_symint<FuncType>::type>();
+    }
 #endif
     return TypedOperatorHandle<FuncType>(operatorIterator_);
   }
diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp
index ea44155..a1c9c63 100644
--- a/aten/src/ATen/core/op_registration/op_registration_test.cpp
+++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp
@@ -2157,6 +2157,70 @@
   ASSERT_TRUE(std::includes(all_ops.begin(), all_ops.end(), cpu_ops.begin(), cpu_ops.end(), cmp_lambda));
 }
 
+Tensor symint_op(const Tensor& self, int64_t length) {
+  return self.clone();
+}
+
+TEST(OperatorRegistrationTest, TestSymNonSymCompatibility) {
+  auto m = MAKE_TORCH_LIBRARY(_test);
+  m.def("_test::symint_op(Tensor self, SymInt length) -> Tensor");
+  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);
+  m_cpu.impl("symint_op", c10::DispatchKey::CPU, TORCH_FN(symint_op));
+
+  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
+      "_test::symint_op", "");
+
+  opHandle.typed<Tensor(const Tensor&, int64_t)>().call(dummyTensor(c10::DispatchKey::CPU), 4);
+  opHandle.typed<Tensor(const Tensor&, c10::SymInt)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
+
+  expectThrows<c10::Error>([&] {
+    opHandle.typed<Tensor(const Tensor&, const c10::SymInt&)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
+  }, "Tried to access or call an operator with a wrong signature");
+}
+
+Tensor symint_op2(const Tensor& self, c10::SymInt length) {
+  return self.clone();
+}
+
+TEST(OperatorRegistrationTest, TestSymSymCompatibility) {
+  auto m = MAKE_TORCH_LIBRARY(_test);
+  m.def("_test::symint_op(Tensor self, SymInt length) -> Tensor");
+  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);
+  m_cpu.impl("symint_op", c10::DispatchKey::CPU, TORCH_FN(symint_op2));
+
+  auto opHandle = c10::Dispatcher::singleton().findSchemaOrThrow(
+      "_test::symint_op", "");
+
+  opHandle.typed<Tensor(const Tensor&, int64_t)>().call(dummyTensor(c10::DispatchKey::CPU), 4);
+  opHandle.typed<Tensor(const Tensor&, c10::SymInt)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
+  // TODO: We should reject this on principle, but today it accidentally works
+  // due to going through the boxed calling convention.
+  //
+  // First, we attempt to test if const SymInt& has SymInt. It does not,
+  // because we only accept something as SymInt if it has exactly SymInt in
+  // its signature. So we check if there is a non-symint kernel. But there is
+  // no non-SymInt kernel, because we only registered a real SymInt kernel.
+  // When this occurs, we fall back to the boxed calling convention.  And the
+  // boxed calling convention can deal with const SymInt& fine, as during
+  // boxing it will just create a SymInt to push onto the argument stack and
+  // everything is fine.
+  opHandle.typed<Tensor(const Tensor&, const c10::SymInt&)>().call(dummyTensor(c10::DispatchKey::CPU), c10::SymInt(4));
+}
+
+Tensor symint_op3(const Tensor& self, const c10::SymInt& length) {
+  return self.clone();
+}
+
+TEST(OperatorRegistrationTest, TestSymSymRefCompatibility) {
+  auto m = MAKE_TORCH_LIBRARY(_test);
+  m.def("_test::symint_op(Tensor self, SymInt length) -> Tensor");
+  auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(_test, CPU);
+
+  expectThrows<c10::Error>([&] {
+    m_cpu.impl("symint_op", c10::DispatchKey::CPU, TORCH_FN(symint_op3));
+  }, "doesn't match the expected function schema");
+}
+
 }
 
 #pragma GCC diagnostic pop