[dynamo] fine-grained bytecode-source attribution in python 3.11 (#104676)
Since Python 3.11, bytecode carries end-line and column information for each instruction, so we can attribute each bytecode instruction to its corresponding source code more precisely. For example, we can highlight a single function call within a series of nested function calls, or highlight a function call that spans multiple lines.
Sample:
```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond
def h(x):
return x * 5
def true_fn(x):
return x * 2
def false_fn(x):
return x * 3
def f(pred, x):
x = h(
h(h(x))
)
x = x[1:][:2]
torch._dynamo.graph_break()
x = cond(pred, true_fn, false_fn, [x])
opt_f = torch.compile(f, backend="eager")
opt_f(torch.tensor(True), torch.randn(3, 3, 3, 3))
```
Output:
```
$ TORCH_LOGS="trace_call" python playground9.py
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
h(h(x))
~^^^
TRACE FX call mul from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
h(h(x))
~^^^^^^
TRACE FX call mul_1 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:15
x = h(
~^
h(h(x))
^^^^^^^
)
^
TRACE FX call mul_2 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE FX call getitem from f /scratch/williamwen/work/pytorch/playground9.py:18
x = x[1:][:2]
~^^^^
TRACE FX call getitem_1 from f /scratch/williamwen/work/pytorch/playground9.py:18
x = x[1:][:2]
~~~~~^^^^
TRACE inlined call true_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE FX call mul from true_fn /scratch/williamwen/work/pytorch/playground9.py:9 (inline depth: 1)
return x * 2
~~^~~
TRACE inlined call false_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE FX call mul from false_fn /scratch/williamwen/work/pytorch/playground9.py:12 (inline depth: 1)
return x * 3
~~^~~
TRACE FX call cond from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104676
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py
index eed9968..3bd02dc 100644
--- a/test/dynamo/test_logging.py
+++ b/test/dynamo/test_logging.py
@@ -12,6 +12,8 @@
import torch._dynamo.testing
import torch.distributed as dist
+from torch._dynamo.testing import skipIfNotPy311
+
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_utils import find_free_port
from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -408,6 +410,108 @@
)
self.assertIn("[rank0]:", stderr.decode("utf-8"))
+ @skipIfNotPy311
+ @make_logging_test(trace_call=True)
+ def test_trace_call(self, records):
+ def fn(x, y):
+ return (x * 2) @ (y * 3)
+
+ fn_opt = torch._dynamo.optimize("eager")(fn)
+ fn_opt(torch.randn(10, 20), torch.randn(20, 30))
+
+ self.assertEqual(len(records), 3)
+ # only get last 2 lines
+ messages = [
+ "\n".join(record.getMessage().split("\n")[-2:]) for record in records
+ ]
+ self.assertExpectedInline(
+ messages[0],
+ """\
+ return (x * 2) @ (y * 3)
+ ~~^~~""",
+ )
+ self.assertExpectedInline(
+ messages[1],
+ """\
+ return (x * 2) @ (y * 3)
+ ~~^~~""",
+ )
+ self.assertExpectedInline(
+ messages[2],
+ """\
+ return (x * 2) @ (y * 3)
+ ~~~~~~~~^~~~~~~~~""",
+ )
+
+ @skipIfNotPy311
+ @make_logging_test(trace_call=True)
+ def test_trace_call_inline_call(self, records):
+ def g(x):
+ return x * 2
+
+ def f(x):
+ return g(g(x))
+
+ fn_opt = torch._dynamo.optimize("eager")(f)
+ fn_opt(torch.randn(3, 3))
+
+ self.assertEqual(len(records), 4)
+ messages = [
+ "\n".join(record.getMessage().split("\n")[-2:]) for record in records
+ ]
+ self.assertExpectedInline(
+ messages[0],
+ """\
+ return g(g(x))
+ ~^^^""",
+ )
+ self.assertExpectedInline(
+ messages[1],
+ """\
+ return x * 2
+ ~~^~~""",
+ )
+ self.assertExpectedInline(
+ messages[2],
+ """\
+ return g(g(x))
+ ~^^^^^^""",
+ )
+ self.assertExpectedInline(
+ messages[3],
+ """\
+ return x * 2
+ ~~^~~""",
+ )
+
+ @skipIfNotPy311
+ @make_logging_test(trace_call=True)
+ def test_trace_call_graph_break(self, records):
+ def fn(x):
+ x = x * 2
+ torch._dynamo.graph_break()
+ return x * 3
+
+ fn_opt = torch._dynamo.optimize("eager")(fn)
+ fn_opt(torch.randn(3, 3))
+
+ self.assertEqual(len(records), 2)
+ messages = [
+ "\n".join(record.getMessage().split("\n")[-2:]) for record in records
+ ]
+ self.assertExpectedInline(
+ messages[0],
+ """\
+ x = x * 2
+ ~~^~~""",
+ )
+ self.assertExpectedInline(
+ messages[1],
+ """\
+ return x * 3
+ ~~^~~""",
+ )
+
# single record tests
exclusions = {
@@ -421,6 +525,7 @@
"perf_hints",
"not_implemented",
"trace_source",
+ "trace_call",
"custom_format_test_artifact",
}
for name in torch._logging._internal.log_registry.artifact_names:
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 671fed1..cc5306b 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -2185,8 +2185,6 @@
f = "linetable_writer"
return f"Test if {f} generates correct co_linetable: {c}"
- # Dynamo doesn't deal with column locations or end line numbers,
- # so we only check that start line numbers in the linetables match.
keys = bytecode_transformation.get_code_keys()
code_options = {k: getattr(fn.__code__, k) for k in keys}
result = bytecode_transformation.clean_and_assemble_instructions(
@@ -2197,8 +2195,7 @@
l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions())
self.assertEqual(len(l1), len(l2))
for p1, p2 in zip(l1, l2):
- # check that start line numbers match
- self.assertEqual(p1[0], p2[0])
+ self.assertEqual(p1, p2)
self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab)
@skipIfNotPy311
@@ -2238,8 +2235,7 @@
l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions())
self.assertEqual(len(l1), len(l2))
for p1, p2 in zip(l1, l2):
- # check that start line numbers match
- self.assertEqual(p1[0], p2[0])
+ self.assertEqual(p1, p2)
self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab)
@unittest.skipIf(
@@ -5158,6 +5154,146 @@
dis.dis(fn)
self.assertEqual(torch._dynamo.optimize("eager")(fn)(), 3)
+ @skipIfNotPy311
+ def test_get_instruction_source_311(self):
+ def f():
+ # flake8: noqa
+ # fmt: off
+ # test binary ops
+ a = ( b ) + c
+ a = (a + b) // (c - d)
+ a = b \
+ +\
+ c # test
+ a = (
+ (b # test +
+ ) \
+ # +
+ << (
+
+ c # test
+ \
+ ) # test
+ )
+
+ # test slice
+ a = bbb [ ccc ]
+ b = bbbbb \
+ [ ccc # test
+
+ + ddd \
+
+ ] # test
+ a = bbb[ccc][ddd][eee]
+
+ # test nested and multiline function calls
+ a = g(g(g(b)))
+ a = g(h(
+ g(b),
+ c
+ ))
+
+ # test chained function calls
+ a = (g(x).y)(
+ z
+ )(1)(2)
+
+ # test unicode (match traceback behavior)
+ a = ("🔥🔥🔥" +
+ + "🔥🔥") + b
+
+ from torch._dynamo.utils import get_instruction_source_311
+
+ offsets = (3, 11, 15, 19, 23, 29, 35, 46, 58, 74)
+ insts = list(dis.get_instructions(f))
+ # uncomment to determine offsets
+ # print(*enumerate(insts), sep="\n")
+ all_sources = "\n".join(
+ get_instruction_source_311(f.__code__, insts[offset]) for offset in offsets
+ )
+ self.assertExpectedInline(
+ all_sources,
+ """\
+ a = ( b ) + c
+ ~~~~~~~~~~^~~~~
+
+ a = (a + b) // (c - d)
+ ~~~~~~~~^^~~~~~~~~
+
+ a = b \\
+ ~~~~~~
+ +\\
+ ^~
+ c # test
+ ~
+
+ (b # test +
+ ~~~~~~~~~~~~
+ ) \\
+ ~~~~
+ # +
+ ~~~
+ << (
+ ^^~~
+
+
+ c # test
+ ~~~~~~~~~
+ \\
+ ~
+ ) # test
+ ~
+
+ a = bbb [ ccc ]
+ ~~~~~~^^^^^^^^^^^
+
+ b = bbbbb \\
+ ~~~~~~~
+ [ ccc # test
+ ^^^^^^^^^^^^^
+
+
+ + ddd \\
+ ^^^^^^^^
+
+
+ ] # test
+ ^
+
+ a = bbb[ccc][ddd][eee]
+ ~~~~~~~~^^^^^
+
+ a = g(g(g(b)))
+ ~^^^^^^
+
+ a = g(h(
+ ~^
+ g(b),
+ ^^^^^
+ c
+ ^
+ ))
+ ^
+
+ a = (g(x).y)(
+ ~~~~~~~~~
+ z
+ ~
+ )(1)(2)
+ ~^^^
+""",
+ )
+ # test unicode (since assertExpectedInline doesn't support unicode)
+ self.assertEqual(
+ get_instruction_source_311(f.__code__, insts[84]),
+ """\
+ a = ("🔥🔥🔥" +
+ ~~~~~~~~
+ + "🔥🔥") + b
+ ~~~~~~~~^~~
+""",
+ )
+
def test_raise_guard_full_constraint(self):
y = torch.randn([3, 3, 3])
diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py
index 58faa1f..20bd965 100644
--- a/torch/_dynamo/bytecode_transformation.py
+++ b/torch/_dynamo/bytecode_transformation.py
@@ -295,13 +295,32 @@
linetable = []
lineno = first_lineno
- def update(lineno_new, inst_size):
+ def update(positions: dis.Positions, inst_size):
nonlocal lineno
+ lineno_new = positions.lineno if positions else None
def _update(delta, size):
assert 0 < size <= 8
- # first byte - always use no column info code (13)
- linetable.append(0b1_1101_000 + size - 1)
+ # first byte - use 13 (no column info) if positions is
+ # malformed, otherwise use 14 (long form)
+ other_varints = ()
+ if (
+ positions
+ and positions.lineno is not None
+ and positions.end_lineno is not None
+ and positions.col_offset is not None
+ and positions.end_col_offset is not None
+ ):
+ linetable.append(0b1_1110_000 + size - 1)
+ # for whatever reason, column offset needs `+ 1`
+ # https://github.com/python/cpython/blob/1931c2a438c50e6250725c84dff94fc760b9b951/Python/compile.c#L7603
+ other_varints = (
+ positions.end_lineno - positions.lineno,
+ positions.col_offset + 1,
+ positions.end_col_offset + 1,
+ )
+ else:
+ linetable.append(0b1_1101_000 + size - 1)
# encode signed int
if delta < 0:
delta = ((-delta) << 1) | 1
@@ -309,6 +328,8 @@
delta <<= 1
# encode unsigned int
linetable.extend(encode_varint(delta))
+ for n in other_varints:
+ linetable.extend(encode_varint(n))
if lineno_new is None:
lineno_delta = 0
@@ -420,14 +441,19 @@
if sys.version_info >= (3, 11):
lnotab, update_lineno = linetable_311_writer(firstlineno)
num_ext = 0
- for inst in instructions:
+ for i, inst in enumerate(instructions):
if inst.opname == "EXTENDED_ARG":
inst_size = 1
num_ext += 1
+ # copy positions from the actual instruction
+ for j in (1, 2, 3):
+ if instructions[i + j].opname != "EXTENDED_ARG":
+ inst.positions = instructions[i + j].positions
+ break
else:
inst_size = instruction_size(inst) // 2 + num_ext
num_ext = 0
- update_lineno(inst.starts_line, inst_size)
+ update_lineno(inst.positions, inst_size)
num_ext = 0
arg = inst.arg or 0
code.extend((inst.opcode, arg & 0xFF))
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 3bc65ff..0002650 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -62,6 +62,7 @@
count_calls,
counters,
dynamo_timed,
+ get_instruction_source_311,
graph_break_reasons,
increment_op_count,
lazy_format_graph_code,
@@ -85,6 +86,7 @@
graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
+trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
class OutputGraphState(NamedTuple):
@@ -1100,6 +1102,7 @@
# This is a OrderedDict so that we can
# maintain the order of args for the HigherOrderOperator call.
self.lifted_freevars = collections.OrderedDict()
+ self.prev_inst = None
def create_proxy(
self,
@@ -1160,6 +1163,24 @@
# append stack trace to fx node
tx = self.output_graph.current_tx
+ # log detailed location of line of code in 3.11
+ if sys.version_info >= (3, 11) and kind in (
+ "call_function",
+ "call_method",
+ "call_module",
+ ):
+ cur_inst = tx.current_instruction
+ if cur_inst is not self.prev_inst and cur_inst.positions.lineno is not None:
+ tx_code = tx.f_code
+ header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno)
+
+ def get_trace_call_log_str():
+ line = get_instruction_source_311(tx_code, cur_inst).rstrip()
+ return f"TRACE FX call {rv.node.name} from {header}\n{line}"
+
+ trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
+ self.prev_inst = cur_inst
+
nn_module_stack = tx.nn_module_stack
if nn_module_stack:
rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py
index f3eafba..f5b09e0 100644
--- a/torch/_dynamo/resume_execution.py
+++ b/torch/_dynamo/resume_execution.py
@@ -437,6 +437,8 @@
if inst.offset == target.offset:
break
inst.starts_line = None
+ if sys.version_info >= (3, 11):
+ inst.positions = None
if cleanup:
prefix.extend(cleanup)
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index c246794..de05bc4 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -59,6 +59,7 @@
from .utils import (
counters,
get_fake_value,
+ get_instruction_source_311,
graph_break_dup_warning_checker,
istype,
LazyString,
@@ -108,6 +109,7 @@
log = logging.getLogger(__name__)
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
+trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")
@@ -596,19 +598,22 @@
self.restore_graphstate(state)
raise
- def get_log_starts_line_log_str(self):
+ def get_line_of_code_header(self, lineno=None):
+ if lineno is None:
+ lineno = self.lineno
inline_depth_str = (
f" (inline depth: {self.inline_depth})" if self.inline_depth > 0 else ""
)
- log_str = f"TRACE starts_line {self.f_code.co_name} {self.f_code.co_filename}:{self.lineno}{inline_depth_str}\n"
+ return f"{self.f_code.co_name} {self.f_code.co_filename}:{lineno}{inline_depth_str}"
+
+ def get_log_starts_line_log_str(self):
+ log_str = f"TRACE starts_line {self.get_line_of_code_header()}\n"
line = linecache.getline(self.f_code.co_filename, self.lineno).rstrip()
log_str += f" {line}"
return log_str
def log_starts_line(self):
- trace_source_log.debug(
- "%s", LazyString(lambda: self.get_log_starts_line_log_str())
- )
+ trace_source_log.debug("%s", LazyString(self.get_log_starts_line_log_str))
def step(self):
"""Process exactly one instruction, return False we should exit"""
@@ -2242,6 +2247,16 @@
# with a single alias
if torch._logging._internal.log_state.is_artifact_enabled("output_code"):
suffix = f"\n{dis.Bytecode(code).dis()}"
+ if sys.version_info >= (3, 11):
+ cur_inst = parent.current_instruction
+ parent_code = parent.f_code
+ header = parent.get_line_of_code_header(lineno=cur_inst.positions.lineno)
+
+ def get_trace_call_log_str():
+ line = get_instruction_source_311(parent_code, cur_inst).rstrip()
+ return f"TRACE inlined call {code.co_name} from {header}\n{line}"
+
+ trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
log.debug("INLINING %s%s", code, suffix)
tracer: InliningInstructionTranslator
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 27aecce..6e4614a 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -11,6 +11,7 @@
import gc
import inspect
import itertools
+import linecache
import logging
import math
import operator
@@ -24,7 +25,7 @@
import weakref
from contextlib import contextmanager
from functools import lru_cache, wraps
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
import torch._logging
from torch._guards import detect_fake_mode # noqa: F401
@@ -1756,6 +1757,260 @@
return compile_supported
+# The following 3.11 source code functions are adapted from
+# https://github.com/python/cpython/blob/v3.11.4/Lib/traceback.py
+# in order to output source code corresponding to bytecode in 3.11+.
+# We need our own versions since we want to support multiline expressions.
+def _fix_offset(str: str, offset: int) -> int:
+ """
+ Convert byte offset `offset` of `str` into character offset.
+ Byte offset is used for 3.11+ instruction column data.
+ Takes things like unicode characters into consideration.
+
+ Unchanged from CPython implementation.
+ """
+ as_utf8 = str.encode("utf-8")
+ return len(as_utf8[:offset].decode("utf-8", errors="replace"))
+
+
+@dataclasses.dataclass
+class _Anchors:
+ # inclusive
+ left_end_lineno: int
+ left_end_offset: int
+ right_start_lineno: int
+ # exclusive
+ right_start_offset: int
+
+
+def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]:
+ """
+ Given source code `segment` corresponding to a bytecode
+ instruction, determine:
+ - for binary ops, the location of the binary op
+ - for indexing, the location of the brackets.
+ `segment` is expected to be a valid Python expression
+ """
+ assert sys.version_info >= (3, 11)
+
+ import ast
+
+ try:
+ # Without brackets, `segment` is parsed as a statement.
+ # We expect an expression, so wrap `segment` in
+ # brackets to handle multi-line expressions.
+ tree = ast.parse("(\n" + segment + "\n)")
+ except SyntaxError:
+ return None
+
+ if len(tree.body) != 1:
+ return None
+
+ lines = segment.split("\n")
+
+ # get character index given byte offset
+ def normalize(lineno, offset):
+ return _fix_offset(lines[lineno], offset)
+
+ # Gets the next valid character index in `lines`, if
+ # the current location is not valid. Handles empty lines.
+ def next_valid_char(lineno, col):
+ while lineno < len(lines) and col >= len(lines[lineno]):
+ col = 0
+ lineno += 1
+ assert lineno < len(lines) and col < len(lines[lineno])
+ return lineno, col
+
+ # Get the next valid character index in `lines`.
+ def increment(lineno, col):
+ col += 1
+ lineno, col = next_valid_char(lineno, col)
+ assert lineno < len(lines) and col < len(lines[lineno])
+ return lineno, col
+
+ # Get the next valid character at least on the next line
+ def nextline(lineno, col):
+ col = 0
+ lineno += 1
+ lineno, col = next_valid_char(lineno, col)
+ assert lineno < len(lines) and col < len(lines[lineno])
+ return lineno, col
+
+ statement = tree.body[0]
+ if isinstance(statement, ast.Expr):
+ expr = statement.value
+ if isinstance(expr, ast.BinOp):
+ # ast gives locations for BinOp subexpressions, e.g.
+ # ( left_expr ) + ( right_expr )
+ # left^^^^^ right^^^^^
+ # -2 since end_lineno is 1-indexed and because we added an extra
+ # bracket to `segment` when calling ast.parse
+ cur_lineno = expr.left.end_lineno - 2
+ cur_col = normalize(cur_lineno, expr.left.end_col_offset)
+ cur_lineno, cur_col = next_valid_char(cur_lineno, cur_col)
+
+ # Heuristic to find the operator character.
+ # The original CPython implementation did not look for ), \, or #,
+ # leading to incorrect anchor location, e.g.
+ # (x) + (y)
+ # ~~^~~~~~~
+ while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#":
+ if ch in "\\#":
+ cur_lineno, cur_col = nextline(cur_lineno, cur_col)
+ else:
+ cur_lineno, cur_col = increment(cur_lineno, cur_col)
+
+ # binary op is 1 or 2 characters long, on the same line
+ right_col = cur_col + 1
+ if (
+ right_col < len(lines[cur_lineno])
+ and not (ch := lines[cur_lineno][right_col]).isspace()
+ and ch not in "\\#"
+ ):
+ right_col += 1
+ # right_col can be invalid since it is exclusive
+
+ return _Anchors(cur_lineno, cur_col, cur_lineno, right_col)
+ elif isinstance(expr, ast.Subscript):
+ # ast gives locations for value and slice subexpressions, e.g.
+ # ( value_expr ) [ slice_expr ]
+ # value^^^^^ slice^^^^^
+ # subscript^^^^^^^^^^^^^^^^^^^^
+ # find left bracket (first '[' after value)
+ left_lineno = expr.value.end_lineno - 2
+ left_col = normalize(left_lineno, expr.value.end_col_offset)
+ left_lineno, left_col = next_valid_char(left_lineno, left_col)
+ while lines[left_lineno][left_col] != "[":
+ left_lineno, left_col = increment(left_lineno, left_col)
+ # find right bracket (final character of expression)
+ right_lineno = expr.end_lineno - 2
+ right_col = normalize(right_lineno, expr.end_col_offset)
+ return _Anchors(left_lineno, left_col, right_lineno, right_col)
+ elif isinstance(expr, ast.Call):
+ # ( func_expr ) (args, kwargs)
+ # func^^^^^
+ # call^^^^^^^^^^^^^^^^^^^^^^^^
+ # find left bracket (first '(' after func)
+ left_lineno = expr.func.end_lineno - 2
+ left_col = normalize(left_lineno, expr.func.end_col_offset)
+ left_lineno, left_col = next_valid_char(left_lineno, left_col)
+ while lines[left_lineno][left_col] != "(":
+ left_lineno, left_col = increment(left_lineno, left_col)
+ # find right bracket (final character of expression)
+ right_lineno = expr.end_lineno - 2
+ right_col = normalize(right_lineno, expr.end_col_offset)
+ return _Anchors(left_lineno, left_col, right_lineno, right_col)
+
+ return None
+
+
+def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> str:
+ """
+ Python 3.11+ only. Returns lines of source code (from code object `code`)
+ corresponding to `inst`'s location data, and underlines the code relevant to `inst`.
+
+ Example: CALL on `g`:
+ f(g(
+ ^^
+ h(x)))
+ ^^^^^
+
+ We need our own implementation since `format_frame_summary` in
+ Python's `traceback` module doesn't handle multi-line expressions
+ (and their anchor extraction code is not completely correct).
+ """
+ if inst.positions.lineno is None:
+ return ""
+ # The rstrip + "\n" pattern is used throughout this function to handle
+ # linecache.getline errors. Error lines are treated as empty strings "", but we want
+ # to treat them as blank lines "\n".
+ first_line = linecache.getline(code.co_filename, inst.positions.lineno).rstrip()
+ if inst.positions.end_lineno is None:
+ return first_line
+ if inst.positions.col_offset is None or inst.positions.end_col_offset is None:
+ return first_line
+
+ # character index of the start of the instruction
+ start_offset = _fix_offset(first_line, inst.positions.col_offset)
+ # character index of the end of the instruction
+ # compute later since end may be a different line
+ end_offset = None
+ # expression corresponding to the instruction so we can get anchors
+ segment = ""
+ # underline markers to be printed - start with `~` marker and replace with `^` later
+ markers = []
+
+ # Compute segment and initial markers
+ if inst.positions.end_lineno == inst.positions.lineno:
+ end_offset = _fix_offset(first_line, inst.positions.end_col_offset)
+ segment = first_line[start_offset:end_offset]
+ markers.append(" " * start_offset + "~" * (end_offset - start_offset))
+ else:
+ segment = first_line[start_offset:] + "\n"
+ markers.append(" " * start_offset + "~" * (len(first_line) - start_offset))
+ last_line = linecache.getline(
+ code.co_filename, inst.positions.end_lineno
+ ).rstrip()
+ end_offset = _fix_offset(last_line, inst.positions.end_col_offset)
+ for lineno in range(inst.positions.lineno + 1, inst.positions.end_lineno):
+ line = linecache.getline(code.co_filename, lineno).rstrip()
+ segment += line + "\n"
+ # don't underline leading spaces
+ num_spaces = len(line) - len(line.lstrip())
+ markers.append(" " * num_spaces + "~" * (len(line) - num_spaces))
+ segment += last_line[:end_offset]
+ num_spaces = len(last_line) - len(last_line.lstrip())
+ markers.append(" " * num_spaces + "~" * (end_offset - num_spaces))
+
+ anchors: Optional[_Anchors] = None
+ try:
+ anchors = _extract_anchors_from_expr(segment)
+ except AssertionError:
+ pass
+
+ # replace `~` markers with `^` where necessary
+ if anchors is None:
+ markers = [marker.replace("~", "^") for marker in markers]
+ else:
+ # make markers mutable
+ markers = [list(marker) for marker in markers]
+
+ # anchor positions do not take start_offset into account
+ if anchors.left_end_lineno == 0:
+ anchors.left_end_offset += start_offset
+ if anchors.right_start_lineno == 0:
+ anchors.right_start_offset += start_offset
+
+ # Turn `~` markers between anchors into `^`
+ for line in range(len(markers)):
+ for col in range(len(markers[line])):
+ if line < anchors.left_end_lineno:
+ continue
+ if line == anchors.left_end_lineno and col < anchors.left_end_offset:
+ continue
+ if (
+ line == anchors.right_start_lineno
+ and col >= anchors.right_start_offset
+ ):
+ continue
+ if line > anchors.right_start_lineno:
+ continue
+ if markers[line][col] == "~":
+ markers[line][col] = "^"
+
+ # make markers into strings again
+ markers = ["".join(marker) for marker in markers]
+
+ result = ""
+ for i in range(len(markers)):
+ result += (
+ linecache.getline(code.co_filename, inst.positions.lineno + i).rstrip()
+ + "\n"
+ )
+ result += markers[i] + "\n"
+ return result
+
+
def is_guard_failure_reporting_enabled():
return (
config.report_guard_failures
diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py
index c4e7fcd..3ad713d 100644
--- a/torch/_logging/_internal.py
+++ b/torch/_logging/_internal.py
@@ -138,6 +138,7 @@
guards: bool = False,
recompiles: bool = False,
trace_source: bool = False,
+ trace_call: bool = False,
output_code: bool = False,
schedule: bool = False,
perf_hints: bool = False,
@@ -238,6 +239,10 @@
trace_source (:class:`bool`):
Whether to emit when TorchDynamo begins tracing a new line. Default: ``False``
+ trace_call (:class:`bool`):
+ Whether to emit detailed line location when TorchDynamo creates an FX node
+ corresponding to a function call. Python 3.11+ only. Default: ``False``
+
output_code (:class:`bool`):
Whether to emit the TorchInductor output code. Default: ``False``
@@ -348,6 +353,7 @@
guards=guards,
recompiles=recompiles,
trace_source=trace_source,
+ trace_call=trace_call,
output_code=output_code,
schedule=schedule,
perf_hints=perf_hints,
diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py
index 4062ace..1639bfe 100644
--- a/torch/_logging/_registrations.py
+++ b/torch/_logging/_registrations.py
@@ -13,6 +13,7 @@
register_artifact("graph_code")
register_artifact("graph_sizes")
register_artifact("trace_source", log_format="")
+register_artifact("trace_call", log_format="")
register_artifact("aot_graphs")
register_artifact("aot_joint_graph")
register_artifact("ddp_graphs")