[dynamo] fine-grained bytecode-source attribution in python 3.11 (#104676)

Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.

Sample:
```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond

def h(x):
    return x * 5

def true_fn(x):
    return x * 2

def false_fn(x):
    return x * 3

def f(pred, x):
    x = h(
        h(h(x))
    )
    x = x[1:][:2]
    torch._dynamo.graph_break()
    x = cond(pred, true_fn, false_fn, [x])

opt_f = torch.compile(f, backend="eager")
opt_f(torch.tensor(True), torch.randn(3, 3, 3, 3))
```

Output:
```
$ TORCH_LOGS="trace_call" python playground9.py
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
          ~^^^
TRACE FX call mul from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
        h(h(x))
        ~^^^^^^
TRACE FX call mul_1 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:15
    x = h(
        ~^
        h(h(x))
        ^^^^^^^
    )
    ^
TRACE FX call mul_2 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
    return x * 5
           ~~^~~
TRACE FX call getitem from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~^^^^
TRACE FX call getitem_1 from f /scratch/williamwen/work/pytorch/playground9.py:18
    x = x[1:][:2]
        ~~~~~^^^^
TRACE inlined call true_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE FX call mul from true_fn /scratch/williamwen/work/pytorch/playground9.py:9 (inline depth: 1)
    return x * 2
           ~~^~~
TRACE inlined call false_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE FX call mul from false_fn /scratch/williamwen/work/pytorch/playground9.py:12 (inline depth: 1)
    return x * 3
           ~~^~~
TRACE FX call cond from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
    x = cond(pred, true_fn, false_fn, [x])
        ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104676
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py
index eed9968..3bd02dc 100644
--- a/test/dynamo/test_logging.py
+++ b/test/dynamo/test_logging.py
@@ -12,6 +12,8 @@
 import torch._dynamo.testing
 import torch.distributed as dist
 
+from torch._dynamo.testing import skipIfNotPy311
+
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing._internal.common_utils import find_free_port
 from torch.testing._internal.inductor_utils import HAS_CUDA
@@ -408,6 +410,108 @@
         )
         self.assertIn("[rank0]:", stderr.decode("utf-8"))
 
+    @skipIfNotPy311
+    @make_logging_test(trace_call=True)
+    def test_trace_call(self, records):
+        def fn(x, y):
+            return (x * 2) @ (y * 3)
+
+        fn_opt = torch._dynamo.optimize("eager")(fn)
+        fn_opt(torch.randn(10, 20), torch.randn(20, 30))
+
+        self.assertEqual(len(records), 3)
+        # only get last 2 lines
+        messages = [
+            "\n".join(record.getMessage().split("\n")[-2:]) for record in records
+        ]
+        self.assertExpectedInline(
+            messages[0],
+            """\
+            return (x * 2) @ (y * 3)
+                    ~~^~~""",
+        )
+        self.assertExpectedInline(
+            messages[1],
+            """\
+            return (x * 2) @ (y * 3)
+                              ~~^~~""",
+        )
+        self.assertExpectedInline(
+            messages[2],
+            """\
+            return (x * 2) @ (y * 3)
+                   ~~~~~~~~^~~~~~~~~""",
+        )
+
+    @skipIfNotPy311
+    @make_logging_test(trace_call=True)
+    def test_trace_call_inline_call(self, records):
+        def g(x):
+            return x * 2
+
+        def f(x):
+            return g(g(x))
+
+        fn_opt = torch._dynamo.optimize("eager")(f)
+        fn_opt(torch.randn(3, 3))
+
+        self.assertEqual(len(records), 4)
+        messages = [
+            "\n".join(record.getMessage().split("\n")[-2:]) for record in records
+        ]
+        self.assertExpectedInline(
+            messages[0],
+            """\
+            return g(g(x))
+                     ~^^^""",
+        )
+        self.assertExpectedInline(
+            messages[1],
+            """\
+            return x * 2
+                   ~~^~~""",
+        )
+        self.assertExpectedInline(
+            messages[2],
+            """\
+            return g(g(x))
+                   ~^^^^^^""",
+        )
+        self.assertExpectedInline(
+            messages[3],
+            """\
+            return x * 2
+                   ~~^~~""",
+        )
+
+    @skipIfNotPy311
+    @make_logging_test(trace_call=True)
+    def test_trace_call_graph_break(self, records):
+        def fn(x):
+            x = x * 2
+            torch._dynamo.graph_break()
+            return x * 3
+
+        fn_opt = torch._dynamo.optimize("eager")(fn)
+        fn_opt(torch.randn(3, 3))
+
+        self.assertEqual(len(records), 2)
+        messages = [
+            "\n".join(record.getMessage().split("\n")[-2:]) for record in records
+        ]
+        self.assertExpectedInline(
+            messages[0],
+            """\
+            x = x * 2
+                ~~^~~""",
+        )
+        self.assertExpectedInline(
+            messages[1],
+            """\
+            return x * 3
+                   ~~^~~""",
+        )
+
 
 # single record tests
 exclusions = {
@@ -421,6 +525,7 @@
     "perf_hints",
     "not_implemented",
     "trace_source",
+    "trace_call",
     "custom_format_test_artifact",
 }
 for name in torch._logging._internal.log_registry.artifact_names:
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 671fed1..cc5306b 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -2185,8 +2185,6 @@
             f = "linetable_writer"
             return f"Test if {f} generates correct co_linetable: {c}"
 
-        # Dynamo doesn't deal with column locations or end line numbers,
-        # so we only check that start line numbers in the linetables match.
         keys = bytecode_transformation.get_code_keys()
         code_options = {k: getattr(fn.__code__, k) for k in keys}
         result = bytecode_transformation.clean_and_assemble_instructions(
@@ -2197,8 +2195,7 @@
         l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions())
         self.assertEqual(len(l1), len(l2))
         for p1, p2 in zip(l1, l2):
-            # check that start line numbers match
-            self.assertEqual(p1[0], p2[0])
+            self.assertEqual(p1, p2)
         self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab)
 
     @skipIfNotPy311
@@ -2238,8 +2235,7 @@
         l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions())
         self.assertEqual(len(l1), len(l2))
         for p1, p2 in zip(l1, l2):
-            # check that start line numbers match
-            self.assertEqual(p1[0], p2[0])
+            self.assertEqual(p1, p2)
         self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab)
 
     @unittest.skipIf(
@@ -5158,6 +5154,146 @@
         dis.dis(fn)
         self.assertEqual(torch._dynamo.optimize("eager")(fn)(), 3)
 
+    @skipIfNotPy311
+    def test_get_instruction_source_311(self):
+        def f():
+            # flake8: noqa
+            # fmt: off
+            # test binary ops
+            a = ( b   )   +   c
+            a = (a + b) // (c - d)
+            a = b    \
+         +\
+               c  # test
+            a = (
+                (b  # test +
+                    )  \
+                # +
+            << (
+
+                c  # test
+                \
+            )  # test
+            )
+
+            # test slice
+            a = bbb   [  ccc    ]
+            b = bbbbb \
+                [  ccc # test
+
+                 + ddd  \
+
+                ] # test
+            a = bbb[ccc][ddd][eee]
+
+            # test nested and multiline function calls
+            a = g(g(g(b)))
+            a = g(h(
+                g(b),
+                c
+            ))
+
+            # test chained function calls
+            a = (g(x).y)(
+                z
+            )(1)(2)
+
+            # test unicode (match traceback behavior)
+            a = ("🔥🔥🔥" +
+                + "🔥🔥") + b
+
+        from torch._dynamo.utils import get_instruction_source_311
+
+        offsets = (3, 11, 15, 19, 23, 29, 35, 46, 58, 74)
+        insts = list(dis.get_instructions(f))
+        # uncomment to determine offsets
+        # print(*enumerate(insts), sep="\n")
+        all_sources = "\n".join(
+            get_instruction_source_311(f.__code__, insts[offset]) for offset in offsets
+        )
+        self.assertExpectedInline(
+            all_sources,
+            """\
+            a = ( b   )   +   c
+                ~~~~~~~~~~^~~~~
+
+            a = (a + b) // (c - d)
+                ~~~~~~~~^^~~~~~~~~
+
+            a = b    \\
+                ~~~~~~
+         +\\
+         ^~
+               c  # test
+               ~
+
+                (b  # test +
+                ~~~~~~~~~~~~
+                    )  \\
+                    ~~~~
+                # +
+                ~~~
+            << (
+            ^^~~
+
+
+                c  # test
+                ~~~~~~~~~
+                \\
+                ~
+            )  # test
+            ~
+
+            a = bbb   [  ccc    ]
+                ~~~~~~^^^^^^^^^^^
+
+            b = bbbbb \\
+                ~~~~~~~
+                [  ccc # test
+                ^^^^^^^^^^^^^
+
+
+                 + ddd  \\
+                 ^^^^^^^^
+
+
+                ] # test
+                ^
+
+            a = bbb[ccc][ddd][eee]
+                ~~~~~~~~^^^^^
+
+            a = g(g(g(b)))
+                  ~^^^^^^
+
+            a = g(h(
+                  ~^
+                g(b),
+                ^^^^^
+                c
+                ^
+            ))
+            ^
+
+            a = (g(x).y)(
+                ~~~~~~~~~
+                z
+                ~
+            )(1)(2)
+            ~^^^
+""",
+        )
+        # test unicode (since assertExpectedInline doesn't support unicode)
+        self.assertEqual(
+            get_instruction_source_311(f.__code__, insts[84]),
+            """\
+            a = ("🔥🔥🔥" +
+                ~~~~~~~~
+                + "🔥🔥") + b
+                ~~~~~~~~^~~
+""",
+        )
+
     def test_raise_guard_full_constraint(self):
         y = torch.randn([3, 3, 3])
 
diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py
index 58faa1f..20bd965 100644
--- a/torch/_dynamo/bytecode_transformation.py
+++ b/torch/_dynamo/bytecode_transformation.py
@@ -295,13 +295,32 @@
     linetable = []
     lineno = first_lineno
 
-    def update(lineno_new, inst_size):
+    def update(positions: dis.Positions, inst_size):
         nonlocal lineno
+        lineno_new = positions.lineno if positions else None
 
         def _update(delta, size):
             assert 0 < size <= 8
-            # first byte - always use no column info code (13)
-            linetable.append(0b1_1101_000 + size - 1)
+            # first byte - use 13 (no column info) is positions is
+            # malformed, otherwise use 14 (long form)
+            other_varints = ()
+            if (
+                positions
+                and positions.lineno is not None
+                and positions.end_lineno is not None
+                and positions.col_offset is not None
+                and positions.end_col_offset is not None
+            ):
+                linetable.append(0b1_1110_000 + size - 1)
+                # for whatever reason, column offset needs `+ 1`
+                # https://github.com/python/cpython/blob/1931c2a438c50e6250725c84dff94fc760b9b951/Python/compile.c#L7603
+                other_varints = (
+                    positions.end_lineno - positions.lineno,
+                    positions.col_offset + 1,
+                    positions.end_col_offset + 1,
+                )
+            else:
+                linetable.append(0b1_1101_000 + size - 1)
             # encode signed int
             if delta < 0:
                 delta = ((-delta) << 1) | 1
@@ -309,6 +328,8 @@
                 delta <<= 1
             # encode unsigned int
             linetable.extend(encode_varint(delta))
+            for n in other_varints:
+                linetable.extend(encode_varint(n))
 
         if lineno_new is None:
             lineno_delta = 0
@@ -420,14 +441,19 @@
     if sys.version_info >= (3, 11):
         lnotab, update_lineno = linetable_311_writer(firstlineno)
         num_ext = 0
-        for inst in instructions:
+        for i, inst in enumerate(instructions):
             if inst.opname == "EXTENDED_ARG":
                 inst_size = 1
                 num_ext += 1
+                # copy positions from the actual instruction
+                for j in (1, 2, 3):
+                    if instructions[i + j].opname != "EXTENDED_ARG":
+                        inst.positions = instructions[i + j].positions
+                        break
             else:
                 inst_size = instruction_size(inst) // 2 + num_ext
                 num_ext = 0
-            update_lineno(inst.starts_line, inst_size)
+            update_lineno(inst.positions, inst_size)
             num_ext = 0
             arg = inst.arg or 0
             code.extend((inst.opcode, arg & 0xFF))
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 3bc65ff..0002650 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -62,6 +62,7 @@
     count_calls,
     counters,
     dynamo_timed,
+    get_instruction_source_311,
     graph_break_reasons,
     increment_op_count,
     lazy_format_graph_code,
@@ -85,6 +86,7 @@
 graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
 graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
 graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes")
+trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
 
 
 class OutputGraphState(NamedTuple):
@@ -1100,6 +1102,7 @@
         # This is a OrderedDict so that we can
         # maintain the order of args for the HigherOrderOperator call.
         self.lifted_freevars = collections.OrderedDict()
+        self.prev_inst = None
 
     def create_proxy(
         self,
@@ -1160,6 +1163,24 @@
         # append stack trace to fx node
         tx = self.output_graph.current_tx
 
+        # log detailed location of line of code in 3.11
+        if sys.version_info >= (3, 11) and kind in (
+            "call_function",
+            "call_method",
+            "call_module",
+        ):
+            cur_inst = tx.current_instruction
+            if cur_inst is not self.prev_inst and cur_inst.positions.lineno is not None:
+                tx_code = tx.f_code
+                header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno)
+
+                def get_trace_call_log_str():
+                    line = get_instruction_source_311(tx_code, cur_inst).rstrip()
+                    return f"TRACE FX call {rv.node.name} from {header}\n{line}"
+
+                trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
+                self.prev_inst = cur_inst
+
         nn_module_stack = tx.nn_module_stack
         if nn_module_stack:
             rv.node.meta["nn_module_stack"] = nn_module_stack.copy()
diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py
index f3eafba..f5b09e0 100644
--- a/torch/_dynamo/resume_execution.py
+++ b/torch/_dynamo/resume_execution.py
@@ -437,6 +437,8 @@
                 if inst.offset == target.offset:
                     break
                 inst.starts_line = None
+                if sys.version_info >= (3, 11):
+                    inst.positions = None
 
             if cleanup:
                 prefix.extend(cleanup)
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index c246794..de05bc4 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -59,6 +59,7 @@
 from .utils import (
     counters,
     get_fake_value,
+    get_instruction_source_311,
     graph_break_dup_warning_checker,
     istype,
     LazyString,
@@ -108,6 +109,7 @@
 
 log = logging.getLogger(__name__)
 graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
+trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
 trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")
 
 
@@ -596,19 +598,22 @@
             self.restore_graphstate(state)
             raise
 
-    def get_log_starts_line_log_str(self):
+    def get_line_of_code_header(self, lineno=None):
+        if lineno is None:
+            lineno = self.lineno
         inline_depth_str = (
             f" (inline depth: {self.inline_depth})" if self.inline_depth > 0 else ""
         )
-        log_str = f"TRACE starts_line {self.f_code.co_name} {self.f_code.co_filename}:{self.lineno}{inline_depth_str}\n"
+        return f"{self.f_code.co_name} {self.f_code.co_filename}:{lineno}{inline_depth_str}"
+
+    def get_log_starts_line_log_str(self):
+        log_str = f"TRACE starts_line {self.get_line_of_code_header()}\n"
         line = linecache.getline(self.f_code.co_filename, self.lineno).rstrip()
         log_str += f"    {line}"
         return log_str
 
     def log_starts_line(self):
-        trace_source_log.debug(
-            "%s", LazyString(lambda: self.get_log_starts_line_log_str())
-        )
+        trace_source_log.debug("%s", LazyString(self.get_log_starts_line_log_str))
 
     def step(self):
         """Process exactly one instruction, return False we should exit"""
@@ -2242,6 +2247,16 @@
         # with a single alias
         if torch._logging._internal.log_state.is_artifact_enabled("output_code"):
             suffix = f"\n{dis.Bytecode(code).dis()}"
+        if sys.version_info >= (3, 11):
+            cur_inst = parent.current_instruction
+            parent_code = parent.f_code
+            header = parent.get_line_of_code_header(lineno=cur_inst.positions.lineno)
+
+            def get_trace_call_log_str():
+                line = get_instruction_source_311(parent_code, cur_inst).rstrip()
+                return f"TRACE inlined call {code.co_name} from {header}\n{line}"
+
+            trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
         log.debug("INLINING %s%s", code, suffix)
 
         tracer: InliningInstructionTranslator
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 27aecce..6e4614a 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -11,6 +11,7 @@
 import gc
 import inspect
 import itertools
+import linecache
 import logging
 import math
 import operator
@@ -24,7 +25,7 @@
 import weakref
 from contextlib import contextmanager
 from functools import lru_cache, wraps
-from typing import Any, Dict, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import torch._logging
 from torch._guards import detect_fake_mode  # noqa: F401
@@ -1756,6 +1757,260 @@
     return compile_supported
 
 
+# The following 3.11 source code functions are adapted from
+# https://github.com/python/cpython/blob/v3.11.4/Lib/traceback.py
+# in order to output source code corresponding to bytecode in 3.11+.
+# We need our own versions since we want to support multiline expressions.
+def _fix_offset(str: str, offset: int) -> int:
+    """
+    Convert byte offset `offset` of `str` into character offset.
+    Byte offset is used for 3.11+ instruction column data.
+    Takes things like unicode characters into consideration.
+
+    Unchanged from CPython implementation.
+    """
+    as_utf8 = str.encode("utf-8")
+    return len(as_utf8[:offset].decode("utf-8", errors="replace"))
+
+
+@dataclasses.dataclass
+class _Anchors:
+    # inclusive
+    left_end_lineno: int
+    left_end_offset: int
+    right_start_lineno: int
+    # exclusive
+    right_start_offset: int
+
+
+def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]:
+    """
+    Given source code `segment` corresponding to a bytecode
+    instruction, determine:
+        - for binary ops, the location of the binary op
+        - for indexing, the location of the brackets.
+    `segment` is expected to be a valid Python expression
+    """
+    assert sys.version_info >= (3, 11)
+
+    import ast
+
+    try:
+        # Without brackets, `segment` is parsed as a statement.
+        # We expect an expression, so wrap `segment` in
+        # brackets to handle multi-line expressions.
+        tree = ast.parse("(\n" + segment + "\n)")
+    except SyntaxError:
+        return None
+
+    if len(tree.body) != 1:
+        return None
+
+    lines = segment.split("\n")
+
+    # get character index given byte offset
+    def normalize(lineno, offset):
+        return _fix_offset(lines[lineno], offset)
+
+    # Gets the next valid character index in `lines`, if
+    # the current location is not valid. Handles empty lines.
+    def next_valid_char(lineno, col):
+        while lineno < len(lines) and col >= len(lines[lineno]):
+            col = 0
+            lineno += 1
+        assert lineno < len(lines) and col < len(lines[lineno])
+        return lineno, col
+
+    # Get the next valid character index in `lines`.
+    def increment(lineno, col):
+        col += 1
+        lineno, col = next_valid_char(lineno, col)
+        assert lineno < len(lines) and col < len(lines[lineno])
+        return lineno, col
+
+    # Get the next valid character at least on the next line
+    def nextline(lineno, col):
+        col = 0
+        lineno += 1
+        lineno, col = next_valid_char(lineno, col)
+        assert lineno < len(lines) and col < len(lines[lineno])
+        return lineno, col
+
+    statement = tree.body[0]
+    if isinstance(statement, ast.Expr):
+        expr = statement.value
+        if isinstance(expr, ast.BinOp):
+            # ast gives locations for BinOp subexpressions, e.g.
+            # ( left_expr ) + ( right_expr )
+            #   left^^^^^       right^^^^^
+            # -2 since end_lineno is 1-indexed and because we added an extra
+            # bracket to `segment` when calling ast.parse
+            cur_lineno = expr.left.end_lineno - 2
+            cur_col = normalize(cur_lineno, expr.left.end_col_offset)
+            cur_lineno, cur_col = next_valid_char(cur_lineno, cur_col)
+
+            # Heuristic to find the operator character.
+            # The original CPython implementation did not look for ), \, or #,
+            # leading to incorrect anchor location, e.g.
+            # (x) + (y)
+            # ~~^~~~~~~
+            while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#":
+                if ch in "\\#":
+                    cur_lineno, cur_col = nextline(cur_lineno, cur_col)
+                else:
+                    cur_lineno, cur_col = increment(cur_lineno, cur_col)
+
+            # binary op is 1 or 2 characters long, on the same line
+            right_col = cur_col + 1
+            if (
+                right_col < len(lines[cur_lineno])
+                and not (ch := lines[cur_lineno][right_col]).isspace()
+                and ch not in "\\#"
+            ):
+                right_col += 1
+            # right_col can be invalid since it is exclusive
+
+            return _Anchors(cur_lineno, cur_col, cur_lineno, right_col)
+        elif isinstance(expr, ast.Subscript):
+            # ast gives locations for value and slice subexpressions, e.g.
+            # ( value_expr ) [ slice_expr ]
+            #   value^^^^^     slice^^^^^
+            # subscript^^^^^^^^^^^^^^^^^^^^
+            # find left bracket (first '[' after value)
+            left_lineno = expr.value.end_lineno - 2
+            left_col = normalize(left_lineno, expr.value.end_col_offset)
+            left_lineno, left_col = next_valid_char(left_lineno, left_col)
+            while lines[left_lineno][left_col] != "[":
+                left_lineno, left_col = increment(left_lineno, left_col)
+            # find right bracket (final character of expression)
+            right_lineno = expr.end_lineno - 2
+            right_col = normalize(right_lineno, expr.end_col_offset)
+            return _Anchors(left_lineno, left_col, right_lineno, right_col)
+        elif isinstance(expr, ast.Call):
+            # ( func_expr ) (args, kwargs)
+            #   func^^^^^
+            # call^^^^^^^^^^^^^^^^^^^^^^^^
+            # find left bracket (first '(' after func)
+            left_lineno = expr.func.end_lineno - 2
+            left_col = normalize(left_lineno, expr.func.end_col_offset)
+            left_lineno, left_col = next_valid_char(left_lineno, left_col)
+            while lines[left_lineno][left_col] != "(":
+                left_lineno, left_col = increment(left_lineno, left_col)
+            # find right bracket (final character of expression)
+            right_lineno = expr.end_lineno - 2
+            right_col = normalize(right_lineno, expr.end_col_offset)
+            return _Anchors(left_lineno, left_col, right_lineno, right_col)
+
+    return None
+
+
+def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> str:
+    """
+    Python 3.11+ only. Returns lines of source code (from code object `code`)
+    corresponding to `inst`'s location data, and underlines relevant code to `inst`.
+
+    Example: CALL on `g`:
+    f(g(
+      ^^
+        h(x)))
+        ^^^^^
+
+    We need our own implementation since `format_frame_summary` in
+    Python's `traceback` module doesn't handle multi-line expressions
+    (and their anchor extraction code is not completely correct).
+    """
+    if inst.positions.lineno is None:
+        return ""
+    # The rstrip + "\n" pattern is used throughout this function to handle
+    # linecache.getline errors. Error lines are treated as empty strings "", but we want
+    # to treat them as blank lines "\n".
+    first_line = linecache.getline(code.co_filename, inst.positions.lineno).rstrip()
+    if inst.positions.end_lineno is None:
+        return first_line
+    if inst.positions.col_offset is None or inst.positions.end_col_offset is None:
+        return first_line
+
+    # character index of the start of the instruction
+    start_offset = _fix_offset(first_line, inst.positions.col_offset)
+    # character index of the end of the instruction
+    # compute later since end may be a different line
+    end_offset = None
+    # expression corresponding to the instruction so we can get anchors
+    segment = ""
+    # underline markers to be printed - start with `~` marker and replace with `^` later
+    markers = []
+
+    # Compute segment and initial markers
+    if inst.positions.end_lineno == inst.positions.lineno:
+        end_offset = _fix_offset(first_line, inst.positions.end_col_offset)
+        segment = first_line[start_offset:end_offset]
+        markers.append(" " * start_offset + "~" * (end_offset - start_offset))
+    else:
+        segment = first_line[start_offset:] + "\n"
+        markers.append(" " * start_offset + "~" * (len(first_line) - start_offset))
+        last_line = linecache.getline(
+            code.co_filename, inst.positions.end_lineno
+        ).rstrip()
+        end_offset = _fix_offset(last_line, inst.positions.end_col_offset)
+        for lineno in range(inst.positions.lineno + 1, inst.positions.end_lineno):
+            line = linecache.getline(code.co_filename, lineno).rstrip()
+            segment += line + "\n"
+            # don't underline leading spaces
+            num_spaces = len(line) - len(line.lstrip())
+            markers.append(" " * num_spaces + "~" * (len(line) - num_spaces))
+        segment += last_line[:end_offset]
+        num_spaces = len(last_line) - len(last_line.lstrip())
+        markers.append(" " * num_spaces + "~" * (end_offset - num_spaces))
+
+    anchors: Optional[_Anchors] = None
+    try:
+        anchors = _extract_anchors_from_expr(segment)
+    except AssertionError:
+        pass
+
+    # replace `~` markers with `^` where necessary
+    if anchors is None:
+        markers = [marker.replace("~", "^") for marker in markers]
+    else:
+        # make markers mutable
+        markers = [list(marker) for marker in markers]
+
+        # anchor positions do not take start_offset into account
+        if anchors.left_end_lineno == 0:
+            anchors.left_end_offset += start_offset
+        if anchors.right_start_lineno == 0:
+            anchors.right_start_offset += start_offset
+
+        # Turn `~`` markers between anchors to `^`
+        for line in range(len(markers)):
+            for col in range(len(markers[line])):
+                if line < anchors.left_end_lineno:
+                    continue
+                if line == anchors.left_end_lineno and col < anchors.left_end_offset:
+                    continue
+                if (
+                    line == anchors.right_start_lineno
+                    and col >= anchors.right_start_offset
+                ):
+                    continue
+                if line > anchors.right_start_lineno:
+                    continue
+                if markers[line][col] == "~":
+                    markers[line][col] = "^"
+
+        # make markers into strings again
+        markers = ["".join(marker) for marker in markers]
+
+    result = ""
+    for i in range(len(markers)):
+        result += (
+            linecache.getline(code.co_filename, inst.positions.lineno + i).rstrip()
+            + "\n"
+        )
+        result += markers[i] + "\n"
+    return result
+
+
 def is_guard_failure_reporting_enabled():
     return (
         config.report_guard_failures
diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py
index c4e7fcd..3ad713d 100644
--- a/torch/_logging/_internal.py
+++ b/torch/_logging/_internal.py
@@ -138,6 +138,7 @@
     guards: bool = False,
     recompiles: bool = False,
     trace_source: bool = False,
+    trace_call: bool = False,
     output_code: bool = False,
     schedule: bool = False,
     perf_hints: bool = False,
@@ -238,6 +239,10 @@
         trace_source (:class:`bool`):
             Whether to emit when TorchDynamo begins tracing a new line. Default: ``False``
 
+        trace_call (:class:`bool`):
+            Whether to emit detailed line location when TorchDynamo creates an FX node
+            corresponding to function call. Python 3.11+ only. Default: ``False``
+
         output_code (:class:`bool`):
             Whether to emit the TorchInductor output code. Default: ``False``
 
@@ -348,6 +353,7 @@
         guards=guards,
         recompiles=recompiles,
         trace_source=trace_source,
+        trace_call=trace_call,
         output_code=output_code,
         schedule=schedule,
         perf_hints=perf_hints,
diff --git a/torch/_logging/_registrations.py b/torch/_logging/_registrations.py
index 4062ace..1639bfe 100644
--- a/torch/_logging/_registrations.py
+++ b/torch/_logging/_registrations.py
@@ -13,6 +13,7 @@
 register_artifact("graph_code")
 register_artifact("graph_sizes")
 register_artifact("trace_source", log_format="")
+register_artifact("trace_call", log_format="")
 register_artifact("aot_graphs")
 register_artifact("aot_joint_graph")
 register_artifact("ddp_graphs")