Enable registering fallthroughs to (op, dispatch key) from torch.library (#106086)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106086
Approved by: https://github.com/zou3519, https://github.com/albanD
diff --git a/docs/source/library.rst b/docs/source/library.rst
index 2b89856..7f5f9cf 100644
--- a/docs/source/library.rst
+++ b/docs/source/library.rst
@@ -36,6 +36,8 @@
 .. autoclass:: torch.library.Library
   :members:
 
+.. autofunction:: fallthrough_kernel
+
 We have also added some function decorators to make it convenient to register functions for operators:
 
 * :func:`torch.library.impl`
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index 06f26ec..7cd736d 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -3,7 +3,7 @@
 import tempfile
 import torch
 from copy import deepcopy
-from torch.library import Library, impl
+from torch.library import Library, impl, fallthrough_kernel
 from torch.fx.experimental.proxy_tensor import ShapeEnv
 from torch import SymInt
 from torch._subclasses.fake_tensor import FakeTensorMode
@@ -543,6 +543,25 @@
             getattr(torch.ops, self.test_ns).foo.default,
             getattr(torch.ops, self.test_ns).foo_functional.default, (x, y, z, w))
 
+    def test_register_fallthrough(self):
+        my_lib = Library('aten', 'IMPL')  # created before `try` so `finally` never sees it unbound
+        try:
+            my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")
+
+            a = torch.randn(2, 3, device='cpu', dtype=torch.float32)
+            b = torch.randn(3, 2, device='cpu', dtype=torch.float32)
+            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
+                # dtype for mm should be float32 since we registered a fallthrough
+                self.assertEqual(torch.mm(a, b).dtype, torch.float32)
+                # ops that don't have a fallthrough registered should not be affected
+                self.assertEqual(torch.matmul(a, b).dtype, torch.bfloat16)
+        finally:
+            del my_lib  # release the library even if an assertion above failed
+
+        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
+            # default behavior should have been restored
+            self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16)
+
 
 class TestPythonDispatch(TestCase):
     def test_basic(self) -> None:
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 27b1c9c..e0957cf 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -314,17 +314,26 @@
              py::object func) {
             HANDLE_TH_ERRORS
             auto& lib = self.cast<torch::Library&>();
-            lib.impl(
-                name,
-                torch::dispatch(
-                    dispatch,
-                    CppFunction::makeFromBoxedFunctor(
-                        std::make_unique<PythonKernelHolder>(func, dispatch))),
-                register_or_verify());
-            python_registrations_[lib._resolve(name)].insert_or_assign(
-                dispatch,
-                std::make_shared<c10::SafePyObject>(
-                    func.release().ptr(), getPyInterpreter()));
+            if (func.is(py::module::import("torch.library")
+                            .attr("fallthrough_kernel"))) {
+              lib.impl(
+                  name,
+                  torch::dispatch(dispatch, CppFunction::makeFallthrough()),
+                  register_or_verify());
+            } else {
+              lib.impl(
+                  name,
+                  torch::dispatch(
+                      dispatch,
+                      CppFunction::makeFromBoxedFunctor(
+                          std::make_unique<PythonKernelHolder>(
+                              func, dispatch))),
+                  register_or_verify());
+              python_registrations_[lib._resolve(name)].insert_or_assign(
+                  dispatch,
+                  std::make_shared<c10::SafePyObject>(
+                      func.release().ptr(), getPyInterpreter()));
+            }
             END_HANDLE_TH_ERRORS_PYBIND
           },
           "",
diff --git a/torch/library.py b/torch/library.py
index def51fd..07f8d33 100644
--- a/torch/library.py
+++ b/torch/library.py
@@ -4,7 +4,7 @@
 import torch
 import weakref
 
-__all__ = ['Library', 'impl', 'define']
+__all__ = ['Library', 'impl', 'define', 'fallthrough_kernel']
 
 # Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered
 # The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
@@ -15,6 +15,11 @@
 # prim is reserved by TorchScript interpreter
 _reserved_namespaces = ['prim']
 
+def fallthrough_kernel():
+    """
+    Sentinel to pass as ``fn`` to ``Library.impl`` to register a fallthrough; it must never be called directly.
+    """
+    raise NotImplementedError("fallthrough_kernel() should never be called.")
 
 class Library:
     """
@@ -83,7 +88,8 @@
 
         Args:
             op_name: operator name (along with the overload) or OpOverload object.
-            fn: function that's the operator implementation for the input dispatch key.
+            fn: function that's the operator implementation for the input dispatch key, or
+                :func:`~fallthrough_kernel` to register a fallthrough.
             dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                           the dispatch key that the library was created with.