Support torchbind op dispatch in python (#123367)

We override the `__call__` method and register fake, functional, proxy default dispatch mode implementation in its python_key_mode_table.

The idea is:
1. when inputs contains FakeScriptObject,  we dispatch it through _get_dispatch mechanism. We implement dispatch mode keys automatically in the operator's constructor.
2. when inputs are not fakified, we dispatch through the original c++ dispatcher.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123367
Approved by: https://github.com/zou3519
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h
index c6d3365..caf73d7 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.h
+++ b/aten/src/ATen/core/dispatch/Dispatcher.h
@@ -403,6 +403,10 @@
     return operatorDef_->op.hasKernelForDispatchKey(k);
   }
 
+  bool isKernelFallthroughKernel(DispatchKey k) const {
+    return operatorDef_->op.kernelForDispatchKey(k).isFallthrough();
+  }
+
   bool hasKernelForAnyDispatchKey(DispatchKeySet k) const {
     return operatorDef_->op.hasKernelForAnyDispatchKey(k);
   }
diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py
index 7523af9..b5312fe 100644
--- a/test/export/test_torchbind.py
+++ b/test/export/test_torchbind.py
@@ -85,6 +85,19 @@
                 test.tq_size_counter += 1
                 return len(self.queue)
 
+        self.torch_bind_ops = [
+            torch.ops._TorchScriptTesting.takes_foo,
+            torch.ops._TorchScriptTesting.takes_foo_python_meta,
+            torch.ops._TorchScriptTesting.takes_foo_list_return,
+            torch.ops._TorchScriptTesting.takes_foo_tuple_return,
+            torch.ops._TorchScriptTesting.take_an_instance,
+            torch.ops._TorchScriptTesting.take_an_instance_inferred,
+            torch.ops._TorchScriptTesting.takes_foo_cia,
+            torch.ops._TorchScriptTesting.queue_pop,
+            torch.ops._TorchScriptTesting.queue_push,
+            torch.ops._TorchScriptTesting.queue_size,
+        ]
+
     def tearDown(self):
         torch._library.fake_class_registry.deregister_fake_class(
             "_TorchScriptTesting::_Foo"
@@ -555,6 +568,181 @@
         self.assertEqual(tq.size(), 0)
         self.assertEqual(tq1.size(), 0)
 
+    def test_identifying_torchbind_ops(self):
+        for op in self.torch_bind_ops:
+            self.assertTrue(op._has_torchbind_op_overload)
+
+        for op in [
+            torch.ops.aten.add,
+            torch.ops.aten.cos,
+        ]:
+            self.assertFalse(op._has_torchbind_op_overload)
+
+    def test_torchbind_op_register_fallthrough(self):
+        TEST_DISPATCH_KEY = torch._C.DispatchKey.AutocastCPU
+        TEST_DISPATCH_KEY_STR = "AutocastCPU"
+
+        for op_packet in self.torch_bind_ops:
+            op = op_packet.default
+            ns, _ = torch._library.utils.parse_namespace(op_packet._qualified_op_name)
+            with torch.library._scoped_library(ns, "FRAGMENT") as lib:
+                lib.impl(
+                    op.name(), torch.library.fallthrough_kernel, TEST_DISPATCH_KEY_STR
+                )
+                self.assertTrue(
+                    torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
+                        op.name(), TEST_DISPATCH_KEY
+                    )
+                )
+
+    def test_torchbind_op_fallthrough_keys_respects_lib_impl(self):
+        TEST_DISPATCH_KEY = torch._C.DispatchKey.AutogradCPU
+        TEST_DISPATCH_KEY_STR = "AutogradCPU"
+
+        tested = 0
+        for op_packet in self.torch_bind_ops:
+            op = op_packet.default
+            ns, _ = torch._library.utils.parse_namespace(op_packet._qualified_op_name)
+            if (
+                not torch._C._dispatch_has_kernel_for_dispatch_key(
+                    op.name(), TEST_DISPATCH_KEY
+                )
+                and TEST_DISPATCH_KEY not in op.py_kernels
+            ):
+                tested += 1
+                with torch.library._scoped_library(ns, "FRAGMENT") as lib:
+                    lib.impl(
+                        op.name(), lambda *args, **kwargs: args, TEST_DISPATCH_KEY_STR
+                    )
+                    self.assertTrue(TEST_DISPATCH_KEY not in op._fallthrough_keys())
+
+                with torch.library._scoped_library(ns, "FRAGMENT") as lib:
+                    lib.impl(
+                        op.name(),
+                        torch.library.fallthrough_kernel,
+                        TEST_DISPATCH_KEY_STR,
+                    )
+                    self.assertTrue(TEST_DISPATCH_KEY in op._fallthrough_keys())
+        self.assertTrue(tested > 0)
+
+    def test_make_fx_schema_checking_script_object(self):
+        class Model(torch.nn.Module):
+            def forward(self, tq, x, foo):
+                torch.ops._TorchScriptTesting.queue_push(foo, x.cos())
+                return tq
+
+        class ModelCallByKW(torch.nn.Module):
+            def forward(self, tq, x, foo):
+                torch.ops._TorchScriptTesting.queue_push(x=x.cos(), foo=foo)
+                return tq
+
+        mod = Model()
+        modkw = ModelCallByKW()
+
+        foo = torch.classes._TorchScriptTesting._Foo(10, 20)
+        x = torch.ones(3, 3)
+        tq = torch.classes._TorchScriptTesting._TensorQueue(
+            torch.empty(
+                0,
+            ).fill_(-1)
+        )
+        ns = "_TorchScriptTesting"
+        with torch.library._scoped_library(ns, "FRAGMENT") as lib:
+            op = torch.ops._TorchScriptTesting.queue_push
+            lib.impl(op.__name__, torch.library.fallthrough_kernel, "AutogradCPU")
+            lib.impl(op.__name__, torch.library.fallthrough_kernel, "ADInplaceOrView")
+            lib.impl(
+                op.__name__,
+                torch.library.fallthrough_kernel,
+                "PythonTLSSnapshot",
+            )
+
+            with self.assertRaisesRegex(
+                RuntimeError, "is expected to be a FakeScriptObject"
+            ):
+                _ = make_fx(mod, tracing_mode="fake")(tq, x, foo)
+
+            with self.assertRaisesRegex(
+                RuntimeError, "is expected to be a FakeScriptObject"
+            ):
+                _ = make_fx(modkw, tracing_mode="fake")(tq, x, foo)
+
+    @parametrize("fallthrough_via", ["lib_impl", "py_impl"])
+    def test_make_fx_tensor_queue_operators(self, fallthrough_via):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, tq, x):
+                with torch.autocast("cuda", dtype=torch.bfloat16):
+                    torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
+                    torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
+                    x_sin = torch.ops._TorchScriptTesting.queue_pop(
+                        tq
+                    ) - torch.ops._TorchScriptTesting.queue_size(tq)
+                    x_cos = torch.ops._TorchScriptTesting.queue_pop(
+                        tq
+                    ) + torch.ops._TorchScriptTesting.queue_size(tq)
+                    return x_sin, x_cos, tq
+
+        mod = Model()
+
+        tq1 = torch.classes._TorchScriptTesting._TensorQueue(
+            torch.empty(
+                0,
+            ).fill_(-1)
+        )
+        tq2 = torch.classes._TorchScriptTesting._TensorQueue(
+            torch.empty(
+                0,
+            ).fill_(-1)
+        )
+        x = torch.ones(2, 3)
+
+        mod(tq1, x)
+
+        ops = [
+            torch.ops._TorchScriptTesting.queue_push,
+            torch.ops._TorchScriptTesting.queue_pop,
+            torch.ops._TorchScriptTesting.queue_size,
+        ]
+        if fallthrough_via == "lib_impl":
+            ns = "_TorchScriptTesting"
+            with torch.library._scoped_library(ns, "FRAGMENT") as lib:
+                for op in ops:
+                    lib.impl(
+                        op.__name__, torch.library.fallthrough_kernel, "AutocastCUDA"
+                    )
+
+                gm = make_fx(mod, tracing_mode="fake")(tq1, x)
+        else:
+            for op in ops:
+                op.default.py_impl(torch._C.DispatchKey.AutocastCUDA)(
+                    torch.library.fallthrough_kernel
+                )
+            gm = make_fx(mod, tracing_mode="fake")(tq1, x)
+            for op in ops:
+                op.default._dispatch_cache.clear()
+                del op.default.py_kernels[torch._C.DispatchKey.AutocastCUDA]
+
+        self.assertExpectedInline(
+            gm.code.strip(),
+            """\
+def forward(self, arg0_1, arg1_1):
+    cos = torch.ops.aten.cos.default(arg1_1)
+    queue_push = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, cos);  cos = None
+    sin = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
+    queue_push_1 = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, sin);  sin = None
+    queue_pop = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1)
+    queue_size = torch.ops._TorchScriptTesting.queue_size.default(arg0_1)
+    sub = torch.ops.aten.sub.Tensor(queue_pop, 1);  queue_pop = None
+    queue_pop_1 = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1)
+    queue_size_1 = torch.ops._TorchScriptTesting.queue_size.default(arg0_1)
+    add = torch.ops.aten.add.Tensor(queue_pop_1, 0);  queue_pop_1 = None
+    return (sub, add, arg0_1)""",
+        )
+        self._assertEqualSkipScriptObject(gm(tq1, x), mod(tq2, x))
+
 
 @skipIfTorchDynamo("torchbind not supported with dynamo yet")
 class TestRegisterFakeClass(TestCase):
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 20aba946..583bd38 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -428,6 +428,16 @@
 
 # Defined in torch/csrc/jit/python/script_init.cpp
 #        and torch/csrc/jit/python/init.cpp
