[dynamo, 3.12] fix the block stack... again (#123978)

Some changes to how we handle blocks in 3.11+:
- We only keep track of with blocks that are not enclosed in a try block
- We do not compile partial graphs if we are in a block that is not in a tracked with block - i.e. any block enclosed in some non-with try/except/etc. block

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123978
Approved by: https://github.com/jansel
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 19fc38c..e633e88 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -7025,7 +7025,7 @@
 
     def test_variable_access_in_exception(self):
         def fn():
-            x = torch.ones(3, 3)
+            x = torch.ones(1)
             try:
                 raise RuntimeError("bad")
             except RuntimeError:
@@ -7033,7 +7033,87 @@
             return x
 
         opt_fn = torch._dynamo.optimize("eager")(fn)
-        torch.allclose(opt_fn(), torch.tensor([3.0]))
+        self.assertEqual(opt_fn(), torch.tensor([2.0]))
+
+    def test_nested_sequential_with(self):
+        def fn(x):
+            with torch.set_grad_enabled(True):
+                with torch.set_grad_enabled(False):
+                    x = x + 1
+                with torch.set_grad_enabled(True):
+                    x = x + 1
+                return x
+
+        opt_fn = torch._dynamo.optimize("eager")(fn)
+        self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0]))
+
+    def test_nested_sequential_try(self):
+        def fn(x):
+            try:
+                try:
+                    x = x + 1
+                except:
+                    pass
+                try:
+                    try:
+                        x = x + 1
+                    except:
+                        pass
+                except:
+                    pass
+            except:
+                pass
+            return x
+
+        opt_fn = torch._dynamo.optimize("eager")(fn)
+        self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0]))
+
+    def test_nested_sequential_try_with(self):
+        def fn(x):
+            with torch.set_grad_enabled(True):
+                try:
+                    x = x + 1
+                except:
+                    pass
+                try:
+                    with torch.set_grad_enabled(False):
+                        x = x + 1
+                except:
+                    pass
+            return x
+
+        opt_fn = torch._dynamo.optimize("eager")(fn)
+        self.assertEqual(opt_fn(torch.ones(1)), torch.tensor([3.0]))
+
+    def test_nested_sequential_try_with_graph_break(self):
+        def fn(x, n):
+            with torch.set_grad_enabled(True):
+                with torch.set_grad_enabled(False):
+                    x = x + 1
+                    torch._dynamo.graph_break()
+                try:
+                    with torch.set_grad_enabled(False):
+                        x = x + 1
+                        if n == 0:
+                            torch._dynamo.graph_break()
+                except:
+                    pass
+                with torch.set_grad_enabled(False):
+                    x = x + 1
+                    torch._dynamo.graph_break()
+                x = x + 1
+            return x
+
+        counter = CompileCounter()
+        opt_fn = torch._dynamo.optimize(counter)(fn)
+        self.assertEqual(opt_fn(torch.ones(1), 0), torch.tensor([5.0]))
+        self.assertEqual(counter.frame_count, 1)
+
+        torch._dynamo.reset()
+        counter = CompileCounter()
+        opt_fn = torch._dynamo.optimize(counter)(fn)
+        self.assertEqual(opt_fn(torch.ones(1), 1), torch.tensor([5.0]))
+        self.assertEqual(counter.frame_count, 3)
 
     def test_ordered_dict_alias_reconstruct(self):
         od = collections.OrderedDict
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index b7520c8..1a6fc55 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -801,48 +801,43 @@
     if sys.version_info >= (3, 11):
 
         def update_block_stack(self, inst):
-            # 3.11 no longer uses a block stack, but we still keep track of one
+            # 3.11+ no longer uses a block stack, but we still keep track of one
             # so that we know which contexts are currently active.
             # For our purposes, all exception table entries with the same target
             # are considered to be part of the same "block".
+            # NOTE: we only keep track of with blocks that are not contained in try blocks.
+            # This is because we will not create continuation functions on graph breaks in try blocks,
+            # but we may for with blocks. We do not push blocks here since
+            # with blocks are pushed when handling BEFORE_WITH.
             entry = inst.exn_tab_entry
