Print restarting analysis at INFO level with a exception breadcrumb (#101573)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101573
Approved by: https://github.com/albanD
diff --git a/test/test_utils.py b/test/test_utils.py
index 7ee2f5e..37151e4 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -26,7 +26,7 @@
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from torch import set_default_device
from torch.utils._device import set_device
-from torch.utils._traceback import report_compile_source_on_error
+from torch.utils._traceback import report_compile_source_on_error, format_traceback_short
import torch.utils.cpp_extension
from torch.autograd._functions.utils import check_onnx_broadcast
from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
@@ -946,6 +946,12 @@
except RuntimeError as e:
self.assertIn("HEYA", ''.join(traceback.format_tb(e.__traceback__)))
+ def test_format_traceback_short(self):
+ try:
+ raise RuntimeError()
+ except RuntimeError as e:
+ self.assertRegex(format_traceback_short(e.__traceback__), r'.*test_utils.py:\d+ in test_format_traceback_short')
+
if __name__ == '__main__':
run_tests()
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index ff4bb11..e873769 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -16,6 +16,7 @@
GuardOnDataDependentSymNode,
)
from torch.fx.graph_module import _forward_from_src as original_forward_from_src
+from torch.utils._traceback import format_traceback_short
from . import config, exc
from .allowed_functions import is_allowed
@@ -52,6 +53,7 @@
increment_frame,
is_namedtuple,
istype,
+ LazyString,
orig_code_map,
reset_graph_break_dup_checker,
setup_compile_debug,
@@ -428,8 +430,11 @@
out_code = transform_code_object(code, transform)
orig_code_map[out_code] = code
break
- except exc.RestartAnalysis:
- log.debug("Restarting analysis ...")
+ except exc.RestartAnalysis as e:
+ log.info(
+ "Restarting analysis due to %s",
+ LazyString(format_traceback_short, e.__traceback__),
+ )
if attempt > 100:
unimplemented("100+ RestartAnalysis() calls")
except exc.SkipFrame as e:
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 98bcf54..0c8b45b 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -6,7 +6,6 @@
import logging
import math
import operator
-import os
import re
import sys
import textwrap
@@ -34,6 +33,7 @@
from torch._guards import ShapeGuard, Source, TracingContext, detect_fake_mode
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges, ValueRangeError
+from torch.utils._traceback import format_frame
InputList = List
DimList = List
@@ -69,10 +69,6 @@
]
return {inspect.getfile(m) for m in mods}
-def shorten_filename(fn):
- prefix = os.path.commonprefix([fn, __file__])
- return fn[len(prefix):]
-
SYM_FUNCTION_MODE = None
# We don't bother with the metaclass as all of the dispatching logic happens
@@ -2831,9 +2827,6 @@
if frame.filename not in uninteresting_files():
break
- def format_frame(frame):
- return f"{shorten_filename(frame.filename)}:{frame.lineno} in {frame.name}"
-
# NB: this stack is truncated, but it's fine because the main
# stack_info will give you the rest of the info you need
maybe_user_loc = ""
diff --git a/torch/utils/_traceback.py b/torch/utils/_traceback.py
index 92e186b..0adb522 100644
--- a/torch/utils/_traceback.py
+++ b/torch/utils/_traceback.py
@@ -1,7 +1,9 @@
from types import TracebackType
import tempfile
+import traceback
import contextlib
import inspect
+import os.path
# This file contains utilities for ensuring dynamically compile()'d
# code fragments display their line numbers in backtraces.
@@ -126,3 +128,25 @@
tb_next = tb
raise exc.with_traceback(tb_next)
+
+def shorten_filename(fn):
+ """
+ Shorten a source filepath, under the assumption that anything under torch/
+ directory is "obvious" and doesn't need to be shown to user.
+ """
+ # Truncate torch/foo.py to foo.py
+ prefix = os.path.commonprefix([fn, os.path.join(os.path.dirname(os.path.dirname(__file__)), "")])
+ return fn[len(prefix):]
+
+def format_frame(frame):
+ """
+ Format a FrameSummary in a short way, without printing full absolute path
+ or code. The idea is the result fits on a single line.
+ """
+ return f"{shorten_filename(frame.filename)}:{frame.lineno} in {frame.name}"
+
+def format_traceback_short(tb):
+ """
+ Format a TracebackType in a short way, printing only the inner-most frame.
+ """
+ return format_frame(traceback.extract_tb(tb)[-1])