[dynamo] list index: add more list types to testing, support namedtuple, improve error handling (#110919)
Follow up: #110817
Minor improvements as discussed in prev PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110919
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index b5d6198..e030ea7 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -3068,18 +3068,31 @@
)
def test_list_index(self):
- for index in ([], [2], [0, 3]):
-
- def f(t):
- xs = ["bar", "foo", "baz", "buzz"]
- res = xs.index("baz", *index)
- return t + res
-
- res = torch._dynamo.optimize(backend="eager", nopython=True)(f)(
- torch.zeros(1)
+ for i, list_type in enumerate(
+ (
+ list,
+ tuple,
+ torch.Size,
+ collections.deque,
+ namedtuple("FourElems", "one two three four", defaults=[0, 0, 0, 0]),
)
+ ):
+ torch._dynamo.reset()
+ for index in ([], [2], [0, 3]):
- self.assertEqual(res, torch.tensor([2.0]))
+ def f(t):
+ if i == 4: # namedtuple
+ xs = list_type(1, 2, 3, 4)
+ else:
+ xs = list_type([1, 2, 3, 4])
+ res = xs.index(3, *index)
+ return t + res
+
+ res = torch._dynamo.optimize(backend="eager", nopython=True)(f)(
+ torch.zeros(1)
+ )
+
+ self.assertEqual(res, torch.tensor([2.0]))
def test_list_index_not_found(self):
def f(t):
diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py
index 632eba9..ffcf3e2 100644
--- a/torch/_dynamo/variables/lists.py
+++ b/torch/_dynamo/variables/lists.py
@@ -155,8 +155,6 @@
elif name == "index":
from .builder import SourcelessBuilder
- assert len(kwargs) == 0
- assert len(args) > 0 and len(args) <= 3
return tx.inline_user_function_return(
SourcelessBuilder()(tx, polyfill.index), [self] + list(args), kwargs
)
@@ -653,7 +651,7 @@
if name not in fields:
method = check_and_create_method()
if not method:
- unimplemented(f"NamedTupleVariable.{name}")
+ super().var_getattr(tx, name)
return method
return self.items[fields.index(name)].add_options(self)