-            if not (
-                # still in the same block
-                entry
-                and self.block_stack
-                and self.block_stack[-1].target is entry.target
-            ):
-                if not entry:
-                    # no longer in any block
-                    # It is possible for NOPs to be between two instructions
-                    # in the same block, but the NOPs are not covered by an
-                    # exception table entry. In this case, assume that we
-                    # are still in the same block.
-                    # In 3.12+, JUMP_BACKWARD might also not be covered by
-                    # an exception table entry, so we also assume that we
-                    # are still in the same block. It is probably safe to do
-                    # this in 3.11, even though we haven't encountered this case before.
-                    if self.block_stack and inst.opname not in ("NOP", "JUMP_BACKWARD"):
-                        # If we really escape from a block and the current
-                        # instruction is not in another block, then there
-                        # should be no other nested blocks that we are in.
-                        assert len(self.block_stack) == 1
-                        self.block_stack.pop()
-                elif (
-                    # current instruction is in the previous block
-                    len(self.block_stack) > 1
-                    and self.block_stack[-2].target is entry.target
+            if entry:
+                # Detect when we have exited the top with block.
+                # The with blocks on the block stack are not enclosed in try
+                # blocks, so a with block's cleanup code should be in the
+                # previous with block (if any).
+                if (
+                    len(self.block_stack) >= 2
+                    and entry.target is not self.block_stack[-1].target
+                    and entry.target is self.block_stack[-2].target
                 ):
                     # exit the current block
                     self.block_stack.pop()
-                else:
-                    # current instruction is in a new block
-                    # push block to stack - note, BEFORE_WITH blocks won't
-                    # be pushed here since BEFORE_WITH pushes the block, and
-                    # the current instruction would be counted as being in that block.
-                    self.block_stack.append(
-                        BlockStackEntry(entry.target, len(self.stack))
-                    )
+            else:
+                # no longer in any block
+                # It is possible for NOPs to be between two instructions
+                # in the same block, but the NOPs are not covered by an
+                # exception table entry. In this case, assume that we
+                # are still in the same block.
+                # In 3.12+, JUMP_BACKWARD might also not be covered by
+                # an exception table entry, so we also assume that we
+                # are still in the same block. It is probably safe to do
+                # this in 3.11, even though we haven't encountered this case before.
+                if self.block_stack and inst.opname not in ("NOP", "JUMP_BACKWARD"):
+                    # If we really escape from a block and the current
+                    # instruction is not in another block, then there
+                    # should be no other nested blocks that we are in.
+                    assert len(self.block_stack) == 1
+                    self.block_stack.pop()
 
     else:
 
@@ -1392,6 +1387,8 @@
         speculation.fail_and_restart_analysis()
 
     def store_attr_graph_break(self, inst):
+        if not self.should_compile_partial_graph():
+            unimplemented("should_compile_partial_graph=False")
         self.output.compile_subgraph(
             self, reason=GraphCompileReason("store_attr", [self.frame_summary()])
         )
@@ -1882,15 +1879,27 @@
             ctx,
             inst.target,
         )
+
         if sys.version_info >= (3, 11):
-            # see create_call_resume_at for block stack details
-            target = self.next_instruction.exn_tab_entry.target
+            # See create_call_resume_at for block stack details.
+            # Only push a block if the current instruction's block is a
+            # with block that is not nested in a try block - that is, the current
+            # instruction's block target is the same as the top block's target.
+            if inst.exn_tab_entry and (
+                not self.block_stack
+                or inst.exn_tab_entry.target is not self.block_stack[-1].target
+            ):
+                target = None
+            else:
+                target = self.next_instruction.exn_tab_entry.target
         else:
             target = inst.target
-        if isinstance(self, InstructionTranslator):
-            self.block_stack.append(BlockStackEntry(target, len(self.stack), ctx))
-        else:
-            self.block_stack.append(BlockStackEntry(target))
+
+        if target:
+            if isinstance(self, InstructionTranslator):
+                self.block_stack.append(BlockStackEntry(target, len(self.stack), ctx))
+            else:
+                self.block_stack.append(BlockStackEntry(target))
 
         self.push(exit)
         self.push(ctx.enter(self))
@@ -2234,6 +2243,13 @@
         return self.symbolic_locals[name]
 
     def should_compile_partial_graph(self):
+        if sys.version_info >= (3, 11):
+            # Do not compile if current instruction's block is not the top with block
+            entry = self.current_instruction.exn_tab_entry
+            if entry and (
+                not self.block_stack or entry.target is not self.block_stack[-1].target
+            ):
+                return False
         return (
             all(b.can_restore() for b in self.block_stack)
             and not self.one_graph