Use fast traceback for symbolic shapes (#107439)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107439
Approved by: https://github.com/voznesenskym
ghstack dependencies: #107505, #107516, #107530, #107532, #107562, #107471
diff --git a/torch/_guards.py b/torch/_guards.py
index b47f28b..647ec55 100644
--- a/torch/_guards.py
+++ b/torch/_guards.py
@@ -126,8 +126,7 @@
class ShapeGuard(NamedTuple):
expr: sympy.Expr
- # TODO: store this in slightly less formatted form
- stack: str
+ stack: CapturedTraceback
@dataclasses.dataclass
@@ -694,6 +693,12 @@
e.real_stack = context.extract_stack() # type: ignore[attr-defined]
raise
finally:
+ if (
+ context is not None
+ and context.fake_mode is not None
+ and context.fake_mode.shape_env is not None
+ ):
+ context.fake_mode.shape_env.cleanup()
_TLS.tracing_context = old_context
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 5e14e13..fce87be 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -8,7 +8,6 @@
import operator
import re
import sys
-import textwrap
import threading
import traceback
from collections import defaultdict
@@ -36,7 +35,7 @@
from torch.utils._sympy.functions import FloorDiv, LShift, Mod, RShift
from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError
-from torch.utils._traceback import format_frame
+from torch.utils._traceback import format_frame, CapturedTraceback
from torch._utils_internal import signpost_event
InputList = List
@@ -1059,11 +1058,6 @@
raise AssertionError("shouldn't be hit")
-def get_debugging_stack(num_frames_to_cut=2):
- # cut this frame and the caller's frame by default
- return ''.join(traceback.format_list(traceback.extract_stack()[:-num_frames_to_cut]))
-
-
def floor_ceil_helper(a, fn):
if isinstance(a, sympy.Mul):
aa = a.args
@@ -2032,7 +2026,7 @@
# for N < 2. Therefore, it will be too strict to assert N=2 at runtime.
self.runtime_var_to_range: Dict[sympy.Symbol, ValueRanges] = {}
self.var_to_sources: Dict[sympy.Symbol, List[Source]] = {}
- self.var_to_stack: Dict[sympy.Symbol, traceback.StackSummary] = {}
+ self.var_to_stack: Dict[sympy.Symbol, CapturedTraceback] = {}
# Maps symbolic ints to the guards that refine their lower/upper
# bound. If one of them is None, it means that there are no guards
# that refine that respective bound.
@@ -2377,7 +2371,7 @@
def create_unbacked_symfloat(self):
symbol: sympy.Symbol = sympy.Symbol(f"f{next(self.unbacked_symfloat_counter)}")
self.counter["create_unbacked_symbol"] += 1
- self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
+ self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
self.var_to_range[symbol] = ValueRanges.unknown()
# Create a new FX placeholder and Z3 variable for 'symbol'.
@@ -2388,7 +2382,7 @@
def create_unbacked_symint(self):
symbol: sympy.Symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
self.counter["create_unbacked_symbol"] += 1
- self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
+ self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
self.var_to_range[symbol] = self._default_unspecified_value_range()
# Create a new FX placeholder and Z3 variable for 'symbol'.
@@ -2399,7 +2393,7 @@
def create_unbacked_symbool(self):
symbol: sympy.Symbol = sympy.Symbol(f"i{next(self.unbacked_symint_counter)}", integer=True)
self.counter["create_unbacked_symbol"] += 1
- self.var_to_stack[symbol] = traceback.extract_stack()[:-1]
+ self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1)
self.var_to_range[symbol] = ValueRanges(0, 1)
# Create a new FX placeholder and Z3 variable for 'symbol'.
@@ -2824,7 +2818,7 @@
else:
raise AssertionError(f"unrecognized constraint {c}")
except Exception:
- self.log.warning("Failing guard allocated at: \n%s", guard.stack)
+ self.log.warning("Failing guard allocated at: \n%s", ''.join(guard.stack.format()))
raise
# First, issue all the non-trivial guards.
@@ -3015,7 +3009,7 @@
def format_tb(tb):
if not verbose:
return ""
- return f"\n Guarded at:\n{textwrap.indent(tb, ' ')}"
+ return f"\n Guarded at:\n{''.join(' ' + l for l in tb.format())}"
return '\n'.join(f" - {guard.expr}{format_tb(guard.stack)}" for guard in self.guards)
@@ -3190,7 +3184,7 @@
# TODO: in a Dynamo context, having user code, and having the
# name of the local, will be much better
for s in expr.free_symbols:
- stacktrace = ''.join(traceback.format_list(self.var_to_stack[s]))
+ stacktrace = ''.join(self.var_to_stack[s].format())
self.log.debug("Data dependent variable '%s' allocated at:\n%s", s, stacktrace)
return GuardOnDataDependentSymNode(
"It appears that you're trying to get a value out of symbolic int/float "
@@ -3332,11 +3326,22 @@
log.warning("Ignored guard %s == %s, this could result in accuracy problems", expr, concrete_val)
- def _log_guard(self, prefix: str, g, tb):
+ def _log_guard(self, prefix: str, g):
if self.log.isEnabledFor(logging.INFO):
- for frame in reversed(tb):
- if frame.filename not in uninteresting_files():
- break
+ fsummary = None
+ frame = inspect.currentframe()
+ try:
+ while frame is not None:
+ if frame.f_code.co_filename not in uninteresting_files():
+ fsummary = traceback.FrameSummary(
+ frame.f_code.co_filename,
+ frame.f_lineno,
+ frame.f_code.co_name,
+ )
+ break
+ frame = frame.f_back
+ finally:
+ del frame
# NB: this stack is truncated, but it's fine because the main
# stack_info will give you the rest of the info you need
@@ -3357,7 +3362,7 @@
"eval %s [guard added]%s (%s)%s",
g,
maybe_user_loc,
- format_frame(frame),
+ format_frame(fsummary),
maybe_extra_debug,
stack_info=is_debug,
)
@@ -3460,8 +3465,7 @@
g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type]
if not self._suppress_guards_tls():
- tb = traceback.extract_stack()[:-1]
- stack = ''.join(traceback.format_list(tb))
+ stack = CapturedTraceback.extract(skip=1)
guard = ShapeGuard(g, stack)
self.guards.append(guard)
except Exception:
@@ -3470,16 +3474,27 @@
else:
if not self._suppress_guards_tls():
assert guard is not None
- assert tb is not None
self.refine_ranges(guard)
- self._log_guard("eval", g, tb)
+ self._log_guard("eval", g)
else:
self.log.debug("eval %s [guard suppressed]", g)
return concrete_val
+ def cleanup(self):
+ # Break reference cycles.
+ # This destroys the stacks. If you really want to keep them, we
+ # would first need some way to break the references on code objects.
+ for g in self.guards:
+ g.stack.cleanup()
+ for s in self.var_to_stack.values():
+ s.cleanup()
+ for ras in self.deferred_runtime_asserts.values():
+ for ra in ras:
+ ra.stack.cleanup()
+
def defer_runtime_assert(self, orig_expr: "sympy.Expr", msg, fx_node=None):
expr = orig_expr
@@ -3511,8 +3526,7 @@
# here)
if not self._suppress_guards_tls():
- tb = traceback.extract_stack()[:-1]
- stack = ''.join(traceback.format_list(tb))
+ stack = CapturedTraceback.extract(skip=1)
ra = RuntimeAssert(expr, msg, stack)
# TODO: Do this in a way that is less janky than int(s.name[1:])
cands = sorted([s for s in expr.free_symbols if s.name.startswith("i")], key=lambda s: int(s.name[1:]))
@@ -3524,7 +3538,7 @@
# in ranges. For example, i0 <= s0 is un-rangeable, because
# we can't put s0 in the range. So this is not very high
# priority at the moment.
- self._log_guard("runtime_assert", expr, tb)
+ self._log_guard("runtime_assert", expr)
else:
self.log.debug("runtime_assert %s [guard suppressed]", expr)