[reland][CustomOp] Add Dispatcher error callback (#101452)
Reland of #101015; the original stack was reverted due to internal test
flakiness.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101452
Approved by: https://github.com/soulitzer
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h
index f4de0b4..f2b8584 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.h
+++ b/aten/src/ATen/core/dispatch/Dispatcher.h
@@ -13,6 +13,7 @@
#include <mutex>
#include <condition_variable>
#include <type_traits>
+#include <c10/core/SafePyObject.h>
#include <ATen/core/grad_mode.h>
#include <ATen/core/enum_tag.h>
@@ -390,6 +391,10 @@
return operatorDef_->op.getTags();
}
+ // Installs `callback` as the custom error reporter for this operator by
+ // moving ownership of it into operatorDef_->op.
+ void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback) {
+ operatorDef_->op.setReportErrorCallback_(std::move(callback));
+ }
+
bool hasTag(const at::Tag& tag) const {
for(const auto& tag_: getTags()) {
if (tag == tag_) {
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
index b3f7700..627109c 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -530,6 +530,11 @@
// If there is an invariant problem, report it now.
checkInvariants();
+ if (report_error_callback_ != nullptr) {
+   // A custom handler was installed via setReportErrorCallback_: give it the
+   // failing dispatch key, going through the interpreter that the callback
+   // object is associated with. This runs before any of the built-in error
+   // messages below.
+   report_error_callback_->pyinterpreter()->reportErrorCallback(
+       report_error_callback_->ptr(&report_error_callback_->pyinterpreter()),
+       dispatchKey);
+   // The handler is expected to raise. Returning normally is a contract
+   // violation, so fail loudly with a message that identifies the cause.
+   TORCH_INTERNAL_ASSERT(
+       false,
+       "report error callback returned without raising an error");
+ }
if (dispatchKey == DispatchKey::Undefined) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"There were no tensor arguments to this function (e.g., you passed an "
@@ -574,6 +579,10 @@
return oss.str();
}
+// Stores `callback` as the custom error handler for this operator, taking
+// ownership of it. The move-assignment replaces any callback stored earlier;
+// the null check at the call site means a null pointer disables the handler.
+void OperatorEntry::setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback) {
+ report_error_callback_ = std::move(callback);
+}
+
// Inspect the "canonical" information in OperatorEntry. This only prints out
// *non-derived* information including kernels registered to alias dispatch keys;
// i.e., what the source of truth says about the operator. This dumping function
diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h
index ea6d53a..e3e6ac1 100644
--- a/aten/src/ATen/core/dispatch/OperatorEntry.h
+++ b/aten/src/ATen/core/dispatch/OperatorEntry.h
@@ -7,6 +7,7 @@
#include <c10/util/Optional.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/PyHandleCache.h>
+#include <c10/core/SafePyObject.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/DispatchKeyExtractor.h>
@@ -211,6 +212,7 @@
bool hasComputedKernelForDispatchKey(DispatchKey k) const;
// Returns all the operator tags added at the time of registration
const std::vector<at::Tag>& getTags() const;
+ // Takes ownership of `callback` and stores it as this entry's custom error
+ // handler (see report_error_callback_).
+ void setReportErrorCallback_(std::unique_ptr<c10::SafePyObject> callback);
template <typename F>
PyObject* getPythonOp(PyInterpreter* self_interpreter, F slow_accessor) const {
@@ -286,6 +288,9 @@
c10::optional<CppSignatureWithDebug> cpp_signature_;
c10::optional<CppSignatureWithDebug> sym_cpp_signature_;
+ // A Python custom error handler for OperatorEntry::reportError. Set through
+ // setReportErrorCallback_; when non-null it is invoked with the failing
+ // dispatch key before the built-in error messages are produced.
+ std::unique_ptr<c10::SafePyObject> report_error_callback_;
+
// Whether this operator needs to be observed with RecordFunction
const bool is_observed_;
diff --git a/c10/core/impl/PyInterpreter.cpp b/c10/core/impl/PyInterpreter.cpp
index d574de0..63c432e 100644
--- a/c10/core/impl/PyInterpreter.cpp
+++ b/c10/core/impl/PyInterpreter.cpp
@@ -27,6 +27,10 @@
PANIC(dispatch);
}
+ // Placeholder override: ignores both arguments and unconditionally invokes
+ // PANIC(reportErrorCallback), mirroring the neighbouring overrides such as
+ // dispatch(). NOTE(review): PANIC is defined outside this hunk.
+ void reportErrorCallback(PyObject* callback, DispatchKey key) const override {
+ PANIC(reportErrorCallback);
+ }
+
void python_op_registration_trampoline(
const c10::OperatorHandle& op,
c10::DispatchKey,
diff --git a/c10/core/impl/PyInterpreter.h b/c10/core/impl/PyInterpreter.h
index e329fbc..6326128 100644
--- a/c10/core/impl/PyInterpreter.h
+++ b/c10/core/impl/PyInterpreter.h
@@ -141,6 +141,9 @@
virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
const = 0;
+ // Invoke the Python callable `callback` to report a dispatch failure for
+ // `key`. The caller in OperatorEntry asserts right after this call, so an
+ // implementation is expected to raise rather than return normally.
+ virtual void reportErrorCallback(PyObject* callback, DispatchKey key)
+ const = 0;
+
// This is only invoked in the multipy/torchdeploy situation from
// pythonOpRegistrationTrampoline; this lets us get to the Python
// interpreter to actually find the appropriate Python op registration
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index 6270148..086622a 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -908,6 +908,26 @@
self.assertTrue(f'{TestCustomOp.test_ns}.foo' in gm.code)
foo._destroy()
+ def test_not_implemented_error(self):
+ # Calling a CustomOp that has no impl registered must raise
+ # NotImplementedError with the messages produced by
+ # report_error_callback in torch/_custom_op.py.
+ @custom_op(f'{TestCustomOp.test_ns}::foo')
+ def foo(x: torch.Tensor) -> torch.Tensor:
+ ...
+
+ # CPU tensor input -> "no cpu impl registered" message.
+ x = torch.randn(3)
+ with self.assertRaisesRegex(NotImplementedError, "cpu impl registered"):
+ foo(x)
+
+ # Meta tensor input -> "no abstract impl registered" message.
+ x = torch.randn(3, device='meta')
+ with self.assertRaisesRegex(NotImplementedError, "abstract impl registered"):
+ foo(x)
+
+ # An op with no Tensor arguments exercises the "Undefined" key branch.
+ @custom_op(f'{TestCustomOp.test_ns}::bar')
+ def bar(sizes: Sequence[int]) -> torch.Tensor:
+ ...
+
+ with self.assertRaisesRegex(NotImplementedError, "no Tensor inputs"):
+ bar((1, 2, 3))
+
def test_abstract_registration_location(self):
loc = torch.testing._internal.custom_op_db.numpy_nonzero._get_impl('abstract').location
matches = re.match(r'.*custom_op_db.py:\d+', loc)
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 046ffde..0bd3ddb 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1297,6 +1297,7 @@
def _dispatch_check_all_invariants() -> None: ...
def _dispatch_call_boxed(handle: _DispatchOperatorHandle, *args, **kwargs) -> Any: ...
def _dispatch_find_schema_or_throw(name: str, overload_name: str) -> _DispatchOperatorHandle: ...
+# Bound in torch/csrc/utils/python_dispatch.cpp; `callback` is later called with the dispatch key name as a str.
+def _dispatch_set_report_error_callback(handle: _DispatchOperatorHandle, callback: Callable) -> None: ...
def _dispatch_has_kernel(name: str) -> _bool: ...
def _dispatch_has_kernel_for_dispatch_key(
name: str,
diff --git a/torch/_custom_op.py b/torch/_custom_op.py
index 71f090e..b845d9d 100644
--- a/torch/_custom_op.py
+++ b/torch/_custom_op.py
@@ -157,6 +157,10 @@
get_autograd_not_implemented_kernel(weakref.proxy(result))
)
+ # Have the dispatcher report missing-kernel errors for this op through
+ # report_error_callback, with the op bound as its first argument.
+ # NOTE(review): the op is passed as a weakref.proxy, presumably so the
+ # registered callback does not keep the CustomOp alive - confirm.
+ torch._C._dispatch_set_report_error_callback(
+ ophandle, functools.partial(report_error_callback, weakref.proxy(result))
+ )
+
return result
return inner
@@ -786,3 +790,33 @@
SUPPORTED_PARAM_TYPES = get_supported_param_types()
+
+
+def report_error_callback(custom_op: typing.Any, key: str) -> None:
+    """Raise a NotImplementedError explaining why ``custom_op`` cannot run for ``key``.
+
+    Args:
+        custom_op: the operator being reported on; only its string form is
+            used, as a prefix of the error message.
+        key: name of the dispatch key for which no implementation was found
+            (e.g. "Undefined", "Meta", "CPU", "CUDA").
+
+    Raises:
+        NotImplementedError: always. This function never returns normally.
+    """
+    if key == "Undefined":
+        raise NotImplementedError(
+            f"{custom_op}: There were no Tensor inputs to this operator "
+            f"(e.g. you passed an empty list of Tensors). If your operator is a "
+            f"factory function (that is, it takes no Tensors and constructs "
+            f"a new one), then please use CustomOp.impl_factory to register "
+            f"an implementation for it"
+        )
+    if key == "Meta":
+        # The device string users actually pass is lowercase 'meta', so quote
+        # it that way (consistent with the lowercased CPU/CUDA branch below).
+        raise NotImplementedError(
+            f"{custom_op}: when running with device='meta' tensors: there is no "
+            f"abstract impl registered for this CustomOp. Please register one via "
+            f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
+        )
+    if key in ("CPU", "CUDA"):
+        device = key.lower()
+        raise NotImplementedError(
+            f"{custom_op}: when running with device='{device}' tensors: there is no "
+            f"{device} impl registered for this CustomOp. Please register one via "
+            f"CustomOp.impl(device_type='{device}')"
+        )
+    # Any other dispatch key: no tailored guidance is available.
+    raise NotImplementedError(
+        f"{custom_op}: No implementation for dispatch key {key}. It is likely "
+        f"that we have not added this functionality yet, please either open an "
+        f"issue or if you're feeling adventurous, use the low-level "
+        f"torch.library API"
+    )
diff --git a/torch/csrc/PyInterpreter.cpp b/torch/csrc/PyInterpreter.cpp
index a01b1d3..709175f 100644
--- a/torch/csrc/PyInterpreter.cpp
+++ b/torch/csrc/PyInterpreter.cpp
@@ -44,6 +44,7 @@
void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
const override;
+ // Calls the Python callable `callback` with the name of `key`; see the out-of-line definition.
+ void reportErrorCallback(PyObject* callback, DispatchKey key) const override;
void python_dispatcher(
const c10::OperatorHandle& op,
c10::DispatchKeySet,
@@ -254,6 +255,16 @@
return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python);
}
+void ConcretePyInterpreterVTable::reportErrorCallback(
+    PyObject* callback,
+    DispatchKey key) const {
+  // Take the GIL for the duration of the call into Python.
+  py::gil_scoped_acquire gil;
+  // Only some DispatchKeys have pybind bindings and nothing enforces full
+  // coverage, so the key is handed to Python by name rather than as an enum.
+  const auto key_name = c10::toString(key);
+  py::object handler = py::reinterpret_borrow<py::object>(callback);
+  handler(key_name);
+}
+
void ConcretePyInterpreterVTable::dispatch(
const c10::OperatorHandle& op,
torch::jit::Stack* stack) const {
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 852710e..51350e5 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -688,6 +688,14 @@
return names;
},
py::arg("dispatch_key") = static_cast<const char*>(""));
+ // Registers `callback` (a Python object) as the error reporter of the
+ // operator behind `handle`.
+ m.def(
+ "_dispatch_set_report_error_callback",
+ [](c10::OperatorHandle& handle, py::object callback) {
+ // release() gives up pybind's ownership of the reference without
+ // decrementing it; the raw pointer is then wrapped in a SafePyObject
+ // associated with the current interpreter.
+ // NOTE(review): presumably SafePyObject drops this reference when it
+ // is destroyed - confirm.
+ auto obj = callback.release().ptr();
+ auto callback_obj =
+ std::make_unique<c10::SafePyObject>(obj, getPyInterpreter());
+ handle.setReportErrorCallback_(std::move(callback_obj));
+ });
m.def(
"_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });