[dynamo] add infinite generators `itertools.{count, repeat, cycle}` (#110967)
Fixes https://github.com/pytorch/pytorch/pull/110953/files#r1352868935
Depends on: https://github.com/pytorch/pytorch/pull/110953
Why not use these for `repeat(item, count)`:
> These are not preferred as they return an opaque VariableTracker. In particular, one cannot do `enumerate(repeat(1))`. `repeat(1, 10)` benefits from the integration enjoyed by `ListVariableIterator`
Follow ups:
- [ ] make listiterator an IteratorVariable, define iterator integrations on base IteratorVariable where unspecialized https://github.com/pytorch/pytorch/pull/110967#discussion_r1356656469
- Please make a new issue for this
- [ ] explore integrating cpython itertools test suite https://github.com/pytorch/pytorch/pull/110967#discussion_r1358326402
- [ ] Use something other than `StopIteration` to handle iterator termination https://github.com/pytorch/pytorch/pull/110967#discussion_r1358336038
- [ ] Add test case for consuming iterator simultaneously from two code points https://github.com/pytorch/pytorch/pull/110967/files#r1358325511
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110967
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index f40c12a..3a80ed9 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -7609,6 +7609,103 @@
self.assertEqual(list(eager), list(compiled))
self.assertEqual(len(counters["graph_break"]), 0)
+ def test_itertools_infinite_repeat(self):
+ counters.clear()
+
+ def fn(x):
+ r = itertools.repeat(100.0)
+ idx = 0
+ for i in r:
+ x += i
+ idx += 1
+ if idx > 10:
+ break
+ return x
+
+ x = torch.randn([2, 5])
+ eager = fn(x)
+
+ compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
+ compiled = compiled_fn(x)
+
+ self.assertEqual(list(eager), list(compiled))
+ self.assertEqual(len(counters["graph_break"]), 0)
+
+ def test_itertools_infinite_repeat_mutation(self):
+ counters.clear()
+
+ def fn(x):
+ r = itertools.repeat(x)
+ idx = 0
+ for i in r:
+ x += i
+ i += 1
+ idx += 1
+ if idx > 10:
+ break
+ return x
+
+ x = torch.randn([2, 5])
+ eager = fn(x)
+
+ compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
+ compiled = compiled_fn(x)
+
+ self.assertEqual(list(eager), list(compiled))
+ self.assertEqual(len(counters["graph_break"]), 0)
+
+ def test_itertools_infinite_count(self):
+ for args in ([], [10], [5, -1]):
+ counters.clear()
+
+ def fn(x):
+ r = itertools.count(*args)
+ idx = 0
+ for i in r:
+ x += i
+ idx += 1
+ if idx > 10:
+ break
+ return x
+
+ x = torch.randn([2, 5])
+ eager = fn(x)
+
+ compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
+ compiled = compiled_fn(x)
+
+ self.assertEqual(list(eager), list(compiled))
+ self.assertEqual(len(counters["graph_break"]), 0)
+
+ def test_itertools_infinite_cycle(self):
+ counters.clear()
+
+ def fn(x):
+ for iterator in (
+ iter([]),
+ iter([10, 11.0]),
+ itertools.repeat(-1, 3),
+ itertools.count(10),
+ ):
+ r = itertools.cycle(iterator)
+ idx = 0
+ x += 1
+ for i in r:
+ x += i
+ idx += 1
+ if idx > 10:
+ break
+ return x
+
+ x = torch.randn([2, 5])
+ eager = fn(x)
+
+ compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
+ compiled = compiled_fn(x)
+
+ self.assertEqual(list(eager), list(compiled))
+ self.assertEqual(len(counters["graph_break"]), 0)
+
def test_itertools_accumulate_symint_default_sum(self):
# https://github.com/pytorch/pytorch/issues/110287
counters.clear()
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 61c25c1..e1d7d96 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -1065,11 +1065,10 @@
def FOR_ITER(self, inst):
it = self.pop()
- if isinstance(it, ListIteratorVariable):
+ if isinstance(it, (variables.ListIteratorVariable, variables.IteratorVariable)):
self.output.guards.update(it.guards)
try:
- val, next_iter = it.next_variables()
- self.replace_all(it, next_iter)
+ val, next_iter = it.next_variables(self)
self.push(next_iter)
self.push(val)
except StopIteration:
@@ -2559,11 +2558,12 @@
if isinstance(tos, ConstantVariable) and tos.value is None:
self.pop()
return
- if isinstance(tos, ListIteratorVariable):
+ if isinstance(
+ tos, (variables.ListIteratorVariable, variables.IteratorVariable)
+ ):
self.output.guards.update(tos.guards)
try:
- val, next_iter = tos.next_variables()
- self.replace_all(tos, next_iter)
+ val, next_iter = tos.next_variables(self)
self.push(val)
# TODO(voz): Unclear if we need the push None in YIELD_VALUE?
self.YIELD_VALUE(inst)
diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py
index ea21dad..543c980 100644
--- a/torch/_dynamo/variables/__init__.py
+++ b/torch/_dynamo/variables/__init__.py
@@ -24,6 +24,12 @@
UserMethodVariable,
)
from .higher_order_ops import TorchHigherOrderOperatorVariable
+from .iter import (
+ CountIteratorVariable,
+ CycleIteratorVariable,
+ IteratorVariable,
+ RepeatIteratorVariable,
+)
from .lists import (
BaseListVariable,
ListIteratorVariable,
@@ -79,6 +85,10 @@
"GetAttrVariable",
"GradModeVariable",
"InspectSignatureVariable",
+ "IteratorVariable",
+ "RepeatIteratorVariable",
+ "CountIteratorVariable",
+ "CycleIteratorVariable",
"LambdaVariable",
"ListIteratorVariable",
"ListVariable",
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index e5b80b4..3252eef 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -800,6 +800,12 @@
def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs):
if self._dynamic_args(*args, **kwargs):
return self._dyn_proxy(tx, *args, **kwargs)
+
+ if isinstance(obj, variables.IteratorVariable):
+ # For non-list iterators, we will guard on vars that
+ # determine the control flow
+ return obj
+
# TODO This should probably be treated as a dict, or dicts should also be treated here
if self.fn == set:
cls = SetVariable
@@ -965,9 +971,10 @@
return variables.SuperVariable(a, b)
def call_next(self, tx, arg):
- if isinstance(arg, variables.ListIteratorVariable):
- val, next_iter = arg.next_variables()
- tx.replace_all(arg, next_iter)
+ if isinstance(
+ arg, (variables.ListIteratorVariable, variables.IteratorVariable)
+ ):
+ val, next_iter = arg.next_variables(tx)
return val
elif isinstance(arg, variables.BaseListVariable):
return arg.items[0].add_options(self, arg)
diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py
new file mode 100644
index 0000000..453fb82
--- /dev/null
+++ b/torch/_dynamo/variables/iter.py
@@ -0,0 +1,101 @@
+MAX_CYCLE = 3000
+
+from typing import List, Optional
+
+from ..exc import unimplemented
+
+from .base import VariableTracker
+from .constant import ConstantVariable
+
+
+class IteratorVariable(VariableTracker):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+
+ def next_variables(self, tx):
+ unimplemented("abstract method, must implement")
+
+
+class RepeatIteratorVariable(IteratorVariable):
+ def __init__(self, item: VariableTracker, **kwargs):
+ super().__init__(**kwargs)
+ self.item = item
+
+ # Repeat needs no mutation, clone self
+ def next_variables(self, tx):
+ # add_options will clone self.item
+ return self.item.add_options(self), self
+
+
+class CountIteratorVariable(IteratorVariable):
+ def __init__(self, item: int = 0, step: int = 1, **kwargs):
+ super().__init__(**kwargs)
+ if not isinstance(item, VariableTracker):
+ item = ConstantVariable.create(item)
+ if not isinstance(step, VariableTracker):
+ step = ConstantVariable.create(step)
+ self.item = item
+ self.step = step
+
+ def next_variables(self, tx):
+ assert self.mutable_local
+ next_item = self.item.call_method(tx, "__add__", [self.step], {})
+ next_iter = self.clone(item=next_item)
+ tx.replace_all(self, next_iter)
+ return self.item.add_options(self), next_iter
+
+
+class CycleIteratorVariable(IteratorVariable):
+ def __init__(
+ self,
+ iterator: IteratorVariable,
+ saved: List[VariableTracker] = None,
+ saved_index: int = 0,
+ item: Optional[VariableTracker] = None,
+ **kwargs,
+ ):
+ if saved is None:
+ saved = []
+ super().__init__(**kwargs)
+ self.iterator = iterator
+ self.saved = saved
+ self.saved_index = saved_index
+ self.item = item
+
+ def next_variables(self, tx):
+ assert self.mutable_local
+
+ if self.iterator is not None:
+ try:
+ new_item, next_inner_iter = self.iterator.next_variables(tx)
+ tx.replace_all(self.iterator, next_inner_iter)
+ if len(self.saved) > MAX_CYCLE:
+ unimplemented(
+ "input iterator to itertools.cycle has too many items"
+ )
+ next_iter = self.clone(
+ iterator=next_inner_iter,
+ saved=self.saved + [new_item],
+ item=new_item,
+ )
+
+ tx.replace_all(self, next_iter)
+ if self.item is None:
+ return next_iter.next_variables(tx)
+ return self.item.add_options(self), next_iter
+ except StopIteration:
+ next_iter = self.clone(iterator=None)
+ # this is redundant as next_iter will do the same
+ # but we do it anyway for safety
+ tx.replace_all(self, next_iter)
+ return next_iter.next_variables(tx)
+ elif len(self.saved) > 0:
+ next_iter = self.clone(
+ saved_index=(self.saved_index + 1) % len(self.saved),
+ item=self.saved[self.saved_index],
+ )
+ tx.replace_all(self, next_iter)
+ return self.item.add_options(self), next_iter
+ else:
+ raise StopIteration
+ return self.item.add_options(self), next_iter
diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py
index a505b48..3af048d 100644
--- a/torch/_dynamo/variables/lists.py
+++ b/torch/_dynamo/variables/lists.py
@@ -683,16 +683,18 @@
self.items = items
self.index = index
- def next_variables(self):
+ def next_variables(self, tx):
assert self.mutable_local
if self.index >= len(self.items):
raise StopIteration()
- return self.items[self.index].add_options(self), ListIteratorVariable(
+ next_iter = ListIteratorVariable(
self.items,
self.index + 1,
mutable_local=MutableLocal(),
**VariableTracker.propagate([self]),
)
+ tx.replace_all(self, next_iter)
+ return self.items[self.index].add_options(self), next_iter
def as_python_constant(self):
if self.index > 0:
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index 75f356b..4caabd1 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -885,16 +885,20 @@
fn, args=rest_args, keywords=kwargs, **options
)
elif self.value is itertools.repeat:
- from .builder import SourcelessBuilder
-
if len(args) < 2:
- # We cannot risk infinite generator being consumed to exhaustion by dynamo
- # (i.e. infinite loop)
- unimplemented("Infinite repeat is not supported")
+ return variables.RepeatIteratorVariable(
+ *args, mutable_local=MutableLocal()
+ )
+
+ from .builder import SourcelessBuilder
return tx.inline_user_function_return(
SourcelessBuilder()(tx, polyfill.repeat), args, kwargs
)
+ elif self.value is itertools.count:
+ return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())
+ elif self.value is itertools.cycle:
+ return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal())
else:
try:
path = inspect.getfile(self.value)