Print restarting analysis at INFO level with a exception breadcrumb (#101573)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101573
Approved by: https://github.com/albanD
diff --git a/test/test_utils.py b/test/test_utils.py
index 7ee2f5e..37151e4 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -26,7 +26,7 @@
 from torch.utils.checkpoint import checkpoint, checkpoint_sequential
 from torch import set_default_device
 from torch.utils._device import set_device
-from torch.utils._traceback import report_compile_source_on_error
+from torch.utils._traceback import report_compile_source_on_error, format_traceback_short
 import torch.utils.cpp_extension
 from torch.autograd._functions.utils import check_onnx_broadcast
 from torch.onnx.symbolic_opset9 import _prepare_onnx_paddings
@@ -946,6 +946,12 @@
         except RuntimeError as e:
             self.assertIn("HEYA", ''.join(traceback.format_tb(e.__traceback__)))
 
+    def test_format_traceback_short(self):
+        try:
+            raise RuntimeError()
+        except RuntimeError as e:
+            self.assertRegex(format_traceback_short(e.__traceback__), r'.*test_utils.py:\d+ in test_format_traceback_short')
+
 
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index ff4bb11..e873769 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -16,6 +16,7 @@
     GuardOnDataDependentSymNode,
 )
 from torch.fx.graph_module import _forward_from_src as original_forward_from_src
+from torch.utils._traceback import format_traceback_short
 
 from . import config, exc
 from .allowed_functions import is_allowed
@@ -52,6 +53,7 @@
     increment_frame,
     is_namedtuple,
     istype,
+    LazyString,
     orig_code_map,
     reset_graph_break_dup_checker,
     setup_compile_debug,
@@ -428,8 +430,11 @@
                 out_code = transform_code_object(code, transform)
                 orig_code_map[out_code] = code
                 break
-            except exc.RestartAnalysis:
-                log.debug("Restarting analysis ...")
+            except exc.RestartAnalysis as e:
+                log.info(
+                    "Restarting analysis due to %s",
+                    LazyString(format_traceback_short, e.__traceback__),
+                )
                 if attempt > 100:
                     unimplemented("100+ RestartAnalysis() calls")
             except exc.SkipFrame as e:
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index 98bcf54..0c8b45b 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -6,7 +6,6 @@
 import logging
 import math
 import operator
-import os
 import re
 import sys
 import textwrap
@@ -34,6 +33,7 @@
 from torch._guards import ShapeGuard, Source, TracingContext, detect_fake_mode
 from torch.utils._sympy.interp import sympy_interp
 from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges, ValueRangeError
+from torch.utils._traceback import format_frame
 
 InputList = List
 DimList = List
@@ -69,10 +69,6 @@
     ]
     return {inspect.getfile(m) for m in mods}
 
-def shorten_filename(fn):
-    prefix = os.path.commonprefix([fn, __file__])
-    return fn[len(prefix):]
-
 SYM_FUNCTION_MODE = None
 
 # We don't bother with the metaclass as all of the dispatching logic happens
@@ -2831,9 +2827,6 @@
                     if frame.filename not in uninteresting_files():
                         break
 
-                def format_frame(frame):
-                    return f"{shorten_filename(frame.filename)}:{frame.lineno} in {frame.name}"
-
                 # NB: this stack is truncated, but it's fine because the main
                 # stack_info will give you the rest of the info you need
                 maybe_user_loc = ""
diff --git a/torch/utils/_traceback.py b/torch/utils/_traceback.py
index 92e186b..0adb522 100644
--- a/torch/utils/_traceback.py
+++ b/torch/utils/_traceback.py
@@ -1,7 +1,9 @@
 from types import TracebackType
 import tempfile
+import traceback
 import contextlib
 import inspect
+import os.path
 
 # This file contains utilities for ensuring dynamically compile()'d
 # code fragments display their line numbers in backtraces.
@@ -126,3 +128,25 @@
             tb_next = tb
 
         raise exc.with_traceback(tb_next)
+
+def shorten_filename(fn):
+    """
+    Shorten a source filepath, under the assumption that anything under torch/
+    directory is "obvious" and doesn't need to be shown to user.
+    """
+    # Truncate torch/foo.py to foo.py
+    prefix = os.path.commonprefix([fn, os.path.join(os.path.dirname(os.path.dirname(__file__)), "")])
+    return fn[len(prefix):]
+
+def format_frame(frame):
+    """
+    Format a FrameSummary in a short way, without printing full absolute path
+    or code.  The idea is the result fits on a single line.
+    """
+    return f"{shorten_filename(frame.filename)}:{frame.lineno} in {frame.name}"
+
+def format_traceback_short(tb):
+    """
+    Format a TracebackType in a short way, printing only the inner-most frame.
+    """
+    return format_frame(traceback.extract_tb(tb)[-1])