Use fast traceback for symbolic shapes (#107439)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107439
Approved by: https://github.com/voznesenskym
ghstack dependencies: #107505, #107516, #107530, #107532, #107562, #107471
diff --git a/torch/_guards.py b/torch/_guards.py
index b47f28b..647ec55 100644
--- a/torch/_guards.py
+++ b/torch/_guards.py
@@ -126,8 +126,7 @@
 
 class ShapeGuard(NamedTuple):
     expr: sympy.Expr
-    # TODO: store this in slightly less formatted form
-    stack: str
+    stack: CapturedTraceback
 
 
 @dataclasses.dataclass
@@ -694,6 +693,12 @@
             e.real_stack = context.extract_stack()  # type: ignore[attr-defined]
         raise
     finally:
+        if (
+            context is not None
+            and context.fake_mode is not None
+            and context.fake_mode.shape_env is not None
+        ):
+            context.fake_mode.shape_env.cleanup()
         _TLS.tracing_context = old_context
 
 
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 5e14e13..fce87be 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -8,7 +8,6 @@
 import operator
 import re
 import sys
-import textwrap
 import threading
 import traceback
 from collections import defaultdict
@@ -36,7 +35,7 @@
 from torch.utils._sympy.functions import FloorDiv, LShift, Mod, RShift
 from torch.utils._sympy.solve import try_solve
 from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
-from torch.utils._traceback import format_frame
+from torch.utils._traceback import format_frame, CapturedTraceback
 from torch._utils_internal import signpost_event
 
 InputList = List
@@ -1059,11 +1058,6 @@
     raise AssertionError("shouldn't be hit")
 
 
-def get_debugging_stack(num_frames_to_cut=2):
-    # cut this frame and the caller's frame by default
-    return ''.join(traceback.format_list(traceback.extract_stack()[:-num_frames_to_cut]))
-
-
 def floor_ceil_helper(a, fn):
     if isinstance(a, sympy.Mul):
         aa = a.args
@@ -2032,7 +2026,7 @@
         # for N < 2. Therefore, it will be too strict to assert N=2 at runtime.
         self.runtime_var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
         self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
-        self.var_to_stack: Dict[sympy.Symbol, traceback.StackSummary] = {}
+        self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {}
         # Maps symbolic ints to the guards that refine their lower/upper
         # bound. If one of them is None, it means that there are no guards
         # that refine that respective bound.
@@ -2377,7 +2371,7 @@
     def create_unbacked_symfloat(self):
         symbol: sympy.Symbol = sympy.Symbol(f"f{next(self.unbacked_symfloat_counter)}")
         self.counter["create_unbacked_symbol"] += 1
-        self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
+        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
         self.var_to_range[symbol] = ValueRanges.unknown()
 
         # Create a new FX placeholder and Z3 variable for 'symbol'.
@@ -2388,7 +2382,7 @@
     def create_unbacked_symint(self):
         symbol: sympy.Symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
         self.counter["create_unbacked_symbol"] += 1
-        self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
+        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
         self.var_to_range[symbol] = self._default_unspecified_value_range()
 
         # Create a new FX placeholder and Z3 variable for 'symbol'.
@@ -2399,7 +2393,7 @@
     def create_unbacked_symbool(self):
         symbol: sympy.Symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
         self.counter["create_unbacked_symbol"] += 1
-        self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
+        self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
         self.var_to_range[symbol] = ValueRanges(0, 1)
 
         # Create a new FX placeholder and Z3 variable for 'symbol'.
@@ -2824,7 +2818,7 @@
                         else:
                             raise AssertionError(f"unrecognized constraint {c}")
             except Exception:
-                self.log.warning("Failing guard allocated at: \n%s", guard.stack)
+                self.log.warning("Failing guard allocated at: \n%s", ''.join(guard.stack.format()))
                 raise
 
         # First, issue all the non-trivial guards.
@@ -3015,7 +3009,7 @@
         def format_tb(tb):
             if not verbose:
                 return ""
-            return f"\n   Guarded at:\n{textwrap.indent(tb, '   ')}"
+            return f"\n   Guarded at:\n{''.join('   ' + l for l in tb.format())}"
 
         return '\n'.join(f" - {guard.expr}{format_tb(guard.stack)}" for guard in self.guards)
 
@@ -3190,7 +3184,7 @@
         # TODO: in a Dynamo context, having user code, and having the
         # name of the local, will be much better
         for s in expr.free_symbols:
-            stacktrace = ''.join(traceback.format_list(self.var_to_stack[s]))
+            stacktrace = ''.join(self.var_to_stack[s].format())
             self.log.debug("Data dependent variable '%s' allocated at:\n%s", s, stacktrace)
         return GuardOnDataDependentSymNode(
             "It appears that you're trying to get a value out of symbolic int/float "
@@ -3332,11 +3326,22 @@
             log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val)
 
 
-    def _log_guard(self, prefix: str, g, tb):
+    def _log_guard(self, prefix: str, g):
         if self.log.isEnabledFor(logging.INFO):
-            for frame in reversed(tb):
-                if frame.filename not in uninteresting_files():
-                    break
+            fsummary = None
+            frame = inspect.currentframe()
+            try:
+                while frame is not None:
+                    if frame.f_code.co_filename not in uninteresting_files():
+                        fsummary = traceback.FrameSummary(
+                            frame.f_code.co_filename,
+                            frame.f_lineno,
+                            frame.f_code.co_name,
+                        )
+                        break
+                    frame = frame.f_back
+            finally:
+                del frame
 
             # NB: this stack is truncated, but it's fine because the main
             # stack_info will give you the rest of the info you need
@@ -3357,7 +3362,7 @@
                 "eval %s [guard added]%s (%s)%s",
                 g,
                 maybe_user_loc,
-                format_frame(frame),
+                format_frame(fsummary),
                 maybe_extra_debug,
                 stack_info=is_debug,
             )
