[compiled autograd][aot] Trim runtime refs for list inputs from dynamo (#122535)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122535
Approved by: https://github.com/bdhirsh
ghstack dependencies: #123630, #123674, #122353, #123359
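
The wrappers that receive list inputs (dynamo's flatten_graph_inputs wrapper, the AOTAutograd runtime wrapper, and the boxed forward returned by aot_module_simplified) now clear the caller-side list as soon as its contents have been moved into the boxed argument list, so activations stolen from compiled autograd are not kept alive by stale references.

A minimal standalone sketch of the boxed-call / list-stealing pattern the change relies on (illustrative names only, not the actual PyTorch wrappers): the caller hands its inputs over as a list and clears that list, so the callee holds the only remaining references and each input can be freed as soon as it has been consumed.

import torch

def boxed_call(flat_args):
    # pop the input out of the boxed list so this frame holds the only reference
    x = flat_args.pop()
    out = x.add(1)
    del x  # the input tensor can be freed here, before boxed_call even returns
    return (out,)

def stealing_wrapper(activations):
    flat_args = list(activations)
    activations.clear()  # trim the caller's references, as this PR does in the real wrappers
    return boxed_call(flat_args)

activations = [torch.ones(1_000_000)]
(out,) = stealing_wrapper(activations)
assert len(activations) == 0  # the input list was stolen, mirroring test_free_activation_memory
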
diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index da4f7a9..7992ce8 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -1191,6 +1191,43 @@
self.check_output_and_recompiles(fn, 3)
+ @unittest.skipIf(not HAS_CUDA, "requires cuda")
+ def test_free_activation_memory(self):
+ self.assertTrue(torch.cuda.memory_allocated() == 0)
+
+ # Use an op to check that the memory is freed by the time the op is executed
+ def assertion_impl(to_clone):
+ mem_allocated = torch.cuda.memory_allocated()
+ self.assertTrue(
+ mem_allocated < 4000000, "activations should have been freed"
+ )
+ return to_clone.clone()
+
+ with torch.library._scoped_library("test_compiled_autograd", "FRAGMENT") as lib:
+ lib.define(
+ "assertion_op(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,)
+ )
+ lib.impl("assertion_op", assertion_impl, "CPU")
+ lib.impl("assertion_op", lambda x: x.clone(), "Meta")
+
+ # Create a graph that allows input stealing
+ def forward(activations):
+ add = activations[0] + 1
+ out = add.cpu()
+ cloned_out = torch.ops.test_compiled_autograd.assertion_op(out)
+ return (cloned_out,)
+
+ gm = torch.fx.symbolic_trace(forward)
+ torch._dynamo.utils.set_locals_to_steal(gm, ["activations"])
+ compiled_fn = torch.compile(gm)
+
+ # allocate at least 4,000,000 bytes (1,000,000 * 4 bytes)
+ activations = [torch.ones(1000000, dtype=torch.float32, device="cuda")]
+ self.assertTrue(torch.cuda.memory_allocated() > 4000000)
+
+ out = compiled_fn(activations)
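+ # the flatten_graph_inputs wrapper should have stolen (cleared) the activations list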
+ self.assertTrue(len(activations) == 0)
+
def load_test_module(name):
testdir = Path(__file__).absolute().parent.parent
@@ -1362,6 +1399,7 @@
"test_save_for_backward_inputs_are_namedtuple", # torch._dynamo.exc.Unsupported: 'skip function
"test_autograd_function_backed_op", # RuntimeError: compiled_args not implemented
"test_setitem", # AssertionError: Tensor-likes are not close!
+ "test_grad_nonleaf_register_hook", # IndexError: list index out of range (NB: x.grad = y where both x and y are input tensors)
}
if not HAS_CUDA:
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index c60a308..153a1c2 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -2627,6 +2627,17 @@
return proxy
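+# Hoisted out of flatten_graph_inputs so aot_module_simplified can detect it via isinstance and return a boxed forward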
+class GmWrapper(torch.nn.Module):
+ def __init__(self, gm, spec):
+ super().__init__()
+ self.gm = gm
+ self.spec = spec
+
+ def forward(self, *args):
+ args: List[Any] = list(args)
+ return self.gm(*pytree.tree_unflatten(args, self.spec))
+
+
def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
"""
Mutate inputs so that they are flat and wrap gm such that it
@@ -2634,21 +2645,24 @@
bumpy inputs.
"""
inputs, spec = pytree.tree_flatten(inputs)
- class GmWrapper(torch.nn.Module):
- def __init__(self):
- super().__init__()
- self.gm = gm
-
- def forward(self, *args):
- args: List[Any] = list(args)
- return self.gm(*pytree.tree_unflatten(args, spec))
-
- compiled_fn = compile_gm(GmWrapper(), inputs)
+ compiled_fn = compile_gm(GmWrapper(gm, spec), inputs)
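+ # indices of list-input placeholders that dynamo marked as stealable via node.meta["steal_arg"]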
+ idx_to_steal = [
+ i
+ for i, node in enumerate(gm.graph.nodes)
+ if node.op == "placeholder" and node.meta.get("steal_arg", False)
+ ]
def wrapper(*args):
# note this doesn't check the spec, assuming it is the same
- return compiled_fn(*pytree.arg_tree_leaves(*args))
+ flat_args = pytree.arg_tree_leaves(*args)
+
+ # flat_args is a new list, so we need to clear references from the old list
+ for i in idx_to_steal:
+ args[i].clear()
+
+ # this call is boxed to avoid increasing the inputs' refcounts until we reach aot_module_simplified's forward
+ return compiled_fn(flat_args)
return wrapper
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 8d1626f..f94c464 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -894,6 +894,7 @@
tensor_list_proxy = self.tx.output.root_tracer.create_graph_input(
re.sub(r"[^a-zA-Z0-9]+", "_", self.name), type(value), source=source
)
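+ # mark this list input so flatten_graph_inputs can clear (steal) the original list at runtime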
+ tensor_list_proxy.node.meta["steal_arg"] = True
list_variable = wrap_fx_proxy_cls(
target_cls=TensorVariable,
diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
index e2291ea..4bbed04 100644
--- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py
@@ -170,7 +170,6 @@
fakified_out = None
return out
- # args is a list because compiled_fw is boxed_call
if fw_metadata.is_rng_op_functionalized:
# Add the seed and offset to args
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple()
diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py
index 8ef179b..eaecf43 100644
--- a/torch/_functorch/_aot_autograd/runtime_wrappers.py
+++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py
@@ -99,8 +99,9 @@
assert num_tokens == 0
elif num_tokens > 0:
# Pass in effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
- # NOTE: this keeps an extra reference to the old args until the end of this function
+ old_args = args
args = [[None] * num_tokens, *args]
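+ # clear the old list so it does not keep an extra reference to the inputs for the rest of this function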
+ old_args.clear()
# stash a ref to each input tensor we plan to use after the compiled function
orig_inputs = {i: args[i] for i in epilogue_args_idx}
diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index 818cfd0..44fa024 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -954,10 +954,29 @@
aot_config,
)
+ if isinstance(mod, torch._dynamo.utils.GmWrapper):
+ # This function is called by the flatten_graph_inputs wrapper, which boxes
+ # the inputs so that they can be freed before the end of this scope.
+ # For overhead reasons, this is not the default wrapper, see comment:
+ # https://github.com/pytorch/pytorch/pull/122535/files#r1560096481
+ def boxed_forward(runtime_args: List[Any]):
+ flat_args = []
+ flat_args.extend(params_flat)
+ flat_args.extend(runtime_args)
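+ # clear the boxed runtime_args so the caller's list no longer references the inputs (flat_args now owns them)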
+ runtime_args.clear()
+ return compiled_fn(flat_args)
+
+ # Just for convenience
+ boxed_forward.zero_grad = mod.zero_grad
+ boxed_forward.named_parameters = mod.named_parameters
+ boxed_forward.named_buffers = mod.named_buffers
+ return boxed_forward
+
# TODO: There is something deeply wrong here; compiled_fn running with
# the boxed calling convention, but aot_module_simplified somehow
# historically returned a function that was not the boxed calling
# convention. This should get fixed...
+ # NB: GraphModule/nn.Module rely on the non-boxed calling convention here
def forward(*runtime_args: Tuple[Any]):
full_args = []
full_args.extend(params_flat)