Revamp guard debug logging (#107505)
The new guard printout looks like this:
```
[DEBUG] GUARDS:
[DEBUG] ___check_type_id(L['name'], 7605632) # if name == "special_attr": # test/dynamo/test_misc.py:1155 in __getattribute__
[DEBUG] L['name'] == '_backward_pre_hooks' # if name == "special_attr": # test/dynamo/test_misc.py:1155 in __getattribute__
[DEBUG] ___check_obj_id(L['self'], 139746432564960) # return super().__getattribute__(name) # test/dynamo/test_misc.py:1157 in __getattribute__
[DEBUG] ___check_obj_id(L['__class__'], 1451499216) # return super().__getattribute__(name) # test/dynamo/test_misc.py:1157 in __getattribute__
[DEBUG] ___is_grad_enabled() # _dynamo/output_graph.py:346 in init_ambient_guards
[DEBUG] not ___are_deterministic_algorithms_enabled() # _dynamo/output_graph.py:342 in init_ambient_guards
[DEBUG] ___is_torch_function_enabled() # _dynamo/output_graph.py:350 in init_ambient_guards
[DEBUG] utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:348 in init_ambient_guards
```
Along with the guards, we also print what line of user code caused the guard to be added, or what line of Dynamo internal code added the guard (if there is no user stack trace, which is typically the case for ambient guards.)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107505
Approved by: https://github.com/mlazos, https://github.com/voznesenskym, https://github.com/anijain2305
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 9128985..33be9ec 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -1,4 +1,5 @@
import functools
+import inspect
import itertools
import logging
import os
@@ -16,7 +17,7 @@
GuardOnDataDependentSymNode,
)
from torch.fx.graph_module import _forward_from_src as original_forward_from_src
-from torch.utils._traceback import format_traceback_short
+from torch.utils._traceback import format_frame, format_traceback_short
from . import config, exc
from .allowed_functions import is_allowed
@@ -72,6 +73,17 @@
recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles")
+# For user stack printing
+@functools.lru_cache(None)
+def uninteresting_files():
+ import torch._dynamo.external_utils
+
+ mods = [
+ torch._dynamo.external_utils,
+ ]
+ return {inspect.getfile(m) for m in mods}
+
+
class Tracker:
def __init__(self):
self.seen = []
@@ -530,14 +542,24 @@
if guards_log.isEnabledFor(logging.DEBUG):
guard_str = "GUARDS:\n"
- guard_str += "\n".join(
- [
- f" {code}"
- for guard in sorted(output.guards)
- if guard.code_list is not None
- for code in guard.code_list
- ]
- )
+ base = os.path.dirname(__file__)
+ for guard in sorted(output.guards):
+ if guard.code_list is None:
+ continue
+
+ extra = ""
+ if guard.user_stack:
+ for fs in reversed(guard.user_stack):
+ if fs.filename not in uninteresting_files():
+ break
+ else:
+ fs = guard.user_stack[-1]
+ extra = f" # {format_frame(fs, line=True)}"
+ elif guard.stack:
+ extra = f" # {format_frame(guard.stack.summary()[-1])}"
+
+ for code in guard.code_list:
+ guard_str += f" {code:<60}{extra}\n"
guards_log.debug("%s", guard_str)
if verbose_guards_log.isEnabledFor(logging.DEBUG):
@@ -545,11 +567,16 @@
if guard.code_list is None:
continue
cat_code = " and ".join(guard.code_list)
+ maybe_user_stack = ""
+ if guard.user_stack:
+ maybe_user_stack = (
+ f"\nUser stack:\n{''.join(guard.user_stack.format())}"
+ )
verbose_guards_log.debug(
- "GUARD: %s\nStack:\n%sUser Stack:\n%s",
+ "GUARD: %s\nStack:\n%s%s",
cat_code,
"".join(guard.stack.format()),
- "".join(guard.user_stack.format()),
+ maybe_user_stack,
)
if not output.is_empty_graph() and hooks.guard_export_fn is not None:
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index db68951..47d47eb 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -1007,7 +1007,7 @@
guard_body, pycode = build_guard_function(unique_code_parts, make_guard_fn_args)
if os.environ.get("TORCHDYNAMO_PRINT_GUARDS", None) == "1":
- print("GUARDS", guard_body)
+ print("GUARDS\n", guard_body)
if is_guard_failure_reporting_enabled() or guard_fail_fn is not None:
# Guard fail hook is called everytime guard eval fails. For a cache
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index ebed110..367181b 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -268,21 +268,7 @@
allow_non_fake_inputs=True if self.export else False,
)
self.tracing_context: TracingContext = TracingContext(fake_mode)
- # Register a SHAPE_ENV guard to make sure we setup shape guards
- # that show up in ShapeEnv
- self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
-
- self.guards.add(
- GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS)
- )
-
- self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE))
-
- self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))
-
- self.guards.add(
- GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
- )
+ self.init_ambient_guards()
# tracked_fakes says where any tensor that was wrapped to fake came
# from. It is similar to GraphArg, in that all GraphArgs will get
@@ -346,6 +332,25 @@
# presence of torch.no_grad) and there is a graph break.
self.save_global_state()
+ # This gets its own helper function so guards DEBUG logs are more
+ # informative
+ def init_ambient_guards(self):
+ # Register a SHAPE_ENV guard to make sure we setup shape guards
+ # that show up in ShapeEnv
+ self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
+
+ self.guards.add(
+ GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS)
+ )
+
+ self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE))
+
+ self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE))
+
+ self.guards.add(
+ GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
+ )
+
@property
def root_tracer(self):
return self.tracers[0]
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 9d41de4..7c84a04 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -2027,46 +2027,49 @@
if k in f_locals
)
- # symbolic_locals contains the mapping from original f_locals to the
- # Variable objects. During the Variable building phase, each object also
- # has its associated guards. At the end, we will accumulate these
- # guards.
- #
- # One way of handling these guards is to just accumulate all of them
- # right now. However, many f_locals might not be used in the frame and
- # thus can unnecessarily increase guard execution overhead. Therefore,
- # we selectively update output.guards as we run the Python Bytecode
- # instruction by instruction.
- #
- # An exception here is list/dict variables. Guards related to these
- # variables have indexed access, like Tensor_match on args[0], and if
- # args is not used in this frame, we will miss a LIST_LENGTH check like
- # len(args) == 2. Missing the LIST_LENGTH check causes problem for the
- # next invocation when args is not a list, and args[0] is a runtime
- # error. Therefore, we recursively add guards for list/dict variable here.
- for val in self.symbolic_locals.values():
- if isinstance(
- val, (ListIteratorVariable, BaseListVariable, ConstDictVariable)
- ):
- local_guards = VariableTracker.propagate(val)["guards"]
- index_guards = [
- guard
- for guard in local_guards
- if guard.create_fn
- in (
- GuardBuilder.LIST_LENGTH,
- GuardBuilder.DICT_KEYS,
- GuardBuilder.ODICT_KEYS,
- GuardBuilder.TUPLE_ITERATOR_LEN,
- )
- ]
- self.output.guards.update(index_guards)
+ self.init_local_index_guards_hack()
self._freevars_ids = dict()
for name in self.code_options["co_freevars"]:
if name in f_locals:
self._freevars_ids[name] = id(f_locals[name])
+ def init_local_index_guards_hack(self):
+ # symbolic_locals contains the mapping from original f_locals to the
+ # Variable objects. During the Variable building phase, each object also
+ # has its associated guards. At the end, we will accumulate these
+ # guards.
+ #
+ # One way of handling these guards is to just accumulate all of them
+ # right now. However, many f_locals might not be used in the frame and
+ # thus can unnecessarily increase guard execution overhead. Therefore,
+ # we selectively update output.guards as we run the Python Bytecode
+ # instruction by instruction.
+ #
+ # An exception here is list/dict variables. Guards related to these
+ # variables have indexed access, like Tensor_match on args[0], and if
+ # args is not used in this frame, we will miss a LIST_LENGTH check like
+ # len(args) == 2. Missing the LIST_LENGTH check causes problem for the
+ # next invocation when args is not a list, and args[0] is a runtime
+ # error. Therefore, we recursively add guards for list/dict variable here.
+ for val in self.symbolic_locals.values():
+ if isinstance(
+ val, (ListIteratorVariable, BaseListVariable, ConstDictVariable)
+ ):
+ local_guards = VariableTracker.propagate(val)["guards"]
+ index_guards = [
+ guard
+ for guard in local_guards
+ if guard.create_fn
+ in (
+ GuardBuilder.LIST_LENGTH,
+ GuardBuilder.DICT_KEYS,
+ GuardBuilder.ODICT_KEYS,
+ GuardBuilder.TUPLE_ITERATOR_LEN,
+ )
+ ]
+ self.output.guards.update(index_guards)
+
def run(self):
super().run()
diff --git a/torch/_guards.py b/torch/_guards.py
index b616e13..99aa1d9 100644
--- a/torch/_guards.py
+++ b/torch/_guards.py
@@ -148,10 +148,6 @@
stack = None
user_stack = None
- def __post_init__(self):
- self.stack = CapturedTraceback.extract(skip=2)
- self.user_stack = TracingContext.extract_stack()
-
def __hash__(self):
return hash((self.name, self.source, id(self.create_fn)))
@@ -436,17 +432,55 @@
"""
+# Like a Set[Guard] but will record the user stack on all guards at the
+# time they were installed at their destination
+class GuardsSet:
+ def __init__(self, inner=None):
+ if inner is None:
+ inner = set()
+ self.inner = inner
+
+ def __iter__(self):
+ return iter(self.inner)
+
+ def __len__(self):
+ return len(self.inner)
+
+ # Subtraction along with bool is typically used to determine the delta of
+ # added guards between checkpoints for higher order ops
+ def __sub__(self, other):
+ return GuardsSet(self.inner - other.inner)
+
+ def __bool__(self):
+ return bool(self.inner)
+
+ def add(self, guard: Guard, *, skip=0):
+ if guard in self.inner:
+ return
+ if guard.stack is None:
+ guard.stack = CapturedTraceback.extract(skip=1 + skip)
+ if guard.user_stack is None:
+ guard.user_stack = TracingContext.extract_stack()
+ self.inner.add(guard)
+
+ def update(self, *others: Set[Guard]):
+ for o in others:
+ for g in o:
+ self.add(g, skip=1)
+
+
class GuardsContext(Checkpointable[GuardsCheckpointState]):
def __init__(self):
- self.dynamo_guards: Set[Guard] = set()
+ self.dynamo_guards: GuardsSet = GuardsSet()
self.aotautograd_guards: List[GuardEnvExpr] = []
def copy_graphstate(self):
- return GuardsCheckpointState(set(self.dynamo_guards))
+ return GuardsCheckpointState(set(self.dynamo_guards.inner))
def restore_graphstate(self, state):
+ # NB: "steals" the passed in state
assert isinstance(state, GuardsCheckpointState)
- self.dynamo_guards = state.dynamo_guards
+ self.dynamo_guards = GuardsSet(state.dynamo_guards)
_TLS = threading.local()
diff --git a/torch/utils/_traceback.py b/torch/utils/_traceback.py
index fb4d01a..24b574b 100644
--- a/torch/utils/_traceback.py
+++ b/torch/utils/_traceback.py
@@ -130,21 +130,30 @@
raise exc.with_traceback(tb_next)
-def shorten_filename(fn):
+def shorten_filename(fn, *, base=None):
"""
Shorten a source filepath, under the assumption that anything under torch/
directory is "obvious" and doesn't need to be shown to user.
"""
+ if base is None:
+ base = os.path.dirname(os.path.dirname(__file__))
# Truncate torch/foo.py to foo.py
- prefix = os.path.commonprefix([fn, os.path.join(os.path.dirname(os.path.dirname(__file__)), "")])
- return fn[len(prefix):]
+ try:
+ prefix = os.path.commonpath([fn, base])
+ except ValueError:
+ return fn
+ else:
+ return fn[len(prefix) + 1:]
-def format_frame(frame):
+def format_frame(frame, *, base=None, line=False):
"""
Format a FrameSummary in a short way, without printing full absolute path
or code. The idea is the result fits on a single line.
"""
- return f"{shorten_filename(frame.filename)}:{frame.lineno} in {frame.name}"
+ extra_line = ""
+ if line:
+ extra_line = f"{frame.line} # "
+ return f"{extra_line}{shorten_filename(frame.filename, base=base)}:{frame.lineno} in {frame.name}"
def format_traceback_short(tb):
"""