Support torchbind op dispatch in python (#123367)
We override the `__call__` method and register fake, functional, and proxy default dispatch mode implementations in its python_key_mode_table.
The idea is:
1. When inputs contain a FakeScriptObject, we dispatch through the _get_dispatch mechanism. The dispatch mode keys are registered automatically in the operator's constructor.
2. When inputs are not fakified, we dispatch through the original C++ dispatcher. A sketch of both paths follows.
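A rough sketch of the two paths (assuming the _TorchScriptTesting extension from
test/export/test_torchbind.py is built and loaded, and that a fake class has been
registered for _TensorQueue, as in that test's setUp):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx

    tq = torch.classes._TorchScriptTesting._TensorQueue(torch.empty(0).fill_(-1))
    x = torch.ones(3)

    # Path 2: real script object inputs go through the original C++ dispatcher.
    torch.ops._TorchScriptTesting.queue_push(tq, x)

    def f(tq, x):
        torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
        return tq

    # Path 1: under fake tracing, tq is fakified into a FakeScriptObject, so the
    # op skips the C++ dispatcher and dispatches purely in Python via _get_dispatch.
    gm = make_fx(f, tracing_mode="fake")(tq, x)
    print(gm.code)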
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123367
Approved by: https://github.com/zou3519
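For reference, a minimal sketch of the runtime fallthrough handling exercised by
test_torchbind_op_register_fallthrough in the diff below (the op and dispatch key
are simply the ones that test happens to use):

    import torch

    op = torch.ops._TorchScriptTesting.queue_push.default
    with torch.library._scoped_library("_TorchScriptTesting", "FRAGMENT") as lib:
        # Register a fallthrough kernel for AutocastCPU at runtime; the new
        # _dispatch_kernel_for_dispatch_key_is_fallthrough binding reports it,
        # and _fallthrough_keys() re-discovers such keys on re-dispatch.
        lib.impl(op.name(), torch.library.fallthrough_kernel, "AutocastCPU")
        assert torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
            op.name(), torch._C.DispatchKey.AutocastCPU
        )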
diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h
index c6d3365..caf73d7 100644
--- a/aten/src/ATen/core/dispatch/Dispatcher.h
+++ b/aten/src/ATen/core/dispatch/Dispatcher.h
@@ -403,6 +403,10 @@
return operatorDef_->op.hasKernelForDispatchKey(k);
}
+ bool isKernelFallthroughKernel(DispatchKey k) const {
+ return operatorDef_->op.kernelForDispatchKey(k).isFallthrough();
+ }
+
bool hasKernelForAnyDispatchKey(DispatchKeySet k) const {
return operatorDef_->op.hasKernelForAnyDispatchKey(k);
}
diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py
index 7523af9..b5312fe 100644
--- a/test/export/test_torchbind.py
+++ b/test/export/test_torchbind.py
@@ -85,6 +85,19 @@
test.tq_size_counter += 1
return len(self.queue)
+ self.torch_bind_ops = [
+ torch.ops._TorchScriptTesting.takes_foo,
+ torch.ops._TorchScriptTesting.takes_foo_python_meta,
+ torch.ops._TorchScriptTesting.takes_foo_list_return,
+ torch.ops._TorchScriptTesting.takes_foo_tuple_return,
+ torch.ops._TorchScriptTesting.take_an_instance,
+ torch.ops._TorchScriptTesting.take_an_instance_inferred,
+ torch.ops._TorchScriptTesting.takes_foo_cia,
+ torch.ops._TorchScriptTesting.queue_pop,
+ torch.ops._TorchScriptTesting.queue_push,
+ torch.ops._TorchScriptTesting.queue_size,
+ ]
+
def tearDown(self):
torch._library.fake_class_registry.deregister_fake_class(
"_TorchScriptTesting::_Foo"
@@ -555,6 +568,181 @@
self.assertEqual(tq.size(), 0)
self.assertEqual(tq1.size(), 0)
+ def test_identifying_torchbind_ops(self):
+ for op in self.torch_bind_ops:
+ self.assertTrue(op._has_torchbind_op_overload)
+
+ for op in [
+ torch.ops.aten.add,
+ torch.ops.aten.cos,
+ ]:
+ self.assertFalse(op._has_torchbind_op_overload)
+
+ def test_torchbind_op_register_fallthrough(self):
+ TEST_DISPATCH_KEY = torch._C.DispatchKey.AutocastCPU
+ TEST_DISPATCH_KEY_STR = "AutocastCPU"
+
+ for op_packet in self.torch_bind_ops:
+ op = op_packet.default
+ ns, _ = torch._library.utils.parse_namespace(op_packet._qualified_op_name)
+ with torch.library._scoped_library(ns, "FRAGMENT") as lib:
+ lib.impl(
+ op.name(), torch.library.fallthrough_kernel, TEST_DISPATCH_KEY_STR
+ )
+ self.assertTrue(
+ torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
+ op.name(), TEST_DISPATCH_KEY
+ )
+ )
+
+ def test_torchbind_op_fallthrough_keys_respects_lib_impl(self):
+ TEST_DISPATCH_KEY = torch._C.DispatchKey.AutogradCPU
+ TEST_DISPATCH_KEY_STR = "AutogradCPU"
+
+ tested = 0
+ for op_packet in self.torch_bind_ops:
+ op = op_packet.default
+ ns, _ = torch._library.utils.parse_namespace(op_packet._qualified_op_name)
+ if (
+ not torch._C._dispatch_has_kernel_for_dispatch_key(
+ op.name(), TEST_DISPATCH_KEY
+ )
+ and TEST_DISPATCH_KEY not in op.py_kernels
+ ):
+ tested += 1
+ with torch.library._scoped_library(ns, "FRAGMENT") as lib:
+ lib.impl(
+ op.name(), lambda *args, **kwargs: args, TEST_DISPATCH_KEY_STR
+ )
+ self.assertTrue(TEST_DISPATCH_KEY not in op._fallthrough_keys())
+
+ with torch.library._scoped_library(ns, "FRAGMENT") as lib:
+ lib.impl(
+ op.name(),
+ torch.library.fallthrough_kernel,
+ TEST_DISPATCH_KEY_STR,
+ )
+ self.assertTrue(TEST_DISPATCH_KEY in op._fallthrough_keys())
+ self.assertTrue(tested > 0)
+
+ def test_make_fx_schema_checking_script_object(self):
+ class Model(torch.nn.Module):
+ def forward(self, tq, x, foo):
+ torch.ops._TorchScriptTesting.queue_push(foo, x.cos())
+ return tq
+
+ class ModelCallByKW(torch.nn.Module):
+ def forward(self, tq, x, foo):
+ torch.ops._TorchScriptTesting.queue_push(x=x.cos(), foo=foo)
+ return tq
+
+ mod = Model()
+ modkw = ModelCallByKW()
+
+ foo = torch.classes._TorchScriptTesting._Foo(10, 20)
+ x = torch.ones(3, 3)
+ tq = torch.classes._TorchScriptTesting._TensorQueue(
+ torch.empty(
+ 0,
+ ).fill_(-1)
+ )
+ ns = "_TorchScriptTesting"
+ with torch.library._scoped_library(ns, "FRAGMENT") as lib:
+ op = torch.ops._TorchScriptTesting.queue_push
+ lib.impl(op.__name__, torch.library.fallthrough_kernel, "AutogradCPU")
+ lib.impl(op.__name__, torch.library.fallthrough_kernel, "ADInplaceOrView")
+ lib.impl(
+ op.__name__,
+ torch.library.fallthrough_kernel,
+ "PythonTLSSnapshot",
+ )
+
+ with self.assertRaisesRegex(
+ RuntimeError, "is expected to be a FakeScriptObject"
+ ):
+ _ = make_fx(mod, tracing_mode="fake")(tq, x, foo)
+
+ with self.assertRaisesRegex(
+ RuntimeError, "is expected to be a FakeScriptObject"
+ ):
+ _ = make_fx(modkw, tracing_mode="fake")(tq, x, foo)
+
+ @parametrize("fallthrough_via", ["lib_impl", "py_impl"])
+ def test_make_fx_tensor_queue_operators(self, fallthrough_via):
+ class Model(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, tq, x):
+ with torch.autocast("cuda", dtype=torch.bfloat16):
+ torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
+ torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
+ x_sin = torch.ops._TorchScriptTesting.queue_pop(
+ tq
+ ) - torch.ops._TorchScriptTesting.queue_size(tq)
+ x_cos = torch.ops._TorchScriptTesting.queue_pop(
+ tq
+ ) + torch.ops._TorchScriptTesting.queue_size(tq)
+ return x_sin, x_cos, tq
+
+ mod = Model()
+
+ tq1 = torch.classes._TorchScriptTesting._TensorQueue(
+ torch.empty(
+ 0,
+ ).fill_(-1)
+ )
+ tq2 = torch.classes._TorchScriptTesting._TensorQueue(
+ torch.empty(
+ 0,
+ ).fill_(-1)
+ )
+ x = torch.ones(2, 3)
+
+ mod(tq1, x)
+
+ ops = [
+ torch.ops._TorchScriptTesting.queue_push,
+ torch.ops._TorchScriptTesting.queue_pop,
+ torch.ops._TorchScriptTesting.queue_size,
+ ]
+ if fallthrough_via == "lib_impl":
+ ns = "_TorchScriptTesting"
+ with torch.library._scoped_library(ns, "FRAGMENT") as lib:
+ for op in ops:
+ lib.impl(
+ op.__name__, torch.library.fallthrough_kernel, "AutocastCUDA"
+ )
+
+ gm = make_fx(mod, tracing_mode="fake")(tq1, x)
+ else:
+ for op in ops:
+ op.default.py_impl(torch._C.DispatchKey.AutocastCUDA)(
+ torch.library.fallthrough_kernel
+ )
+ gm = make_fx(mod, tracing_mode="fake")(tq1, x)
+ for op in ops:
+ op.default._dispatch_cache.clear()
+ del op.default.py_kernels[torch._C.DispatchKey.AutocastCUDA]
+
+ self.assertExpectedInline(
+ gm.code.strip(),
+ """\
+def forward(self, arg0_1, arg1_1):
+ cos = torch.ops.aten.cos.default(arg1_1)
+ queue_push = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, cos); cos = None
+ sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
+ queue_push_1 = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, sin); sin = None
+ queue_pop = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1)
+ queue_size = torch.ops._TorchScriptTesting.queue_size.default(arg0_1)
+ sub = torch.ops.aten.sub.Tensor(queue_pop, 1); queue_pop = None
+ queue_pop_1 = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1)
+ queue_size_1 = torch.ops._TorchScriptTesting.queue_size.default(arg0_1)
+ add = torch.ops.aten.add.Tensor(queue_pop_1, 0); queue_pop_1 = None
+ return (sub, add, arg0_1)""",
+ )
+ self._assertEqualSkipScriptObject(gm(tq1, x), mod(tq2, x))
+
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestRegisterFakeClass(TestCase):
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 20aba946..583bd38 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -428,6 +428,16 @@
# Defined in torch/csrc/jit/python/script_init.cpp
# and torch/csrc/jit/python/init.cpp
+def _maybe_call_torch_function_for_op_packet(
+ op_overload_packet: Any,
+ args: Any,
+ kwargs: Any,
+) -> Any: ...
+def _check_schema_allow_fake_script_object(
+ schema: FunctionSchema,
+ args: Any,
+ kwargs: Any,
+) -> _bool: ...
def _create_function_from_graph(qualname: str, graph: Graph) -> ScriptFunction: ...
def _debug_set_autodiff_subgraph_inlining(disabled: _bool) -> None: ...
def _ivalue_tags_match(lhs: ScriptModule, rhs: ScriptModule) -> _bool: ...
@@ -1493,6 +1503,10 @@
name: str,
dispatch_key_set: DispatchKeySet,
) -> _bool: ...
+def _dispatch_kernel_for_dispatch_key_is_fallthrough(
+ name: str,
+ dispatch: _dispatchkey,
+) -> _bool: ...
def _dispatch_has_computed_kernel_for_dispatch_key(
name: str,
dispatch: _dispatchkey,
diff --git a/torch/_dynamo/tensor_version_op.py b/torch/_dynamo/tensor_version_op.py
index f12ed95..4c42464 100644
--- a/torch/_dynamo/tensor_version_op.py
+++ b/torch/_dynamo/tensor_version_op.py
@@ -13,13 +13,13 @@
@_tensor_version.py_impl(FakeTensorMode)
-def _tensor_version_fake(self):
+def _tensor_version_fake(fake_mode, self_tensor):
"""
The initial dynamo capture of _tensor_version + _unsafe_set_version_counter turns the
`._version` into an unbacked SymInt so that we don't need to specialize on the `._version`
of input tensors to the graph.
"""
- return self.fake_mode.shape_env.create_unbacked_symint()
+ return fake_mode.shape_env.create_unbacked_symint()
_unsafe_set_version_counter = _make_prim(
@@ -48,10 +48,10 @@
@_tensor_version.py_impl(FunctionalTensorMode)
-def _tensor_version_functional(self):
+def _tensor_version_functional(mode, self):
return self._version
@_unsafe_set_version_counter.py_impl(FunctionalTensorMode)
-def _unsafe_set_version_counter_functional(self, version):
+def _unsafe_set_version_counter_functional(ctx, self, version):
torch._C._autograd._unsafe_set_version_counter(self, version)
diff --git a/torch/_library/utils.py b/torch/_library/utils.py
index e2d0110..2bab3d6 100644
--- a/torch/_library/utils.py
+++ b/torch/_library/utils.py
@@ -199,3 +199,21 @@
the C++ op with a python module.
"""
return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True)
+
+
+def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
+ assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
+ overload_types = []
+ args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
+ for a in args_flattened:
+ # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
+ # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
+ # where in one case we only include tensors with the python key, and in another
+ # we include **all** tensors.
+ if isinstance(a, torch.Tensor) and torch._C._dispatch_keys(a).has(
+ torch._C.DispatchKey.Python
+ ):
+ overload_types.append(type(a))
+ # TODO: check that I got these args correct (in C++, we pass in "0000"??)
+
+ return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
diff --git a/torch/_ops.py b/torch/_ops.py
index 8bfacf8..0809d7c 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -4,7 +4,7 @@
import inspect
import sys
import types
-from typing import Any, Callable, Dict, Set, Type, Union
+from typing import Any, Callable, Dict, List, Set, Type, Union
import torch._C
import torch.utils._pytree as pytree
@@ -261,6 +261,7 @@
if self.__class__ is HigherOrderOperator:
self_name_space = "." + self.namespace if self.namespace else ""
self.__module__ = self.__module__ + self_name_space
+
self.non_fallthrough_keys = torch._C._dispatch_keyset_full()
for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS:
@@ -684,7 +685,10 @@
assert key not in self._dispatch_cache, f"{self} {key}"
if key == torch._C.DispatchKey.Python:
- if not self.python_key_mode_table:
+ if (
+ not isinstance(self, TorchBindOpOverload)
+ and not self.python_key_mode_table
+ ):
self._dispatch_cache[key] = key
add_cached_op(self)
return key
@@ -698,12 +702,18 @@
assert (
curr_mode is not None
), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode."
+
if curr_mode not in self.python_key_mode_table:
- # TODO: This path is slow, should generally encourage this
- # case to not happen
- return self._op_dk(key, *args, **kwargs)
- # TODO(voz): The idea behind this is that we do not yet support dispatch by key + mode, only key.
- return self.python_key_mode_table[curr_mode](*args, **kwargs)
+ if isinstance(self, TorchBindOpOverload):
+ with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
+ return torch._library.utils.handle_dispatch_mode(
+ mode, self, *args, **kwargs
+ )
+ else:
+ return self._op_dk(key, *args, **kwargs)
+
+ with torch.utils._python_dispatch._pop_mode_temporarily() as mode:
+ return self.python_key_mode_table[curr_mode](mode, *args, **kwargs)
self._dispatch_cache[key] = handler
add_cached_op(self)
@@ -731,24 +741,8 @@
_set_mode_pre_dispatch(top_mode)
with _temporarily_pop_modes_from_pre_dispatch() as curr_mode:
- assert isinstance(curr_mode, TorchDispatchMode)
- overload_types = []
- args_flattened, _ = torch.utils._pytree.tree_flatten(
- (args, kwargs.values())
- )
- for a in args_flattened:
- # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
- # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
- # where in one case we only include tensors with the python key, and in another
- # we include **all** tensors.
- if isinstance(a, torch.Tensor) and torch._C._dispatch_keys(
- a
- ).has(torch._C.DispatchKey.Python):
- overload_types.append(type(a))
- # TODO: check that I got these args correct (in C++, we pass in "0000"??)
-
- return curr_mode.__torch_dispatch__(
- self, overload_types, args, kwargs
+ return torch._library.utils.handle_dispatch_mode(
+ curr_mode, self, *args, **kwargs
)
# Note [Not Caching Per-Dispatch-Key Mode Handlers]
@@ -776,7 +770,6 @@
add_cached_op(self)
return handler
- # print(self, key, final_key)
r = self.py_kernels.get(final_key, final_key)
if cache_result:
self._dispatch_cache[key] = r
@@ -801,6 +794,98 @@
# TODO: add more methods to expose information about input and output arguments
+# A TorchBindOpOverload is a custom op that has at least one overload whose
+# schema contains a torch.ScriptObject (i.e. custom class) input.
+# A TorchBindOpOverload skips the C++ dispatcher and is dispatched purely in
+# Python when its inputs contain a FakeScriptObject, similar to higher-order ops.
+class TorchBindOpOverload(OpOverload):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def _fallthrough_keys(self) -> List[DispatchKey]:
+        # TODO: we should be calling the fallback for these, but a fallthrough is
+        # close enough to the fallback in most cases that we care about.
+ _DEFAULT_FALLTHROUGH_KEYS = [
+ DispatchKey.Autograd,
+ DispatchKey.AutogradCPU,
+ DispatchKey.AutogradCUDA,
+ DispatchKey.ADInplaceOrView,
+ DispatchKey.PythonTLSSnapshot,
+ ]
+
+ def _may_use_fallthrough_instead_of_fallback(key: DispatchKey):
+ if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key):
+ return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
+ self.name(), key
+ )
+
+ return (
+ key not in self.py_kernels
+ or self.py_kernels[key] is torch.library.fallthrough_kernel
+ )
+
+ return [
+ key
+ for key in _DEFAULT_FALLTHROUGH_KEYS
+ if _may_use_fallthrough_instead_of_fallback(key)
+ ]
+
+    # Use `self_` to avoid a naming collision with arguments that are named
+    # "self". This way, those arguments can still be passed as kwargs.
+ def __call__(self_, *args, **kwargs): # noqa: B902
+ if _must_dispatch_in_python(args, kwargs):
+            # When any input is a FakeScriptObject, we need to skip the C++
+            # dispatcher and dispatch in Python through the python_dispatcher's _get_dispatch.
+ return self_._dispatch_in_python(args, kwargs, self_._fallthrough_keys())
+
+ return self_._op(*args, **kwargs)
+
+ def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
+ non_fallthrough_keys = torch._C._dispatch_keyset_full()
+ for key in fallthrough_keys:
+ non_fallthrough_keys = non_fallthrough_keys.remove(key)
+
+ dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys)
+ dispatch_key = dispatch_key_set.highestPriorityTypeId()
+
+ handler = (
+ self._get_dispatch(dispatch_key)
+ if dispatch_key not in self._dispatch_cache
+ else self._dispatch_cache[dispatch_key]
+ )
+
+ if isinstance(handler, DispatchKey):
+            # Fallthrough kernels can be registered at runtime via torch.library.impl,
+            # so we need to add the key to fallthrough_keys and re-dispatch.
+ if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
+ self.name(), dispatch_key
+ ):
+ return self._dispatch_in_python(
+ args, kwargs, fallthrough_keys + [dispatch_key]
+ )
+
+ raise RuntimeError(
+ f"Cannot handle FakeScriptObject with python dispatcher with dispatch key {handler}."
+ f"Please implement it by annotating a python callable with py_impl({handler})."
+ )
+
+ assert isinstance(handler, Callable) # type: ignore[arg-type]
+ return handler(*args, **kwargs)
+
+
+def _must_dispatch_in_python(args, kwargs):
+ return pytree.tree_any(
+ lambda obj: isinstance(
+ obj, torch._library.fake_class_registry.FakeScriptObject
+ ),
+ (args, kwargs),
+ )
+
+
+def _has_script_object_arg(schema: torch.FunctionSchema) -> bool:
+ return any(isinstance(arg.type, torch.ClassType) for arg in schema.arguments)
+
+
# OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
# You can obtain an OpOverload object through attribute query.
class OpOverloadPacket:
@@ -812,6 +897,9 @@
self._op = op
self._overload_names = overload_names
self._dir = []
+ self._has_torchbind_op_overload = any(
+ _has_script_object_arg(schema) for schema in self._schemas.values()
+ )
# it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
def __deepcopy__(self, memo=None):
@@ -832,6 +920,13 @@
def op(self):
return self._op
+ @property
+ def _schemas(self):
+ return {
+ overload_name: torch._C._get_schema(self._qualified_op_name, overload_name)
+ for overload_name in self._overload_names
+ }
+
def __getattr__(self, key):
# It is not a valid op_name when __file__ is passed in
if key == "__file__":
@@ -865,7 +960,11 @@
self._qualified_op_name, use_key
)
schema = torch._C._get_schema(self._qualified_op_name, use_key)
- overload = OpOverload(self, op_, op_dk_, schema, tags)
+ overload = (
+ OpOverload(self, op_, op_dk_, schema, tags)
+ if not _has_script_object_arg(schema)
+ else TorchBindOpOverload(self, op_, op_dk_, schema, tags)
+ )
# cache the overload object
setattr(self, key, overload)
self._dir.append(key)
@@ -886,6 +985,12 @@
# is still callable from JIT
# We save the function ptr as the `op` attribute on
# OpOverloadPacket to access it here.
+
+        # Directly calling an OverloadPacket goes into C++, which checks the
+        # schema and raises an error for torchbind ops when inputs contain a
+        # FakeScriptObject, so we intercept here and call TorchBindOpOverload instead.
+ if self_._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
+ return _call_overload_packet_from_python(self_, args, kwargs)
return self_._op(*args, **(kwargs or {}))
# TODO: use this to make a __dir__
@@ -893,6 +998,46 @@
return [n if n else "default" for n in self._overload_names]
+# Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp
+# _jit_get_operations, which calls _get_operation_for_overload_or_packet.
+def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs):
+    # Reuse the torch_function handling logic implemented in C++.
+ torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet(
+ op, *args, **kwargs
+ )
+
+ if torch_function_called:
+ return ret
+
+    # The following mirrors getOpWithStack.
+    # In C++, we do schema matching for the arguments and call toIValue to
+    # check whether they are valid. We need to do something similar here and,
+    # in addition, check against the schema whether each FakeScriptObject is an
+    # instance of the fake class corresponding to the actual class in the schema.
+ exceptions = {}
+ found_op = None
+ for overload_name in op.overloads():
+ op_overload = getattr(op, overload_name)
+ try:
+ _ = torch._C._check_schema_allow_fake_script_object(
+ op_overload._schema, *args, **kwargs
+ )
+ found_op = op_overload
+ break
+ except RuntimeError as e:
+ exceptions[overload_name] = e
+
+ if found_op:
+ return found_op(*args, **kwargs)
+
+    err_msg = (
+        f"Failed to match any TorchBindOverload of {op} with the following exceptions:\n"
+    )
+    for key, msg in exceptions.items():
+ err_msg += f"Overload name {key}:\n {msg}\n"
+ raise RuntimeError(err_msg)
+
+
# Resolution of torch.fn is different from torch.ops.aten.fn
# torch.fn uses the Python argparser, matches with the
# appropriate schema, and calls into the unboxed version of the method
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 057d0fc..5eb4851 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -1667,6 +1667,14 @@
});
m.def(
+ "_check_schema_allow_fake_script_object",
+ [](const FunctionSchema& schema, py::args args, py::kwargs kwargs) {
+        // checkSchemaAllowFakeScriptObject throws a runtime error if there is
+        // a schema mismatch. Otherwise, it returns true.
+ return checkSchemaAllowFakeScriptObject(schema, args, kwargs);
+ });
+
+ m.def(
"_jit_resolve_packet",
[](const char* op_name, py::args args, py::kwargs kwargs) {
try {
@@ -1735,6 +1743,20 @@
py::arg("qualified_name"));
m.def(
+ "_maybe_call_torch_function_for_op_packet",
+ [](py::handle op_overload_packet, py::args args, py::kwargs kwargs) {
+ py::list ns_method =
+ op_overload_packet.attr("_qualified_op_name").attr("split")("::");
+ return _maybe_handle_torch_function(
+ py::cast<std::string>(ns_method[0]),
+ py::cast<std::string>(ns_method[1]),
+ "",
+ false,
+ args,
+ kwargs);
+ });
+
+ m.def(
"parse_ir",
[](const std::string& input, bool parse_tensor_constants) {
auto graph = std::make_shared<Graph>();
diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp
index 9360b83..ba0135a 100644
--- a/torch/csrc/jit/python/pybind_utils.cpp
+++ b/torch/csrc/jit/python/pybind_utils.cpp
@@ -757,6 +757,23 @@
}
}
+// This function checks whether the schema is valid for the given args and
+// kwargs. For script object arguments, it checks whether the FakeScriptObject
+// is an instance of the fake class that corresponds to the actual class used
+// in the schema.
+bool checkSchemaAllowFakeScriptObject(
+ const FunctionSchema& schema,
+ py::args args,
+ const py::kwargs& kwargs) {
+ bool match = false;
+ try {
+ match = matchSchemaAllowFakeScriptObject(schema, std::move(args), kwargs);
+ } catch (schema_match_error& error) {
+ throw std::runtime_error(error.what());
+ }
+ return match;
+}
+
py::object invokeOperatorFromPython(
const std::vector<std::shared_ptr<Operator>>& operations,
py::args args,
@@ -775,13 +792,13 @@
return createPyObjectForStack(std::move(stack));
}
-py::object _get_operation_for_overload_or_packet(
- const std::vector<std::shared_ptr<Operator>>& operations,
- Symbol symbol,
- py::args args,
- const py::kwargs& kwargs,
+py::tuple _maybe_handle_torch_function(
+ const std::string& ns,
+ const std::string& method_name,
+ const std::string& overload_name,
bool is_overload,
- c10::optional<c10::DispatchKey> dk) {
+ py::args args,
+ const py::kwargs& kwargs) {
std::vector<PyObject*> overloaded_args;
size_t total_arg_num = args.size() + kwargs.size();
for (const auto i : c10::irange(args.size())) {
@@ -807,15 +824,11 @@
false /* throw_error */);
}
if (!overloaded_args.empty() || at::impl::torch_function_mode_enabled()) {
- py::object ret;
- std::string ns = symbol.ns().toUnqualString();
- std::string method_name = symbol.toUnqualString();
auto self_func = py::module::import("torch")
.attr("ops")
.attr(ns.c_str())
.attr(method_name.c_str());
if (is_overload) {
- auto overload_name = operations[0]->schema().overload_name();
if (overload_name.empty()) {
self_func = self_func.attr("default");
} else {
@@ -824,16 +837,36 @@
}
std::string module_name("torch.ops");
module_name.append(ns);
- return pybind11::reinterpret_steal<py::object>(
- handle_torch_function_no_python_arg_parser(
- overloaded_args,
- args.ptr(),
- kwargs.ptr(),
- method_name.c_str(),
- self_func.ptr(),
- module_name.c_str()));
+ return py::make_tuple(
+ true,
+ pybind11::reinterpret_steal<py::object>(
+ handle_torch_function_no_python_arg_parser(
+ overloaded_args,
+ args.ptr(),
+ kwargs.ptr(),
+ method_name.c_str(),
+ self_func.ptr(),
+ module_name.c_str())));
}
- return invokeOperatorFromPython(operations, args, kwargs, dk);
+ return py::make_tuple(false, py::none());
+}
+
+py::object _get_operation_for_overload_or_packet(
+ const std::vector<std::shared_ptr<Operator>>& operations,
+ Symbol symbol,
+ py::args args,
+ const py::kwargs& kwargs,
+ bool is_overload,
+ c10::optional<c10::DispatchKey> dk) {
+ std::string ns = symbol.ns().toUnqualString();
+ std::string method_name = symbol.toUnqualString();
+ std::string overload_name = operations[0]->schema().overload_name();
+ auto res = _maybe_handle_torch_function(
+ ns, method_name, overload_name, is_overload, args, kwargs);
+ auto torch_function_called = py::cast<bool>(res[0]);
+ return torch_function_called
+ ? res[1]
+ : invokeOperatorFromPython(operations, args, kwargs, dk);
}
} // namespace torch::jit
diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h
index cbb7791..a78c3e0 100644
--- a/torch/csrc/jit/python/pybind_utils.h
+++ b/torch/csrc/jit/python/pybind_utils.h
@@ -873,6 +873,116 @@
int64_t e;
};
+inline bool validateFakeScriptObjectSchema(
+ const c10::FunctionSchema& schema,
+ size_t argumentPosition,
+ py::handle object) {
+ auto argument = schema.arguments().at(argumentPosition);
+ auto class_type = argument.real_type()->expect<c10::ClassType>();
+ auto fake_class_registry =
+ py::module::import("torch._library.fake_class_registry");
+ auto fake_class = fake_class_registry.attr("find_fake_class")(
+ class_type->name().value().qualifiedName());
+ if (!py::isinstance(object.attr("wrapped_obj"), fake_class)) {
+ throw schema_match_error(c10::str(
+ schema.formatTypeMismatchMsg(
+ argument,
+ friendlyTypeName(object),
+ argumentPosition,
+ py::repr(object.attr("wrapped_obj"))),
+ "\nCast error details: ",
+ argument.name(),
+ " is expected to be a FakeScriptObject of ",
+ class_type->name().value().qualifiedName()));
+ }
+ return true;
+}
+
+inline bool matchSchemaAllowFakeScriptObject(
+ const FunctionSchema& schema,
+ const tuple_slice& args,
+ const py::kwargs& kwargs) {
+ size_t all_arguments = args.size() + kwargs.size();
+ if (all_arguments > schema.arguments().size()) {
+ throw schema_match_error(c10::str(
+ schema.name(),
+ "() expected at most ",
+ schema.arguments().size(),
+ " argument(s) but received ",
+ all_arguments,
+ " argument(s). Declaration: ",
+ schema));
+ }
+
+ int64_t arg_idx = 0;
+ auto fake_class_registry =
+ py::module::import("torch._library.fake_class_registry");
+
+ // First push all positional args.
+ for (const auto& arg : args) {
+ // ...but refuse to do it if the schema says that this was supposed
+ // to be keyword only
+ if (schema.arguments()[arg_idx].kwarg_only()) {
+ throw schema_match_error(c10::str(
+ schema.name(),
+ "() takes ",
+ arg_idx,
+ " positional argument(s) but ",
+ args.size(),
+ " was/were given. Declaration: ",
+ schema));
+ }
+ // Use the type information from the schema to convert the PyObject.
+ const auto& argument = schema.arguments().at(arg_idx);
+ if (argument.real_type()->kind() == TypeKind::ClassType &&
+ py::isinstance(arg, fake_class_registry.attr("FakeScriptObject"))) {
+ validateFakeScriptObjectSchema(schema, arg_idx, arg);
+ } else {
+ argumentToIValue(schema, arg_idx, arg);
+ }
+
+ arg_idx++;
+ }
+
+ // Now for every remaining non-positional argument in the schema, look for it
+ // in the kwargs dict and push it if found, or use its default value if it
+ // has one.
+ size_t consumed_kwargs = 0;
+ for (size_t i = arg_idx; i < schema.arguments().size(); ++i) {
+ const auto& arg = schema.arguments()[i];
+ if (kwargs.contains(arg.name().c_str())) {
+ auto cur_kwarg = kwargs[arg.name().c_str()];
+ if (arg.real_type()->kind() == TypeKind::ClassType &&
+ py::isinstance(
+ cur_kwarg, fake_class_registry.attr("FakeScriptObject"))) {
+ validateFakeScriptObjectSchema(schema, i, cur_kwarg);
+ } else {
+ argumentToIValue(schema, i, cur_kwarg);
+ }
+ consumed_kwargs += 1;
+ } else if (arg.default_value()) {
+ continue;
+ } else {
+ throw schema_match_error(c10::str(
+ schema.name(),
+ "() is missing value for argument '",
+ arg.name(),
+ "'. Declaration: ",
+ schema));
+ }
+ }
+
+ if (consumed_kwargs != kwargs.size()) {
+ std::vector<std::string> names;
+ for (const auto& kwarg : kwargs) {
+ names.emplace_back(py::cast<std::string>(kwarg.first));
+ }
+ throw schema_match_error(schema.findErrorInKwargs(names));
+ }
+
+ return true;
+}
+
inline Stack createStackForSchema(
const FunctionSchema& schema,
const tuple_slice& args,
@@ -1147,6 +1257,19 @@
const py::kwargs& kwargs,
c10::optional<c10::DispatchKey> dk = c10::nullopt);
+TORCH_PYTHON_API py::tuple _maybe_handle_torch_function(
+ const std::string& ns,
+ const std::string& method_name,
+ const std::string& overload_name,
+ bool is_overload,
+ py::args args,
+ const py::kwargs& kwargs);
+
+TORCH_PYTHON_API bool checkSchemaAllowFakeScriptObject(
+ const FunctionSchema& schema,
+ py::args args,
+ const py::kwargs& kwargs);
+
TORCH_PYTHON_API py::object _get_operation_for_overload_or_packet(
const std::vector<std::shared_ptr<Operator>>& operations,
Symbol symbol,
diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp
index 780f37a..2d115a8 100644
--- a/torch/csrc/utils/python_dispatch.cpp
+++ b/torch/csrc/utils/python_dispatch.cpp
@@ -29,9 +29,7 @@
namespace py = pybind11;
-namespace torch {
-namespace impl {
-namespace dispatch {
+namespace torch::impl::dispatch {
// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
// guarantee that the main interpreter has finish doing all registrations before
@@ -519,6 +517,16 @@
});
m.def(
+      // Returns whether or not the kernel for this dispatch key is a
+      // fallthrough kernel
+ "_dispatch_kernel_for_dispatch_key_is_fallthrough",
+ [](const char* name, c10::DispatchKey dispatch) -> bool {
+ auto op =
+ c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
+ return op->isKernelFallthroughKernel(dispatch);
+ });
+
+ m.def(
"_dispatch_has_kernel_for_any_dispatch_key",
[](const char* name, c10::DispatchKeySet ks) -> bool {
auto op =
@@ -938,6 +946,4 @@
pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}
-} // namespace dispatch
-} // namespace impl
-} // namespace torch
+} // namespace torch::impl::dispatch
diff --git a/torch/testing/_internal/torchbind_impls.py b/torch/testing/_internal/torchbind_impls.py
index 933de44..f66388d 100644
--- a/torch/testing/_internal/torchbind_impls.py
+++ b/torch/testing/_internal/torchbind_impls.py
@@ -3,7 +3,7 @@
def register_if_not(qualname):
entry = torch._library.simple_registry.singleton.find(qualname)
- if entry.abstract_impl.kernel is not None:
+ if entry.abstract_impl.kernel is None:
return torch.library.impl_abstract(qualname)
else:
@@ -28,5 +28,5 @@
return tq.push(x)
@register_if_not("_TorchScriptTesting::queue_size")
- def fake_queue_size(tq, x):
+ def fake_queue_size(tq):
return tq.size()