[Traceable FSDP2] Add auto-functionalize support for mutable list[Tensor] (copy from Brian's PR #127347); enable E2E inductor unit test for transformer model (#129502)
This is a copy of Brian's PR (https://github.com/pytorch/pytorch/pull/127347) with additional changes to support mutable `List[Tensor]` arguments in Inductor. It also enables the E2E Inductor unit test for Traceable FSDP2 + transformer model.
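As an illustration of what this enables, here is a minimal sketch (the op name `mylib::fill_outs` and its body are hypothetical, not part of this PR): a custom op that mutates a `Tensor(a!)[]` argument can now be auto-functionalized and compiled by Inductor, mirroring the new `test_auto_functionalize_tensorlist` and `test_fallback_mutable_op_list` tests.

```python
import torch
from typing import List

# Hypothetical custom op whose inferred schema has a mutable Tensor[] argument,
# roughly "(Tensor x, Tensor(a!)[] outs) -> ()".
@torch.library.custom_op("mylib::fill_outs", mutates_args={"outs"})
def fill_outs(x: torch.Tensor, outs: List[torch.Tensor]) -> None:
    for o in outs:
        o.copy_(x)

@torch.compile(backend="inductor", fullgraph=True)
def f(x):
    outs = [torch.empty_like(x) for _ in range(2)]
    torch.ops.mylib.fill_outs(x, outs)
    return outs

out0, out1 = f(torch.ones(4))  # both outputs hold copies of x; the list mutation is auto-functionalized
```

Before this change, `can_auto_functionalize` rejected mutable `Tensor[]` arguments and Inductor raised `NYI: Can't generate FallbackKernel` for such ops; both paths are updated in the diff below.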
Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_trace_fsdp_set_`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_simple_mlp_fullgraph_backend_aot_eager`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_simple_mlp_fullgraph_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_fullgraph_backend_aot_eager`
- `pytest -rA test/dynamo/test_misc.py::MiscTests::test_auto_functionalize_tensorlist`
- `pytest -rA test/inductor/test_torchinductor.py::GPUTests::test_fallback_mutable_op_list_cuda`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129502
Approved by: https://github.com/zou3519
diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py
index e4b5d44..a7c2bc0 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py
@@ -2,10 +2,12 @@
import contextlib
+import copy
import unittest
import torch
import torch._dynamo.testing
+import torch.distributed._composable.fsdp._fsdp_param
from torch import nn
from torch._dynamo import compiled_autograd
@@ -113,6 +115,24 @@
self.assertEqual(cnt.op_count, 1)
self.assertEqual(len(cnt.graphs), 1)
+ def test_trace_fsdp_set_(self):
+ @torch.library.custom_op("mylib::add_one_out", mutates_args={"out"})
+ def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None:
+ torch.add(x, 1, out=out)
+
+ def f(x):
+ buf = torch.zeros(2)
+ buf_view = buf.view(-1)
+ torch.ops.mylib.add_one_out(x, out=buf_view)
+ buf_view2 = buf.view(-1)
+ torch.ops.fsdp.set_(x, buf_view2)
+
+ ref_x = torch.zeros(2)
+ x = copy.deepcopy(ref_x)
+ f(ref_x)
+ torch.compile(f, backend="aot_eager")(x)
+ self.assertEqual(x, ref_x)
+
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
@torch._functorch.config.patch(recompute_views=True)
@torch._functorch.config.patch(cse=False)
@@ -210,7 +230,15 @@
*self._create_simple_mlp_factory_fns(), "aot_eager", fullgraph=True
)
- @unittest.expectedFailure
+ @skipIfRocm
+ @skip_if_lt_x_gpu(2)
+ def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self):
+ self._test_traceable_fsdp(
+ *self._create_simple_mlp_factory_fns(),
+ "aot_eager_decomp_partition",
+ fullgraph=True,
+ )
+
@skipIfRocm
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
@@ -253,10 +281,22 @@
*self._create_transformer_factory_fns(), "aot_eager", fullgraph=True
)
- @unittest.expectedFailure
+ @skipIfRocm
+ @skip_if_lt_x_gpu(2)
+ # TODO: native_dropout has worse accuracy after decomp, need to figure out why
+ @torch._inductor.config.patch(fallback_random=True)
+ def test_transformer_fullgraph_backend_aot_eager_decomp_partition(self):
+ self._test_traceable_fsdp(
+ *self._create_transformer_factory_fns(),
+ "aot_eager_decomp_partition",
+ fullgraph=True,
+ )
+
@skipIfRocm
@unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
@skip_if_lt_x_gpu(2)
+ # TODO: native_dropout causes CUDA IMA error, need to figure out why
+ @torch._inductor.config.patch(fallback_random=True)
def test_transformer_fullgraph_backend_inductor(self):
self._test_traceable_fsdp(
*self._create_transformer_factory_fns(), "inductor", fullgraph=True
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index bff953d..bd016ea 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -671,6 +671,7 @@
"(Tensor(a!) x) -> ()",
"(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
"(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
+ "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()",
"(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor",
"(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)",
]
@@ -678,7 +679,6 @@
"(Tensor x) -> ()",
"(Tensor(a) x) -> Tensor(a)",
"(Tensor(a!) x) -> Tensor(a!)",
- "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()",
"(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)",
"(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
"(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
@@ -913,6 +913,39 @@
cleanup_op("mylib::foo")
del lib
+ def test_auto_functionalize_tensorlist(self):
+ with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+ torch.library.define(
+ "mylib::foo",
+ "(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim, Tensor(a!)[] out) -> ()",
+ tags=torch.Tag.pt2_compliant_tag,
+ lib=lib,
+ )
+
+ @torch.library.impl("mylib::foo", "cpu", lib=lib)
+ @torch._dynamo.disable
+ def foo_impl(all_gather_output, all_gather_input_split_sizes, dim, out):
+ for o in out:
+ o.copy_(all_gather_output)
+
+ def f(all_gather_output, all_gather_input_split_sizes, dim, out):
+ torch.ops.mylib.foo(
+ all_gather_output, all_gather_input_split_sizes, dim, out
+ )
+
+ a = torch.ones(4)
+ b = [2, 3]
+ c = 0
+ d = [torch.empty(4) for _ in range(2)]
+ orig_args = (a, b, c, d)
+
+ compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
+ torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)
+
+ eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
+ f(*eager_args)
+ self.assertEqual(compiled_args, eager_args)
+
def test_shape_int_inplace_binops(self):
def fn(x):
p = x.shape[0]
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 0632d14..09598bf 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -7512,7 +7512,7 @@
def f(a, b):
torch.ops.mylib.inplace_(a, b)
- return ()
+ return None
a = torch.tensor([0.0, 1.0, 2])
b = [torch.tensor([2.0, 3.0, 5.0]), torch.tensor([1.0, 4.0, 6.0])]
@@ -7521,11 +7521,26 @@
mod = make_fx(f)(*cloned_args)
cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
- with self.assertRaisesRegex(
- torch._inductor.exc.LoweringException,
- "NYI: Can't generate FallbackKernel",
- ):
- compiled_f = compile_fx_inner(mod, cloned_args)
+ compiled_f = compile_fx_inner(mod, cloned_args)
+
+ @torch.library.custom_op("mylib::sin_out", mutates_args={"outs"})
+ def sin_out(x: torch.Tensor, outs: typing.List[torch.Tensor]) -> None:
+ x_np = x.numpy()
+ assert len(outs) == 2
+ out_np0 = outs[0].numpy()
+ out_np1 = outs[1].numpy()
+ np.sin(x_np, out=out_np0)
+ np.sin(x_np, out=out_np1)
+
+ @torch.compile
+ def g(x):
+ outs = [torch.empty_like(x) for _ in range(2)]
+ sin_out(x, outs)
+ return outs
+
+ x = torch.randn(3)
+ out = [torch.empty_like(x) for _ in range(2)]
+ y = g(x)
@expectedFailureXPU
def test_functionalize_rng_wrappers(self):
diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py
index 34624d7..02cf2ab 100644
--- a/torch/_functorch/_aot_autograd/functional_utils.py
+++ b/torch/_functorch/_aot_autograd/functional_utils.py
@@ -416,10 +416,6 @@
), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
placeholders.remove(n.args[0])
mutation_count += 1
- elif n.target in [torch.ops.aten.split_with_sizes_copy.out]:
- # These are mutation ops that can show up in the middle of the graph,
- # because they are ops that we explicitly do **not** functinoalize
- continue
else:
assert (
not n.target._schema.is_mutable
diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py
index 77b548b..aa3d449 100644
--- a/torch/_higher_order_ops/auto_functionalize.py
+++ b/torch/_higher_order_ops/auto_functionalize.py
@@ -93,8 +93,13 @@
and type(arg.type.getElementType()) is torch.TensorType
):
continue
+ if (
+ type(arg.type) is torch.ListType
+ and type(arg.type.getElementType()) is torch.TensorType
+ ):
+ continue
# Not yet supported: other Tensor types. This includes things like
- # Tensor[], Tensor?[], Tensor[]?.
+ # Tensor?[], Tensor[]?.
return False
if len(schema.returns) == 1 and isinstance(schema.returns[0].type, torch.NoneType):
@@ -129,7 +134,9 @@
new_kwargs[name] = kwargs[name]
else:
new_kwargs[name] = (
- clone_preserve_strides(kwargs[name])
+ [clone_preserve_strides(x) for x in kwargs[name]]
+ if kwargs[name] is not None and isinstance(kwargs[name], list)
+ else clone_preserve_strides(kwargs[name])
if kwargs[name] is not None
else None
)
@@ -250,11 +257,27 @@
# Can be None if input was `Tensor(a!)?`
if unwrapped_out is None:
continue
- assert isinstance(unwrapped_out, torch.Tensor)
+
+ # We only handle Tensor or List[Tensor] here for now.
+ def sync_update(o, orig_arg):
+ ctx.replace(orig_arg, o)
+ ctx.commit_update(orig_arg)
+ ctx.sync(orig_arg)
+
orig_arg = normalized_kwargs[name]
- ctx.replace(orig_arg, unwrapped_out)
- ctx.commit_update(orig_arg)
- ctx.sync(orig_arg)
+
+ if isinstance(unwrapped_out, torch.Tensor):
+ sync_update(unwrapped_out, orig_arg)
+ elif isinstance(unwrapped_out, list) and all(
+ isinstance(o, torch.Tensor) for o in unwrapped_out
+ ):
+ assert len(orig_arg) == len(unwrapped_out)
+ for orig_a, o in zip(orig_arg, unwrapped_out):
+ sync_update(o, orig_a)
+ else:
+ raise RuntimeError(
+ f"unsupported type for auto-functionalization: {unwrapped_out}"
+ )
return ctx.wrap_tensors(unwrapped_actual_out) # type: ignore[arg-type]
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 5ca8bef..28020cf 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -5454,6 +5454,9 @@
is_optional_tensor = isinstance(
info.type, torch.OptionalType
) and isinstance(info.type.getElementType(), torch.TensorType)
+ is_list_tensor = isinstance(info.type, torch.ListType) and isinstance(
+ info.type.getElementType(), torch.TensorType
+ )
if is_optional_tensor or isinstance(info.type, torch.TensorType):
# PyTorch also accepts None and scalar types for args marked as "Tensor".
# We're not going to check all of them here.
@@ -5463,12 +5466,15 @@
return
if info.alias_info is None:
return
- # can_auto_functionalize already filters out mutable List[Tensor].
- # We can support this in the future, but this is very uncommon.
- assert isinstance(info.type, torch.TensorType) or is_optional_tensor
- self.alias_names.append(arg.get_name())
- if info.alias_info.is_write:
- mark_node_as_mutating(self, arg)
+ if is_list_tensor:
+ for tensor_arg in arg:
+ self.alias_names.append(tensor_arg.get_name())
+ mark_node_as_mutating(self, tensor_arg)
+ else:
+ assert isinstance(info.type, torch.TensorType) or is_optional_tensor
+ self.alias_names.append(arg.get_name())
+ if info.alias_info.is_write:
+ mark_node_as_mutating(self, arg)
for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):
handle_aliasing_and_mutation(info, arg)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 764022a..ac92757 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -292,14 +292,13 @@
unpacked = True
args = args[0]
- # explicitly assert for "out=" ops for better error messages
- assert not any(
- x == "out" for x in kwargs.keys()
- ), "out= ops aren't yet supported"
# kwargs tensors not supported yet unless it's a fallback op
- assert not any(isinstance(x, TensorBox) for x in kwargs.values()) or all(
- fn in fallbacks for fn in aten_fn
- )
+ if not all(fn in fallbacks for fn in aten_fn):
+ assert not any(isinstance(x, TensorBox) for x in kwargs.values())
+ # explicitly assert for "out=" ops for better error messages
+ assert not any(
+ x == "out" for x in kwargs.keys()
+ ), "out= ops aren't yet supported"
args = transform_args(
args, broadcast, type_promotion_kind, convert_input_to_bool
diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
index fdbb461..33de0d7 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
@@ -64,6 +64,7 @@
@torch.library.impl(lib, "all_gather_copy_in", "CUDA")
+@torch.library.impl(lib, "all_gather_copy_in", "CPU")
def all_gather_copy_in_cuda(
all_gather_inputs: List[torch.Tensor],
inp_split_sizes: List[int],
@@ -92,6 +93,7 @@
@torch.library.impl(lib, "split_with_sizes_copy", "Meta")
@torch.library.impl(lib, "split_with_sizes_copy", "CUDA")
+@torch.library.impl(lib, "split_with_sizes_copy", "CPU")
def split_with_sizes_copy(
all_gather_output: torch.Tensor,
all_gather_input_split_sizes: List[int],
@@ -103,20 +105,6 @@
)
-@torch.library.impl(lib, "split_with_sizes_copy", "Functionalize")
-def split_with_sizes_copy_functionalize(
- all_gather_output: torch.Tensor,
- all_gather_input_split_sizes: List[int],
- dim: int,
- out: List[torch.Tensor],
-) -> None:
- ag_output_elem = torch._from_functional_tensor(all_gather_output)
- out_elem = [torch._from_functional_tensor(x) for x in out]
- torch.split_with_sizes_copy(
- ag_output_elem, all_gather_input_split_sizes, dim=dim, out=out_elem
- )
-
-
lib.define(
"chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()"
)
@@ -124,6 +112,7 @@
@torch.library.impl(lib, "chunk_cat", "Meta")
@torch.library.impl(lib, "chunk_cat", "CUDA")
+@torch.library.impl(lib, "chunk_cat", "CPU")
def chunk_cat(
tensors: List[torch.Tensor],
dim: int,
diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py
index 955beb8..a8f6f55 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_param.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_param.py
@@ -74,8 +74,50 @@
tensor.set_(data)
+"""
+[Note: Avoiding functionalization for fsdp.set_ and inductor.resize_storage_bytes_(0)]
+
+Currently we don't functionalize the `fsdp.set_` op or the `inductor.resize_storage_bytes_(0)` op
+(i.e. they show up as mutation ops in the middle of the AOT joint graph).
+
+Reason:
+The Traceable FSDP2 compiled autograd BWD graph has the following traits:
+(1) Two inputs of the graph are aliased to each other (one from hook closed-over tensors, one from FWD saved tensors).
+(2) One of them is mutated (set_ and resize_(0) to handle the all-gathered param).
+(3) They are both subclasses.
+The combination of these traits is not supported by AOTAutograd (it's difficult to reason about subclass aliasing).
+So this doesn't work at all for Traceable FSDP2.
+
+The compromise we use is to avoid functionalization for the FSDP2 set_ and resize_(0) ops.
+This avoids the problem above, because from AOTAutograd's point of view there are no mutations
+that functionalization needs to handle. (Although we need to be careful not to DCE those mutable ops.)
+
+We can avoid this functionalization because:
+(1) The nn.Parameter is never used before its .set_() is called in eager code (i.e. no alias of it is created),
+so it's safe to call .set_() in the middle of the graph to swap out its storage and start using the nn.Parameter downstream.
+(2) We always re-allocate the buffer for nn.Parameter to store the AllGather output and to be used in downstream user ops.
+So calling resize-to-0 in the middle of the graph to free nn.Parameter memory after use should always be okay
+(since we always allocate anew next time we need it, we strictly don't need to keep the old tensor storage around anymore).
+
+Q: But doesn't the torch.compile stack have the "functional graph" assumption in many places?
+A: Yes - this is WIP, but we will try to get back to a functional graph as early as possible in the lowering process.
+Specifically, we believe we can move both the .set_ and .resize_(0) ops to the end of the graph in the AOT joint graph before the partitioner
+(i.e. effectively "re-functionalizing" those ops). Put another way, we avoid functionalization for those two ops just to
+make AOTAutograd alias analysis happy, and as soon as we are past that point, we "re-functionalize" the graph.
+This requires a custom FX pass but we believe it's not hard to write and maintain.
+
+Q: Why is it important that the partitioner does not save views of the nn.Parameter as FWD saved tensors?
+A: This is critical: we do want to save the FWD nn.Parameter graph input (instead of its view) for BWD use,
+so that downstream ops in the BWD graph use the post-`.set_` nn.Parameter instead of any of its saved views as input.
+This is because .set_ will not update any of the nn.Parameter's views, so BWD downstream ops must use the original
+nn.Parameter in order to see the result of .set_.
+"""
+
+
@torch.library.impl(lib, "set_", "Functionalize")
def set__functionalize(tensor, data):
+ torch._sync(tensor)
+ torch._sync(data)
tensor_inner = torch._from_functional_tensor(tensor)
data_inner = torch._from_functional_tensor(data)
tensor_inner.set_(data_inner) # type: ignore[call-overload]
diff --git a/torch/fx/node.py b/torch/fx/node.py
index cabaf98..0105025 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -56,7 +56,6 @@
_ops.profiler._record_function_enter_new,
_ops.profiler._record_function_exit,
_ops.inductor.accumulate_grad_.default,
- _ops.aten.split_with_sizes_copy.out,
} | _side_effectful_need_to_be_preserved_pre_dispatch
if hasattr(_ops.inductor, "resize_storage_bytes_"):
_side_effectful_functions.add(_ops.inductor.resize_storage_bytes_.default)