[custom_op] fix schema inference for kwarg-only args (#124637)
Test Plan:
- new tests
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124637
Approved by: https://github.com/williamwen42, https://github.com/albanD
diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py
index 9c36762..58c2ba7 100644
--- a/test/test_custom_ops.py
+++ b/test/test_custom_ops.py
@@ -598,6 +598,18 @@
self.assertExpectedInline(infer_schema(a), """(Tensor x) -> Tensor""")
+ def kwonly1(x: Tensor, *, y: int, z: float) -> Tensor:
+ return torch.empty([])
+
+ self.assertExpectedInline(
+ infer_schema(kwonly1), """(Tensor x, *, SymInt y, float z) -> Tensor"""
+ )
+
+ def kwonly2(*, y: Tensor) -> Tensor:
+ return torch.empty([])
+
+ self.assertExpectedInline(infer_schema(kwonly2), """(*, Tensor y) -> Tensor""")
+
def b(
x: Tensor,
y: int,
diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py
index e85803d..fd03f91 100644
--- a/torch/_library/infer_schema.py
+++ b/torch/_library/infer_schema.py
@@ -23,10 +23,17 @@
params = []
seen_args = set()
+ saw_kwarg_only_arg = False
for idx, (name, param) in enumerate(sig.parameters.items()):
if not supported_param(param):
error_fn("We do not support positional-only args, varargs, or varkwargs.")
+ if param.kind == inspect.Parameter.KEYWORD_ONLY:
+ # The first time we see a kwarg-only arg, add "*" to the schema.
+ if not saw_kwarg_only_arg:
+ params.append("*")
+ saw_kwarg_only_arg = True
+
if param.annotation is inspect.Parameter.empty:
error_fn(f"Parameter {name} must have a type annotation.")