Move auto-functionalize tests into their own test file (#134834)

As the title says: move the auto-functionalize tests out of `test/dynamo/test_misc.py` into `test/inductor/test_auto_functionalize.py`, and use `with torch.library._scoped_library(...) as lib` where needed instead of manual cleanup.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134834
Approved by: https://github.com/zou3519
ghstack dependencies: #134831
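
For context, a minimal sketch of the pattern the moved tests now standardize on (the op name `mylib::sin_out` below is purely illustrative, not an op from this PR): instead of constructing a `torch.library.Library("mylib", "FRAGMENT")` and tearing it down in a `finally` block with a hand-rolled `cleanup_op` helper, the tests rely on `torch.library._scoped_library`, which destroys the fragment when the `with` block exits.

```python
import torch

def scoped_library_example():
    # Sketch only: "mylib::sin_out" is a made-up example op.
    with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
        torch.library.define(
            "mylib::sin_out", "(Tensor x, Tensor(a!) out) -> ()", lib=lib
        )

        @torch.library.impl("mylib::sin_out", "CompositeExplicitAutograd", lib=lib)
        def sin_out_impl(x, out):
            out.copy_(x.sin())

        x = torch.randn(3)
        out = torch.empty(3)
        torch.ops.mylib.sin_out(x, out)
        assert torch.allclose(out, x.sin())
    # On exit, the library fragment is destroyed automatically,
    # so no cleanup_op("mylib::sin_out") / del lib is needed.
```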
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 1bfb15d..0057383 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -112,16 +112,6 @@
     return wrapper
 
 
-def cleanup_op(opname):
-    ns, name = opname.split("::")
-    if not hasattr(torch.ops, ns):
-        return
-    actual_ns = getattr(torch.ops, ns)
-    if not hasattr(actual_ns, name):
-        return
-    delattr(actual_ns, name)
-
-
 class MyPickledModule(torch.nn.Module):
     def __init__(self, z):
         super().__init__()
@@ -509,8 +499,7 @@
 
     @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True)
     def test_pt2_compliant_ops_are_allowed(self):
-        lib = torch.library.Library("mylib", "FRAGMENT")
-        try:
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
             torch.library.define(
                 "mylib::bar",
                 "(Tensor x) -> Tensor",
@@ -538,14 +527,10 @@
 
             optimized_g = torch._dynamo.optimize(counts, nopython=True)(f)
             _ = optimized_g(x)
-        finally:
-            cleanup_op("mylib::bar")
-            del lib
 
     @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True)
     def test_non_pt2_compliant_ops_graph_break(self):
-        lib = torch.library.Library("mylib", "FRAGMENT")
-        try:
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
             torch.library.define("mylib::bar2", "(Tensor x) -> Tensor", lib=lib)
             torch.library.impl(
                 "mylib::bar2", "CompositeImplicitAutograd", torch.sin, lib=lib
@@ -574,14 +559,10 @@
             ):
                 optimized_g = torch._dynamo.optimize(counts, nopython=True)(f)
                 y = optimized_g(x)
-        finally:
-            cleanup_op("mylib::bar2")
-            del lib
 
     @torch._dynamo.config.patch(only_allow_pt2_compliant_ops=True)
     def test_pt2_compliant_overload(self):
-        lib = torch.library.Library("mylib", "FRAGMENT")
-        try:
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
             torch.library.define(
                 "mylib::bar3.tensor",
                 "(Tensor x) -> Tensor",
@@ -630,70 +611,6 @@
             with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "failed to"):
                 y = optimized_h(x)
 
-        finally:
-            cleanup_op("mylib::bar3")
-            del lib
-
-    def test_auto_functionalize_can_with_default(self):
-        lib = torch.library.Library("mylib", "FRAGMENT")
-        torch.library.define(
-            "mylib::foo",
-            "(Tensor a, int b, Tensor(d!)? c=None, Tensor? d=None, int e=-1) -> ()",
-            tags=torch.Tag.pt2_compliant_tag,
-            lib=lib,
-        )
-
-        @torch.library.impl("mylib::foo", "cpu", lib=lib)
-        def foo_impl(a, b, c=None, d=None, e=-1):
-            a + b
-            return
-
-        def f(a, mode):
-            return torch.ops.mylib.foo(
-                a,
-                0,
-            )
-
-        a = torch.tensor([10, 10, 10], dtype=torch.int64)
-
-        torch.compile(f)(a, 0)
-
-        cleanup_op("mylib::foo")
-        del lib
-
-    def test_auto_functionalize_can_with_none_return(self):
-        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
-            lib.define("foo(Tensor x, Tensor(a!) out) -> None")
-
-            def foo_impl(x, out):
-                out.copy_(x)
-
-            lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
-            x = torch.randn(3)
-            out = torch.zeros(3)
-
-            @torch.compile
-            def f(x, out):
-                torch.ops.mylib.foo(x, out)
-
-            f(x, out)
-
-    def test_auto_functionalize_self_as_mutate_arg(self):
-        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
-            lib.define("foo(Tensor(a!) self) -> None")
-
-            def foo_impl(self: torch.Tensor) -> None:
-                self.sin_()
-
-            x = torch.randn(3)
-            lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
-
-            @torch.compile(backend="inductor", fullgraph=True)
-            def f(x):
-                torch.ops.mylib.foo(x)
-
-            f(x)
-
     def test_user_defined_setattr1(self):
         @torch.compile(backend="eager", fullgraph=True)
         def fn(obj):
@@ -742,8 +659,7 @@
         self.assertEqual(cnt.frame_count, 2)
 
     def test_generate_trivial_abstract_impl(self):
-        try:
-            lib = torch.library.Library("mylib", "FRAGMENT")
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
             torch.library.define(
                 "mylib::foo",
                 "(Tensor x, Tensor[] y, Tensor(a!)? z, SymInt w) -> ()",
@@ -768,291 +684,6 @@
 
             output = torch.compile(f, backend="eager", fullgraph=True)(*args)
             self.assertEqual(output, None)
-        finally:
-            cleanup_op("mylib::foo")
-            del lib
-
-    def test_can_auto_functionalize(self):
-        from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
-
-        expected_true = [
-            "(Tensor(a!) x) -> ()",
-            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
-            "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
-            "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()",
-            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor",
-            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)",
-        ]
-        expected_false = [
-            "(Tensor x) -> ()",
-            "(Tensor(a) x) -> Tensor(a)",
-            "(Tensor(a!) x) -> Tensor(a!)",
-            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)",
-            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
-            "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
-            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor[])",
-        ]
-        for schema in expected_true:
-            try:
-                lib = torch.library.Library("mylib", "FRAGMENT")
-                torch.library.define("mylib::a", schema, lib=lib)
-                self.assertTrue(
-                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
-                )
-                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))
-            finally:
-                cleanup_op("mylib::a")
-                del lib
-        for schema in expected_false:
-            try:
-                lib = torch.library.Library("mylib", "FRAGMENT")
-                torch.library.define("mylib::a", schema, lib=lib)
-                self.assertFalse(
-                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
-                )
-                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))
-            finally:
-                cleanup_op("mylib::a")
-                del lib
-
-    def test_auto_functionalize(self):
-        try:
-            lib = torch.library.Library("mylib", "FRAGMENT")
-            torch.library.define(
-                "mylib::foo",
-                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
-                tags=torch.Tag.pt2_compliant_tag,
-                lib=lib,
-            )
-
-            @torch.library.impl("mylib::foo", "cpu", lib=lib)
-            @torch._dynamo.disable
-            def foo_impl(x, y, z, w, n):
-                x.add_(y[0] + w)
-                z.add_(y[1] + n)
-
-            def f(x, y, z, n):
-                torch.ops.mylib.foo(x, y, z, 2, n)
-
-            x = torch.randn(3)
-            y = (torch.randn(3), torch.randn(3))
-            z = torch.randn(3)
-            n = torch.randn(3)
-            orig_args = (x, y, z, n)
-
-            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
-
-            log_stream, ctx = logs_to_string(
-                "torch._inductor.compile_fx", "post_grad_graphs"
-            )
-            with ctx():
-                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)
-
-            post_grad_graphs = "\n".join(
-                log_stream.getvalue().strip().split("\n")[3:]
-            ).strip()
-
-            # Check the graph under static shapes
-            if torch._dynamo.config.assume_static_by_default:
-                self.assertExpectedInline(
-                    post_grad_graphs,
-                    """\
-def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
-        # No stacktrace found for following nodes
-        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
-        return ()""",
-                )
-
-            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
-            f(*eager_args)
-            self.assertEqual(compiled_args, eager_args)
-        finally:
-            cleanup_op("mylib::foo")
-            del lib
-
-    def test_auto_functionalize_with_returns(self):
-        try:
-            lib = torch.library.Library("mylib", "FRAGMENT")
-            torch.library.define(
-                "mylib::foo",
-                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
-                tags=torch.Tag.pt2_compliant_tag,
-                lib=lib,
-            )
-
-            @torch.library.impl("mylib::foo", "cpu", lib=lib)
-            @torch._dynamo.disable
-            def foo_impl(x, y, z, w, n):
-                x.add_(y[0] + w)
-                z.add_(y[1] + n)
-                return y[0] + w, y[1] + n
-
-            @torch.library.impl_abstract("mylib::foo", lib=lib)
-            def foo_abstract(x, y, z, w, n):
-                return y[0] + w, y[1] + n
-
-            def f(x, y, z, n):
-                return torch.ops.mylib.foo(x, y, z, 2, n)
-
-            x = torch.randn(3)
-            y = (torch.randn(3), torch.randn(3))
-            z = torch.randn(3)
-            n = torch.randn(3)
-            orig_args = (x, y, z, n)
-
-            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
-            log_stream, ctx = logs_to_string(
-                "torch._inductor.compile_fx", "post_grad_graphs"
-            )
-            with ctx():
-                compiled_out = torch.compile(f, backend="inductor", fullgraph=True)(
-                    *compiled_args
-                )
-
-            if torch._dynamo.config.assume_static_by_default:
-                post_grad_graphs = "\n".join(
-                    log_stream.getvalue().strip().split("\n")[3:]
-                ).strip()
-                self.assertExpectedInline(
-                    post_grad_graphs,
-                    """\
-def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
-        # No stacktrace found for following nodes
-        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
-        getitem_4: "f32[3][1]cpu" = foo_default[0]
-        getitem_5: "f32[3][1]cpu" = foo_default[1];  foo_default = None
-        return (getitem_4, getitem_5)""",
-                )
-
-            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
-            eager_out = f(*eager_args)
-            self.assertEqual(compiled_args, eager_args)
-            self.assertEqual(compiled_out, eager_out)
-        finally:
-            cleanup_op("mylib::foo")
-            del lib
-
-    def test_auto_functionalize_on_view(self):
-        try:
-            lib = torch.library.Library("mylib", "FRAGMENT")
-            torch.library.define(
-                "mylib::foo",
-                "(Tensor(a!) x) -> ()",
-                tags=torch.Tag.pt2_compliant_tag,
-                lib=lib,
-            )
-
-            @torch.library.impl("mylib::foo", "cpu", lib=lib)
-            @torch._dynamo.disable
-            def foo_impl(x):
-                x_np = x.detach().numpy()  # view
-                np.sin(x_np, out=x_np)
-                return
-
-            x = torch.randn(3)
-            expected = x.sin()
-            torch.ops.mylib.foo(x)
-            assert torch.allclose(x, expected)
-
-            @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True)
-            def f(x):
-                x = x.clone()
-                y = x[:]
-                torch.ops.mylib.foo(y)
-                return x
-
-            y = f(x)
-            self.assertEqual(y, x.sin())
-        finally:
-            cleanup_op("mylib::foo")
-            del lib
-
-    def test_auto_functionalize_optional(self):
-        try:
-            lib = torch.library.Library("mylib", "FRAGMENT")
-            torch.library.define(
-                "mylib::foo",
-                "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()",
-                tags=torch.Tag.pt2_compliant_tag,
-                lib=lib,
-            )
-
-            @torch.library.impl("mylib::foo", "cpu", lib=lib)
-            @torch._dynamo.disable
-            def foo_impl(x, y, z, w, n):
-                if x is not None:
-                    x.add_(y[0] + w)
-                if z is not None:
-                    z.add_(y[1] + n)
-
-            def f(x, y, z, n):
-                torch.ops.mylib.foo(x, y, z, 2, n)
-
-            x = None
-            y = (torch.randn(3), torch.randn(3))
-            z = torch.randn(3)
-            n = torch.randn(3)
-            orig_args = (x, y, z, n)
-
-            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
-            log_stream, ctx = logs_to_string(
-                "torch._inductor.compile_fx", "post_grad_graphs"
-            )
-            with ctx():
-                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)
-
-            if torch._dynamo.config.assume_static_by_default:
-                post_grad_graphs = "\n".join(
-                    log_stream.getvalue().strip().split("\n")[3:]
-                ).strip()
-                self.assertExpectedInline(
-                    post_grad_graphs,
-                    """\
-def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
-        # No stacktrace found for following nodes
-        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
-        return ()""",
-                )
-
-            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
-            f(*eager_args)
-            self.assertEqual(compiled_args, eager_args)
-        finally:
-            cleanup_op("mylib::foo")
-            del lib
-
-    def test_auto_functionalize_tensorlist(self):
-        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
-            torch.library.define(
-                "mylib::foo",
-                "(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim, Tensor(a!)[] out) -> ()",
-                tags=torch.Tag.pt2_compliant_tag,
-                lib=lib,
-            )
-
-            @torch.library.impl("mylib::foo", "cpu", lib=lib)
-            @torch._dynamo.disable
-            def foo_impl(all_gather_output, all_gather_input_split_sizes, dim, out):
-                for o in out:
-                    o.copy_(all_gather_output)
-
-            def f(all_gather_output, all_gather_input_split_sizes, dim, out):
-                torch.ops.mylib.foo(
-                    all_gather_output, all_gather_input_split_sizes, dim, out
-                )
-
-            a = torch.ones(4)
-            b = [2, 3]
-            c = 0
-            d = [torch.empty(4) for _ in range(2)]
-            orig_args = (a, b, c, d)
-
-            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
-            torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)
-
-            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
-            f(*eager_args)
-            self.assertEqual(compiled_args, eager_args)
 
     def test_shape_int_inplace_binops(self):
         def fn(x):
@@ -8961,25 +8592,6 @@
     @torch._dynamo.config.patch(
         capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
     )
-    def test_unbacked_auto_functionalize_op(self):
-        @torch.library.custom_op(
-            "mylib::mk_image", mutates_args=("decoder",), device_types=["cpu"]
-        )
-        def mk_image(decoder: Tensor) -> Tensor:
-            return torch.randn(2, 3, 4, 5)
-
-        @torch.library.register_fake("mylib::mk_image")
-        def _(decoder: Tensor) -> Tensor:
-            image_size = [torch.library.get_ctx().new_dynamic_size() for _ in range(4)]
-            return torch.empty(image_size)
-
-        @torch.compile(fullgraph=True)
-        def f(x):
-            return torch.ops.mylib.mk_image.default(x)
-
-        x = torch.zeros(100, dtype=torch.int64)
-        f(x)
-
     def test_out_variant_custom_op(self):
         with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
             lib.define(
diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py
new file mode 100644
index 0000000..a007c34
--- /dev/null
+++ b/test/inductor/test_auto_functionalize.py
@@ -0,0 +1,363 @@
+# Owner(s): ["module: functionalization"]
+
+import numpy as np
+
+import torch
+import torch._dynamo.testing
+import torch._inductor.test_case
+import torch.onnx.operators
+import torch.utils._pytree as pytree
+import torch.utils.cpp_extension
+from torch import Tensor
+from torch.testing._internal.logging_utils import logs_to_string
+
+
+class AutoFunctionalizeTests(torch._inductor.test_case.TestCase):
+    def test_auto_functionalize_can_with_default(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define(
+                "mylib::foo",
+                "(Tensor a, int b, Tensor(d!)? c=None, Tensor? d=None, int e=-1) -> ()",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo", "cpu", lib=lib)
+            def foo_impl(a, b, c=None, d=None, e=-1):
+                a + b
+                return
+
+            def f(a, mode):
+                return torch.ops.mylib.foo(
+                    a,
+                    0,
+                )
+
+            a = torch.tensor([10, 10, 10], dtype=torch.int64)
+
+            torch.compile(f)(a, 0)
+
+    def test_auto_functionalize_can_with_none_return(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            lib.define("foo(Tensor x, Tensor(a!) out) -> None")
+
+            def foo_impl(x, out):
+                out.copy_(x)
+
+            lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
+            x = torch.randn(3)
+            out = torch.zeros(3)
+
+            @torch.compile
+            def f(x, out):
+                torch.ops.mylib.foo(x, out)
+
+            f(x, out)
+
+    def test_auto_functionalize_self_as_mutate_arg(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            lib.define("foo(Tensor(a!) self) -> None")
+
+            def foo_impl(self: torch.Tensor) -> None:
+                self.sin_()
+
+            x = torch.randn(3)
+            lib.impl("foo", foo_impl, "CompositeExplicitAutograd")
+
+            @torch.compile(backend="inductor", fullgraph=True)
+            def f(x):
+                torch.ops.mylib.foo(x)
+
+            f(x)
+
+    def test_auto_functionalize_tensorlist(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define(
+                "mylib::foo",
+                "(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim, Tensor(a!)[] out) -> ()",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo", "cpu", lib=lib)
+            @torch._dynamo.disable
+            def foo_impl(all_gather_output, all_gather_input_split_sizes, dim, out):
+                for o in out:
+                    o.copy_(all_gather_output)
+
+            def f(all_gather_output, all_gather_input_split_sizes, dim, out):
+                torch.ops.mylib.foo(
+                    all_gather_output, all_gather_input_split_sizes, dim, out
+                )
+
+            a = torch.ones(4)
+            b = [2, 3]
+            c = 0
+            d = [torch.empty(4) for _ in range(2)]
+            orig_args = (a, b, c, d)
+
+            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
+            torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)
+
+            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
+            f(*eager_args)
+            self.assertEqual(compiled_args, eager_args)
+
+    def test_can_auto_functionalize(self):
+        from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
+
+        expected_true = [
+            "(Tensor(a!) x) -> ()",
+            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
+            "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> ()",
+            "(Tensor(a!) x, Tensor y, Tensor(b!)[] z, SymInt w) -> ()",
+            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor",
+            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor)",
+        ]
+        expected_false = [
+            "(Tensor x) -> ()",
+            "(Tensor(a) x) -> Tensor(a)",
+            "(Tensor(a!) x) -> Tensor(a!)",
+            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> Tensor(a)",
+            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
+            "(Tensor(a) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor(a))",
+            "(Tensor(a!) x, Tensor y, Tensor(b!) z, SymInt w, Tensor(c!)? n) -> (Tensor, Tensor[])",
+        ]
+        for schema in expected_true:
+            with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+                torch.library.define("mylib::a", schema, lib=lib)
+                self.assertTrue(
+                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
+                )
+                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))
+
+        for schema in expected_false:
+            with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+                torch.library.define("mylib::a", schema, lib=lib)
+                self.assertFalse(
+                    can_auto_functionalize(torch.ops.mylib.a.default), msg=schema
+                )
+                self.assertFalse(can_auto_functionalize(torch.ops.mylib.a))
+
+    def test_auto_functionalize(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define(
+                "mylib::foo",
+                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo", "cpu", lib=lib)
+            @torch._dynamo.disable
+            def foo_impl(x, y, z, w, n):
+                x.add_(y[0] + w)
+                z.add_(y[1] + n)
+
+            def f(x, y, z, n):
+                torch.ops.mylib.foo(x, y, z, 2, n)
+
+            x = torch.randn(3)
+            y = (torch.randn(3), torch.randn(3))
+            z = torch.randn(3)
+            n = torch.randn(3)
+            orig_args = (x, y, z, n)
+
+            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
+
+            log_stream, ctx = logs_to_string(
+                "torch._inductor.compile_fx", "post_grad_graphs"
+            )
+            with ctx():
+                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)
+
+            post_grad_graphs = "\n".join(
+                log_stream.getvalue().strip().split("\n")[3:]
+            ).strip()
+
+            # Check the graph under static shapes
+            if torch._dynamo.config.assume_static_by_default:
+                self.assertExpectedInline(
+                    post_grad_graphs,
+                    """\
+def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: \
+"f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
+        # No stacktrace found for following nodes
+        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  arg4_1 = arg2_1 = \
+arg3_1 = arg1_1 = arg0_1 = foo_default = None
+        return ()""",
+                )
+
+            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
+            f(*eager_args)
+            self.assertEqual(compiled_args, eager_args)
+
+    def test_auto_functionalize_with_returns(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define(
+                "mylib::foo",
+                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo", "cpu", lib=lib)
+            @torch._dynamo.disable
+            def foo_impl(x, y, z, w, n):
+                x.add_(y[0] + w)
+                z.add_(y[1] + n)
+                return y[0] + w, y[1] + n
+
+            @torch.library.impl_abstract("mylib::foo", lib=lib)
+            def foo_abstract(x, y, z, w, n):
+                return y[0] + w, y[1] + n
+
+            def f(x, y, z, n):
+                return torch.ops.mylib.foo(x, y, z, 2, n)
+
+            x = torch.randn(3)
+            y = (torch.randn(3), torch.randn(3))
+            z = torch.randn(3)
+            n = torch.randn(3)
+            orig_args = (x, y, z, n)
+
+            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
+            log_stream, ctx = logs_to_string(
+                "torch._inductor.compile_fx", "post_grad_graphs"
+            )
+            with ctx():
+                compiled_out = torch.compile(f, backend="inductor", fullgraph=True)(
+                    *compiled_args
+                )
+
+            if torch._dynamo.config.assume_static_by_default:
+                post_grad_graphs = "\n".join(
+                    log_stream.getvalue().strip().split("\n")[3:]
+                ).strip()
+                self.assertExpectedInline(
+                    post_grad_graphs,
+                    """\
+def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", \
+arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"):
+        # No stacktrace found for following nodes
+        foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  \
+arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None
+        getitem_4: "f32[3][1]cpu" = foo_default[0]
+        getitem_5: "f32[3][1]cpu" = foo_default[1];  foo_default = None
+        return (getitem_4, getitem_5)""",
+                )
+
+            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
+            eager_out = f(*eager_args)
+            self.assertEqual(compiled_args, eager_args)
+            self.assertEqual(compiled_out, eager_out)
+
+    def test_auto_functionalize_on_view(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define(
+                "mylib::foo",
+                "(Tensor(a!) x) -> ()",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo", "cpu", lib=lib)
+            @torch._dynamo.disable
+            def foo_impl(x):
+                x_np = x.detach().numpy()  # view
+                np.sin(x_np, out=x_np)
+                return
+
+            x = torch.randn(3)
+            expected = x.sin()
+            torch.ops.mylib.foo(x)
+            assert torch.allclose(x, expected)
+
+            @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True)
+            def f(x):
+                x = x.clone()
+                y = x[:]
+                torch.ops.mylib.foo(y)
+                return x
+
+            y = f(x)
+            self.assertEqual(y, x.sin())
+
+    def test_auto_functionalize_optional(self):
+        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
+            torch.library.define(
+                "mylib::foo",
+                "(Tensor(a!)? x, Tensor[] y, Tensor(b!)? z, SymInt w, Tensor n) -> ()",
+                tags=torch.Tag.pt2_compliant_tag,
+                lib=lib,
+            )
+
+            @torch.library.impl("mylib::foo", "cpu", lib=lib)
+            @torch._dynamo.disable
+            def foo_impl(x, y, z, w, n):
+                if x is not None:
+                    x.add_(y[0] + w)
+                if z is not None:
+                    z.add_(y[1] + n)
+
+            def f(x, y, z, n):
+                torch.ops.mylib.foo(x, y, z, 2, n)
+
+            x = None
+            y = (torch.randn(3), torch.randn(3))
+            z = torch.randn(3)
+            n = torch.randn(3)
+            orig_args = (x, y, z, n)
+
+            compiled_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
+            log_stream, ctx = logs_to_string(
+                "torch._inductor.compile_fx", "post_grad_graphs"
+            )
+            with ctx():
+                torch.compile(f, backend="inductor", fullgraph=True)(*compiled_args)
+
+            if torch._dynamo.config.assume_static_by_default:
+                post_grad_graphs = "\n".join(
+                    log_stream.getvalue().strip().split("\n")[3:]
+                ).strip()
+                self.assertExpectedInline(
+                    post_grad_graphs,
+                    """\
+def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu"):
+        # No stacktrace found for following nodes
+        foo_default = torch.ops.mylib.foo.default(None, [arg2_1, arg3_1], arg1_1, 2, arg0_1);  \
+arg2_1 = arg3_1 = arg1_1 = arg0_1 = foo_default = None
+        return ()""",
+                )
+
+            eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args)
+            f(*eager_args)
+            self.assertEqual(compiled_args, eager_args)
+
+    @torch._dynamo.config.patch(
+        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
+    )
+    def test_unbacked_auto_functionalize_op(self):
+        @torch.library.custom_op(
+            "mylib::mk_image", mutates_args=("decoder",), device_types=["cpu"]
+        )
+        def mk_image(decoder: Tensor) -> Tensor:
+            return torch.randn(2, 3, 4, 5)
+
+        @torch.library.register_fake("mylib::mk_image")
+        def _(decoder: Tensor) -> Tensor:
+            image_size = [torch.library.get_ctx().new_dynamic_size() for _ in range(4)]
+            return torch.empty(image_size)
+
+        @torch.compile(fullgraph=True)
+        def f(x):
+            return torch.ops.mylib.mk_image.default(x)
+
+        x = torch.zeros(100, dtype=torch.int64)
+        f(x)
+
+
+if __name__ == "__main__":
+    from torch._inductor.test_case import run_tests
+
+    run_tests()