[compiled autograd][aot] Trim runtime refs for list inputs from dynamo (#122535)
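
When dynamo passes a list input (e.g. the activations list for the compiled
autograd graph) into the compiled function, the runtime wrappers previously
kept references to every element of that list alive until the compiled
forward returned, so activations could not be freed as soon as they were
last used. This PR marks such list-input placeholders with a steal_arg meta
in dynamo, has the flatten_graph_inputs wrapper clear the original list
after flattening, clears the old args list when prepending effect tokens in
runtime_wrappers, and adds a boxed forward in aot_module_simplified (used
only for GmWrapper modules) that clears its runtime args before calling the
compiled function. test_free_activation_memory checks that CUDA activation
memory is released before a downstream op runs.

Illustrative sketch of the boxed calling convention this relies on
(hypothetical names, not the code added by this PR): the callee takes
ownership of the argument list and clears it, so the caller's references
stop pinning the tensors.

    import torch

    def boxed_call(args: list):
        # Take ownership of the inputs and drop the caller's references,
        # so each tensor can be freed as soon as it is last used.
        activation = args[0]
        args.clear()
        out = activation + 1
        del activation  # last use: the input can be deallocated here
        return out

    activations = [torch.ones(4)]
    result = boxed_call(activations)
    assert len(activations) == 0  # caller's list no longer pins the tensor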

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122535
Approved by: https://github.com/bdhirsh
ghstack dependencies: #123630, #123674, #122353, #123359
diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index da4f7a9..7992ce8 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -1191,6 +1191,43 @@
 
         self.check_output_and_recompiles(fn, 3)
 
+    @unittest.skipIf(not HAS_CUDA, "requires cuda")
+    def test_free_activation_memory(self):
+        self.assertTrue(torch.cuda.memory_allocated() == 0)
+
+        # Use an op to check that the memory is freed by the time the op is executed
+        def assertion_impl(to_clone):
+            mem_allocated = torch.cuda.memory_allocated()
+            self.assertTrue(
+                mem_allocated < 4000000, "activations should have been freed"
+            )
+            return to_clone.clone()
+
+        with torch.library._scoped_library("test_compiled_autograd", "FRAGMENT") as lib:
+            lib.define(
+                "assertion_op(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,)
+            )
+            lib.impl("assertion_op", assertion_impl, "CPU")
+            lib.impl("assertion_op", lambda x: x.clone(), "Meta")
+
+            # Create a graph that allows input stealing
+            def forward(activations):
+                add = activations[0] + 1
+                out = add.cpu()
+                cloned_out = torch.ops.test_compiled_autograd.assertion_op(out)
+                return (cloned_out,)
+
+            gm = torch.fx.symbolic_trace(forward)
+            torch._dynamo.utils.set_locals_to_steal(gm, ["activations"])
+            compiled_fn = torch.compile(gm)
+
+            # allocate at least 4,000,000 bytes (1,000,000 * 4 bytes)
+            activations = [torch.ones(1000000, dtype=torch.float32, device="cuda")]
+            self.assertTrue(torch.cuda.memory_allocated() > 4000000)
+
+            out = compiled_fn(activations)
+            self.assertTrue(len(activations) == 0)
+
 
 def load_test_module(name):
     testdir = Path(__file__).absolute().parent.parent
@@ -1362,6 +1399,7 @@
     "test_save_for_backward_inputs_are_namedtuple",  # torch._dynamo.exc.Unsupported: 'skip function
     "test_autograd_function_backed_op",  # RuntimeError: compiled_args not implemented
     "test_setitem",  # AssertionError: Tensor-likes are not close!
+    "test_grad_nonleaf_register_hook",  # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
 }
 
 if not HAS_CUDA:
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index c60a308..153a1c2 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -2627,6 +2627,17 @@
     return proxy
 
 
+class GmWrapper(torch.nn.Module):
+    def __init__(self, gm, spec):
+        super().__init__()
+        self.gm = gm
+        self.spec = spec
+
+    def forward(self, *args):
+        args: List[Any] = list(args)
+        return self.gm(*pytree.tree_unflatten(args, self.spec))
+
+
 def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
     """
     Mutate inputs so that they are flat and wrap gm such that it
@@ -2634,21 +2645,24 @@
     bumpy inputs.
     """
     inputs, spec = pytree.tree_flatten(inputs)
+    compiled_fn = compile_gm(GmWrapper(gm, spec), inputs)
 
-    class GmWrapper(torch.nn.Module):
-        def __init__(self):
-            super().__init__()
-            self.gm = gm
-
-        def forward(self, *args):
-            args: List[Any] = list(args)
-            return self.gm(*pytree.tree_unflatten(args, spec))
-
-    compiled_fn = compile_gm(GmWrapper(), inputs)
+    idx_to_steal = [
+        i
+        for i, node in enumerate(gm.graph.nodes)
+        if node.op == "placeholder" and node.meta.get("steal_arg", False)
+    ]
 
     def wrapper(*args):
         # note this doesn't check the spec, assuming it is the same
-        return compiled_fn(*pytree.arg_tree_leaves(*args))
+        flat_args = pytree.arg_tree_leaves(*args)
+
+        # flat_args is a new list, so we need to clear references from the old list
+        for i in idx_to_steal:
+            args[i].clear()
+
+        # this call is boxed to avoid increasing the refcount until we reach the aot_module_simplified forward
+        return compiled_fn(flat_args)
 
     return wrapper
 
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 8d1626f..f94c464 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -894,6 +894,7 @@
             tensor_list_proxy = self.tx.output.root_tracer.create_graph_input(
                 re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source
             )
+            tensor_list_proxy.node.meta["steal_arg"] = True
 
             list_variable = wrap_fx_proxy_cls(
                 target_cls=TensorVariable,
diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
index e2291ea..4bbed04 100644
--- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
@@ -170,7 +170,6 @@
             fakified_out = None
             return out
 
-        # args is a list because compiled_fw is boxed_call
         if fw_metadata.is_rng_op_functionalized:
             # Add the seed and offset to args
             seed, offset = CUDARngStateHelper.get_torch_state_as_tuple()
diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py
index 8ef179b..eaecf43 100644
--- a/torch/_functorch/_aot_autograd/runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py
@@ -99,8 +99,9 @@
             assert num_tokens == 0
         elif num_tokens > 0:
             # Pass in effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
-            # NOTE: this keeps an extra reference to the old args until the end of this function
+            old_args = args
             args = [[None] * num_tokens, *args]
+            old_args.clear()
 
         # stash a ref to each input tensor we plan to use after the compiled function
         orig_inputs = {i: args[i] for i in epilogue_args_idx}
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index 818cfd0..44fa024 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -954,10 +954,29 @@
             aot_config,
         )
 
+    if isinstance(mod, torch._dynamo.utils.GmWrapper):
+        # This function is called by the flatten_graph_inputs wrapper, which boxes
+        # the inputs so that they can be freed before the end of this scope.
+        # For overhead reasons, this is not the default wrapper; see the comment:
+        # https://github.com/pytorch/pytorch/pull/122535/files#r1560096481
+        def boxed_forward(runtime_args: List[Any]):
+            flat_args = []
+            flat_args.extend(params_flat)
+            flat_args.extend(runtime_args)
+            runtime_args.clear()
+            return compiled_fn(flat_args)
+
+        # Just for convenience
+        boxed_forward.zero_grad = mod.zero_grad
+        boxed_forward.named_parameters = mod.named_parameters
+        boxed_forward.named_buffers = mod.named_buffers
+        return boxed_forward
+
     # TODO: There is something deeply wrong here; compiled_fn running with
     # the boxed calling convention, but aot_module_simplified somehow
     # historically returned a function that was not the boxed calling
     # convention.  This should get fixed...
+    # NB: GraphModule/nn.Module rely on the non-boxed calling convention here
     def forward(*runtime_args: Tuple[Any]):
         full_args = []
         full_args.extend(params_flat)