Enable registering fallthroughs to (operator, dispatch key) from torch.library (#106086)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106086
Approved by: https://github.com/zou3519, https://github.com/albanD
diff --git a/docs/source/library.rst b/docs/source/library.rst
index 2b89856..7f5f9cf 100644
--- a/docs/source/library.rst
+++ b/docs/source/library.rst
@@ -36,6 +36,8 @@
.. autoclass:: torch.library.Library
:members:
+.. autofunction:: fallthrough_kernel
+
We have also added some function decorators to make it convenient to register functions for operators:
* :func:`torch.library.impl`
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index 06f26ec..7cd736d 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -3,7 +3,7 @@
import tempfile
import torch
from copy import deepcopy
-from torch.library import Library, impl
+from torch.library import Library, impl, fallthrough_kernel
from torch.fx.experimental.proxy_tensor import ShapeEnv
from torch import SymInt
from torch._subclasses.fake_tensor import FakeTensorMode
@@ -543,6 +543,25 @@
getattr(torch.ops, self.test_ns).foo.default,
getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
+ def test_register_fallthrough(self):
+ try:
+ my_lib = Library('aten', 'IMPL')
+ my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")
+
+ a = torch.randn(2, 3, device='cpu', dtype=torch.float32)
+ b = torch.randn(3, 2, device='cpu', dtype=torch.float32)
+ with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
+ # dtype for mm should be float32 since we registered a fallthrough
+ self.assertEqual(torch.mm(a, b).dtype, torch.float32)
+ # ops that don't have a fallthrough registered should not be affected
+ self.assertEqual(torch.matmul(a, b).dtype, torch.bfloat16)
+ finally:
+ del my_lib
+
+ with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
+ # default behavior should have been restored
+ self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16)
+
class TestPythonDispatch(TestCase):
def test_basic(self) -> None:
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 27b1c9c..e0957cf 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -314,17 +314,26 @@
py::object func) {
HANDLE_TH_ERRORS
auto& lib = self.cast<torch::Library&>();
- lib.impl(
- name,
- torch::dispatch(
- dispatch,
- CppFunction::makeFromBoxedFunctor(
- std::make_unique<PythonKernelHolder>(func, dispatch))),
- register_or_verify());
- python_registrations_[lib._resolve(name)].insert_or_assign(
- dispatch,
- std::make_shared<c10::SafePyObject>(
- func.release().ptr(), getPyInterpreter()));
+ if (func.is(py::module::import("torch.library")
+ .attr("fallthrough_kernel"))) {
+ lib.impl(
+ name,
+ torch::dispatch(dispatch, CppFunction::makeFallthrough()),
+ register_or_verify());
+ } else {
+ lib.impl(
+ name,
+ torch::dispatch(
+ dispatch,
+ CppFunction::makeFromBoxedFunctor(
+ std::make_unique<PythonKernelHolder>(
+ func, dispatch))),
+ register_or_verify());
+ python_registrations_[lib._resolve(name)].insert_or_assign(
+ dispatch,
+ std::make_shared<c10::SafePyObject>(
+ func.release().ptr(), getPyInterpreter()));
+ }
END_HANDLE_TH_ERRORS_PYBIND
},
"",
diff --git a/torch/library.py b/torch/library.py
index def51fd..07f8d33 100644
--- a/torch/library.py
+++ b/torch/library.py
@@ -4,7 +4,7 @@
import torch
import weakref
-__all__ = ['Library', 'impl', 'define']
+__all__ = ['Library', 'impl', 'define', 'fallthrough_kernel']
# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
@@ -15,6 +15,11 @@
# prim is reserved by TorchScript interpreter
_reserved_namespaces = ['prim']
+def fallthrough_kernel():
+ """
+ A sentinel to pass to ``Library.impl`` to register a fallthrough. It is matched by identity and is never called.
+ """
+ raise NotImplementedError("fallthrough_kernel() should never be called.")
class Library:
"""
@@ -83,7 +88,8 @@
Args:
op_name: operator name (along with the overload) or OpOverload object.
- fn: function that's the operator implementation for the input dispatch key.
+ fn: function that's the operator implementation for the input dispatch key, or :func:`~fallthrough_kernel`
+ to register a fallthrough for that dispatch key.
dispatch_key: dispatch key that the input function should be registered for. By default, it uses
the dispatch key that the library was created with.