[reland][CustomOp] Add Dispatcher error callback (#101452)

Reland of #101015; the original stack was reverted due to internal test
flakiness.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101452
Approved by: https://github.com/soulitzer
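
This stack teaches the dispatcher to defer its "missing kernel" errors to a
per-operator Python callback. `OperatorEntry` now owns an optional
`c10::SafePyObject` callback; when `reportError` runs and a callback is
installed, the dispatcher calls back into Python with the dispatch key's name
and expects the callback to raise. The CustomOp API installs such a callback
at registration time, so users see CustomOp-specific guidance instead of the
generic dispatcher error text.

A minimal sketch of the user-visible behavior, modeled on the new test in
test_python_dispatch.py below (the `mylib` namespace is illustrative):

```python
import torch
from torch._custom_op import custom_op

@custom_op('mylib::foo')
def foo(x: torch.Tensor) -> torch.Tensor:
    ...

foo(torch.randn(3))
# NotImplementedError: ... there is no cpu impl registered for this
# CustomOp. Please register one via CustomOp.impl(device_type='cpu')
```
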
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h
index f4de0b4..f2b8584 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.h
+++ b/aten/src/ATen/core/dispatch/Dispatcher.h
@@ -13,6 +13,7 @@
 #include <mutex>
 #include <condition_variable>
 #include <type_traits>
+#include <c10/core/SafePyObject.h>
 
 #include <ATen/core/grad_mode.h>
 #include <ATen/core/enum_tag.h>
@@ -390,6 +391,10 @@
     return operatorDef_->op.getTags();
   }
 
+  void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback) {
+    operatorDef_->op.setReportErrorCallback_(std::move(callback));
+  }
+
   bool hasTag(const at::Tag& tag) const {
     for(const auto& tag_: getTags()) {
       if (tag == tag_) {
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
index b3f7700..627109c 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -530,6 +530,11 @@
   // If there is an invariant problem, report it now.
   checkInvariants();
 
+  if (report_error_callback_ != nullptr) {
+    report_error_callback_->pyinterpreter()->reportErrorCallback(report_error_callback_->ptr(&report_error_callback_->pyinterpreter()), dispatchKey);
+    // reportErrorCallback should have raised an error
+    TORCH_INTERNAL_ASSERT(false);
+  }
   if (dispatchKey == DispatchKey::Undefined) {
     TORCH_CHECK_NOT_IMPLEMENTED(false,
           "There were no tensor arguments to this function (e.g., you passed an "
@@ -574,6 +579,10 @@
   return oss.str();
 }
 
+void OperatorEntry::setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback) {
+  report_error_callback_ = std::move(callback);
+}
+
 // Inspect the "canonical" information in OperatorEntry.  This only prints out
 // *non-derived* information including kernels registered to alias dispatch keys;
 // i.e., what the source of truth says about the operator.  This dumping function
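
The contract in `reportError` above: if a callback is installed, it runs
before any of the default `TORCH_CHECK_NOT_IMPLEMENTED` branches, receives
the dispatch key, and must raise; returning normally trips the
`TORCH_INTERNAL_ASSERT(false)`. A minimal sketch of a conforming callback
(the key arrives as a string, per the PyInterpreter change further down):

```python
def on_report_error(key: str) -> None:
    # Always raise; the C++ side asserts if the callback returns normally.
    raise NotImplementedError(f"no kernel registered for dispatch key {key}")
```
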
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h
index ea6d53a..e3e6ac1 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.h
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.h
@@ -7,6 +7,7 @@
 #include <c10/util/Optional.h>
 #include <c10/core/DispatchKey.h>
 #include <c10/core/PyHandleCache.h>
+#include <c10/core/SafePyObject.h>
 #include <ATen/core/ivalue.h>
 #include <ATen/core/boxing/KernelFunction.h>
 #include <ATen/core/dispatch/DispatchKeyExtractor.h>
@@ -211,6 +212,7 @@
   bool hasComputedKernelForDispatchKey(DispatchKey k) const;
   // Returns all the operator tags added at the time of registration
   const std::vector<at::Tag>& getTags() const;
+  void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback);
 
   template <typename F>
   PyObject* getPythonOp(PyInterpreter* self_interpreter, F slow_accessor) const {
@@ -286,6 +288,9 @@
   c10::optional<CppSignatureWithDebug> cpp_signature_;
   c10::optional<CppSignatureWithDebug> sym_cpp_signature_;
 
+  // A Python custom error handler for OperatorEntry::reportError
+  std::unique_ptr<c10::SafePyObject> report_error_callback_;
+
   // Whether this operator needs to be observed with RecordFunction
   const bool is_observed_;
 
diff --git a/c10/core/impl/PyInterpreter.cpp b/c10/core/impl/PyInterpreter.cpp
index d574de0..63c432e 100644
--- a/c10/core/impl/PyInterpreter.cpp
+++ b/c10/core/impl/PyInterpreter.cpp
@@ -27,6 +27,10 @@
     PANIC(dispatch);
   }
 
+  void reportErrorCallback(PyObject* callback, DispatchKey key) const override {
+    PANIC(reportErrorCallback);
+  }
+
   void python_op_registration_trampoline(
       const c10::OperatorHandle& op,
       c10::DispatchKey,
diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h
index e329fbc..6326128 100644
--- a/c10/core/impl/PyInterpreter.h
+++ b/c10/core/impl/PyInterpreter.h
@@ -141,6 +141,9 @@
   virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
       const = 0;
 
+  virtual void reportErrorCallback(PyObject* callback, DispatchKey key)
+      const = 0;
+
   // This is only invoked in the multipy/torchdeploy situation from
   // pythonOpRegistrationTrampoline; this lets us get to the Python
   // interpreter to actually find the appropriate Python op registration
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index 6270148..086622a 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -908,6 +908,26 @@
         self.assertTrue(f'{TestCustomOp.test_ns}.foo' in gm.code)
         foo._destroy()
 
+    def test_not_implemented_error(self):
+        @custom_op(f'{TestCustomOp.test_ns}::foo')
+        def foo(x: torch.Tensor) -> torch.Tensor:
+            ...
+
+        x = torch.randn(3)
+        with self.assertRaisesRegex(NotImplementedError, "cpu impl registered"):
+            foo(x)
+
+        x = torch.randn(3, device='meta')
+        with self.assertRaisesRegex(NotImplementedError, "abstract impl registered"):
+            foo(x)
+
+        @custom_op(f'{TestCustomOp.test_ns}::bar')
+        def bar(sizes: Sequence[int]) -> torch.Tensor:
+            ...
+
+        with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):
+            bar((1, 2, 3))
+
     def test_abstract_registration_location(self):
         loc = torch.testing._internal.custom_op_db.numpy_nonzero._get_impl('abstract').location
         matches = re.match(r'.*custom_op_db.py:\d+', loc)
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 046ffde..0bd3ddb 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1297,6 +1297,7 @@
 def _dispatch_check_all_invariants() -> None: ...
 def _dispatch_call_boxed(handle: _DispatchOperatorHandle, *args, **kwargs) -> Any: ...
 def _dispatch_find_schema_or_throw(name: str, overload_name: str) -> _DispatchOperatorHandle: ...
+def _dispatch_set_report_error_callback(handle: _DispatchOperatorHandle, callback: Callable) -> None: ...
 def _dispatch_has_kernel(name: str) -> _bool: ...
 def _dispatch_has_kernel_for_dispatch_key(
     name: str,
diff --git a/torch/_custom_op.py b/torch/_custom_op.py
index 71f090e..b845d9d 100644
--- a/torch/_custom_op.py
+++ b/torch/_custom_op.py
@@ -157,6 +157,10 @@
             get_autograd_not_implemented_kernel(weakref.proxy(result))
         )
 
+        torch._C._dispatch_set_report_error_callback(
+            ophandle, functools.partial(report_error_callback, weakref.proxy(result))
+        )
+
         return result
 
     return inner
@@ -786,3 +790,33 @@
 
 
 SUPPORTED_PARAM_TYPES = get_supported_param_types()
+
+
+def report_error_callback(custom_op: typing.Any, key: str) -> None:
+    if key == "Undefined":
+        raise NotImplementedError(
+            f"{custom_op}: There were no Tensor inputs to this operator "
+            f"(e.g. you passed an empty list of Tensors). If your operator is a "
+            f"factory function (that is, it takes no Tensors and constructs "
+            f"a new one), then please use CustomOp.impl_factory to register "
+            f"an implementation for it"
+        )
+    if key == "Meta":
+        raise NotImplementedError(
+            f"{custom_op}: when running with device='Meta' tensors: there is no "
+            f"abstract impl registered for this CustomOp. Please register one via "
+            f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
+        )
+    if key in ("CPU", "CUDA"):
+        device = key.lower()
+        raise NotImplementedError(
+            f"{custom_op}: when running with device='{device}' tensors: there is no "
+            f"{device} impl registered for this CustomOp. Please register one via "
+            f"CustomOp.impl(device_type='{device}')"
+        )
+    raise NotImplementedError(
+        f"{custom_op}: No implementation for dispatch key {key}. It is likely "
+        f"that we have not added this functionality yet, please either open an "
+        f"issue or if you're feeling adventurous, use the low-level "
+        f"torch.library API"
+    )
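
Each branch of `report_error_callback` points at the registration API that
would resolve the error. A hypothetical sketch of acting on that guidance,
with spellings taken from the error strings themselves (the exact decorator
signatures in `torch._custom_op` may differ):

```python
# Hypothetical fixes for the errors above, for a CustomOp named `foo`:
@foo.impl(device_type='cpu')   # resolves "no cpu impl registered"
def foo_cpu(x):
    return x.clone()

@foo.impl_abstract()           # resolves "no abstract impl registered" (Meta)
def foo_meta(x):
    return x.new_empty(x.shape)
```
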
diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp
index a01b1d3..709175f 100644
--- a/torch/csrc/PyInterpreter.cpp
+++ b/torch/csrc/PyInterpreter.cpp
@@ -44,6 +44,7 @@
 
   void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
       const override;
+  void reportErrorCallback(PyObject* callback, DispatchKey key) const override;
   void python_dispatcher(
       const c10::OperatorHandle& op,
       c10::DispatchKeySet,
@@ -254,6 +255,16 @@
   return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
 }
 
+void ConcretePyInterpreterVTable::reportErrorCallback(
+    PyObject* callback,
+    DispatchKey key) const {
+  py::gil_scoped_acquire g;
+  auto func = py::reinterpret_borrow<py::object>(callback);
+  // Not all DispatchKeys are pybind'ed into Python and we do not have infra
+  // to ensure this, so just pass a string back to Python.
+  func(c10::toString(key));
+}
+
 void ConcretePyInterpreterVTable::dispatch(
     const c10::OperatorHandle& op,
     torch::jit::Stack* stack) const {
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 852710e..51350e5 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -688,6 +688,14 @@
         return names;
       },
       py::arg("dispatch_key") = static_cast<const char*>(""));
+  m.def(
+      "_dispatch_set_report_error_callback",
+      [](c10::OperatorHandle& handle, py::object callback) {
+        auto obj = callback.release().ptr();
+        auto callback_obj =
+            std::make_unique<c10::SafePyObject>(obj, getPyInterpreter());
+        handle.setReportErrorCallback_(std::move(callback_obj));
+      });
 
   m.def(
       "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });