[dynamo, 3.12] fix the block stack... again (#123978)
Some changes to how we handle blocks in 3.11+:
- We only keep track of with blocks that are not enclosed in a try block.
- We do not compile partial graphs when the current instruction's block is not a tracked with block, i.e. when it is enclosed in some non-with block (try/except, etc.).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123978
Approved by: https://github.com/jansel
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 19fc38c..e633e88 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -7025,7 +7025,7 @@
def test_variable_access_in_exception(self):
def fn():
- x = torch.ones(3, 3)
+ x = torch.ones(1)
try:
raise RuntimeError("bad")
except RuntimeError:
@@ -7033,7 +7033,87 @@
return x
opt_fn = torch._dynamo.optimize("eager")(fn)
- torch.allclose(opt_fn(), torch.tensor([3.0]))
+ self.assertEqual(opt_fn(), torch.tensor([2.0]))
+
+ def test_nested_sequential_with(self):
+ def fn(x):
+ with torch.set_grad_enabled(True):
+ with torch.set_grad_enabled(False):
+ x = x + 1
+ with torch.set_grad_enabled(True):
+ x = x + 1
+ return x
+
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0]))
+
+ def test_nested_sequential_try(self):
+ def fn(x):
+ try:
+ try:
+ x = x + 1
+ except:
+ pass
+ try:
+ try:
+ x = x + 1
+ except:
+ pass
+ except:
+ pass
+ except:
+ pass
+ return x
+
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0]))
+
+ def test_nested_sequential_try_with(self):
+ def fn(x):
+ with torch.set_grad_enabled(True):
+ try:
+ x = x + 1
+ except:
+ pass
+ try:
+ with torch.set_grad_enabled(False):
+ x = x + 1
+ except:
+ pass
+ return x
+
+ opt_fn = torch._dynamo.optimize("eager")(fn)
+ self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0]))
+
+ def test_nested_sequential_try_with_graph_break(self):
+ def fn(x, n):
+ with torch.set_grad_enabled(True):
+ with torch.set_grad_enabled(False):
+ x = x + 1
+ torch._dynamo.graph_break()
+ try:
+ with torch.set_grad_enabled(False):
+ x = x + 1
+ if n == 0:
+ torch._dynamo.graph_break()
+ except:
+ pass
+ with torch.set_grad_enabled(False):
+ x = x + 1
+ torch._dynamo.graph_break()
+ x = x + 1
+ return x
+
+ counter = CompileCounter()
+ opt_fn = torch._dynamo.optimize(counter)(fn)
+ self.assertEqual(opt_fn(torch.ones(1), 0), torch.tensor([5.0]))
+ self.assertEqual(counter.frame_count, 1)
+
+ torch._dynamo.reset()
+ counter = CompileCounter()
+ opt_fn = torch._dynamo.optimize(counter)(fn)
+ self.assertEqual(opt_fn(torch.ones(1), 1), torch.tensor([5.0]))
+ self.assertEqual(counter.frame_count, 3)
def test_ordered_dict_alias_reconstruct(self):
od = collections.OrderedDict
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index b7520c8..1a6fc55 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -801,48 +801,43 @@
if sys.version_info >= (3, 11):
def update_block_stack(self, inst):
- # 3.11 no longer uses a block stack, but we still keep track of one
+ # 3.11+ no longer uses a block stack, but we still keep track of one
# so that we know which contexts are currently active.
# For our purposes, all exception table entries with the same target
# are considered to be part of the same "block".
+ # NOTE: we only keep track of with blocks that are not contained in try blocks.
+ # This is because we will not create continuation functions on graph breaks in try blocks,
+ # but we may for with blocks. We do not push blocks here since
+ # with blocks are pushed when handling BEFORE_WITH.
entry = inst.exn_tab_entry
- if not (
- # still in the same block
- entry
- and self.block_stack
- and self.block_stack[-1].target is entry.target
- ):
- if not entry:
- # no longer in any block
- # It is possible for NOPs to be between two instructions
- # in the same block, but the NOPs are not covered by an
- # exception table entry. In this case, assume that we
- # are still in the same block.
- # In 3.12+, JUMP_BACKWARD might also not be covered by
- # an exception table entry, so we also assume that we
- # are still in the same block. It is probably safe to do
- # this in 3.11, even though we haven't encountered this case before.
- if self.block_stack and inst.opname not in ("NOP", "JUMP_BACKWARD"):
- # If we really escape from a block and the current
- # instruction is not in another block, then there
- # should be no other nested blocks that we are in.
- assert len(self.block_stack) == 1
- self.block_stack.pop()
- elif (
- # current instruction is in the previous block
- len(self.block_stack) > 1
- and self.block_stack[-2].target is entry.target
+ if entry:
+ # Detect when we have exited the top with block.
+ # The with blocks on the block stack are not enclosed in try
+ # blocks, so a with block's cleanup code should be in the
+ # previous with block (if any).
+ if (
+ len(self.block_stack) >= 2
+ and entry.target is not self.block_stack[-1].target
+ and entry.target is self.block_stack[-2].target
):
# exit the current block
self.block_stack.pop()
- else:
- # current instruction is in a new block
- # push block to stack - note, BEFORE_WITH blocks won't
- # be pushed here since BEFORE_WITH pushes the block, and
- # the current instruction would be counted as being in that block.
- self.block_stack.append(
- BlockStackEntry(entry.target, len(self.stack))
- )
+ else:
+ # no longer in any block
+ # It is possible for NOPs to be between two instructions
+ # in the same block, but the NOPs are not covered by an
+ # exception table entry. In this case, assume that we
+ # are still in the same block.
+ # In 3.12+, JUMP_BACKWARD might also not be covered by
+ # an exception table entry, so we also assume that we
+ # are still in the same block. It is probably safe to do
+ # this in 3.11, even though we haven't encountered this case before.
+ if self.block_stack and inst.opname not in ("NOP", "JUMP_BACKWARD"):
+ # If we really escape from a block and the current
+ # instruction is not in another block, then there
+ # should be no other nested blocks that we are in.
+ assert len(self.block_stack) == 1
+ self.block_stack.pop()
else:
@@ -1392,6 +1387,8 @@
speculation.fail_and_restart_analysis()
def store_attr_graph_break(self, inst):
+ if not self.should_compile_partial_graph():
+ unimplemented("should_compile_partial_graph=False")
self.output.compile_subgraph(
self, reason=GraphCompileReason("store_attr", [self.frame_summary()])
)
@@ -1882,15 +1879,27 @@
ctx,
inst.target,
)
+
if sys.version_info >= (3, 11):
- # see create_call_resume_at for block stack details
- target = self.next_instruction.exn_tab_entry.target
+ # See create_call_resume_at for block stack details.
+ # Only push a block if the current instruction's block is a
+ # with block that is not nested in a try block - that is, the current
+ # instruction's block target is the same as the top block's target.
+ if inst.exn_tab_entry and (
+ not self.block_stack
+ or inst.exn_tab_entry.target is not self.block_stack[-1].target
+ ):
+ target = None
+ else:
+ target = self.next_instruction.exn_tab_entry.target
else:
target = inst.target
- if isinstance(self, InstructionTranslator):
- self.block_stack.append(BlockStackEntry(target, len(self.stack), ctx))
- else:
- self.block_stack.append(BlockStackEntry(target))
+
+ if target:
+ if isinstance(self, InstructionTranslator):
+ self.block_stack.append(BlockStackEntry(target, len(self.stack), ctx))
+ else:
+ self.block_stack.append(BlockStackEntry(target))
self.push(exit)
self.push(ctx.enter(self))
@@ -2234,6 +2243,13 @@
return self.symbolic_locals[name]
def should_compile_partial_graph(self):
+ if sys.version_info >= (3, 11):
+ # Do not compile if current instruction's block is not the top with block
+ entry = self.current_instruction.exn_tab_entry
+ if entry and (
+ not self.block_stack or entry.target is not self.block_stack[-1].target
+ ):
+ return False
return (
all(b.can_restore() for b in self.block_stack)
and not self.one_graph