Revert "Limit loop unrolling (#120023)"
This reverts commit 6cc7f9a2e6bedff3109ea066278e9805713da4bb.
Reverted https://github.com/pytorch/pytorch/pull/120023 on behalf of https://github.com/anijain2305 due to it breaking LLM export ([comment](https://github.com/pytorch/pytorch/pull/120023#issuecomment-1974104633))
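
For context, the change being reverted added a max_loop_unroll_nodes config knob (settable via the TORCHDYNAMO_MAX_LOOP_UNROLL_NODES environment variable, default 5000) plus a BackedgeTracker that raises SkipFrame once an unrolled Python loop has grown the traced graph past that cap. As a rough illustration (not code from the PR; the function, shapes, and iteration count below are made up), this is the kind of user code the limit acted on, since Dynamo unrolls Python loops while tracing and the FX graph grows with the iteration count:

    import torch

    def decode(x, steps):
        # torch.compile traces through this Python loop by unrolling it, so
        # each iteration adds more nodes to the captured FX graph.
        for _ in range(steps):
            x = x + torch.sin(x)
        return x

    compiled = torch.compile(decode)
    out = compiled(torch.randn(8, 8), 64)  # 64 unrolled iterations in one graph

With the reverted change in place, a loop like this would hit SkipFrame (falling back to eager for that frame) once the graph growth since the first backedge exceeded the cap, which is presumably how it interfered with the LLM export flows cited in the revert reason, where long decode loops need to be captured in full.
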
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index b3f0919..3179381 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -406,11 +406,6 @@
# WARNING: this is an experimental flag and is subject to change.
_experimental_support_context_fn_in_torch_utils_checkpoint = False
-# Approximate maximum number of nodes to unroll loops into. A value of 0 will
-# unroll fully but could result in very large graphs that take an inordinate
-# amount of time to process.
-max_loop_unroll_nodes = int(os.environ.get("TORCHDYNAMO_MAX_LOOP_UNROLL_NODES", 5000))
-
if TYPE_CHECKING:
from torch.utils._config_typing import * # noqa: F401, F403
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 0df831d..6701115 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -598,44 +598,6 @@
return decorator
-class BackedgeTracker:
- """
- A BackedgeTracker is used to keep track of backedges and how many nodes
- we're expanding the graph by each time the backedge is seen. The general
- idea is that if you're looping and each loop is inlining a large sub-graph
- we detect that and skip frame when it's looking untenable.
- """
-
- def __init__(self):
- # How many loops have we performed?
- self.n_seen = 0
- # How many nodes were in the graph on our first loop?
- self.n_nodes_on_first_loop = None
-
- # Raises SkipFrame if the loop is getting too big.
- def append(self, count):
- self.n_seen += 1
- if self.n_seen == 1:
- self.n_nodes_on_first_loop = count
-
- # Don't skip if we haven't seen this particular backedge at least a few
- # times.
- if self.n_seen < 3:
- return
-
- # For now use the trivial heuristic of checking the raw number of nodes
- # since the first time we saw this backedge. In the future we could do
- # something more interesting like watching the rate of growth to trim
- # loops earlier (so we don't have to wait for `max` nodes before
- # skipping).
- added_nodes = count - self.n_nodes_on_first_loop
- if (
- config.max_loop_unroll_nodes > 0
- and added_nodes > config.max_loop_unroll_nodes
- ):
- raise exc.SkipFrame("unrolled loop getting too big")
-
-
class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState]):
output: OutputGraph
symbolic_locals: Dict[str, VariableTracker]
@@ -652,8 +614,6 @@
inline_depth: int
inconsistent_side_effects: bool
current_speculation: Optional[SpeculationEntry]
- # Used to track how big the graph is getting on loop backedges.
- loop_backedge_trackers: Dict[Tuple[int, int], BackedgeTracker]
def mark_inconsistent_side_effects(self):
"""
@@ -1111,12 +1071,7 @@
self.push(ConstantVariable.create(value=val))
def jump(self, inst):
- target = self.indexof[inst.target]
- if self.instruction_pointer and target < self.instruction_pointer:
- key = (self.instruction_pointer, target)
- count = len(self.output.graph.nodes)
- self.loop_backedge_trackers[key].append(count)
- self.instruction_pointer = target
+ self.instruction_pointer = self.indexof[inst.target]
JUMP_FORWARD = jump
JUMP_ABSOLUTE = jump
@@ -2005,7 +1960,6 @@
):
super().__init__()
self.speculation_log = speculation_log
- self.loop_backedge_trackers = collections.defaultdict(BackedgeTracker)
# Mutable state checkpointed by copy_graphstate()
self.output = output
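
For reviewers who want the removed heuristic in one place, here is a standalone sketch (not part of this commit; names mirror the diff above, and SkipFrame stands in for torch._dynamo.exc.SkipFrame). On every backward jump the tracker records the current graph size; once the same backedge has been taken at least three times, it skips the frame if the graph has grown past the configured cap:

    import os

    MAX_LOOP_UNROLL_NODES = int(os.environ.get("TORCHDYNAMO_MAX_LOOP_UNROLL_NODES", 5000))

    class SkipFrame(Exception):
        """Stand-in for torch._dynamo.exc.SkipFrame in this sketch."""

    class BackedgeTracker:
        """Per-backedge record of how much the graph has grown since the loop started."""

        def __init__(self):
            self.n_seen = 0                    # times this backedge has been taken
            self.n_nodes_on_first_loop = None  # graph size on the first iteration

        def append(self, count):
            self.n_seen += 1
            if self.n_seen == 1:
                self.n_nodes_on_first_loop = count
            if self.n_seen < 3:
                return  # tolerate a couple of iterations before judging
            added_nodes = count - self.n_nodes_on_first_loop
            if MAX_LOOP_UNROLL_NODES > 0 and added_nodes > MAX_LOOP_UNROLL_NODES:
                raise SkipFrame("unrolled loop getting too big")

    # Each unrolled iteration here pretends to add ~1000 graph nodes, so the
    # tracker trips once growth since the first backedge exceeds the 5000 default.
    tracker = BackedgeTracker()
    try:
        for node_count in range(1000, 10001, 1000):
            tracker.append(node_count)
    except SkipFrame as err:
        print(err)  # "unrolled loop getting too big"

The revert removes exactly this bookkeeping: jump() goes back to unconditionally setting the instruction pointer, and no cap is applied to how far a loop is unrolled.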