[dynamo] add infinite generators `itertools.{count, repeat, cycle}` (#110967)

Fixes https://github.com/pytorch/pytorch/pull/110953/files#r1352868935

Depends on: https://github.com/pytorch/pytorch/pull/110953

Why not use these for `repeat(item, count)`:
> These are not preferred as they return an opaque VariableTracker. In particular, one cannot do `enumerate(repeat(1))`. `repeat(1, 10)` benefits from the integration enjoyed by `ListVariableIterator`

Follow ups:
- [ ] make listiterator an IteratorVariable, define iterator integrations on base IteratorVariable where unspecialized https://github.com/pytorch/pytorch/pull/110967#discussion_r1356656469
    - Please make a new issue for this
- [ ] explore integrating cpython itertools test suite https://github.com/pytorch/pytorch/pull/110967#discussion_r1358326402
- [ ] Use something other than `StopIteration` to handle iterator termination https://github.com/pytorch/pytorch/pull/110967#discussion_r1358336038
- [ ] Add test case for consuming iterator simultaneously from two code points https://github.com/pytorch/pytorch/pull/110967/files#r1358325511

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110967
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index f40c12a..3a80ed9 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -7609,6 +7609,103 @@
         self.assertEqual(list(eager), list(compiled))
         self.assertEqual(len(counters["graph_break"]), 0)
 
+    def test_itertools_infinite_repeat(self):
+        counters.clear()
+
+        def fn(x):
+            r = itertools.repeat(100.0)
+            idx = 0
+            for i in r:
+                x += i
+                idx += 1
+                if idx > 10:
+                    break
+            return x
+
+        x = torch.randn([2, 5])
+        eager = fn(x)
+
+        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
+        compiled = compiled_fn(x)
+
+        self.assertEqual(list(eager), list(compiled))
+        self.assertEqual(len(counters["graph_break"]), 0)
+
+    def test_itertools_infinite_repeat_mutation(self):
+        counters.clear()
+
+        def fn(x):
+            r = itertools.repeat(x)
+            idx = 0
+            for i in r:
+                x += i
+                i += 1
+                idx += 1
+                if idx > 10:
+                    break
+            return x
+
+        x = torch.randn([2, 5])
+        eager = fn(x)
+
+        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
+        compiled = compiled_fn(x)
+
+        self.assertEqual(list(eager), list(compiled))
+        self.assertEqual(len(counters["graph_break"]), 0)
+
+    def test_itertools_infinite_count(self):
+        for args in ([], [10], [5, -1]):
+            counters.clear()
+
+            def fn(x):
+                r = itertools.count(*args)
+                idx = 0
+                for i in r:
+                    x += i
+                    idx += 1
+                    if idx > 10:
+                        break
+                return x
+
+            x = torch.randn([2, 5])
+            eager = fn(x)
+
+            compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
+            compiled = compiled_fn(x)
+
+            self.assertEqual(list(eager), list(compiled))
+            self.assertEqual(len(counters["graph_break"]), 0)
+
+    def test_itertools_infinite_cycle(self):
+        counters.clear()
+
+        def fn(x):
+            for iterator in (
+                iter([]),
+                iter([10, 11.0]),
+                itertools.repeat(-1, 3),
+                itertools.count(10),
+            ):
+                r = itertools.cycle(iterator)
+                idx = 0
+                x += 1
+                for i in r:
+                    x += i
+                    idx += 1
+                    if idx > 10:
+                        break
+            return x
+
+        x = torch.randn([2, 5])
+        eager = fn(x)
+
+        compiled_fn = torch._dynamo.optimize(backend="eager", nopython=True)(fn)
+        compiled = compiled_fn(x)
+
+        self.assertEqual(list(eager), list(compiled))
+        self.assertEqual(len(counters["graph_break"]), 0)
+
     def test_itertools_accumulate_symint_default_sum(self):
         # https://github.com/pytorch/pytorch/issues/110287
         counters.clear()
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 61c25c1..e1d7d96 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -1065,11 +1065,10 @@
 
     def FOR_ITER(self, inst):
         it = self.pop()
-        if isinstance(it, ListIteratorVariable):
+        if isinstance(it, (variables.ListIteratorVariable, variables.IteratorVariable)):
             self.output.guards.update(it.guards)
             try:
-                val, next_iter = it.next_variables()
-                self.replace_all(it, next_iter)
+                val, next_iter = it.next_variables(self)
                 self.push(next_iter)
                 self.push(val)
             except StopIteration:
@@ -2559,11 +2558,12 @@
             if isinstance(tos, ConstantVariable) and tos.value is None:
                 self.pop()
                 return
-            if isinstance(tos, ListIteratorVariable):
+            if isinstance(
+                tos, (variables.ListIteratorVariable, variables.IteratorVariable)
+            ):
                 self.output.guards.update(tos.guards)
                 try:
-                    val, next_iter = tos.next_variables()
-                    self.replace_all(tos, next_iter)
+                    val, next_iter = tos.next_variables(self)
                     self.push(val)
                     # TODO(voz): Unclear if we need the push None in YIELD_VALUE?
                     self.YIELD_VALUE(inst)
diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py
index ea21dad..543c980 100644
--- a/torch/_dynamo/variables/__init__.py
+++ b/torch/_dynamo/variables/__init__.py
@@ -24,6 +24,12 @@
     UserMethodVariable,
 )
 from .higher_order_ops import TorchHigherOrderOperatorVariable
