[reland][custom_op] Change the python type that maps to ListType in schema (#101451)
Reland of #101190. The original stack was reverted due to internal test
flakiness.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101451
Approved by: https://github.com/soulitzer
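In short: the Python annotation that maps to `int[]`-style list types in the inferred schema is now `Sequence[X]` (any `collections.abc.Sequence` is accepted at call time) instead of `Tuple[X, ...]`. A minimal sketch of the new convention; `mylib::repeat` is a hypothetical op name used only for illustration:

```python
import torch
from typing import Sequence
from torch._custom_op import custom_op

@custom_op("mylib::repeat")
def repeat(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor:
    # The body stays empty; custom_op infers the schema from these
    # annotations, with Sequence[int] mapping to the list type in it.
    ...

@repeat.impl("cpu")
def repeat_cpu(x, sizes):
    # By the time an impl runs, the dispatcher has normalized whatever
    # Sequence the caller passed (tuple, custom Sequence, ...) to a list.
    return x.repeat(*sizes)
```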
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index ad2c875..6270148 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -18,7 +18,8 @@
from torch._custom_op import custom_op, CustomOp
from torch.fx.experimental.proxy_tensor import make_fx
import typing
-from typing import Optional, Tuple, Union, List, Callable
+import collections
+from typing import Optional, Tuple, Union, List, Callable, Sequence
from torch import Tensor
import itertools
@@ -535,9 +536,9 @@
assert len(args) == 2 and (args[0] is type(None) or args[1] is type(None))
elt = args[0] if args[1] is type(None) else args[1]
return generate_examples(elt) + [None]
- if origin is tuple:
+ if origin is collections.abc.Sequence:
args = typing.get_args(typ)
- assert len(args) == 2 and args[1] == ...
+ assert len(args) == 1
examples = generate_examples(args[0])
return list(itertools.product(examples, examples)) + []
raise AssertionError(f"unsupported param type {typ}")
@@ -565,11 +566,43 @@
del foo
del foo_cpu
+ def test_sequences(self):
+ # Sequence[int] gets automagically turned into int[] in the schema.
+ # This test checks that we actually do support arbitrary sequence types.
+ class MySequence(collections.abc.Sequence):
+ def __init__(self):
+ self._container = [1, 2, 3]
+
+ def __getitem__(self, idx):
+ return self._container[idx]
+
+ def __len__(self):
+ return len(self._container)
+
+ @custom_op("blah::foo")
+ def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor:
+ ...
+
+ called = 0
+
+ @foo.impl('cpu')
+ def foo_cpu(x, sizes):
+ nonlocal called
+ called += 1
+ # The dispatcher will normalize the sequence type into a list
+ self.assertEqual(sizes, [1, 2, 3])
+ return x.clone()
+
+ x = torch.randn([])
+ seq = MySequence()
+ foo(x, seq)
+ self.assertEqual(called, 1)
+
def test_unsupported_param_types(self):
# Not comprehensive (it doesn't need to be), just a check that our mechanism works
with self.assertRaisesRegex(ValueError, 'unsupported type'):
@custom_op(f'{TestCustomOp.test_ns}::foo')
- def foo(x: Tensor, y: Tuple[Optional[int], ...]) -> Tensor:
+ def foo(x: Tensor, y: List[Optional[int]]) -> Tensor:
...
del foo
@@ -582,7 +615,7 @@
with self.assertRaisesRegex(ValueError, 'unsupported type'):
# We could theoretically support this, but the syntax for supporting
- # int[] is Tuple[int, ...]
+ # int[] is Sequence[int]
@custom_op(f'{TestCustomOp.test_ns}::foo')
def foo(x: Tensor, y: List[int]) -> Tensor:
...
@@ -698,7 +731,7 @@
foo._destroy()
@custom_op(f'{TestCustomOp.test_ns}::foo')
- def foo(x: Tuple[torch.Tensor, ...]) -> torch.Tensor:
+ def foo(x: Sequence[torch.Tensor]) -> torch.Tensor:
...
x = torch.randn(3, requires_grad=True)
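For reference, the `origin is collections.abc.Sequence` check in `generate_examples` above follows from how `typing` normalizes generic aliases; these assertions show both the new and the old shape (plain Python 3.8+ behavior):

```python
import collections.abc
import typing

# Sequence[int] normalizes to its runtime origin class with a single arg.
assert typing.get_origin(typing.Sequence[int]) is collections.abc.Sequence
assert typing.get_args(typing.Sequence[int]) == (int,)

# Tuple[int, ...] has origin `tuple` and carries Ellipsis as a second arg,
# which is what the old `len(args) == 2 and args[1] == ...` check matched.
assert typing.get_origin(typing.Tuple[int, ...]) is tuple
assert typing.get_args(typing.Tuple[int, ...]) == (int, ...)
```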
diff --git a/torch/_custom_op.py b/torch/_custom_op.py
index df13a6a..71f090e 100644
--- a/torch/_custom_op.py
+++ b/torch/_custom_op.py
@@ -740,11 +740,11 @@
(typing.Optional[base_type], f"{cpp_type}?"),
]
if list_base:
- result.append((typing.Tuple[base_type, ...], f"{cpp_type}[]"))
+ result.append((typing.Sequence[base_type], f"{cpp_type}[]")) # type: ignore[valid-type]
if optional_base_list:
- result.append((typing.Tuple[typing.Optional[base_type], ...], f"{cpp_type}?[]"))
+ result.append((typing.Sequence[typing.Optional[base_type]], f"{cpp_type}?[]")) # type: ignore[valid-type]
if optional_list_base:
- result.append((typing.Optional[typing.Tuple[base_type, ...]], f"{cpp_type}[]?"))
+ result.append((typing.Optional[typing.Sequence[base_type]], f"{cpp_type}[]?")) # type: ignore[valid-type]
return result
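Concretely, the hunk above now yields (Python annotation, schema type) pairs along these lines for a list-capable base type; a sketch, with the leading `(base_type, cpp_type)` entry assumed from surrounding code this diff does not show:

```python
import typing

base_type, cpp_type = int, "int"  # illustrative choice of base type

result = [
    (base_type, cpp_type),                                            # assumed from context
    (typing.Optional[base_type], f"{cpp_type}?"),
    (typing.Sequence[base_type], f"{cpp_type}[]"),                    # was Tuple[int, ...]
    (typing.Sequence[typing.Optional[base_type]], f"{cpp_type}?[]"),  # was Tuple[Optional[int], ...]
    (typing.Optional[typing.Sequence[base_type]], f"{cpp_type}[]?"),  # was Optional[Tuple[int, ...]]
]
```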
diff --git a/torch/_prims/debug_prims.py b/torch/_prims/debug_prims.py
index 3290f5e..df84552 100644
--- a/torch/_prims/debug_prims.py
+++ b/torch/_prims/debug_prims.py
@@ -1,5 +1,5 @@
import contextlib
-from typing import Tuple
+from typing import Sequence
import torch
from torch._custom_op import custom_op
@@ -29,8 +29,8 @@
@custom_op("debugprims::load_tensor")
def load_tensor(
name: str,
- size: Tuple[int, ...],
- stride: Tuple[int, ...],
+ size: Sequence[int],
+ stride: Sequence[int],
*,
dtype: torch.dtype,
device: torch.device,
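The annotation change matters for call sites like this one: `torch.Size` is a `tuple` subclass, so it satisfied the old `Tuple[int, ...]` annotation, but plain lists (a common way to spell sizes and strides) only type-check against `Sequence[int]`. A small runtime illustration:

```python
import collections.abc
import torch

sz = torch.Size([2, 3])
assert isinstance(sz, tuple)                     # fit the old Tuple[int, ...] shape
assert isinstance(sz, collections.abc.Sequence)  # fits the new Sequence[int] shape
assert isinstance([2, 3], collections.abc.Sequence)  # lists fit Sequence[int] too
```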