[Traceable FSDP2] Add auto-functionalize support for mutable list[Tensor] (copy from Brian's PR #127347); enable E2E inductor unit test for transformer model (#129502)

This is a copy of Brian's PR (https://github.com/pytorch/pytorch/pull/127347) with additional changes to support mutable `List[Tensor]` arguments in Inductor. It also enables the E2E Inductor unit test for Traceable FSDP2 + transformer model.
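
A minimal sketch of the newly supported pattern, mirroring the `sin_out` and `test_auto_functionalize_tensorlist` tests in the diff below (the op name `mylib::add_one_outs` is illustrative, not part of this PR): a custom op that mutates a `List[Tensor]` argument can now be auto-functionalized and lowered through Inductor under `torch.compile`.

```python
import typing

import torch


# Illustrative custom op (hypothetical name) that mutates a List[Tensor] argument.
@torch.library.custom_op("mylib::add_one_outs", mutates_args={"outs"})
def add_one_outs(x: torch.Tensor, outs: typing.List[torch.Tensor]) -> None:
    # Writes into every tensor of the mutable list argument.
    for o in outs:
        o.copy_(x + 1)


@torch.compile(fullgraph=True)  # default backend is Inductor
def f(x):
    outs = [torch.empty_like(x) for _ in range(2)]
    torch.ops.mylib.add_one_outs(x, outs)
    return outs


x = torch.zeros(3)
print(f(x))  # expected: two tensors equal to torch.ones(3)
```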

Test commands:
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_trace_fsdp_set_`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_simple_mlp_fullgraph_backend_aot_eager`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_simple_mlp_fullgraph_backend_inductor`
- `pytest -rA test/distributed/_composable/fsdp/test_fully_shard_compile.py::TestFullyShardCompile::test_transformer_fullgraph_backend_aot_eager`
- `pytest -rA test/dynamo/test_misc.py::MiscTests::test_auto_functionalize_tensorlist`
- `pytest -rA test/inductor/test_torchinductor.py::GPUTests::test_fallback_mutable_op_list_cuda`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129502
Approved by: https://github.com/zou3519
diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py
index e4b5d44..a7c2bc0 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py
@@ -2,10 +2,12 @@
 
 
 import contextlib
+import copy
 import unittest
 
 import torch
 import torch._dynamo.testing
+import torch.distributed._composable.fsdp._fsdp_param
 from torch import nn
 from torch._dynamo import compiled_autograd
 
@@ -113,6 +115,24 @@
         self.assertEqual(cnt.op_count, 1)
         self.assertEqual(len(cnt.graphs), 1)
 
+    def test_trace_fsdp_set_(self):
+        @torch.library.custom_op("mylib::add_one_out", mutates_args={"out"})
+        def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None:
+            torch.add(x, 1, out=out)
+
+        def f(x):
+            buf = torch.zeros(2)
+            buf_view = buf.view(-1)
+            torch.ops.mylib.add_one_out(x, out=buf_view)
+            buf_view2 = buf.view(-1)
+            torch.ops.fsdp.set_(x, buf_view2)
+
+        ref_x = torch.zeros(2)
+        x = copy.deepcopy(ref_x)
+        f(ref_x)
+        torch.compile(f, backend="aot_eager")(x)
+        self.assertEqual(x, ref_x)
+
     @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
     @torch._functorch.config.patch(recompute_views=True)
     @torch._functorch.config.patch(cse=False)
@@ -210,7 +230,15 @@
             *self._create_simple_mlp_factory_fns(), "aot_eager", fullgraph=True
         )
 
-    @unittest.expectedFailure
+    @skipIfRocm
+    @skip_if_lt_x_gpu(2)
+    def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self):
+        self._test_traceable_fsdp(
+            *self._create_simple_mlp_factory_fns(),
+            "aot_eager_decomp_partition",
+            fullgraph=True,
+        )
+
     @skipIfRocm
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_lt_x_gpu(2)
@@ -253,10 +281,22 @@
             *self._create_transformer_factory_fns(), "aot_eager", fullgraph=True
         )
 
-    @unittest.expectedFailure
+    @skipIfRocm
+    @skip_if_lt_x_gpu(2)
+    # TODO: native_dropout has worse accuracy after decomp, need to figure out why
+    @torch._inductor.config.patch(fallback_random=True)
+    def test_transformer_fullgraph_backend_aot_eager_decomp_partition(self):
+        self._test_traceable_fsdp(
+            *self._create_transformer_factory_fns(),
+            "aot_eager_decomp_partition",
+            fullgraph=True,
+        )
+
     @skipIfRocm
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     @skip_if_lt_x_gpu(2)
+    # TODO: native_dropout causes CUDA IMA error, need to figure out why
+    @torch._inductor.config.patch(fallback_random=True)
     def test_transformer_fullgraph_backend_inductor(self):
         self._test_traceable_fsdp(
             *self._create_transformer_factory_fns(), "inductor", fullgraph=True
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index bff953d..bd016ea 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -671,6 +671,7 @@
             "(Tensor(a!) x) -> ()",
             "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
             "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
+            "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()",
             "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor",
             "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)",
         ]
@@ -678,7 +679,6 @@
             "(Tensor x) -> ()",
             "(Tensor(a) x) -> Tensor(a)",
             "(Tensor(a!) x) -> Tensor(a!)",
-            "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()",
             "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)",
             "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
             "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
@@ -913,6 +913,39 @@
             cleanup_op("mylib::foo")
             del lib
 
+    def test_auto_functionalize_tensorlist(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define(
+                "mylib::foo",
+                "(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim, Tensor(a!)[] out) -> ()",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo", "cpu", lib=lib)
+            @torch._dynamo.disable
+            def foo_impl(all_gather_output, all_gather_input_split_sizes, dim, out):
+                for o in out:
+                    o.copy_(all_gather_output)
+
+            def f(all_gather_output, all_gather_input_split_sizes, dim, out):
+                torch.ops.mylib.foo(
+                    all_gather_output, all_gather_input_split_sizes, dim, out
+                )
+
+            a = torch.ones(4)
+            b = [2, 3]
+            c = 0
+            d = [torch.empty(4) for _ in range(2)]
+            orig_args = (a, b, c, d)
+
+            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
+            torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)
+
+            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
+            f(*eager_args)
+            self.assertEqual(compiled_args, eager_args)
+
     def test_shape_int_inplace_binops(self):
         def fn(x):
             p = x.shape[0]
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 0632d14..09598bf 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -7512,7 +7512,7 @@
 
             def f(a, b):
                 torch.ops.mylib.inplace_(a, b)
-                return ()
+                return None
 
             a = torch.tensor([0.0, 1.0, 2])
             b = [torch.tensor([2.0, 3.0, 5.0]), torch.tensor([1.0, 4.0, 6.0])]
@@ -7521,11 +7521,26 @@
             mod = make_fx(f)(*cloned_args)
             cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
 
-            with self.assertRaisesRegex(
-                torch._inductor.exc.LoweringException,
-                "NYI: Can't generate FallbackKernel",
-            ):
-                compiled_f = compile_fx_inner(mod, cloned_args)
+            compiled_f = compile_fx_inner(mod, cloned_args)
+
+        @torch.library.custom_op("mylib::sin_out", mutates_args={"outs"})
+        def sin_out(x: torch.Tensor, outs: typing.List[torch.Tensor]) -> None:
+            x_np = x.numpy()
+            assert len(outs) == 2
+            out_np0 = outs[0].numpy()
+            out_np1 = outs[1].numpy()
+            np.sin(x_np, out=out_np0)
+            np.sin(x_np, out=out_np1)
+
+        @torch.compile
+        def g(x):
+            outs = [torch.empty_like(x) for _ in range(2)]
+            sin_out(x, outs)
+            return outs
+
+        x = torch.randn(3)
+        out = [torch.empty_like(x) for _ in range(2)]
+        y = g(x)
 
     @expectedFailureXPU
     def test_functionalize_rng_wrappers(self):
diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py
index 34624d7..02cf2ab 100644
--- a/torch/_functorch/_aot_autograd/functional_utils.py
+++ b/torch/_functorch/_aot_autograd/functional_utils.py
@@ -416,10 +416,6 @@
                     ), f"n={str(n)}, n.args[0]={str(n.args[0])}, placeholders={str(placeholders)}, graph={str(fx_g)}"
                     placeholders.remove(n.args[0])
                 mutation_count += 1
-            elif n.target in [torch.ops.aten.split_with_sizes_copy.out]:
-                # These are mutation ops that can show up in the middle of the graph,
-                # because they are ops that we explicitly do **not** functinoalize
-                continue
             else:
                 assert (
                     not n.target._schema.is_mutable
diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py
index 77b548b..aa3d449 100644
--- a/torch/_higher_order_ops/auto_functionalize.py
+++ b/torch/_higher_order_ops/auto_functionalize.py
@@ -93,8 +93,13 @@
             and type(arg.type.getElementType()) is torch.TensorType
         ):
             continue
+        if (
+            type(arg.type) is torch.ListType
+            and type(arg.type.getElementType()) is torch.TensorType
+        ):
+            continue
         # Not yet supported: other Tensor types. This includes things like
-        # Tensor[], Tensor?[], Tensor[]?.
+        # Tensor?[], Tensor[]?.
         return False
 
     if len(schema.returns) == 1 and isinstance(schema.returns[0].type, torch.NoneType):
@@ -129,7 +134,9 @@
             new_kwargs[name] = kwargs[name]
         else:
             new_kwargs[name] = (
-                clone_preserve_strides(kwargs[name])
+                [clone_preserve_strides(x) for x in kwargs[name]]
+                if kwargs[name] is not None and isinstance(kwargs[name], list)
+                else clone_preserve_strides(kwargs[name])
                 if kwargs[name] is not None
                 else None
             )
@@ -250,11 +257,27 @@
         # Can be None if input was `Tensor(a!)?`
         if unwrapped_out is None:
             continue
-        assert isinstance(unwrapped_out, torch.Tensor)
+
+        # We only handle Tensor or List[Tensor] here for now.
+        def sync_update(o, orig_arg):
+            ctx.replace(orig_arg, o)
+            ctx.commit_update(orig_arg)
+            ctx.sync(orig_arg)
+
         orig_arg = normalized_kwargs[name]
-        ctx.replace(orig_arg, unwrapped_out)
-        ctx.commit_update(orig_arg)
-        ctx.sync(orig_arg)
+
+        if isinstance(unwrapped_out, torch.Tensor):
+            sync_update(unwrapped_out, orig_arg)
+        elif isinstance(unwrapped_out, list) and all(
+            isinstance(o, torch.Tensor) for o in unwrapped_out
+        ):
+            assert len(orig_arg) == len(unwrapped_out)
+            for orig_a, o in zip(orig_arg, unwrapped_out):
+                sync_update(o, orig_a)
+        else:
+            raise RuntimeError(
+                f"unsupported type for auto-functionalization: {unwrapped_out}"
+            )
 
     return ctx.wrap_tensors(unwrapped_actual_out)  # type: ignore[arg-type]
 
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 5ca8bef..28020cf 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -5454,6 +5454,9 @@
             is_optional_tensor = isinstance(
                 info.type, torch.OptionalType
             ) and isinstance(info.type.getElementType(), torch.TensorType)
+            is_list_tensor = isinstance(info.type, torch.ListType) and isinstance(
+                info.type.getElementType(), torch.TensorType
+            )
             if is_optional_tensor or isinstance(info.type, torch.TensorType):
                 # PyTorch also accepts None and scalar types for args marked as "Tensor".
                 # We're not going to check all of them here.
@@ -5463,12 +5466,15 @@
                 return
             if info.alias_info is None:
                 return
-            # can_auto_functionalize already filters out mutable List[Tensor].
-            # We can support this in the future, but this is very uncommon.
-            assert isinstance(info.type, torch.TensorType) or is_optional_tensor
-            self.alias_names.append(arg.get_name())
-            if info.alias_info.is_write:
-                mark_node_as_mutating(self, arg)
+            if is_list_tensor:
+                for tensor_arg in arg:
+                    self.alias_names.append(tensor_arg.get_name())
+                    mark_node_as_mutating(self, tensor_arg)
+            else:
+                assert isinstance(info.type, torch.TensorType) or is_optional_tensor
+                self.alias_names.append(arg.get_name())
+                if info.alias_info.is_write:
+                    mark_node_as_mutating(self, arg)
 
         for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):
             handle_aliasing_and_mutation(info, arg)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 764022a..ac92757 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -292,14 +292,13 @@
             unpacked = True
             args = args[0]
 
-        # explicitly assert for "out=" ops for better error messages
-        assert not any(
-            x == "out" for x in kwargs.keys()
-        ), "out= ops aren't yet supported"
         # kwargs tensors not supported yet unless it's a fallback op
-        assert not any(isinstance(x, TensorBox) for x in kwargs.values()) or all(
-            fn in fallbacks for fn in aten_fn
-        )
+        if not all(fn in fallbacks for fn in aten_fn):
+            assert not any(isinstance(x, TensorBox) for x in kwargs.values())
+            # explicitly assert for "out=" ops for better error messages
+            assert not any(
+                x == "out" for x in kwargs.keys()
+            ), "out= ops aren't yet supported"
 
         args = transform_args(
             args, broadcast, type_promotion_kind, convert_input_to_bool
diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
index fdbb461..33de0d7 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py
@@ -64,6 +64,7 @@
 
 
 @torch.library.impl(lib, "all_gather_copy_in", "CUDA")
+@torch.library.impl(lib, "all_gather_copy_in", "CPU")
 def all_gather_copy_in_cuda(
     all_gather_inputs: List[torch.Tensor],
     inp_split_sizes: List[int],
@@ -92,6 +93,7 @@
 
 @torch.library.impl(lib, "split_with_sizes_copy", "Meta")
 @torch.library.impl(lib, "split_with_sizes_copy", "CUDA")
+@torch.library.impl(lib, "split_with_sizes_copy", "CPU")
 def split_with_sizes_copy(
     all_gather_output: torch.Tensor,
     all_gather_input_split_sizes: List[int],
@@ -103,20 +105,6 @@
     )
 
 
-@torch.library.impl(lib, "split_with_sizes_copy", "Functionalize")
-def split_with_sizes_copy_functionalize(
-    all_gather_output: torch.Tensor,
-    all_gather_input_split_sizes: List[int],
-    dim: int,
-    out: List[torch.Tensor],
-) -> None:
-    ag_output_elem = torch._from_functional_tensor(all_gather_output)
-    out_elem = [torch._from_functional_tensor(x) for x in out]
-    torch.split_with_sizes_copy(
-        ag_output_elem, all_gather_input_split_sizes, dim=dim, out=out_elem
-    )
-
-
 lib.define(
     "chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()"
 )
@@ -124,6 +112,7 @@
 
 @torch.library.impl(lib, "chunk_cat", "Meta")
 @torch.library.impl(lib, "chunk_cat", "CUDA")
+@torch.library.impl(lib, "chunk_cat", "CPU")
 def chunk_cat(
     tensors: List[torch.Tensor],
     dim: int,
diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py
index 955beb8..a8f6f55 100644
--- a/torch/distributed/_composable/fsdp/_fsdp_param.py
+++ b/torch/distributed/_composable/fsdp/_fsdp_param.py
@@ -74,8 +74,50 @@
     tensor.set_(data)
 
 
+"""
+[Note: Avoiding functionalization for fsdp.set_ and inductor.resize_storage_bytes_(0)]
+
+Currently we don't functionalize the `fsdp.set_` op or the `inductor.resize_storage_bytes_(0)` op
+(i.e. they show up as mutation ops in the middle of the AOT joint graph).
+
+Reason:
+The Traceable FSDP2 compiled autograd BWD graph has the following traits:
+(1) Two inputs of the graph are aliased to each other (one from hook closed-over tensors, one from FWD saved tensors).
+(2) One of them is mutated (set_ and resize_(0) to handle the all-gathered param).
+(3) They are both subclasses.
+The combination of these traits is not supported by AOTAutograd (it's difficult to reason about subclass aliasing).
+So this doesn't work at all for Traceable FSDP2.
+
+The compromise we use is to avoid functionalization for the FSDP2 set_ and resize_(0) ops.
+This avoids the problem above because, from AOTAutograd's point of view, there are no mutations
+that functionalization needs to handle. (Although we need to be careful not to DCE those mutable ops.)
+
+We can avoid this functionalization because:
+(1) The nn.Parameter is never used before its .set_() is called in eager code (i.e. no alias of it is created),
+so it's safe to call .set_() in the middle of the graph to swap out its storage and start using the nn.Parameter downstream.
+(2) We always re-allocate the buffer for nn.Parameter to store the AllGather output and to be used in downstream user ops.
+So calling resize-to-0 in the middle of the graph to free nn.Parameter memory after use should always be okay
+(since we always allocate anew next time we need it, we strictly don't need to keep the old tensor storage around anymore).
+
+Q: But doesn't the torch.compile stack have the "functional graph" assumption in many places?
+A: Yes - this is WIP, but we will try to get back to a functional graph as early as possible in the lowering process.
+Specifically, we believe we can move both the .set_ and .resize_(0) ops to the end of the graph in the AOT joint graph before the partitioner
+(i.e. effectively "re-functionalizing" those ops). Put another way, we avoid functionalization for these two ops just to
+make AOTAutograd alias analysis happy, and as soon as we are past that point, we "re-functionalize" the graph.
+This requires a custom FX pass but we believe it's not hard to write and maintain.
+
+Q: What's the importance of the partitioner not saving views of nn.Parameter as FWD saved tensors?
+A: This is critical: we do want to save the FWD nn.Parameter graph input (instead of its view) for BWD use,
+so that downstream ops in the BWD graph use the post-`.set_` nn.Parameter instead of any of its saved views as input.
+This is because .set_ will not update any of the nn.Parameter's views, so BWD downstream ops must use the original
+nn.Parameter in order to see the result of .set_.
+"""
+
+
 @torch.library.impl(lib, "set_", "Functionalize")
 def set__functionalize(tensor, data):
+    torch._sync(tensor)
+    torch._sync(data)
     tensor_inner = torch._from_functional_tensor(tensor)
     data_inner = torch._from_functional_tensor(data)
     tensor_inner.set_(data_inner)  # type: ignore[call-overload]
diff --git a/torch/fx/node.py b/torch/fx/node.py
index cabaf98..0105025 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -56,7 +56,6 @@
     _ops.profiler._record_function_enter_new,
     _ops.profiler._record_function_exit,
     _ops.inductor.accumulate_grad_.default,
-    _ops.aten.split_with_sizes_copy.out,
 } | _side_effectful_need_to_be_preserved_pre_dispatch
 if hasattr(_ops.inductor, "resize_storage_bytes_"):
     _side_effectful_functions.add(_ops.inductor.resize_storage_bytes_.default)