+from .iter import (
+    CountIteratorVariable,
+    CycleIteratorVariable,
+    IteratorVariable,
+    RepeatIteratorVariable,
+)
 from .lists import (
     BaseListVariable,
     ListIteratorVariable,
@@ -79,6 +85,10 @@
     "GetAttrVariable",
     "GradModeVariable",
     "InspectSignatureVariable",
+    "IteratorVariable",
+    "RepeatIteratorVariable",
+    "CountIteratorVariable",
+    "CycleIteratorVariable",
     "LambdaVariable",
     "ListIteratorVariable",
     "ListVariable",
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index e5b80b4..3252eef 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -800,6 +800,12 @@
     def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs):
         if self._dynamic_args(*args, **kwargs):
             return self._dyn_proxy(tx, *args, **kwargs)
+
+        if isinstance(obj, variables.IteratorVariable):
+            # For non-list iterators, we will guard on vars that
+            # determine the control flow
+            return obj
+
         # TODO This should probably be treated as a dict, or dicts should also be treated here
         if self.fn == set:
             cls = SetVariable
@@ -965,9 +971,10 @@
         return variables.SuperVariable(a, b)
 
     def call_next(self, tx, arg):
-        if isinstance(arg, variables.ListIteratorVariable):
-            val, next_iter = arg.next_variables()
-            tx.replace_all(arg, next_iter)
+        if isinstance(
+            arg, (variables.ListIteratorVariable, variables.IteratorVariable)
+        ):
+            val, next_iter = arg.next_variables(tx)
             return val
         elif isinstance(arg, variables.BaseListVariable):
             return arg.items[0].add_options(self, arg)
diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py
new file mode 100644
index 0000000..453fb82
--- /dev/null
+++ b/torch/_dynamo/variables/iter.py
@@ -0,0 +1,101 @@
+MAX_CYCLE = 3000
+
+from typing import List, Optional
+
+from ..exc import unimplemented
+
+from .base import VariableTracker
+from .constant import ConstantVariable
+
+
+class IteratorVariable(VariableTracker):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def next_variables(self, tx):
+        unimplemented("abstract method, must implement")
+
+
+class RepeatIteratorVariable(IteratorVariable):
+    def __init__(self, item: VariableTracker, **kwargs):
+        super().__init__(**kwargs)
+        self.item = item
+
+    # Repeat needs no mutation, clone self
+    def next_variables(self, tx):
+        # add_options will clone self.item
+        return self.item.add_options(self), self
+
+
+class CountIteratorVariable(IteratorVariable):
+    def __init__(self, item: int = 0, step: int = 1, **kwargs):
+        super().__init__(**kwargs)
+        if not isinstance(item, VariableTracker):
+            item = ConstantVariable.create(item)
+        if not isinstance(step, VariableTracker):
+            step = ConstantVariable.create(step)
+        self.item = item
+        self.step = step
+
+    def next_variables(self, tx):
+        assert self.mutable_local
+        next_item = self.item.call_method(tx, "__add__", [self.step], {})
+        next_iter = self.clone(item=next_item)
+        tx.replace_all(self, next_iter)
+        return self.item.add_options(self), next_iter
+
+
+class CycleIteratorVariable(IteratorVariable):
+    def __init__(
+        self,
+        iterator: IteratorVariable,
+        saved: List[VariableTracker] = None,
+        saved_index: int = 0,
+        item: Optional[VariableTracker] = None,
+        **kwargs,
+    ):
+        if saved is None:
+            saved = []
+        super().__init__(**kwargs)
+        self.iterator = iterator
+        self.saved = saved
+        self.saved_index = saved_index
+        self.item = item
+
+    def next_variables(self, tx):
+        assert self.mutable_local
+
+        if self.iterator is not None:
+            try:
+                new_item, next_inner_iter = self.iterator.next_variables(tx)
+                tx.replace_all(self.iterator, next_inner_iter)
+                if len(self.saved) > MAX_CYCLE:
+                    unimplemented(
+                        "input iterator to itertools.cycle has too many items"
+                    )
+                next_iter = self.clone(
+                    iterator=next_inner_iter,
+                    saved=self.saved + [new_item],
+                    item=new_item,
+                )
+
+                tx.replace_all(self, next_iter)
+                if self.item is None:
+                    return next_iter.next_variables(tx)
+                return self.item.add_options(self), next_iter
+            except StopIteration:
+                next_iter = self.clone(iterator=None)
+                # this is redundant as next_iter will do the same
+                # but we do it anyway for safety
+                tx.replace_all(self, next_iter)
+                return next_iter.next_variables(tx)
+        elif len(self.saved) > 0:
+            next_iter = self.clone(
+                saved_index=(self.saved_index + 1) % len(self.saved),
+                item=self.saved[self.saved_index],
+            )
+            tx.replace_all(self, next_iter)
+            return self.item.add_options(self), next_iter
+        else:
+            raise StopIteration
+        return self.item.add_options(self), next_iter
diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py
index a505b48..3af048d 100644
--- a/torch/_dynamo/variables/lists.py
+++ b/torch/_dynamo/variables/lists.py
@@ -683,16 +683,18 @@
         self.items = items
         self.index = index
 
-    def next_variables(self):
+    def next_variables(self, tx):
         assert self.mutable_local
         if self.index >= len(self.items):
             raise StopIteration()
-        return self.items[self.index].add_options(self), ListIteratorVariable(
+        next_iter = ListIteratorVariable(
             self.items,
             self.index + 1,
             mutable_local=MutableLocal(),
             **VariableTracker.propagate([self]),
         )
+        tx.replace_all(self, next_iter)
+        return self.items[self.index].add_options(self), next_iter
 
     def as_python_constant(self):
         if self.index > 0:
diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py
index 75f356b..4caabd1 100644
--- a/torch/_dynamo/variables/misc.py
+++ b/torch/_dynamo/variables/misc.py
@@ -885,16 +885,20 @@
                 fn, args=rest_args, keywords=kwargs, **options
             )
         elif self.value is itertools.repeat:
-            from .builder import SourcelessBuilder
-
             if len(args) < 2:
-                # We cannot risk infinite generator being consumed to exhaustion by dynamo
-                # (i.e. infinite loop)
-                unimplemented("Infinite repeat is not supported")
+                return variables.RepeatIteratorVariable(
+                    *args, mutable_local=MutableLocal()
+                )
+
+            from .builder import SourcelessBuilder
 
             return tx.inline_user_function_return(
                 SourcelessBuilder()(tx, polyfill.repeat), args, kwargs
             )
+        elif self.value is itertools.count:
+            return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())
+        elif self.value is itertools.cycle:
+            return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal())
         else:
             try:
                 path = inspect.getfile(self.value)