propagate torch stack trace metadata to copy_() nodes during input mutations (#117587)
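
The copy_() nodes are inserted by functionalization to write input mutations back into the graph inputs, so they start out without any stack_trace metadata. The fix is a small post-pass over the functionalized forward graph that reuses the stack trace of the node producing the mutated value. Roughly (a simplified sketch of the `propagate_input_mutation_stacktraces` helper added in the diff below, with the placeholder bookkeeping and assertions omitted):

```python
import torch

# Simplified sketch: for each copy_() node that functionalization added to
# write an input mutation back into a graph input, reuse the stack trace of
# the node that produced the value being copied in.
def propagate_stacktraces_sketch(fx_g: torch.fx.Graph) -> None:
    for n in fx_g.nodes:
        if n.target is torch.ops.aten.copy_.default:
            copy_from_node = n.args[1]
            if "stack_trace" in copy_from_node.meta:
                n.meta["stack_trace"] = copy_from_node.meta["stack_trace"]
```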

Tested by running the script below:
```python
import torch
@torch.compile(backend="aot_eager", fullgraph=True)
def f(x):
    y = x.view(-1)
    y.mul_(2)
    return

x = torch.ones(4)
f(x)
```

This gives the following ATen graph (notice that the copy_() node is now bundled under the stack trace for `y.mul_(2)`, instead of carrying no stack trace at all):
```
 ===== Forward graph 0 =====
 <eval_with_key>.2 from /data/users/hirsheybar/e/pytorch/torch/fx/experimental/proxy_tensor.py:521 in wrapped class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[4]"):
        # File: /data/users/hirsheybar/e/pytorch/tmp5.py:8, code: y = x.view(-1)
        view: "f32[4]" = torch.ops.aten.view.default(arg0_1, [-1])

        # File: /data/users/hirsheybar/e/pytorch/tmp5.py:9, code: y.mul_(2)
        mul: "f32[4]" = torch.ops.aten.mul.Tensor(view, 2);  view = None
        view_1: "f32[4]" = torch.ops.aten.view.default(mul, [4]);  mul = None
        copy_: "f32[4]" = torch.ops.aten.copy_.default(arg0_1, view_1);  arg0_1 = view_1 = None
        return ()

```
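
The new unit test in the diff below asserts the same thing programmatically. A standalone version of that check, roughly mirroring the test (it assumes `aot_module_simplified` is importable from `functorch.compile`, as in the test file), looks like this:

```python
import torch
from functorch.compile import aot_module_simplified

class MutatingModule(torch.nn.Module):
    def forward(self, x):
        x_view = x[0]
        x_view.mul_(2)
        return (x + x,)

# Record Python stack traces while symbolically tracing the module.
tracer = torch.fx.Tracer()
tracer.record_stack_traces = True
mod = torch.fx.GraphModule(tracer.root, tracer.trace(MutatingModule()))

def inspect_compiler(gm: torch.fx.GraphModule, example_inputs):
    for node in gm.graph.nodes:
        if node.target is torch.ops.aten.copy_.default:
            # After this change, the copy_() node reuses the stack trace of
            # the mutation that produced its source value (x_view.mul_(2)).
            print(node.meta.get("stack_trace"))
    return gm.forward  # return a python callable

aot_module_simplified(
    mod,
    [torch.randn(128, 20)],
    fw_compiler=inspect_compiler,
    keep_inference_input_mutations=True,
)
```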

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117587
Approved by: https://github.com/eellison
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 2ee59e8..920c18f 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -4175,6 +4175,46 @@
         res = compiled_f(*inputs)
         res[0].sum().backward()
 
+    def test_aot_module_simplified_preserves_stack_trace_from_mutation(self):
+        class MockModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                x_view = x[0]
+                x_view.mul_(2)
+                return (x + x, )
+
+        tracer = torch.fx.Tracer()
+        tracer.record_stack_traces = True
+        graph = tracer.trace(MockModule())
+        mod = torch.fx.GraphModule(tracer.root, graph)
+
+        for node in mod.graph.nodes:
+            if node.op == 'output':
+                continue
+            self.assertTrue(node.stack_trace is not None)
+            assert 'test_aotdispatch.py' in node.stack_trace
+
+        def assert_compiler(gm: torch.fx.GraphModule, _):
+            assert torch.ops.aten.copy_.default in [x.target for x in gm.graph.nodes]
+            for node in gm.graph.nodes:
+                if node.target == torch.ops.aten.copy_.default:
+                    assert 'stack_trace' in node.meta
+                    assert 'x_view.mul_(2)' in node.meta['stack_trace']
+            return gm.forward  # return a python callable
+
+        x = torch.randn(128, 20)
+        inputs = [x]
+
+        aot_module_simplified(
+            mod,
+            inputs,
+            fw_compiler=assert_compiler,
+            bw_compiler=assert_compiler,
+            keep_inference_input_mutations=True,
+        )
+
     def test_aot_module_simplified_fake_tensor_gm_raises(self):
         fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
         real_x = torch.randn(4, requires_grad=True)
diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py
index b90458f..94c4e80 100644
--- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py
+++ b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py
@@ -15,7 +15,10 @@
 from torch._subclasses.functional_tensor import FunctionalTensorMode
 from torch.fx.experimental.proxy_tensor import make_fx
 
-from .functional_utils import assert_functional_graph
+from .functional_utils import (
+    assert_functional_graph,
+    propagate_input_mutation_stacktraces,
+)
 from .schemas import AOTConfig, SubclassMeta, ViewAndMutationMeta
 from .traced_function_transforms import (
     aot_dispatch_subclass,
@@ -96,6 +99,7 @@
     fw_module.recompile()
 
     copy_count2 = assert_functional_graph(fw_module.graph)
+    propagate_input_mutation_stacktraces(fw_module.graph)
 
     assert copy_count == copy_count2
 
diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py
index 0340e02..9fc989e 100644
--- a/torch/_functorch/_aot_autograd/functional_utils.py
+++ b/torch/_functorch/_aot_autograd/functional_utils.py
@@ -336,6 +336,25 @@
     return copy_count
 
 
+def propagate_input_mutation_stacktraces(fx_g: torch.fx.Graph) -> None:
+    placeholders = set()
+    for n in fx_g.nodes:
+        if n.op == "placeholder":
+            placeholders.add(n)
+        if isinstance(n.target, torch._ops.OpOverload):
+            if n.target is torch.ops.aten.copy_.default:
+                # Can only copy_ into an input, and can only do so once
+                assert n.args[0] in placeholders
+                placeholders.remove(n.args[0])
+                copy_from_node = n.args[1]
+                # Pre-condition: every node has a "stack_trace" field in its meta,
+                # but copy_() nodes do not (since we manually added them during functionalization).
+                # Instead, we manually propagate here.
+                if "stack_trace" in copy_from_node.meta:
+                    assert "stack_trace" not in n.meta, str(n)
+                    n.meta["stack_trace"] = copy_from_node.meta["stack_trace"]
+
+
 def _check_if_mutation_can_be_in_graph(
     keep_input_mutations: bool,
     mutates_data,