@@ -3460,8 +3465,7 @@
                 g = sympy.Eq(expr, concrete_val)  # type: ignore[arg-type]
 
             if not self._suppress_guards_tls():
-                tb = traceback.extract_stack()[:-1]
-                stack = ''.join(traceback.format_list(tb))
+                stack = CapturedTraceback.extract(skip=1)
                 guard = ShapeGuard(g, stack)
                 self.guards.append(guard)
         except Exception:
@@ -3470,16 +3474,27 @@
         else:
             if not self._suppress_guards_tls():
                 assert guard is not None
-                assert tb is not None
 
                 self.refine_ranges(guard)
 
-                self._log_guard("eval", g, tb)
+                self._log_guard("eval", g)
             else:
                 self.log.debug("eval %s [guard suppressed]", g)
 
         return concrete_val
 
+    def cleanup(self):
+        # Break reference cycles.
+        # This destroys the stacks. If you really want to keep them, we
+        # just need some way to break references on code objects.
+        for g in self.guards:
+            g.stack.cleanup()
+        for s in self.var_to_stack.values():
+            s.cleanup()
+        for ras in self.deferred_runtime_asserts.values():
+            for ra in ras:
+                ra.stack.cleanup()
+
     def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
         expr = orig_expr
 
@@ -3511,8 +3526,7 @@
         # here)
 
         if not self._suppress_guards_tls():
-            tb = traceback.extract_stack()[:-1]
-            stack = ''.join(traceback.format_list(tb))
+            stack = CapturedTraceback.extract(skip=1)
             ra = RuntimeAssert(expr, msg, stack)
             # TODO: Do this in a way that is less janky than int(s.name[1:])
             cands = sorted([s for s in expr.free_symbols if s.name.startswith("i")], key=lambda s: int(s.name[1:]))
@@ -3524,7 +3538,7 @@
             # in ranges.  For example, i0 <= s0 is un-rangeable, because
             # we can't put s0 in the range.  So this is not very high
             # priority at the moment.
-            self._log_guard("runtime_assert", expr, tb)
+            self._log_guard("runtime_assert", expr)
         else:
             self.log.debug("runtime_assert %s [guard suppressed]", expr)