[reland][custom_op] Change the python type that maps to ListType in schema (#101451)

Reland of #101190. Original stack was reverted due to internal test
flakiness.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101451
Approved by: https://github.com/soulitzer
diff --git a/test/test_python_dispatch.py b/test/test_python_dispatch.py
index ad2c875..6270148 100644
--- a/test/test_python_dispatch.py
+++ b/test/test_python_dispatch.py
@@ -18,7 +18,8 @@
 from torch._custom_op import custom_op, CustomOp
 from torch.fx.experimental.proxy_tensor import make_fx
 import typing
-from typing import Optional, Tuple, Union, List, Callable
+import collections
+from typing import Optional, Tuple, Union, List, Callable, Sequence
 from torch import Tensor
 import itertools
 
@@ -535,9 +536,9 @@
                 assert len(args) == 2 and (args[0] is type(None) or args[1] is type(None))
                 elt = args[0] if args[1] is type(None) else args[1]
                 return generate_examples(elt) + [None]
-            if origin is tuple:
+            if origin is collections.abc.Sequence:
                 args = typing.get_args(typ)
-                assert len(args) == 2 and args[1] == ...
+                assert len(args) == 1
                 examples = generate_examples(args[0])
                 return list(itertools.product(examples, examples)) + []
             raise AssertionError(f"unsupported param type {typ}")
@@ -565,11 +566,43 @@
                 del foo
                 del foo_cpu
 
+    def test_sequences(self):
+        # Sequence[int] gets automagically turned into int[] in the schema.
+        # This test checks that we actually do support arbitrary sequence types.
+        class MySequence(collections.abc.Sequence):
+            def __init__(self):
+                self._container = [1, 2, 3]
+
+            def __getitem__(self, idx):
+                return self._container[idx]
+
+            def __len__(self):
+                return len(self._container)
+
+        @custom_op("blah::foo")
+        def foo(x: torch.Tensor, sizes: Sequence[int]) -> torch.Tensor:
+            ...
+
+        called = 0
+
+        @foo.impl('cpu')
+        def foo_cpu(x, sizes):
+            nonlocal called
+            called += 1
+            # Dispatcher will normalize the sequence type into a List
+            self.assertEqual(sizes, [1, 2, 3])
+            return x.clone()
+
+        x = torch.randn([])
+        seq = MySequence()
+        foo(x, seq)
+        self.assertEqual(called, 1)
+
     def test_unsupported_param_types(self):
         # Not comprehensive (it doesn't need to be), just a check that our mechanism works
         with self.assertRaisesRegex(ValueError, 'unsupported type'):
             @custom_op(f'{TestCustomOp.test_ns}::foo')
-            def foo(x: Tensor, y: Tuple[Optional[int], ...]) -> Tensor:
+            def foo(x: Tensor, y: List[Optional[int]]) -> Tensor:
                 ...
             del foo
 
@@ -582,7 +615,7 @@
 
         with self.assertRaisesRegex(ValueError, 'unsupported type'):
             # We could theoretically support this, but the syntax for suporting
-            # int[] is Tuple[int, ...]
+            # int[] is Sequence[int]
             @custom_op(f'{TestCustomOp.test_ns}::foo')
             def foo(x: Tensor, y: List[int]) -> Tensor:
                 ...
@@ -698,7 +731,7 @@
         foo._destroy()
 
         @custom_op(f'{TestCustomOp.test_ns}::foo')
-        def foo(x: Tuple[torch.Tensor, ...]) -> torch.Tensor:
+        def foo(x: Sequence[torch.Tensor]) -> torch.Tensor:
             ...
 
         x = torch.randn(3, requires_grad=True)
diff --git a/torch/_custom_op.py b/torch/_custom_op.py
index df13a6a..71f090e 100644
--- a/torch/_custom_op.py
+++ b/torch/_custom_op.py
@@ -740,11 +740,11 @@
         (typing.Optional[base_type], f"{cpp_type}?"),
     ]
     if list_base:
-        result.append((typing.Tuple[base_type, ...], f"{cpp_type}[]"))
+        result.append((typing.Sequence[base_type], f"{cpp_type}[]"))  # type: ignore[valid-type]
     if optional_base_list:
-        result.append((typing.Tuple[typing.Optional[base_type], ...], f"{cpp_type}?[]"))
+        result.append((typing.Sequence[typing.Optional[base_type]], f"{cpp_type}?[]"))  # type: ignore[valid-type]
     if optional_list_base:
-        result.append((typing.Optional[typing.Tuple[base_type, ...]], f"{cpp_type}[]?"))
+        result.append((typing.Optional[typing.Sequence[base_type]], f"{cpp_type}[]?"))  # type: ignore[valid-type]
     return result
 
 
diff --git a/torch/_prims/debug_prims.py b/torch/_prims/debug_prims.py
index 3290f5e..df84552 100644
--- a/torch/_prims/debug_prims.py
+++ b/torch/_prims/debug_prims.py
@@ -1,5 +1,5 @@
 import contextlib
-from typing import Tuple
+from typing import Sequence
 
 import torch
 from torch._custom_op import custom_op
@@ -29,8 +29,8 @@
     @custom_op("debugprims::load_tensor")
     def load_tensor(
         name: str,
-        size: Tuple[int, ...],
-        stride: Tuple[int, ...],
+        size: Sequence[int],
+        stride: Sequence[int],
         *,
         dtype: torch.dtype,
         device: torch.device,