+def _maybe_call_torch_function_for_op_packet(
+    op_overload_packet: Any,
+    args: Any,
+    kwargs: Any,
+) -> Any: ...
+def _check_schema_allow_fake_script_object(
+    schema: FunctionSchema,
+    args: Any,
+    kwargs: Any,
+) -> _bool: ...
 def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ...
 def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
 def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
@@ -1493,6 +1503,10 @@
     name: str,
     dispatch_key_set: DispatchKeySet,
 ) -> _bool: ...
+def _dispatch_kernel_for_dispatch_key_is_fallthrough(
+    name: str,
+    dispatch: _dispatchkey,
+) -> _bool: ...
 def _dispatch_has_computed_kernel_for_dispatch_key(
     name: str,
     dispatch: _dispatchkey,
diff --git a/torch/_dynamo/tensor_version_op.py b/torch/_dynamo/tensor_version_op.py
index f12ed95..4c42464 100644
--- a/torch/_dynamo/tensor_version_op.py
+++ b/torch/_dynamo/tensor_version_op.py
@@ -13,13 +13,13 @@
 
 
 @_tensor_version.py_impl(FakeTensorMode)
-def _tensor_version_fake(self):
+def _tensor_version_fake(fake_mode, self_tensor):
     """
     The initial dynamo capture of _tensor_version + _unsafe_set_version_counter turns the
     `._version` into an unbacked SymInt so that we don't need to specialize on the `._version`
     of input tensors to the graph.
     """
-    return self.fake_mode.shape_env.create_unbacked_symint()
+    return fake_mode.shape_env.create_unbacked_symint()
 
 
 _unsafe_set_version_counter = _make_prim(
@@ -48,10 +48,10 @@
 
 
 @_tensor_version.py_impl(FunctionalTensorMode)
-def _tensor_version_functional(self):
+def _tensor_version_functional(mode, self):
     return self._version
 
 
 @_unsafe_set_version_counter.py_impl(FunctionalTensorMode)
-def _unsafe_set_version_counter_functional(self, version):
+def _unsafe_set_version_counter_functional(ctx, self, version):
     torch._C._autograd._unsafe_set_version_counter(self, version)
diff --git a/torch/_library/utils.py b/torch/_library/utils.py
index e2d0110..2bab3d6 100644
--- a/torch/_library/utils.py
+++ b/torch/_library/utils.py
@@ -199,3 +199,21 @@
     the C++ op with a python module.
     """
     return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True)
+
+
+def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
+    assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
+    overload_types = []
+    args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
+    for a in args_flattened:
+        # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
+        # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
+        # where in one case we only include tensors with the python key, and in another
+        # we include **all** tensors.
+        if isinstance(a, torch.Tensor) and torch._C._dispatch_keys(a).has(
+            torch._C.DispatchKey.Python
+        ):
+            overload_types.append(type(a))
+    # TODO: check that I got these args correct (in C++, we pass in "0000"??)
+
+    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
diff --git a/torch/_ops.py b/torch/_ops.py
index 8bfacf8..0809d7c 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -4,7 +4,7 @@
 import inspect
 import sys
 import types
-from typing import Any, Callable, Dict, Set, Type, Union
+from typing import Any, Callable, Dict, List, Set, Type, Union
 
 import torch._C
 import torch.utils._pytree as pytree
@@ -261,6 +261,7 @@
         if self.__class__ is HigherOrderOperator:
             self_name_space = "." + self.namespace if self.namespace else ""
             self.__module__ = self.__module__ + self_name_space
+
         self.non_fallthrough_keys = torch._C._dispatch_keyset_full()
 
         for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS:
@@ -684,7 +685,10 @@
         assert key not in self._dispatch_cache, f"{self} {key}"
 
         if key == torch._C.DispatchKey.Python:
-            if not self.python_key_mode_table:
+            if (
+                not isinstance(self, TorchBindOpOverload)
+                and not self.python_key_mode_table
+            ):
                 self._dispatch_cache[key] = key
                 add_cached_op(self)
                 return key
@@ -698,12 +702,18 @@
                 assert (
                     curr_mode is not None
                 ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
+
                 if curr_mode not in self.python_key_mode_table:
-                    # TODO: This path is slow, should generally encourage this
-                    # case to not happen
-                    return self._op_dk(key, *args, **kwargs)
-                # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
-                return self.python_key_mode_table[curr_mode](*args, **kwargs)
+                    if isinstance(self, TorchBindOpOverload):
+                        with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
+                            return torch._library.utils.handle_dispatch_mode(
+                                mode, self, *args, **kwargs
+                            )
+                    else:
+                        return self._op_dk(key, *args, **kwargs)
+
+                with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
+                    return self.python_key_mode_table[curr_mode](mode, *args, **kwargs)
 
             self._dispatch_cache[key] = handler
             add_cached_op(self)
@@ -731,24 +741,8 @@
                             _set_mode_pre_dispatch(top_mode)
 
                     with _temporarily_pop_modes_from_pre_dispatch() as curr_mode:
-                        assert isinstance(curr_mode, TorchDispatchMode)
-                        overload_types = []
-                        args_flattened, _ = torch.utils._pytree.tree_flatten(
-                            (args, kwargs.values())
-                        )
-                        for a in args_flattened:
-                            # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
-                            # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
-                            # where in one case we only include tensors with the python key, and in another
-                            # we include **all** tensors.
-                            if isinstance(a, torch.Tensor) and torch._C._dispatch_keys(
-                                a
-                            ).has(torch._C.DispatchKey.Python):
-                                overload_types.append(type(a))
-                        # TODO: check that I got these args correct (in C++, we pass in "0000"??)
-
-                        return curr_mode.__torch_dispatch__(
-                            self, overload_types, args, kwargs
+                        return torch._library.utils.handle_dispatch_mode(
+                            curr_mode, self, *args, **kwargs
                         )
 
                 # Note [Not Caching Per-Dispatch-Key Mode Handlers]
@@ -776,7 +770,6 @@
                     add_cached_op(self)
                 return handler
 
-        # print(self, key, final_key)
         r = self.py_kernels.get(final_key, final_key)
         if cache_result:
             self._dispatch_cache[key] = r
@@ -801,6 +794,98 @@
     # TODO: add more methods to expose information about input and output arguments
 
 
+# TorchBindOpOverload are those custom ops which have at least one overload's
+# schema consists of torch.ScriptObject (i.e. custom class) input.
+# TorchBindOpOverload will skip C++ dispatcher and purely dispatched in python
+# when its inputs contain FakeScriptObject in a similar way as higher order ops.
+class TorchBindOpOverload(OpOverload):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def _fallthrough_keys(self) -> List[DispatchKey]:
+        # TODO: we should be calling the fallback for these, but a fallthrough is almost close
+        # enough to the fallback in most cases that we care about.
+        _DEFAULT_FALLTHROUGH_KEYS = [
+            DispatchKey.Autograd,
+            DispatchKey.AutogradCPU,
+            DispatchKey.AutogradCUDA,
+            DispatchKey.ADInplaceOrView,
+            DispatchKey.PythonTLSSnapshot,
+        ]
+
+        def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
+            if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key):
+                return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
+                    self.name(), key
+                )
+
+            return (
+                key not in self.py_kernels
+                or self.py_kernels[key] is torch.library.fallthrough_kernel
+            )
+
+        return [
+            key
+            for key in _DEFAULT_FALLTHROUGH_KEYS
+            if _may_use_fallthrough_instead_of_fallback(key)
+        ]
+
+    # use `self_` to avoid naming collide with arguments that
+    # are named "self". This way, they can be called by kwargs.
+    def __call__(self_, *args, **kwargs):  # noqa: B902
+        if _must_dispatch_in_python(args, kwargs):
+            # When any inputs are FakeScriptObject, we need to
+            # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher.
+            return self_._dispatch_in_python(args, kwargs, self_._fallthrough_keys())
+
+        return self_._op(*args, **kwargs)
+
+    def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
+        non_fallthrough_keys = torch._C._dispatch_keyset_full()
+        for key in fallthrough_keys:
+            non_fallthrough_keys = non_fallthrough_keys.remove(key)
+
+        dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys)
+        dispatch_key = dispatch_key_set.highestPriorityTypeId()
+
+        handler = (
+            self._get_dispatch(dispatch_key)
+            if dispatch_key not in self._dispatch_cache
+            else self._dispatch_cache[dispatch_key]
+        )
+
+        if isinstance(handler, DispatchKey):
+            # fallthrough keys can be registered at runtime via torch.library.impl
+            # so need to add it to fallthrough_keys and re-dispatch.
+            if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
+                self.name(), dispatch_key
+            ):
+                return self._dispatch_in_python(
+                    args, kwargs, fallthrough_keys + [dispatch_key]
+                )
+
+            raise RuntimeError(
+                f"Cannot handle FakeScriptObject with python dispatcher with dispatch key {handler}."
+                f"Please implement it by annotating a python callable with py_impl({handler})."
+            )
+
+        assert isinstance(handler, Callable)  # type: ignore[arg-type]
+        return handler(*args, **kwargs)
+
+
+def _must_dispatch_in_python(args, kwargs):
+    return pytree.tree_any(
+        lambda obj: isinstance(
+            obj, torch._library.fake_class_registry.FakeScriptObject
+        ),
+        (args, kwargs),
+    )
+
+
+def _has_script_object_arg(schema: torch.FunctionSchema) -> bool:
+    return any(isinstance(arg.type, torch.ClassType) for arg in schema.arguments)
+
+
 # OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
 # You can obtain an OpOverload object through attribute query.
 class OpOverloadPacket:
@@ -812,6 +897,9 @@
         self._op = op
         self._overload_names = overload_names
         self._dir = []
+        self._has_torchbind_op_overload = any(
+            _has_script_object_arg(schema) for schema in self._schemas.values()
+        )
 
     # it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
     def __deepcopy__(self, memo=None):
@@ -832,6 +920,13 @@
     def op(self):
         return self._op
 
+    @property
+    def _schemas(self):
+        return {
+            overload_name: torch._C._get_schema(self._qualified_op_name, overload_name)
+            for overload_name in self._overload_names
+        }
+
     def __getattr__(self, key):
         # It is not a valid op_name when __file__ is passed in
         if key == "__file__":
@@ -865,7 +960,11 @@
                 self._qualified_op_name, use_key
             )
             schema = torch._C._get_schema(self._qualified_op_name, use_key)
-            overload = OpOverload(self, op_, op_dk_, schema, tags)
+            overload = (
+                OpOverload(self, op_, op_dk_, schema, tags)
+                if not _has_script_object_arg(schema)
+                else TorchBindOpOverload(self, op_, op_dk_, schema, tags)
+            )
             # cache the overload object
             setattr(self, key, overload)
             self._dir.append(key)
@@ -886,6 +985,12 @@
         # is still callable from JIT
         # We save the function ptr as the `op` attribute on
         # OpOverloadPacket to access it here.
+
+        # Directly calling OverloadPacket goes into C++, which will check
+        # the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
+        # intercept it here and call TorchBindOpverload instead.
+        if self_._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
+            return _call_overload_packet_from_python(self_, args, kwargs)
         return self_._op(*args, **(kwargs or {}))
 
     # TODO: use this to make a __dir__
@@ -893,6 +998,46 @@
         return [n if n else "default" for n in self._overload_names]
 
 
+# Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp
+# _jit_get_operations, which calls _get_operation_for_overload_or_packet.
+def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs):
+    # Re-use the torch function handling logic in cpp
+    torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet(
+        op, *args, **kwargs
+    )
+
+    if torch_function_called:
+        return ret
+
+    # The following mirrors getOpWithStack.
+    # In cpp, we do a schema matching for the arguments, and call ToIValue to
+    # to check whether the arguments are valid. But need to do similar things here
+    # and check the schema whether the FakeScriptObject is the corresponding fake class
+    # of the actual class used in schema.
+    exceptions = {}
+    found_op = None
+    for overload_name in op.overloads():
+        op_overload = getattr(op, overload_name)
+        try:
+            _ = torch._C._check_schema_allow_fake_script_object(
+                op_overload._schema, *args, **kwargs
+            )
+            found_op = op_overload
+            break
+        except RuntimeError as e:
+            exceptions[overload_name] = e
+
+    if found_op:
+        return found_op(*args, **kwargs)
+
+    err_msg = (
+        f"Fail to match any TorchBindOverload of {op} with following exceptions:\n"
+    )
+    for i, (key, msg) in enumerate(exceptions.items()):
+        err_msg += f"Overload name {key}:\n {msg}\n"
+    raise RuntimeError(err_msg)
+
+
 # Resolution of torch.fn is different from torch.ops.aten.fn
 # torch.fn uses the Python argparser, matches with the
 # appropriate schema, and calls into the unboxed version of the method
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 057d0fc..5eb4851 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -1667,6 +1667,14 @@
       });
 
   m.def(
+      "_check_schema_allow_fake_script_object",
+      [](const FunctionSchema& schema, py::args args, py::kwargs kwargs) {
+        // checkSchemaAllowFakeScriptObject will throw runtime error if there is
+        // a schema mismatch. Otherwise, it returns true.
+        return checkSchemaAllowFakeScriptObject(schema, args, kwargs);
+      });
+
+  m.def(
       "_jit_resolve_packet",
       [](const char* op_name, py::args args, py::kwargs kwargs) {
         try {
@@ -1735,6 +1743,20 @@
       py::arg("qualified_name"));
 
   m.def(
+      "_maybe_call_torch_function_for_op_packet",
+      [](py::handle op_overload_packet, py::args args, py::kwargs kwargs) {
+        py::list ns_method =
+            op_overload_packet.attr("_qualified_op_name").attr("split")("::");
+        return _maybe_handle_torch_function(
+            py::cast<std::string>(ns_method[0]),
+            py::cast<std::string>(ns_method[1]),
+            "",
+            false,
+            args,
+            kwargs);
+      });
+
+  m.def(
       "parse_ir",
       [](const std::string& input, bool parse_tensor_constants) {
         auto graph = std::make_shared<Graph>();
diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp
index 9360b83..ba0135a 100644
--- a/torch/csrc/jit/python/pybind_utils.cpp
+++ b/torch/csrc/jit/python/pybind_utils.cpp
@@ -757,6 +757,23 @@
   }
 }
 
+// This function is used to check if the schema is valid for the given args and
+// kwargs. It checks script object by checking wether the FakeScriptObject is
+// an instance of the corresponding fake class for the actual class used in
+// schema.
+bool checkSchemaAllowFakeScriptObject(
+    const FunctionSchema& schema,
+    py::args args,
+    const py::kwargs& kwargs) {
+  bool match = false;
+  try {
+    match = matchSchemaAllowFakeScriptObject(schema, std::move(args), kwargs);
+  } catch (schema_match_error& error) {
+    throw std::runtime_error(error.what());
+  }
+  return match;
+}
+
 py::object invokeOperatorFromPython(
     const std::vector<std::shared_ptr<Operator>>& operations,
     py::args args,
@@ -775,13 +792,13 @@
   return createPyObjectForStack(std::move(stack));
 }
 
-py::object _get_operation_for_overload_or_packet(
-    const std::vector<std::shared_ptr<Operator>>& operations,
-    Symbol symbol,
-    py::args args,
-    const py::kwargs& kwargs,
+py::tuple _maybe_handle_torch_function(
+    const std::string& ns,
+    const std::string& method_name,
+    const std::string& overload_name,
     bool is_overload,
-    c10::optional<c10::DispatchKey> dk) {
+    py::args args,
+    const py::kwargs& kwargs) {
   std::vector<PyObject*> overloaded_args;
   size_t total_arg_num = args.size() + kwargs.size();
   for (const auto i : c10::irange(args.size())) {
@@ -807,15 +824,11 @@
         false /* throw_error */);
   }
   if (!overloaded_args.empty() || at::impl::torch_function_mode_enabled()) {
-    py::object ret;
-    std::string ns = symbol.ns().toUnqualString();
-    std::string method_name = symbol.toUnqualString();
     auto self_func = py::module::import("torch")
                          .attr("ops")
                          .attr(ns.c_str())
                          .attr(method_name.c_str());
     if (is_overload) {
-      auto overload_name = operations[0]->schema().overload_name();
       if (overload_name.empty()) {
         self_func = self_func.attr("default");
       } else {
@@ -824,16 +837,36 @@
     }
     std::string module_name("torch.ops");
     module_name.append(ns);
-    return pybind11::reinterpret_steal<py::object>(
-        handle_torch_function_no_python_arg_parser(
-            overloaded_args,
-            args.ptr(),
-            kwargs.ptr(),
-            method_name.c_str(),
-            self_func.ptr(),
-            module_name.c_str()));
+    return py::make_tuple(
+        true,
+        pybind11::reinterpret_steal<py::object>(
+            handle_torch_function_no_python_arg_parser(
+                overloaded_args,
+                args.ptr(),
+                kwargs.ptr(),
+                method_name.c_str(),
+                self_func.ptr(),
+                module_name.c_str())));
   }
-  return invokeOperatorFromPython(operations, args, kwargs, dk);
+  return py::make_tuple(false, py::none());
+}
+
+py::object _get_operation_for_overload_or_packet(
+    const std::vector<std::shared_ptr<Operator>>& operations,
+    Symbol symbol,
+    py::args args,
+    const py::kwargs& kwargs,
+    bool is_overload,
+    c10::optional<c10::DispatchKey> dk) {
+  std::string ns = symbol.ns().toUnqualString();
+  std::string method_name = symbol.toUnqualString();
+  std::string overload_name = operations[0]->schema().overload_name();
+  auto res = _maybe_handle_torch_function(
+      ns, method_name, overload_name, is_overload, args, kwargs);
+  auto torch_function_called = py::cast<bool>(res[0]);
+  return torch_function_called
+      ? res[1]
+      : invokeOperatorFromPython(operations, args, kwargs, dk);
 }
 
 } // namespace torch::jit
diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h
index cbb7791..a78c3e0 100644
--- a/torch/csrc/jit/python/pybind_utils.h
+++ b/torch/csrc/jit/python/pybind_utils.h
@@ -873,6 +873,116 @@
   int64_t e;
 };
 
+inline bool validateFakeScriptObjectSchema(
+    const c10::FunctionSchema& schema,
+    size_t argumentPosition,
+    py::handle object) {
+  auto argument = schema.arguments().at(argumentPosition);
+  auto class_type = argument.real_type()->expect<c10::ClassType>();
+  auto fake_class_registry =
+      py::module::import("torch._library.fake_class_registry");
+  auto fake_class = fake_class_registry.attr("find_fake_class")(
+      class_type->name().value().qualifiedName());
+  if (!py::isinstance(object.attr("wrapped_obj"), fake_class)) {
+    throw schema_match_error(c10::str(
+        schema.formatTypeMismatchMsg(
+            argument,
+            friendlyTypeName(object),
+            argumentPosition,
+            py::repr(object.attr("wrapped_obj"))),
+        "\nCast error details: ",
+        argument.name(),
+        " is expected to be a FakeScriptObject of ",
+        class_type->name().value().qualifiedName()));
+  }
+  return true;
+}
+
+inline bool matchSchemaAllowFakeScriptObject(
+    const FunctionSchema& schema,
+    const tuple_slice& args,
+    const py::kwargs& kwargs) {
+  size_t all_arguments = args.size() + kwargs.size();
+  if (all_arguments > schema.arguments().size()) {
+    throw schema_match_error(c10::str(
+        schema.name(),
+        "() expected at most ",
+        schema.arguments().size(),
+        " argument(s) but received ",
+        all_arguments,
+        " argument(s). Declaration: ",
+        schema));
+  }
+
+  int64_t arg_idx = 0;
+  auto fake_class_registry =
+      py::module::import("torch._library.fake_class_registry");
+
+  // First push all positional args.
+  for (const auto& arg : args) {
+    // ...but refuse to do it if the schema says that this was supposed
+    // to be keyword only
+    if (schema.arguments()[arg_idx].kwarg_only()) {
+      throw schema_match_error(c10::str(
+          schema.name(),
+          "() takes ",
+          arg_idx,
+          " positional argument(s) but ",
+          args.size(),
+          " was/were given.  Declaration: ",
+          schema));
+    }
+    // Use the type information from the schema to convert the PyObject.
+    const auto& argument = schema.arguments().at(arg_idx);
+    if (argument.real_type()->kind() == TypeKind::ClassType &&
+        py::isinstance(arg, fake_class_registry.attr("FakeScriptObject"))) {
+      validateFakeScriptObjectSchema(schema, arg_idx, arg);
+    } else {
+      argumentToIValue(schema, arg_idx, arg);
+    }
+
+    arg_idx++;
+  }
+
+  // Now for every remaining non-positional argument in the schema, look for it
+  // in the kwargs dict and push it if found, or use its default value if it
+  // has one.
+  size_t consumed_kwargs = 0;
+  for (size_t i = arg_idx; i < schema.arguments().size(); ++i) {
+    const auto& arg = schema.arguments()[i];
+    if (kwargs.contains(arg.name().c_str())) {
+      auto cur_kwarg = kwargs[arg.name().c_str()];
+      if (arg.real_type()->kind() == TypeKind::ClassType &&
+          py::isinstance(
+              cur_kwarg, fake_class_registry.attr("FakeScriptObject"))) {
+        validateFakeScriptObjectSchema(schema, i, cur_kwarg);
+      } else {
+        argumentToIValue(schema, i, cur_kwarg);
+      }
+      consumed_kwargs += 1;
+    } else if (arg.default_value()) {
+      continue;
+    } else {
+      throw schema_match_error(c10::str(
+          schema.name(),
+          "() is missing value for argument '",
+          arg.name(),
+          "'. Declaration: ",
+          schema));
+    }
+  }
+
+  if (consumed_kwargs != kwargs.size()) {
+    std::vector<std::string> names;
+    for (const auto& kwarg : kwargs) {
+      names.emplace_back(py::cast<std::string>(kwarg.first));
+    }
+    throw schema_match_error(schema.findErrorInKwargs(names));
+  }
+
+  return true;
+}
+
 inline Stack createStackForSchema(
     const FunctionSchema& schema,
     const tuple_slice& args,
@@ -1147,6 +1257,19 @@
     const py::kwargs& kwargs,
     c10::optional<c10::DispatchKey> dk = c10::nullopt);
 
+TORCH_PYTHON_API py::tuple _maybe_handle_torch_function(
+    const std::string& ns,
+    const std::string& method_name,
+    const std::string& overload_name,
+    bool is_overload,
+    py::args args,
+    const py::kwargs& kwargs);
+
+TORCH_PYTHON_API bool checkSchemaAllowFakeScriptObject(
+    const FunctionSchema& schema,
+    py::args args,
+    const py::kwargs& kwargs);
+
 TORCH_PYTHON_API py::object _get_operation_for_overload_or_packet(
     const std::vector<std::shared_ptr<Operator>>& operations,
     Symbol symbol,
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 780f37a..2d115a8 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -29,9 +29,7 @@
 
 namespace py = pybind11;
 
-namespace torch {
-namespace impl {
-namespace dispatch {
+namespace torch::impl::dispatch {
 
 // NB: I'd like to index this on OperatorHandle, but I can't, as I can't
 // guarantee that the main interpreter has finish doing all registrations before
@@ -519,6 +517,16 @@
       });
 
   m.def(
+      // Returns whether or not the kernel for this dispatach key is a
+      // fallthrough kernel
+      "_dispatch_kernel_for_dispatch_key_is_fallthrough",
+      [](const char* name, c10::DispatchKey dispatch) -> bool {
+        auto op =
+            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
+        return op->isKernelFallthroughKernel(dispatch);
+      });
+
+  m.def(
       "_dispatch_has_kernel_for_any_dispatch_key",
       [](const char* name, c10::DispatchKeySet ks) -> bool {
         auto op =
@@ -938,6 +946,4 @@
   pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
 }
 
-} // namespace dispatch
-} // namespace impl
-} // namespace torch
+} // namespace torch::impl::dispatch
diff --git a/torch/testing/_internal/torchbind_impls.py b/torch/testing/_internal/torchbind_impls.py
index 933de44..f66388d 100644
--- a/torch/testing/_internal/torchbind_impls.py
+++ b/torch/testing/_internal/torchbind_impls.py
@@ -3,7 +3,7 @@
 
 def register_if_not(qualname):
     entry = torch._library.simple_registry.singleton.find(qualname)
-    if entry.abstract_impl.kernel is not None:
+    if entry.abstract_impl.kernel is None:
         return torch.library.impl_abstract(qualname)
     else:
 
@@ -28,5 +28,5 @@
         return tq.push(x)
 
     @register_if_not("_TorchScriptTesting::queue_size")
-    def fake_queue_size(tq, x):
+    def fake_queue_size(tq):
         return tq.size()