propagate torch stack trace metadata to copy_() nodes during input mutations (#117587)
Tested by running the below script:
```
import torch
@torch.compile(backend="aot_eager", fullgraph=True)
def f(x):
    y = x.view(-1)
    y.mul_(2)
    return
x = torch.ones(4)
f(x)
```
Running it gives me the following ATen graph (notice that the copy_() node is now grouped under the stack trace for `mul_(2)`):
```
===== Forward graph 0 =====
<eval_with_key>.2 from /data/users/hirsheybar/e/pytorch/torch/fx/experimental/proxy_tensor.py:521 in wrapped class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[4]"):
        # File: /data/users/hirsheybar/e/pytorch/tmp5.py:8, code: y = x.view(-1)
        view: "f32[4]" = torch.ops.aten.view.default(arg0_1, [-1])
        # File: /data/users/hirsheybar/e/pytorch/tmp5.py:9, code: y.mul_(2)
        mul: "f32[4]" = torch.ops.aten.mul.Tensor(view, 2); view = None
        view_1: "f32[4]" = torch.ops.aten.view.default(mul, [4]); mul = None
        copy_: "f32[4]" = torch.ops.aten.copy_.default(arg0_1, view_1); arg0_1 = view_1 = None
        return ()
```
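
For downstream consumers, the propagated metadata can be read off the ATen graph inside a compiler callback. Below is a minimal sketch (not part of this PR) that mirrors the new test: it traces a small module with `record_stack_traces` enabled, runs it through `aot_module_simplified` with `keep_inference_input_mutations=True`, and prints the `stack_trace` now attached to the `copy_()` node. The module `M`, the callback name `print_copy_traces`, and the `torch._functorch.aot_autograd` import path are illustrative choices, not part of the change:
```
import torch
from torch._functorch.aot_autograd import aot_module_simplified

def print_copy_traces(gm: torch.fx.GraphModule, _):
    # Runs as the fw_compiler: gm here is the functionalized ATen-level graph.
    for node in gm.graph.nodes:
        if node.target is torch.ops.aten.copy_.default:
            # Before this change, copy_() nodes carried no "stack_trace" in node.meta.
            print(node.meta.get("stack_trace"))
    return gm.forward

class M(torch.nn.Module):
    def forward(self, x):
        x_view = x[0]
        x_view.mul_(2)  # input mutation -> becomes a copy_() node in the traced graph
        return (x + x,)

tracer = torch.fx.Tracer()
tracer.record_stack_traces = True  # stack traces must be recorded on the traced graph
mod = torch.fx.GraphModule(tracer.root, tracer.trace(M()))

aot_module_simplified(
    mod,
    [torch.randn(128, 20)],
    fw_compiler=print_copy_traces,
    keep_inference_input_mutations=True,
)
```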
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117587
Approved by: https://github.com/eellison
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 2ee59e8..920c18f 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -4175,6 +4175,46 @@
         res = compiled_f(*inputs)
         res[0].sum().backward()
+    def test_aot_module_simplified_preserves_stack_trace_from_mutation(self):
+        class MockModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                x_view = x[0]
+                x_view.mul_(2)
+                return (x + x, )
+
+        tracer = torch.fx.Tracer()
+        tracer.record_stack_traces = True
+        graph = tracer.trace(MockModule())
+        mod = torch.fx.GraphModule(tracer.root, graph)
+
+        for node in mod.graph.nodes:
+            if node.op == 'output':
+                continue
+            self.assertTrue(node.stack_trace is not None)
+            assert 'test_aotdispatch.py' in node.stack_trace
+
+        def assert_compiler(gm: torch.fx.GraphModule, _):
+            assert torch.ops.aten.copy_.default in [x.target for x in gm.graph.nodes]
+            for node in gm.graph.nodes:
+                if node.target == torch.ops.aten.copy_.default:
+                    assert 'stack_trace' in node.meta
+                    assert 'x_view.mul_(2)' in node.meta['stack_trace']
+            return gm.forward  # return a python callable
+
+        x = torch.randn(128, 20)
+        inputs = [x]
+
+        aot_module_simplified(
+            mod,
+            inputs,
+            fw_compiler=assert_compiler,
+            bw_compiler=assert_compiler,
+            keep_inference_input_mutations=True,
+        )
+
     def test_aot_module_simplified_fake_tensor_gm_raises(self):
         fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
         real_x = torch.randn(4, requires_grad=True)
diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py
index b90458f..94c4e80 100644
--- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py
+++ b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py
@@ -15,7 +15,10 @@
 from torch._subclasses.functional_tensor import FunctionalTensorMode
 from torch.fx.experimental.proxy_tensor import make_fx
-from .functional_utils import assert_functional_graph
+from .functional_utils import (
+    assert_functional_graph,
+    propagate_input_mutation_stacktraces,
+)
 from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta
 from .traced_function_transforms import (
     aot_dispatch_subclass,
@@ -96,6 +99,7 @@
     fw_module.recompile()
     copy_count2 = assert_functional_graph(fw_module.graph)
+    propagate_input_mutation_stacktraces(fw_module.graph)
     assert copy_count == copy_count2
diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py
index 0340e02..9fc989e 100644
--- a/torch/_functorch/_aot_autograd/functional_utils.py
+++ b/torch/_functorch/_aot_autograd/functional_utils.py
@@ -336,6 +336,25 @@
     return copy_count
+def propagate_input_mutation_stacktraces(fx_g: torch.fx.Graph) -> None:
+    placeholders = set()
+    for n in fx_g.nodes:
+        if n.op == "placeholder":
+            placeholders.add(n)
+        if isinstance(n.target, torch._ops.OpOverload):
+            if n.target is torch.ops.aten.copy_.default:
+                # Can only copy_ into an input, and can only do so once
+                assert n.args[0] in placeholders
+                placeholders.remove(n.args[0])
+                copy_from_node = n.args[1]
+                # Pre-condition: every node has a "stack_trace" field in its meta,
+                # but copy_() nodes do not (since we manually added them during functionalization).
+                # Instead, we manually propagate here.
+                if "stack_trace" in copy_from_node.meta:
+                    assert "stack_trace" not in n.meta, str(n)
+                    n.meta["stack_trace"] = copy_from_node.meta["stack_trace"]
+
+
 def _check_if_mutation_can_be_in_graph(
     keep_input_mutations: bool,
     mutates_data,
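
The helper added in functional_utils.py can also be exercised directly on a hand-built graph. The snippet below is a hypothetical standalone sketch (not part of this PR): the node names and the fake stack-trace string are made up, and it only assumes the module path shown in the diff above.
```
import torch
from torch._functorch._aot_autograd.functional_utils import (
    propagate_input_mutation_stacktraces,
)

# Build a tiny ATen-style graph: an input, a mul that carries a stack trace,
# and a copy_() writing the result back into the input (no trace yet).
g = torch.fx.Graph()
arg = g.placeholder("arg0_1")
mul = g.call_function(torch.ops.aten.mul.Tensor, (arg, 2))
mul.meta["stack_trace"] = 'File "tmp.py", line 3, in f\n    y.mul_(2)'  # pretend trace
copy = g.call_function(torch.ops.aten.copy_.default, (arg, mul))
g.output(())

propagate_input_mutation_stacktraces(g)
print(copy.meta["stack_trace"])  # the copy_() node now carries mul's stack trace
```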