# Owner(s): ["module: __torch_dispatch__"]

import tempfile
import unittest
from copy import deepcopy

import torch
from torch import SymInt
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.cuda.jiterator import _create_jit_fn
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.library import _scoped_library, fallthrough_kernel, impl, Library
from torch.testing._internal.common_utils import *  # noqa: F403
import logging
import sys

import torch._dynamo
from torch._C import DispatchKey, DispatchKeySet
from torch._custom_op.functional import register_functional_op
from torch.fx.experimental.proxy_tensor import make_fx
from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.logging_tensor import (
    capture_logs,
    capture_logs_with_logging_tensor_mode,
    log_input,
    LoggingTensor,
    LoggingTensorMode,
    LoggingTensorReentrant,
)
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils import _pytree as pytree
from torch.utils._mode_utils import all_same_mode, no_dispatch
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _get_current_dispatch_mode_stack,
    TorchDispatchMode,
)
from torch.utils._pytree import tree_map, tree_map_only


# used as DataLoader collate_fn below; named here to avoid trying to pickle a lambda
def _identity(x):
    """Return ``x`` unchanged."""
    return x


class TestDispatcherPythonBindings(TestCase):
    def test_call_boxed(self) -> None:
        sin = torch._C._dispatch_find_schema_or_throw("aten::sin", "")
        x = torch.randn(3)
        y = torch._C._dispatch_call_boxed(sin, x)
        self.assertEqual(y, x.sin())


class TestPythonRegistration(TestCase):
    test_ns = "_test_python_registration"

    def tearDown(self):
        if hasattr(torch.ops, self.test_ns):
            del torch.ops._test_python_registration

    def test_override_aten_ops_with_multiple_libraries(self) -> None:
        """Override ``aten`` kernels from two nested IMPL libraries.

        ``my_lib1`` (inner scope) overrides ``neg`` and ``my_lib2`` (outer
        scope) overrides ``mul.Tensor``; the assertions after each scope exits
        check which overrides are still in effect.
        """
        x = torch.tensor([1, 2])
        with _scoped_library("aten", "IMPL") as my_lib2:
            with _scoped_library("aten", "IMPL") as my_lib1:
                # Example 1
                def my_neg(*args, **kwargs):
                    return args[0]._neg_view()

                # Now we are secretly making the operator a view op so autograd needs to know how
                # to handle it
                my_lib1.impl("neg", my_neg, "AutogradCPU")

                self.assertTrue(torch.neg(x).is_neg())

                # RuntimeError: impl("aten::neg", ...):
                # Explicitly provided namespace (aten) in operator name does not match ...
                with self.assertRaisesRegex(
                    RuntimeError, "operator name does not match namespace"
                ):
                    with _scoped_library("foo", "DEF") as my_lib3:
                        my_lib3.define("neg(Tensor self) -> Tensor")
                        my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")

                # Example 2
                def my_mul(*args, **kwargs):
                    return torch.zeros_like(args[0])

                # torch.ops.aten.mul.Tensor
                my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")

                y = torch._efficientzerotensor(2)
                self.assertFalse(torch.mul(x, y)._is_zerotensor())

                # Assert that a user can't override the behavior of a (ns, op, dispatch_key)
                # combination if someone already overrode the behavior for the same before them
                with self.assertRaisesRegex(
                    RuntimeError, "already a kernel registered from python"
                ):
                    my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")

            # Validate that lib2 is not affected by removing lib1
            # (`y` was bound inside the inner scope and is still a live local here)
            self.assertFalse(torch.mul(x, y)._is_zerotensor())

        # Validate that the old behavior is restored for neg and mul
        self.assertFalse(torch.neg(x).is_neg())
        self.assertTrue(torch.mul(x, y)._is_zerotensor())

    def test_error_if_fn_not_callable(self):
        with self.assertRaisesRegex(
            TypeError, "Input function is required to be a callable"
        ):
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl(torch.ops.aten.neg.default, [], "AutogradCPU")

    def test_finalizer(self):
        """Check reference counts around a ``Library`` and its finalizer.

        The assertions compare ``sys.getrefcount`` values before and after
        ``del lib`` and verify that the kernel key registered in
        ``torch.library._impls`` disappears once the library is deleted.
        The exact counts depend on which locals hold references, so the
        statements below must not be reordered or given extra aliases.
        """
        impls_refcnt = sys.getrefcount(torch.library._impls)
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        lib.define("foo123(Tensor x) -> Tensor")

        # 1 for `lib`, 1 for sys.getrefcount
        self.assertEqual(sys.getrefcount(lib), 2)
        # We gained an additional reference that gets cleared when the finalizer runs
        self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt + 1)
        # 1 for `lib`
        # 1 for the finalizer
        # 1 for sys.getrefcount
        self.assertEqual(sys.getrefcount(lib._op_impls), 3)

        def foo123(x):
            pass

        lib.impl(f"{self.test_ns}::foo123", foo123, "CPU")
        # Key format used by torch.library._impls: "<namespace>/<op>/<dispatch key>"
        key = f"{self.test_ns}/foo123/CPU"
        self.assertTrue(key in torch.library._impls)

        saved_op_impls = lib._op_impls

        # del will definitely work if the following passes
        self.assertEqual(sys.getrefcount(lib), 2)
        del lib

        # 1 for saved_op_impls
        # 1 for sys.getrefcount
        # This function should be the last user of lib._op_impls:
        # - lib should not have a reference anymore (it was del'ed)
        # - lib's finalizer should not have a reference anymore
        self.assertEqual(sys.getrefcount(saved_op_impls), 2)

        self.assertTrue(key not in torch.library._impls)

        # lib's finalizer should not have a reference anymore
        self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt)

    def test_override_cpu_sum(self) -> None:
        # Example 1
        run = [False]

        def my_sum(*args, **kwargs):
            run[0] = True
            return args[0].clone()

        with _scoped_library("aten", "IMPL") as my_lib1:
            my_lib1.impl("aten::sum", my_sum, "CPU")
            x = torch.tensor([1, 2])
            self.assertEqual(torch.sum(x), x)
            self.assertTrue(run[0])
        # Validate that the old behavior is restored for sum
        self.assertEqual(torch.sum(x), torch.tensor(3))

    def test_override_cuda_with_jiterator(self) -> None:
        """Override several CUDA kernels with jiterator-generated kernels.

        Each nested helper registers a Python function (wrapping a kernel
        built by ``_create_jit_fn``) as the CUDA implementation of one aten
        op, checks the overridden result, and then checks that the stock
        behavior returns once the scoped library is closed.  The helpers are
        only invoked when CUDA is available and the build is not ROCm;
        otherwise this test body does nothing.
        """

        def override_where_cuda() -> None:
            # Example 1: Invert the behavior of where's condition input
            not_where_code_string = """
            template <typename T> T inverted_where(bool cond, T a, T b){
                return !cond ? a : b;
            }
            """
            jitted_where = _create_jit_fn(not_where_code_string)

            # One-element list so the nested kernel can record that it ran
            CALLED = [False]

            def inverted_where(*args, **kwargs):
                CALLED[0] = True
                return jitted_where(*args, **kwargs)

            # overriding where's cuda kernel with Jiterator generated kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::where.self", inverted_where, "CUDA")

                device = "cuda"
                cond = torch.tensor(
                    [True, True, False], device=device, dtype=torch.bool
                )
                x = torch.tensor([1, 2, 3], device=device)
                y = torch.tensor([-1, -2, -3], device=device)

                self.assertEqual(torch.where(cond, x, y), torch.tensor([-1, -2, 3]))
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(torch.where(cond, x, y), torch.tensor([1, 2, -3]))

        def override_gelu_cuda() -> None:
            # Example 2: Use relu to approximate gelu for faster compute
            fastest_gelu_code_string = """
            template <typename T> T fast_gelu(T a){
                return a > 0 ? a : 0;
            }
            """
            jitted_gelu = _create_jit_fn(fastest_gelu_code_string)

            # One-element list so the nested kernel can record that it ran
            CALLED = [False]

            def fast_gelu(*args, **kwargs):
                CALLED[0] = True
                return jitted_gelu(*args, **kwargs)

            # overriding gelu's cuda kernel with Jiterator generated relu kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::gelu", fast_gelu, "CUDA")

                x = torch.rand([3, 3], device="cuda", dtype=torch.float)
                self.assertEqual(
                    torch.nn.functional.gelu(x), torch.nn.functional.relu(x)
                )
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertNotEqual(
                torch.nn.functional.gelu(x), torch.nn.functional.relu(x)
            )

        def override_exp_cuda() -> None:
            # Example 3: Preventing exp from exploding for float16
            clipped_exp_code_string = """
            template <typename T> T clipped_exp(T a){
                return a > T(10.0) ? T(22026.4657948) : exp(a);
            }
            """
            jitted_exp = _create_jit_fn(clipped_exp_code_string)

            # One-element list so the nested kernel can record that it ran
            CALLED = [False]

            def clipped_exp(*args, **kwargs):
                CALLED[0] = True
                return jitted_exp(*args, **kwargs)

            # overriding exp's cuda kernel with clipped_exp kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::exp", clipped_exp, "CUDA")

                x = torch.tensor([0.0, 100.0], device="cuda", dtype=torch.float16)
                self.assertEqual(
                    torch.exp(x),
                    torch.tensor([1.0, 22026.4657948], dtype=torch.float16),
                )
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(
                torch.exp(x), torch.tensor([1.0, torch.inf], dtype=torch.float16)
            )

        def override_add_cuda() -> None:
            # Example 4: simulate a hardware bug, where the adder is always off by 1
            buggy_add_code_string = """
            template <typename T> T buggy_add(T a, T b){
                return a + b + T(1);
            }
            """
            jitted_add = _create_jit_fn(buggy_add_code_string)

            # One-element list so the nested kernel can record that it ran
            CALLED = [False]

            def buggy_add(*args, **kwargs):
                CALLED[0] = True
                return jitted_add(*args, **kwargs)

            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::add.Tensor", buggy_add, "CUDA")

                x_cpu = torch.rand([3, 3], device="cpu")
                y_cpu = torch.rand([3], device="cpu")

                x_cuda = x_cpu.cuda()
                y_cuda = y_cpu.cuda()

                # The CPU add is not overridden, so it serves as the reference
                self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu + 1)
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu)

        # Without CUDA (or on ROCm) none of the helpers run and the test passes vacuously
        if torch.cuda.is_available() and not TEST_WITH_ROCM:
            override_where_cuda()
            override_gelu_cuda()
            override_exp_cuda()
            override_add_cuda()

    def test_extend_library_with_dispatch_key_arg(self):
        def my_sum(*args, **kwargs):
            return args[0].clone()

        with _scoped_library("aten", "IMPL", dispatch_key="CPU") as my_lib1:
            # RuntimeError: Explicitly provided dispatch key (Conjugate) is
            # inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block
            with self.assertRaisesRegex(
                RuntimeError, "inconsistent with the dispatch key"
            ):
                my_lib1.impl("sum", my_sum, "Conjugate")
            my_lib1.impl("aten::sum", my_sum)
            x = torch.tensor([1, 2])
            self.assertEqual(torch.sum(x), x)

    def test_create_new_library(self) -> None:
        with _scoped_library(self.test_ns, "DEF") as my_lib1:
            my_lib1.define("sum(Tensor self) -> Tensor")

            # Example 1
            @torch.library.impl(my_lib1, "sum", "CPU")
            def my_sum(*args, **kwargs):
                return args[0].clone()

            x = torch.tensor([1, 2])
            op = getattr(torch.ops, self.test_ns).sum
            self.assertEqual(op(x), x)

            with _scoped_library(self.test_ns, "IMPL") as my_lib2:
                # Example 2
                @torch.library.impl(my_lib2, op.default, "ZeroTensor")
                def my_sum_zt(*args, **kwargs):
                    if args[0]._is_zerotensor():
                        return torch._efficientzerotensor(args[0].shape)
                    else:
                        return args[0].clone()

                y = torch._efficientzerotensor(3)
                self.assertTrue(op(y)._is_zerotensor())
                self.assertEqual(op(x), x)

    def test_create_new_library_fragment_no_existing(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as my_lib:
            my_lib.define("sum2(Tensor self) -> Tensor")

            @torch.library.impl(my_lib, "sum2", "CPU")
            def my_sum(*args, **kwargs):
                return args[0]

            x = torch.tensor([1, 2])
            self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)

    def test_create_new_library_fragment_with_existing(self):
        with _scoped_library(self.test_ns, "DEF") as my_lib1:
            # Create a fragment
            with _scoped_library(self.test_ns, "FRAGMENT") as my_lib2:
                my_lib2.define("sum4(Tensor self) -> Tensor")

                @torch.library.impl(my_lib2, "sum4", "CPU")
                def my_sum4(*args, **kwargs):
                    return args[0]

                x = torch.tensor([1, 2])
                self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)

                # Create another fragment
                with _scoped_library(self.test_ns, "FRAGMENT") as my_lib3:
                    my_lib3.define("sum3(Tensor self) -> Tensor")

                    @torch.library.impl(my_lib3, "sum3", "CPU")
                    def my_sum3(*args, **kwargs):
                        return args[0]

                    x = torch.tensor([1, 2])
                    self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)

    @unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
    def test_alias_analysis(self):
        """Check how ``alias_analysis`` affects an op's presence in a scripted graph.

        ``test_helper`` defines ``_op`` in the test namespace with the given
        ``alias_analysis`` value, scripts a function that calls it, and asserts
        that the op's qualified name appears in the printed graph.  With the
        empty string the assertion is expected to fail; with
        ``"CONSERVATIVE"`` it is expected to hold.
        """

        def test_helper(alias_analysis=""):
            # Plain (unscoped) Library: it is only referenced by this local variable.
            # NOTE(review): the second call presumably relies on the first library
            # having been released when the first call unwound — confirm.
            my_lib1 = Library(self.test_ns, "DEF")  # noqa: TOR901

            called = [0]

            @torch.library.define(
                my_lib1, "_op() -> None", alias_analysis=alias_analysis
            )
            def _op(*args, **kwargs):
                called[0] += 1

            # The namespace is spelled out literally here because the function
            # body is compiled by TorchScript.
            @torch.jit.script
            def _test():
                torch.ops._test_python_registration._op()

            assert "_test_python_registration::_op" in str(_test.graph)

        # NOTE(review): presumably the call is dropped from the graph under the
        # default analysis because the op looks side-effect free — confirm.
        with self.assertRaises(AssertionError):
            test_helper("")  # alias_analysis="FROM_SCHEMA"

        test_helper("CONSERVATIVE")

    def test_error_for_unsupported_ns_or_kind(self) -> None:
        with self.assertRaisesRegex(ValueError, "Unsupported kind"):
            my_lib1 = Library("myns", "BLA")  # noqa: TOR901

        for kind in ("DEF", "FRAGMENT"):
            with self.assertRaisesRegex(ValueError, "reserved namespace"):
                my_lib1 = Library("prim", kind)  # noqa: TOR901

    def test_returning_symint(self) -> None:
        shape_env = ShapeEnv()
        fake_tensor_mode = FakeTensorMode(shape_env=shape_env)

        ft = fake_tensor_mode.from_tensor(torch.rand(2, 3))

        s0, s1 = ft.shape

        with _scoped_library(self.test_ns, "DEF") as tlib:
            tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")

            @impl(tlib, "sqsum", "CompositeExplicitAutograd")
            def sqsum(a: SymInt, b: SymInt):
                return a * a + b * b

            out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
            out_val = shape_env.evaluate_expr(out.node.expr)
        self.assertEqual(out_val, 13)

    def test_register_functional_op_error_cases(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            with self.assertRaisesRegex(TypeError, "instance of OpOverload"):
                register_functional_op(lib, "abs", torch.ops.aten.abs_)
            with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
                register_functional_op(lib, "abs", torch.ops.aten.abs_.default)
            with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
                register_functional_op(lib, "abs", torch.ops.aten.abs.out)

            schemas = [
                "foo(Tensor x, Tensor(a!)[] y) -> ()",
                "foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)",
                "foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))",
            ]

        for schema in schemas:
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define(schema)
                with self.assertRaisesRegex(RuntimeError, "NYI"):
                    register_functional_op(
                        lib,
                        "foo_functional",
                        getattr(torch.ops, self.test_ns).foo.default,
                    )

    def _check_is_functional_variant(self, mutable_op, functional_op, args):
        # functional op should not mutate
        cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
        functional_result = functional_op(*cloned_args)
        self.assertEqual(cloned_args, args)

        # check functional_result includes mutable_result
        mutable_result = mutable_op(*cloned_args)
        if mutable_result is None:
            flat_mutable_result = []
        else:
            flat_mutable_result = pytree.tree_leaves(mutable_result)
        flat_functional_result = pytree.tree_leaves(functional_result)
        assert len(flat_functional_result) > len(flat_mutable_result)
        self.assertEqual(
            flat_functional_result[: len(flat_mutable_result)], flat_mutable_result
        )

        # check rest of functional_result is the mutated args
        mutated_args = [
            maybe_mutated_arg
            for maybe_mutated_arg, arg in zip(cloned_args, args)
            if not (
                maybe_mutated_arg is not None
                and arg is not None
                and torch.allclose(maybe_mutated_arg, arg)
            )
        ]
        self.assertEqual(
            flat_functional_result[len(flat_mutable_result) :], mutated_args
        )

        # check that functionalization kernel was indeed registered
        def fn(*args):
            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
            mutable_op(*cloned_args)
            return cloned_args

        gm = make_fx(torch.func.functionalize(fn))(*args)
        has_functional_op = False
        for node in gm.graph.nodes:
            self.assertFalse(node.target is mutable_op)
            if node.target is functional_op:
                has_functional_op = True
        self.assertTrue(has_functional_op)

    def test_register_functional_op_no_returns(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define("foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()")

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                w.fill_(2.71)

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )
            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )

    def test_register_functional_op_with_optional(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                z.fill_(2.71)
                if w is not None:
                    w.fill_(1.618)

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )
            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, None),
            )

    def test_register_functional_op_one_return(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                w.fill_(2.71)
                z.fill_(0.99)
                return x.clone()

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )
            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )

    def test_register_functional_op_multiple_returns(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                w.fill_(2.71)
                return x.clone(), z.clone()

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )

            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )

    def test_register_fallthrough(self):
        with _scoped_library("aten", "IMPL") as my_lib:
            my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")

            a = torch.randn(2, 3, device="cpu", dtype=torch.float32)
            b = torch.randn(3, 2, device="cpu", dtype=torch.float32)
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                # dtype for mm should be float32 since we registered a fallthrough
                self.assertEqual(torch.mm(a, b).dtype, torch.float32)
                # ops that don't have a fallthrough registered should not be affected
                self.assertEqual(torch.matmul(a, b).dtype, torch.bfloat16)

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            # default behavior should have been restored
            self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16)


class TestPythonDispatch(TestCase):
    def test_basic(self) -> None:
        """Log the ops dispatched by a multiply and its backward on ``LoggingTensor``.

        Checks the resulting gradient, that the tensor saved for backward
        tracks ``x``, and the exact sequence of logged ops.
        """
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
            log_input("x", x)
            y = x * x
            # Tensor that autograd saved for the backward of the multiply
            saved_x = y.grad_fn._saved_self
            grad_y = LoggingTensor(torch.tensor([1.0]))
            log_input("grad_y", grad_y)
            (g,) = torch.autograd.grad((y,), (x,), (grad_y,))

        # d(x * x)/dx at x = 3 is 6
        self.assertEqual(g.elem, torch.tensor([6.0]))
        with torch.no_grad():
            self.assertEqual(saved_x, x)
            self.assertEqual(saved_x._version, x._version)
            x.add_(2)
            self.assertEqual(saved_x, x)
            # TODO: figure out why broken
            # self.assertEqual(saved_x._version, x._version)
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.mul.Tensor($0, $0)
$2: f32[1] = input('grad_y')
$3: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$4: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$5: f32[1] = torch._ops.aten.add.Tensor($4, $3)""",
        )

    def test_out(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.zeros(1))
            log_input("x", x)
            log_input("y", y)
            torch.abs(x, out=y)

        self.assertEqual(y.elem, torch.ones(1))
        # TODO: arguably this shouldn't pass and we should complain
        # that out isn't a kwarg
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = input('y')
$2: f32[1] = torch._ops.aten.abs.out($0, out=$1)""",
        )

    def test_kwarg_only(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.ones(1, 1))
            z = LoggingTensor(torch.ones(1))
            log_input("x", x)
            log_input("y", y)
            log_input("z", z)
            torch.addmv(x, y, z)
            torch.addmv(x, y, z, beta=1)
            torch.addmv(x, y, z, beta=2)
            torch.addmv(x, y, z, alpha=2)
            torch.addmv(x, y, z, beta=2, alpha=2)

        # The expectation is that beta/alpha don't show up when they're
        # defaulted.  This is even if the user explicitly specified it.
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1, 1] = input('y')
$2: f32[1] = input('z')
$3: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$4: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$5: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
$6: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
$7: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)""",
        )

    def test_kwarg_only_and_positional_default(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            log_input("x", x)
            torch.ops.aten._foobar(x)
            torch.ops.aten._foobar(x, False)
            torch.ops.aten._foobar(x, arg3=False)
            torch.ops.aten._foobar(x, False, arg3=False)

        # What we are testing here is that we omit arg2
        # if it is defaulted, even if a kwarg is set
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten._foobar.default($0)
$2: f32[1] = torch._ops.aten._foobar.default($0, False)
$3: f32[1] = torch._ops.aten._foobar.default($0, arg3=False)
$4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)""",
        )

    def test_produce_real_type(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(2, 2))
            log_input("x", x)
            x.to(dtype=torch.double)  # non-optional dtype
            torch.cumprod(x, 0, dtype=torch.double)  # optional dtype
            x[:, 1].contiguous(
                memory_format=torch.contiguous_format
            )  # optional memory format
            # There doesn't appear to be any layout signatures which are
            # triggerable using tensor subclasses (need to use a mode)

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[2, 2] = input('x')
$1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
$2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
$4: f32[2] = torch._ops.aten.select.int($3, 1, 1)
$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)""",
        )

    def test_optional_tensor_list(self) -> None:
        def weird(xs):
            print("woof")
            return torch.empty(())

        with _scoped_library("my_lib", "DEF") as my_lib:
            my_lib.define("weird(Tensor?[] self) -> Tensor")
            my_lib.impl("weird", weird, "CPU")
            with capture_logs() as logs:
                x = LoggingTensor(torch.ones(2, 2))
                log_input("x", x)
                torch.ops.my_lib.weird.default([None, x])

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[2, 2] = input('x')
$1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
        )

    def test_list_ret(self) -> None:
        # test all sequence types are permissible returns
        for list_type in (list, tuple):

            class A(torch.Tensor):
                @staticmethod
                def __new__(cls, elem):
                    return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

                @classmethod
                def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                    if func.overloadpacket == torch.ops.aten.split:
                        with no_dispatch():
                            return list_type(torch.split(*args))
                    else:
                        raise AssertionError(f"unrecognized func: {func}")

            self.assertEqual(
                torch.split(A(torch.tensor([0, 1])), 2),
                torch.split(torch.tensor([0, 1]), 2),
            )

    def test_invalid_ret(self) -> None:
        # test invalid return gets reasonable error message
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return "arf"

        # Wobbles depending on NDEBUG mode of pybind11
        self.assertRaisesRegex(
            RuntimeError,
            "Unable to cast",
            lambda: A(torch.zeros(1)).neg(),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Unable to cast",
            lambda: A(torch.zeros(1)).detach(),
        )

    def test_detach_appears_twice_when_called_once(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
            log_input("x", x)
            x.detach()
        # FIXME: We actually want this to emit a single detach. However,
        # it currently emits two, for reasons unclear to us. Leaving
        # this test here to make sure we don't regress even further (it
        # would be bad if calling .detach() once emits 3+ detaches).
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.detach.default($0)
$2: f32[1] = torch._ops.aten.detach.default($1)""",
        )

    def test_storage(self) -> None:
        # For now, just make sure fetching the storage doesn't crash. Ideally
        # we would return some virtual storage that is safe to work with;
        # today asking it for a data pointer is expected to raise.
        wrapped = LoggingTensor(torch.ones(1))
        storage = wrapped.untyped_storage()
        with self.assertRaises(RuntimeError):
            storage.data_ptr()

    def test_make_wrapper_subclass_noalloc(self) -> None:
        # The requested size is ludicrously big; this should still pass
        # because wrapper subclasses don't allocate.
        huge_shape = (1_000_000_000_000,)
        torch.Tensor._make_wrapper_subclass(LoggingTensor, huge_shape)

    def test_version(self) -> None:
        # An in-place op through detach() is expected to change the version
        # counter, whereas an in-place op through .data is expected not to.
        tensor = LoggingTensor(torch.ones(1))
        version_before = tensor._version
        tensor.detach().add_(2)
        version_after_detach = tensor._version
        self.assertNotEqual(version_before, version_after_detach)
        tensor.data.add_(2)
        self.assertEqual(version_after_detach, tensor._version)

    def test_subclass_priority(self) -> None:
        class ErrorA(RuntimeError):
            pass

        class ErrorB(RuntimeError):
            pass

        # The big tests for code coverage are test_precedence_semantics in
        # test_overrides.py; this is just to make sure it is wired up at all
        # correctly for __torch_dispatch__
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorA

        class B(A):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorB

        # (expected error, type of left operand, type of right operand):
        # B's handler is expected whenever a B instance takes part.
        scenarios = [
            (ErrorA, A, A),
            (ErrorB, A, B),
            (ErrorB, B, A),
            (ErrorB, B, B),
        ]
        for expected_error, lhs_cls, rhs_cls in scenarios:
            with self.assertRaises(expected_error):
                torch.add(lhs_cls(torch.empty(1)), rhs_cls(torch.empty(1)))

    def test_format(self) -> None:
        # str(), repr() and f-string formatting must all agree for a
        # LoggingTensor.
        tensor = LoggingTensor(torch.ones(1))
        via_str = str(tensor)
        via_repr = repr(tensor)
        via_fstring = f"{tensor}"
        self.assertExpectedInline(via_str, """LoggingTensor(tensor([1.]))""")
        self.assertEqual(via_str, via_repr)
        self.assertEqual(via_str, via_fstring)

    def test_custom_autograd(self) -> None:
        escape = [None]

        class Square(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x**2
                ctx.save_for_backward(x)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                assert isinstance(grad_output, LoggingTensor)
                (x,) = ctx.saved_tensors
                assert isinstance(x, LoggingTensor)
                escape[0] = x
                return grad_output * 2 * x

        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1), requires_grad=True)
            log_input("x", x)
            x.grad = LoggingTensor(torch.zeros(1))
            log_input("x.grad", x.grad)
            y = Square.apply(x)
            grad_output = LoggingTensor(torch.ones(1))
            log_input("grad_output", grad_output)
            y.backward(grad_output)

        with torch.no_grad():
            self.assertEqual(escape[0], x)
            self.assertEqual(escape[0]._version, x._version)
            # TODO: figure out why x.requires_grad = False doesn't
            # trigger an error for LoggingTensor
            x.add_(2)
            self.assertEqual(escape[0], x)
            # TODO: figure out why this is broken
            # self.assertEqual(escape[0]._version, x._version)

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = input('x.grad')
$2: f32[1] = torch._ops.aten.pow.Tensor_Scalar($0, 2)
$3: f32[1] = input('grad_output')
$4: f32[1] = torch._ops.aten.mul.Tensor($3, 2)
$5: f32[1] = torch._ops.aten.mul.Tensor($4, $0)
$6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)""",
        )

    def test_subclass_creation(self):
        # Re-wrapping a tensor that already belongs to a Python subclass
        # (LoggingTensor) as a different subclass (Foo) is expected to raise,
        # whichever construction path is used.
        class Foo(torch.Tensor):
            pass

        err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
        # The results are never produced (each call raises), so nothing is
        # bound to a local name.
        with self.assertRaisesRegex(RuntimeError, err_msg):
            torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
        with self.assertRaisesRegex(RuntimeError, err_msg):
            LoggingTensor(torch.rand(2)).as_subclass(Foo)
        with self.assertRaisesRegex(RuntimeError, err_msg):
            Foo(LoggingTensor(torch.rand(2)))

        # A wrapper subclass without __torch_dispatch__ is rejected outright.
        with self.assertRaisesRegex(TypeError, "Foo must define __torch_dispatch__"):
            torch.Tensor._make_wrapper_subclass(Foo, (2, 2))

    def test_new_ones(self) -> None:
        # new_ones() on a subclass must return whatever the subclass'
        # __torch_dispatch__ produced, i.e. keep the subclass type.
        class MyTensor(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)

        created = MyTensor(2).new_ones(3)
        self.assertEqual(type(created), MyTensor)

    def test_like(self) -> None:
        # Every *_like factory called on a subclass instance must yield the
        # subclass type returned by __torch_dispatch__.
        class MyTensor(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)

        # Factories that take nothing but the prototype tensor.
        for prefix in ("empty", "ones", "rand", "randn", "zeros"):
            like_fn = getattr(torch, prefix + "_like")
            self.assertEqual(type(like_fn(MyTensor(2))), MyTensor)

        # Factories that need an extra argument.
        filled = torch.full_like(MyTensor(2), 1.0)
        self.assertEqual(type(filled), MyTensor)
        random_ints = torch.randint_like(MyTensor(2), high=3)
        self.assertEqual(type(random_ints), MyTensor)

    def test_make_fx_with_subclass(self) -> None:
        def f(x, y):
            # Returns (TwoTensor, Tensor)
            return x * y, y + y

        x_a = torch.zeros(4)
        x_b = torch.zeros(4)
        y = torch.ones(4)

        # make_fx() is not responsible for unwrapping tensor subclass inputs,
        # so we do it manually here.
        # Why? In general, make_fx(f)(*args) promises that the graph returned has the same calling
        # convention as f(*args). Unwrapping tensor subclass inputs can potentially change
        # the number of input args to the graph, breaking that assumption
        def f_to_trace(x_a, x_b, y):
            x = TwoTensor(x_a, x_b)
            out1, out2 = f(x, y)
            out1_unwrapped_attrs, _ = out1.__tensor_flatten__()
            return (*[getattr(out1, attr) for attr in out1_unwrapped_attrs], out2)

        fx_g = make_fx(f_to_trace, tracing_mode="fake")(x_a, x_b, y)
        self.assertExpectedInline(
            fx_g.code,
            """\



def forward(self, x_a_1, x_b_1, y_1):
    mul = torch.ops.aten.mul.Tensor(x_a_1, y_1);  x_a_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(x_b_1, y_1);  x_b_1 = None
    add = torch.ops.aten.add.Tensor(y_1, y_1);  y_1 = None
    return (mul, mul_1, add)
    """,
        )

    # See https://github.com/pytorch/pytorch/issues/117794
    def test_return_and_correct_aliasing_gives_correct_stride(self):
        wrapped = TwoTensor(torch.randn(2, 2), torch.randn(2, 2))
        dense = torch.randn(2, 2)
        # slicing should result in the same stride for TwoTensor as a dense tensor would give
        wrapped_stride = wrapped[:, 0].stride()
        dense_stride = dense[:, 0].stride()
        self.assertEqual(wrapped_stride, dense_stride)

    def test_make_wrapper_subclass_propagates_metadata(self) -> None:
        # size, stride and storage offset passed to _make_wrapper_subclass
        # must be reported back unchanged by the wrapper.
        class WrapperTensor(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ["elem"]

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                wrapper = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    layout=elem.layout,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                    strides=elem.stride(),
                    storage_offset=elem.storage_offset(),
                )
                wrapper.elem = elem
                return wrapper

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise RuntimeError("NYI")

        # non-contiguous strides, non-zero storage offset
        inner = torch.randn(4, 6).t().diagonal(offset=2)
        outer = WrapperTensor(inner)
        for accessor in ("size", "stride", "storage_offset"):
            self.assertEqual(getattr(outer, accessor)(), getattr(inner, accessor)())

    def test_wrapper_subclass_serializes(self) -> None:
        # Round-trip a wrapper subclass through torch.save / torch.load and
        # check that type, value and wrapped element survive as a new object.
        with tempfile.TemporaryFile() as f:
            # purposefully use int64 to test non-default dtype
            x = LoggingTensor(torch.randperm(3))
            torch.save(x, f)
            f.seek(0)
            x_loaded = torch.load(f)
            # assertIs / assertIsNot report both operands on failure, unlike
            # assertTrue(a is b) which only says "False is not true".
            self.assertIs(type(x_loaded), type(x))
            self.assertEqual(x, x_loaded)
            self.assertEqual(x.elem, x_loaded.elem)
            self.assertIsNot(x, x_loaded)

    def test_deepcopy_wrapper_subclass(self) -> None:
        # deepcopy() of a wrapper subclass must give a distinct object with
        # the same type, value and wrapped element.
        # purposefully use int64 to test non-default dtype
        x = LoggingTensor(torch.randperm(3))
        x_copy = deepcopy(x)
        # assertIs / assertIsNot report both operands on failure, unlike
        # assertTrue(a is b) which only says "False is not true".
        self.assertIs(type(x_copy), type(x))
        self.assertEqual(x, x_copy)
        self.assertEqual(x.elem, x_copy.elem)
        self.assertIsNot(x, x_copy)

    def test_deepcopy_wrapper_subclass_with_clone_returning_different_type(
        self,
    ) -> None:
        # deepcopy() of a wrapper subclass whose clone() hands back a plain
        # tensor is expected to fail with a descriptive RuntimeError.
        class MyWrapperTensor(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ["elem"]

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    layout=elem.layout,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                    strides=elem.stride(),
                    storage_offset=elem.storage_offset(),
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if func.overloadpacket.__name__ == "clone":
                    # Return a plain tensor from clone().
                    return args[0].elem.clone()
                raise RuntimeError("NYI")

            # NB: The default Tensor.__torch_function__ implementation called for deepcopy
            # disables __torch_function__ by the time we get to clone(), so there is no need to
            # explicitly disable __torch_function__ for this subclass.

        x = MyWrapperTensor(torch.randn(3))
        with self.assertRaisesRegex(
            RuntimeError,
            "for which cloning returns another instance of the same subclass",
        ):
            # The call is expected to raise, so its result is not bound.
            deepcopy(x)

    def test_deepcopy_non_wrapper_subclass(self) -> None:
        # Ensure correct error is thrown for common error cases.
        class SubTensorError1(torch.Tensor):
            # Default implementation of new_empty() returns a plain tensor.
            pass

        class SubTensorError2(torch.Tensor):
            # new_empty() incorrectly returns a different type (i.e. a plain tensor).
            def new_empty(self, shape):
                return torch.Tensor(shape)

        for error_cls in [SubTensorError1, SubTensorError2]:
            x = error_cls(3)
            with self.assertRaisesRegex(
                RuntimeError,
                "for which that function returns another instance of the same subclass",
            ):
                # The call is expected to raise, so its result is not bound.
                deepcopy(x)

        # Ensure a correctly implemented new_empty() causes deepcopy() to work.
        class SubTensorSuccess(torch.Tensor):
            def new_empty(self, shape):
                return type(self)(shape)

        x = SubTensorSuccess(3)
        x_copy = deepcopy(x)
        self.assertIs(type(x_copy), type(x))

    def test_wrapper_subclass_extra_dispatch_keys(self) -> None:
        # Extra dispatch keys handed to _make_wrapper_subclass must show up on
        # the resulting tensor; the matching Autograd key must not.
        class ExtraKeysTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                # NB: only the non-kwarg overload of _make_wrapper_subclass supports
                #     extra dispatch keys. We probably want to unify the two APIs
                #     in the future.
                extra_keys = DispatchKeySet(DispatchKey.NestedTensor)
                return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls,
                    elem.size(),
                    elem.stride(),
                    elem.storage_offset(),
                    torch.contiguous_format,
                    elem.dtype,
                    elem.layout,
                    elem.device,
                    False,
                    False,
                    None,
                    False,
                    False,
                    extra_keys,
                )

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return None

        tensor = ExtraKeysTensor(torch.randn(3))
        active_keys = torch._C._dispatch_keys(tensor)
        self.assertTrue(active_keys.has(DispatchKey.NestedTensor))
        self.assertFalse(active_keys.has(DispatchKey.AutogradNestedTensor))

    def test_wrapper_subclass_multiprocessing_preserves_dtype(self):
        # Both components have dtype int64, which is purposefully different
        # from the default assumed by _make_wrapper_subclass().
        first = torch.randperm(5)
        second = torch.randperm(5)
        sample = TwoTensor(first, second)
        expected_dtype = sample.dtype

        # Ship the sample through DataLoader worker processes and check the
        # dtype of what comes back.
        loader_kwargs = {
            "batch_size": 2,
            "num_workers": 2,
            "collate_fn": _identity,
        }
        loader = torch.utils.data.DataLoader([sample, sample], **loader_kwargs)
        for batch in loader:
            received = batch[0]
            self.assertEqual(received.dtype, expected_dtype)

    def test_index_put_where_only_index_is_subclass(self) -> None:
        # index_put_ must reach __torch_dispatch__ even when the only subclass
        # instance involved is one of the index tensors.
        called_funcs = []

        class MyTensor(torch.Tensor):
            elem: torch.Tensor
            __slots__ = ["elem"]

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    layout=elem.layout,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                # Record which op arrived; the returned value is a dummy.
                called_funcs.append(func)
                return MyTensor(torch.tensor(3))

        x = torch.randn(3, 3)
        idxs = (MyTensor(torch.tensor(0)),)
        v = torch.randn(1)
        # Only the recorded dispatch matters, so the return value is dropped.
        x.index_put_(idxs, v)
        self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default])

    def test_torch_dispatch_mode_basic(self) -> None:
        with capture_logs(is_mode=True) as logs:
            with LoggingTensorMode():
                torch.empty([])
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""",
        )

    def test_torch_dispatch_mode_unrelated_tensors(self) -> None:
        # Plain tensors created outside the mode are still logged when they
        # are used inside it.
        lhs = torch.randn([])
        rhs = torch.randn([])
        with capture_logs(is_mode=True) as recorded, LoggingTensorMode():
            lhs + rhs
        rendered = "\n".join(recorded)
        self.assertExpectedInline(
            rendered, """$2: f32[] = torch._ops.aten.add.Tensor($0, $1)"""
        )

    def test_nested_push_logging_tensor_mode(self):
        x = torch.randn([])
        y = torch.randn([])
        with capture_logs(is_mode=True) as logs:
            with LoggingTensorMode():
                with LoggingTensorMode():
                    torch.empty([])
                    x + y

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
        )

    def test_capture_logs_with_torch_dispatch_mode(self):
        x = torch.randn([])
        y = torch.randn([])
        with capture_logs_with_logging_tensor_mode() as logs:
            torch.empty([])
            x + y
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
        )

        x = torch.randn([])
        y = torch.randn([])

        with capture_logs_with_logging_tensor_mode() as logs1:
            with capture_logs_with_logging_tensor_mode() as logs2:
                torch.empty([])
                x + y

        self.assertExpectedInline(
            "\n".join(logs2),
            """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
        )

        self.assertEqual(logs1, logs2)

    def test_torch_dispatch_mode_subclass_priority(self) -> None:
        class ErrorA(RuntimeError):
            pass

        class ErrorB(RuntimeError):
            pass

        # Modes that simply raise their own marker error.
        class AMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                raise ErrorA

        class BMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                raise ErrorB

        # Tensor subclasses whose handlers push the corresponding mode.
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                with AMode():
                    raise ErrorA

        class B(A):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                with BMode():
                    func(*args, **kwargs)

        tensor_a = A(torch.empty(1))
        tensor_b = B(torch.empty(1))

        # Without an active mode the subclass handlers decide.
        with self.assertRaises(ErrorA):
            tensor_a + tensor_a
        with self.assertRaises(ErrorB):
            tensor_a + tensor_b

        # B has precedence over A due to the subclass relationship yet
        # modes take precedence over arguments
        with self.assertRaises(ErrorA), AMode():
            tensor_b + tensor_b
        with self.assertRaises(ErrorB), BMode():
            tensor_a + tensor_a
        with self.assertRaises(ErrorB), BMode():
            tensor_a + tensor_b

    def test_mode_with_make_subclass(self):
        # Constructing a _make_subclass-based tensor while a pass-through mode
        # is active must still yield an instance of the subclass.
        class SubTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

        class BasicMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                result = func(*args, **kwargs)
                return result

        plain = torch.randn(3)
        with BasicMode():
            wrapped = SubTensor(plain)
        self.assertIsInstance(wrapped, SubTensor)

    def test_torch_dispatch_mode_respects_no_dispatch(self) -> None:
        # An op issued under no_dispatch() must leave no trace in the log, so
        # both captures below are expected to be equal.
        with capture_logs(is_mode=True) as logs_with_escape, LoggingTensorMode():
            torch.ones([2, 3])
            with no_dispatch():
                torch.ones([2, 3])
        with capture_logs(is_mode=True) as logs_plain, LoggingTensorMode():
            torch.ones([2, 3])
        self.assertEqual(logs_with_escape, logs_plain)

    def test_shallow_copy_and_detach(self) -> None:
        # Every tensor the mode receives as an input must be one it previously
        # returned as an output, both in the forward and the backward pass.
        seen = set()
        test_case = self

        class TestMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                def check_already_seen(t):
                    test_case.assertIn(t, seen)

                tree_map_only(torch.Tensor, check_already_seen, (args, kwargs))
                call_kwargs = {} if kwargs is None else kwargs
                result = func(*args, **call_kwargs)
                # Remember the outputs so later ops may consume them.
                tree_map_only(torch.Tensor, seen.add, result)
                return result

        with TestMode():
            leaf = torch.randn(3, requires_grad=True)
            squared = leaf * leaf
            loss = squared.sum()
            loss.backward()

    def test_exception_handling(self):
        # After the mode's handler raised for one op, the mode must still be
        # active and handle the next op.
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

        class AMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if func.__name__ == "randn.default":
                    raise RuntimeError
                return A(torch.zeros(()))

        with AMode():
            try:
                torch.randn(())
            except RuntimeError:
                # Deliberately ignored: only the follow-up call is checked.
                pass
            # assertIsInstance reports the offending object on failure,
            # unlike assertTrue(isinstance(...)).
            self.assertIsInstance(torch.zeros(()), A)

    def test_with_mode_created_separately(self):
        # A mode instance built ahead of time must work when entered later.
        class ErrorA(RuntimeError):
            pass

        class A(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                raise ErrorA

        prebuilt_mode = A()
        with self.assertRaises(ErrorA), prebuilt_mode:
            torch.empty([])

    def test_with_nested_modes(self):
        # With two instances of a mode nested, the error must carry the
        # message of the innermost one.
        class ErrorA(RuntimeError):
            pass

        class A(TorchDispatchMode):
            def __init__(self, msg):
                self.msg = msg

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                raise ErrorA(self.msg)

        with self.assertRaisesRegex(ErrorA, "layer2"):
            with A("layer1"), A("layer2"):
                torch.empty([])

    def test_make_subclass_with_modes(self):
        # A mode may wrap the results it produces with _make_subclass, both
        # when used alone and with a pass-through mode stacked on top of it.
        class ModeTensor(torch.Tensor):
            def __new__(cls, elem, mode):
                r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
                r.elem = elem
                r.mode = mode
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise NotImplementedError("Shouldn't be here")

        class Mode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                def unwrap(e):
                    if isinstance(e, ModeTensor):
                        return e.elem
                    else:
                        return e

                def wrap(t):
                    if isinstance(t, torch.Tensor):
                        return ModeTensor(t, self)
                    else:
                        return t

                return wrap(func(*tuple(unwrap(a) for a in args), **kwargs))

        class BasicMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                return func(*args, **kwargs)

        x = torch.tensor(4.0)
        with Mode():
            y = x + x
            z = y + y
        self.assertIsInstance(y, ModeTensor)
        self.assertIsInstance(z, ModeTensor)

        with Mode():
            with BasicMode():  # we can't nest two modes that call make_subclass because it only accepts vanilla tensors
                y = x + x
                z = y + y
        self.assertIsInstance(y, ModeTensor)
        self.assertIsInstance(z, ModeTensor)

        # NOTE: a trailing `assert self.assertRaisesRegex(RuntimeError, ...)`
        # used to live here. Called without a callable, assertRaisesRegex only
        # builds a context manager; it was never entered, so the statement
        # checked nothing and has been removed. The failure of nesting two
        # make_subclass-based modes is therefore NOT covered by this test.

    def test_notimplemented_mode(self):
        # A mode may decline an op by returning NotImplemented; the tensor
        # subclass handler is then expected to run. The counters below record
        # how often each handler was entered.
        sub_count = 0

        class PoliteMode(TorchDispatchMode):
            def __init__(self):
                self.pre_count = 0
                self.post_count = 0

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                self.pre_count += 1
                only_plain_tensors = all(t is torch.Tensor for t in types)
                if not only_plain_tensors:
                    return NotImplemented
                self.post_count += 1
                return func(*args, **kwargs)

        class SubTensor(torch.Tensor):
            def __new__(cls, elem):
                wrapper = torch.Tensor._make_wrapper_subclass(cls, elem.shape)
                wrapper.elem = elem
                return wrapper

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                nonlocal sub_count
                sub_count += 1

                def unwrap(t):
                    return t.elem if isinstance(t, SubTensor) else t

                plain_args = tree_map(unwrap, args)
                plain_kwargs = tree_map(unwrap, kwargs)
                return func(*plain_args, **plain_kwargs)

        wrapped = SubTensor(torch.randn(2))
        with PoliteMode() as mode:
            wrapped.abs()

        self.assertEqual(mode.pre_count, 2)
        self.assertEqual(mode.post_count, 1)
        self.assertEqual(sub_count, 1)

        # make sure this doesn't error
        with PoliteMode(), PoliteMode():
            wrapped.abs()

    def test_nesting_same_mode(self):
        # If the pushed mode is the same instance as the current mode, we allow pushing an already active mode.

        with capture_logs(is_mode=True) as logs:
            with LoggingTensorMode() as reenabled:
                with reenabled:
                    torch.empty([])
            self.assertExpectedInline(
                "\n".join(logs),
                """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""",
            )

    def test_error_using_class_method_on_mode(self):
        # Declaring __torch_dispatch__ as a classmethod on a mode is
        # unsupported and must be reported with a helpful message.
        class A(TorchDispatchMode):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                # The test expects an error to be raised for this mode; the
                # body is still written with the correct calling convention
                # (unpacked args/kwargs) so it does not mislead readers.
                return func(*args, **(kwargs or {}))

        x = torch.tensor(5.0)
        with self.assertRaisesRegex(
            RuntimeError, "classmethod is not supported, please make it a plain method"
        ):
            with A():
                x + x

    def test_get_cur_mode(self):
        # _get_current_dispatch_mode() reports None outside any mode and the
        # innermost active mode otherwise.
        class A(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                return None

        self.assertEqual(_get_current_dispatch_mode(), None)

        with A() as mode1:
            self.assertEqual(_get_current_dispatch_mode(), mode1)

        # Re-enter mode1 and push a second mode on top of it.
        with mode1, A() as mode2:
            self.assertEqual(_get_current_dispatch_mode(), mode2)

    def test_get_mode_stack(self):
        # _get_current_dispatch_mode_stack() lists the active modes from
        # outermost to innermost, and is empty outside any mode.
        class A(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                return None

        self.assertEqual(_get_current_dispatch_mode_stack(), [])

        with A() as mode1:
            self.assertEqual(_get_current_dispatch_mode_stack(), [mode1])

        # Re-enter mode1 and push a second mode on top of it.
        with mode1, A() as mode2:
            expected_stack = [mode1, mode2]
            self.assertEqual(_get_current_dispatch_mode_stack(), expected_stack)

    def test_all_same_mode(self):
        # all_same_mode() is expected to be true only when every entry is the
        # very same mode instance.
        first_mode = LoggingTensorMode()
        second_mode = LoggingTensorMode()
        same_instance = [first_mode, first_mode, first_mode]
        mixed_with_none = [first_mode, None]
        distinct_instances = [first_mode, second_mode]
        self.assertTrue(all_same_mode(same_instance))
        self.assertFalse(all_same_mode(mixed_with_none))
        self.assertFalse(all_same_mode(distinct_instances))

    def test_tolist_numpy_with_torch_dispatch_mode(self) -> None:
        # Conversions that leave the tensor world are expected to be rejected
        # for tensor subclasses.
        wrapped = LoggingTensor(torch.tensor([2.0, 3.0]))
        for conversion in ("tolist", "numpy"):
            with self.assertRaisesRegex(
                RuntimeError, "is not supported for tensor subclasses."
            ):
                getattr(wrapped, conversion)()
        # Comparing the subclass against None must fail the assertion.
        with self.assertRaises(AssertionError):
            self.assertEqual(wrapped, None)

    def test_record_stream(self) -> None:
        # record_stream() must reach the mode with the tensor and an intact
        # torch.Stream (same ids as the one passed in) as arguments.
        class TestMode(TorchDispatchMode):
            def __init__(self, testcase):
                self.testcase = testcase

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                check = self.testcase
                check.assertEqual(func.name(), "aten::record_stream")
                check.assertIsInstance(args[0], torch.Tensor)
                stream_arg = args[1]
                check.assertIsInstance(stream_arg, torch.Stream)
                check.assertEqual(stream_arg.stream_id, 1)
                check.assertEqual(stream_arg.device_index, 2)
                check.assertEqual(stream_arg.device_type, 3)

        tensor = torch.tensor(5.0)
        stream = torch.Stream(stream_id=1, device_index=2, device_type=3)
        with TestMode(self):
            tensor.record_stream(stream)

    def test_return_stream(self) -> None:
        # An op returning a Stream must round-trip its fields, both from the
        # registered CPU kernel and from a mode that overrides the result.
        def check_stream(stream, stream_id, device_index, device_type):
            self.assertIsInstance(stream, torch.Stream)
            self.assertEqual(stream.stream_id, stream_id)
            self.assertEqual(stream.device_index, device_index)
            self.assertEqual(stream.device_type, device_type)

        def cpu_kernel(_):
            return torch.Stream(stream_id=0, device_index=1, device_type=2)

        with _scoped_library("test_return_stream", "DEF") as l_def:
            l_def.define("return_stream(Tensor self) -> Stream")
            with _scoped_library("test_return_stream", "IMPL", "CPU") as l_impl:
                l_impl.impl("return_stream", cpu_kernel)

                class TestMode(TorchDispatchMode):
                    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                        return torch.Stream(stream_id=1, device_index=2, device_type=3)

                tensor = torch.tensor(5.0)

                # Without a mode the CPU kernel's stream comes back.
                from_kernel = torch.ops.test_return_stream.return_stream(tensor)
                check_stream(from_kernel, 0, 1, 2)

                # With the mode active its stream is returned instead.
                with TestMode():
                    from_mode = torch.ops.test_return_stream.return_stream(tensor)
                check_stream(from_mode, 1, 2, 3)

    def test_subclass_autograd_device_check(self) -> None:
        # The subclass reports a "meta" tensor to the outside while the real
        # data lives in .elem; autograd through it must still work.
        class NonWrapperSubclass(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ["elem"]

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                # Wrong device here!
                on_meta = elem.to("meta")
                instance = torch.Tensor._make_subclass(cls, on_meta, elem.requires_grad)
                # ...the real tensor is held as an element on the tensor.
                instance.elem = elem
                return instance

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e):
                    if isinstance(e, NonWrapperSubclass):
                        return e.elem
                    return e

                def wrap(e):
                    if isinstance(e, torch.Tensor):
                        return NonWrapperSubclass(e)
                    return e

                # Run the op on the real tensors, then re-wrap the outputs.
                plain_args = tree_map(unwrap, args)
                plain_kwargs = tree_map(unwrap, kwargs)
                rs = tree_map(wrap, func(*plain_args, **plain_kwargs))
                logging.getLogger("NonWrapperSubclass").info(
                    f"{func.__module__}.{func.__name__}", args, kwargs, rs  # noqa: G004
                )
                return rs

        wrapped = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True))
        plain = torch.randn(2, requires_grad=True)
        product = wrapped * plain
        self.assertIsInstance(product, NonWrapperSubclass)
        product.sum().backward(torch.tensor(1))
        # d(x*y)/dx == y and d(x*y)/dy == x
        self.assertEqual(wrapped.grad, plain)
        self.assertEqual(plain.grad, wrapped)

    def test_none_wrapping(self):
        # A Tensor subclass that returns None when doing add
        # See LoggingTensor above for more details on the subclass
        class SubclassWithNone(torch.Tensor):
            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                outer = torch.Tensor._make_wrapper_subclass(
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    layout=elem.layout,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                )
                outer.elem = elem
                return outer

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(item):
                    if isinstance(item, SubclassWithNone):
                        return item.elem
                    return item

                def wrap(item):
                    if isinstance(item, torch.Tensor):
                        return SubclassWithNone(item)
                    return item

                inner_args = tree_map(unwrap, args)
                inner_kwargs = tree_map(unwrap, kwargs)
                results = tree_map(wrap, func(*inner_args, **inner_kwargs))
                # The op is still executed for "add"; only its result is dropped.
                return None if func.overloadpacket.__name__ == "add" else results

        tensor = SubclassWithNone(torch.rand(2))
        # Make sure both run without error
        self.assertIsInstance(tensor * 2, SubclassWithNone)
        self.assertIsNone(tensor + 2)

        tensor.requires_grad_()
        loss = tensor.acos().sum()

        # The backward of acos does add then rsqrt so here we make sure that the
        # undefined Tensor generated by the user code is nicely handled.
        # If acos formula changes in the future, this can be replaced by any other
        # function that does add then something in the backward in a composite way
        with self.assertRaisesRegex(RuntimeError, "but got None"):
            loss.backward()

    def test_storage_can_be_converted_to_python_object(self):
        # Smoke test: there are no assertions, set_() with a Storage on a
        # LoggingTensor only has to complete without raising.
        storage = torch.Storage()
        logged = LoggingTensor(torch.empty([]))
        logged.set_(storage)

    def test_autograd_in_attr(self):
        # We want the wrapped Tensor to require gradients!
        inner = torch.rand(2, requires_grad=True)
        outer = LoggingTensorReentrant(inner)

        result = outer + 2

        # The wrapper result is expected to carry no autograd state ...
        self.assertFalse(result.requires_grad)
        self.assertIsNone(result.grad_fn)

        # ... while the tensor stored on `.elem` does.
        self.assertTrue(result.elem.requires_grad)
        self.assertIsNotNone(result.elem.grad_fn)

        # Backward through the wrapper must be rejected ...
        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
            result.sum().backward()

        # ... but backward through the carried tensor works.
        result.elem.sum().backward()

        self.assertIsNone(outer.grad)
        self.assertIsNotNone(outer.elem.grad)

    def test_dispatch_super_call(self):
        # Ops observed by the subclass, in call order.
        seen_ops = []

        class SubTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                # Record the op, then defer to the inherited implementation.
                seen_ops.append(func)
                return super().__torch_dispatch__(func, types, args, kwargs)

        lhs = torch.randn(2)
        rhs = torch.randn(2)
        total = SubTensor(lhs) + SubTensor(rhs)
        self.assertEqual(total, lhs + rhs)
        self.assertEqual(seen_ops, [torch.ops.aten.add.Tensor])

    def test_dispatch_super_call_list_arg(self):
        # Ops observed by the subclass, in call order.
        seen_ops = []

        class SubTensorWithListArg(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                seen_ops.append(func)
                # Forward the positional arguments as a list instead of the
                # received `args` object; that is the case under test.
                args_as_list = list(args)
                return super().__torch_dispatch__(func, types, args_as_list, kwargs)

        base = torch.randn(2)
        negated = SubTensorWithListArg(base).neg()
        self.assertEqual(negated, base.neg())
        self.assertEqual(seen_ops, [torch.ops.aten.neg.default])

    def test_dispatch_super_dont_autograd(self):
        # Ops observed by the subclass, in call order.
        seen_ops = []

        class SubTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                seen_ops.append(func)
                # This argument still requires grad because it was passed
                # through directly...
                self.assertTrue(args[0].requires_grad)
                result = super().__torch_dispatch__(func, types, args, kwargs)
                # But the output better not require grad, because that means
                # you did autograd again in torch dispatch (oops)
                self.assertFalse(result.requires_grad)
                return result

        tensor = SubTensor(torch.randn(2, requires_grad=True))
        tensor.neg()
        self.assertEqual(seen_ops, [torch.ops.aten.neg.default])

    def test_set_data(self):
        # Number of times SubTensor.__torch_dispatch__ has been entered.
        dispatch_count = 0

        class SubTensor(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                nonlocal dispatch_count
                dispatch_count += 1
                return super().__torch_dispatch__(func, types, args, kwargs)

        tensor = SubTensor(torch.empty(2))

        # Reading .data is expected to dispatch once ...
        tensor.data
        self.assertEqual(dispatch_count, 1)

        # ... assigning .data is expected not to dispatch at all ...
        tensor.data = torch.empty(2)
        self.assertEqual(dispatch_count, 1)

        tensor.data
        self.assertEqual(dispatch_count, 2)
        self.assertIs(type(tensor), SubTensor)

        # ... and set_() is expected to dispatch once; the type must survive.
        tensor.set_(torch.empty(2))
        self.assertEqual(dispatch_count, 3)

        tensor.data
        self.assertEqual(dispatch_count, 4)
        self.assertIs(type(tensor), SubTensor)

    def test_construct_int_tensor(self):
        # A bare subclass with no overrides, built from an integer tensor.
        class SubTensor(torch.Tensor):
            pass

        int_source = torch.zeros(2, dtype=torch.int)
        # should not fail
        SubTensor(int_source)

    def test_multiple_ops_subclass(self):
        # This is a Direct Subclass, don't do that!
        class MySubclass(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                # Re-run the op with Python dispatch switched off.
                with no_dispatch():
                    return func(*args, **kwargs)

        source = MySubclass(torch.rand(2, 2, dtype=torch.complex64))
        conjugated = source.conj()
        # Details of the bug that this tests for:
        # Here, y dispatch keys are: {PythonTLSSnapshot, AutogradCPU, Conjugate, Python, CPU}
        # (y is the conjugated tensor, named `conjugated` in this test)
        # There are a few calls to the dispatcher that are going to happen here:
        #  - call_exp: User calling exp on y
        #    - PythonTLSSnapshot: records the TLS on entry and redispatch
        #    - AutogradCPU: no input requires grad, so does nothing and redispatch
        #    - Conjugate: no special implementation for exp: use the fallback that
        #                 first clone the Tensor (to materialize the conj) then redispatch
        #      - call_clone: conjugate fallback calling clone on y
        #        - PythonTLSSnapshot: records the TLS on entry and redispatch
        #        - (AutogradCPU: skipped as autograd added itself to the exclude set above)
        #        - Conjugate: special implementation for clone: just skip this key
        #        - Python: Reset the TLS based on the snapshot above and call the user implementation (this
        #                  actually calls into the dispatcher again but since we disable both our keys
        #                  before, not detailed here)
        #        - exit Python: restore the TLS and exit
        #        - exit Conjugate: nothing was inplace so just exit
        #        - exit PythonTLSSnapshot: done with this call, reset the saved TLS to empty
        #    - Python: Reset the TLS again based on the snapshot. <- this used to fail
        #    - More steps....
        conjugated.exp()

    @staticmethod
    def subclass_helper(cls, data, use_wrapper_subclass, **kwargs):
        if use_wrapper_subclass:
            kwargs["device"] = data.device
            kwargs["dtype"] = data.dtype
            kwargs["layout"] = data.layout
            kwargs["requires_grad"] = True
            return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)  # type: ignore[attr-defined]
        else:
            return torch.Tensor._make_subclass(cls, data, True, **kwargs)

    def test_is_contiguous_slow_path(self):
        # Subclasses built with dispatch_sizes_strides_policy="strides" answer
        # is_contiguous() from __torch_dispatch__; the three variants below
        # return NotImplemented, True and False for it respectively.
        data = torch.randn(3, 3)
        contiguous_ref = data.clone()
        noncontiguous_ref = torch.as_strided(data.clone(), (2, 2), (1, 2))

        for as_wrapper in (True, False):

            class ExampleTensor1(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class ExampleTensor2(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket != torch.ops.aten.is_contiguous:
                        return NotImplemented
                    return contiguous_ref.is_contiguous()

            class ExampleTensor3(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket != torch.ops.aten.is_contiguous:
                        return NotImplemented
                    return noncontiguous_ref.is_contiguous()

            # Nothing handled: both calls must fail on is_contiguous itself.
            specific_error = "Multiple dispatch failed for 'torch.ops.aten.is_contiguous'"
            unhandled = ExampleTensor1(torch.randn(3, 3), as_wrapper)
            with self.assertRaisesRegex(TypeError, specific_error):
                unhandled.is_contiguous()
            with self.assertRaisesRegex(TypeError, specific_error):
                unhandled.contiguous()

            reports_true = ExampleTensor2(torch.randn(3, 3), as_wrapper)
            self.assertEqual(reports_true.is_contiguous(), True)
            reports_true.contiguous()  # this will just return the original TensorImpl since is_contiguous = True

            # Reported non-contiguous: contiguous() has to dispatch further,
            # which this subclass does not handle.
            generic_error = "Multiple dispatch failed for"
            reports_false = ExampleTensor3(torch.randn(3, 3), as_wrapper)
            self.assertEqual(reports_false.is_contiguous(), False)
            with self.assertRaisesRegex(TypeError, generic_error):
                reports_false.contiguous()

    def test_fancy_strides(self):
        # (op, trailing args) pairs for every intercepted stride-related query.
        recorded = []

        class ExampleTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, data):
                return TestPythonDispatch.subclass_helper(
                    cls, data, False, dispatch_sizes_strides_policy="strides"
                )

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                intercepted = (
                    torch.ops.aten.is_contiguous.default,
                    torch.ops.aten.is_contiguous.memory_format,
                    torch.ops.aten.is_strides_like_format.default,
                    torch.ops.aten.is_non_overlapping_and_dense.default,
                    torch.ops.aten.stride.default,
                )
                if func in intercepted:
                    # Log the call (without the tensor itself) and answer None.
                    recorded.append((func, list(args)[1:]))
                    return None
                with no_dispatch():
                    return func(*args, **kwargs)

        tensor = ExampleTensor(torch.randn(2, 2))

        self.assertFalse(tensor.is_contiguous(memory_format=torch.channels_last))
        expected = [(torch.ops.aten.is_contiguous.memory_format, [torch.channels_last])]
        self.assertEqual(recorded, expected)
        recorded.clear()

        self.assertFalse(
            torch.ops.aten.is_strides_like_format.default(tensor, torch.channels_last)
        )
        expected = [
            (torch.ops.aten.is_strides_like_format.default, [torch.channels_last])
        ]
        self.assertEqual(recorded, expected)
        recorded.clear()

        self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(tensor))
        expected = [(torch.ops.aten.is_non_overlapping_and_dense.default, [])]
        self.assertEqual(recorded, expected)

    def test_device_slowpath(self):
        # Subclasses built with dispatch_device=True answer `.device` from
        # __torch_dispatch__.  Only the wrapper-subclass flavour is exercised.
        for as_wrapper in (True,):

            class ExampleTensor1(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_device=True
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class ExampleTensor2(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_device=True
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket != torch.ops.prim.device:
                        return NotImplemented
                    return torch.device("meta")

            class ExampleTensor3(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_device=True
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket != torch.ops.prim.device:
                        return NotImplemented
                    return torch.device("meta")

            expected_error = "Multiple dispatch failed for 'torch.ops.prim.device'"
            # Construction and the device access both sit inside the
            # assertion context.
            with self.assertRaisesRegex(TypeError, expected_error):
                unhandled = ExampleTensor1(torch.randn(3, 3), as_wrapper)
                unhandled.device()

            plain = torch.rand([1])

            # The device reported by __torch_dispatch__ ("meta") is what
            # callers see, even though the data was created on "cpu".
            on_meta = ExampleTensor2(torch.randn(3, 3, device="cpu"), as_wrapper)
            self.assertEqual(on_meta.device.type, "meta")
            self.assertEqual(plain.type_as(on_meta).device.type, "meta")

            on_meta = ExampleTensor3(torch.randn(3, 3, device="cpu"), as_wrapper)
            self.assertEqual(on_meta.device.type, "meta")
            self.assertEqual(plain.type_as(on_meta).device.type, "meta")

    def test_dim_slowpath(self):
        # Subclasses built with dispatch_sizes_strides_policy="sizes" answer
        # dim() from __torch_dispatch__.
        data = torch.randn(3, 3)

        for as_wrapper in (True, False):

            class DimNotImplementedTensor(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class DimImplementedTensor(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket != torch.ops.aten.dim:
                        return NotImplemented
                    # Answer with the rank of the outer `data` tensor.
                    return data.dim()

            expected_error = "Multiple dispatch failed for 'torch.ops.aten.dim'"
            unhandled = DimNotImplementedTensor(torch.randn(3, 3), as_wrapper)
            with self.assertRaisesRegex(TypeError, expected_error):
                unhandled.dim()

            handled = DimImplementedTensor(torch.randn(3, 3), as_wrapper)
            self.assertEqual(handled.dim(), 2)

    def test_maybe_tuple_bug(self):
        # Smoke test: index a plain tensor with a list of subclass instances
        # whose __torch_function__ does nothing and returns None.
        class T(torch.Tensor):
            @classmethod
            def __torch_function__(cls, *args, **kwargs):
                pass

        base = torch.rand(3)

        index = [T(), T()]
        base[index]

    def test_standard_is_not_subclass(self):
        # https://github.com/pytorch/pytorch/issues/79079
        # A plain torch.Tensor must not be reported as subclass-like.
        plain = torch.empty(0)
        self.assertFalse(torch._C._dispatch_isTensorSubclassLike(plain))

    def test_sym_sizes_strides_slow_path(self):
        # A wrapper subclass (policy "sizes") whose sym_size / sym_stride
        # handlers return a one-element tuple holding a freshly made SymInt.
        class TestTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, *args, **kwargs):
                return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls, (0,), dispatch_sizes_strides_policy="sizes"
                )

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                handled = (
                    torch.ops.aten.sym_size.default,
                    torch.ops.aten.sym_stride.default,
                )
                if func not in handled:
                    # Every other op gets None.
                    return None

                from torch._dynamo.source import ConstantSource
                from torch.fx.experimental.symbolic_shapes import (
                    DimDynamic,
                    ShapeEnv,
                )

                env = ShapeEnv()
                symbol = env.create_symbol(
                    123,
                    source=ConstantSource("abc"),
                    dynamic_dim=DimDynamic.DUCK,
                    constraint_dim=None,
                )
                return (env.create_symintnode(symbol, hint=123),)

        tensor = TestTensor()
        first_size = tensor.size()[0]
        self.assertIsInstance(first_size, torch.SymInt)
        first_stride = tensor.stride()[0]
        self.assertIsInstance(first_stride, torch.SymInt)

    def test_strides_slow_path(self):
        # Subclasses built with dispatch_sizes_strides_policy="strides" answer
        # stride() from __torch_dispatch__ (via aten.sym_stride).
        for as_wrapper in (True, False):

            class StridesNotImplemented(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class StridesCustomReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func != torch.ops.aten.sym_stride.default:
                        return NotImplemented
                    return (4, 2)

            class StridesDefaultReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func != torch.ops.aten.sym_stride.default:
                        return NotImplemented
                    return None

            expected_error = "Multiple dispatch failed for 'torch.ops.aten.sym_stride'"
            unhandled = StridesNotImplemented(torch.randn(3, 3), as_wrapper)
            with self.assertRaisesRegex(TypeError, expected_error):
                unhandled.stride()

            # A tuple returned from the handler is reported verbatim.
            custom = StridesCustomReturn(torch.randn(3, 3), as_wrapper)
            self.assertEqual(custom.stride(), (4, 2))

            # With a None answer the test expects the ordinary strides of a
            # (6, 2) tensor.
            default = StridesDefaultReturn(torch.randn(6, 2), as_wrapper)
            self.assertEqual(default.stride(), (2, 1))

    def test_sizes_slow_path(self):
        # Subclasses built with dispatch_sizes_strides_policy="sizes" answer
        # size() from __torch_dispatch__ (via aten.sym_size).  All three
        # variants answer dim() with the rank of the shared `data` tensor.
        for as_wrapper in (True, False):
            data = torch.randn(6, 2)

            class SizesNotImplemented(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket != torch.ops.aten.dim:
                        return NotImplemented
                    return data.dim()

            class SizesCustomReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    packet = func.overloadpacket
                    if packet == torch.ops.aten.dim:
                        return data.dim()
                    if packet == torch.ops.aten.sym_size:
                        return (5, 3)
                    return NotImplemented

            class SizesDefaultReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    packet = func.overloadpacket
                    if packet == torch.ops.aten.dim:
                        return data.dim()
                    if packet == torch.ops.aten.sym_size:
                        return None
                    return NotImplemented

            expected_error = "Multiple dispatch failed for 'torch.ops.aten.sym_size'"
            unhandled = SizesNotImplemented(torch.randn(3, 3), as_wrapper)
            with self.assertRaisesRegex(TypeError, expected_error):
                unhandled.size()

            # A tuple returned from the handler is reported verbatim.
            custom = SizesCustomReturn(torch.randn(3, 3), as_wrapper)
            self.assertEqual(custom.size(), (5, 3))

            # With a None answer the test expects the tensor's real size.
            default = SizesDefaultReturn(torch.randn(4, 2), as_wrapper)
            self.assertEqual(default.size(), (4, 2))

    def test_custom_size_policy_dynamic_shapes(self):
        # Traces size()/stride() of a wrapper subclass (policy "sizes") under
        # make_fx with symbolic tracing and checks the generated graph code.
        # NOTE(review): `data` is not referenced anywhere else in this test.
        data = torch.randn(6, 2)

        class CustomSizeDynamicShapesTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, inner):
                # All arguments are passed positionally on purpose (see the
                # TODO below).
                # NOTE(review): the two None and the False slots presumably map
                # to storage_offset, memory_format and pin_memory — confirm
                # against the _make_wrapper_subclass binding.
                return torch.Tensor._make_wrapper_subclass(
                    # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
                    # Calling the overload that has kwargs causes us to go down the first overload path,
                    # which will **always** specialize sizes.
                    # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
                    cls,
                    inner.size(),
                    inner.stride(),
                    None,
                    None,
                    inner.dtype,
                    inner.layout,
                    inner.device,
                    False,
                    inner.requires_grad,
                    "sizes",
                )

            def __init__(self, inner):
                # Keep the wrapped tensor so the dispatch handler can read its shape.
                self.inner = inner

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func == torch.ops.aten.sym_size.default:
                    return args[0].inner.shape
                # The stride query is answered with the inner *shape* as well;
                # the expected graph below returns the same pair of sym_size
                # values for both size() and stride().
                if func == torch.ops.aten.sym_stride.default:
                    return args[0].inner.shape
                return NotImplemented

        x = torch.ones(2, 2)

        # The parameter name `x` shows up in the traced graph as `x_1`, so the
        # expected string below depends on it.
        def trace_fn(x):
            x_wrapper = CustomSizeDynamicShapesTensor(x)
            return x_wrapper.size(), x_wrapper.stride()

        fx_g = make_fx(trace_fn, tracing_mode="symbolic")(x)
        self.assertExpectedInline(
            fx_g.code.strip(),
            """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1);  x_1 = None
    return ((sym_size_int, sym_size_int_1), (sym_size_int, sym_size_int_1))""",
        )

    def test_data_ptr_respects_numel_slow_path(self):
        # With the "sizes" policy, data_ptr() is expected to query numel
        # through __torch_dispatch__; the handler records that it was asked.
        data = torch.randn(6, 2)

        class NumelDefaultReturn(torch.Tensor):
            @staticmethod
            def __new__(cls, data, wrapper):
                return TestPythonDispatch.subclass_helper(
                    cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
                )

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                packet = func.overloadpacket
                if packet == torch.ops.aten.dim:
                    return data.dim()
                if packet == torch.ops.aten.numel:
                    # `numel_seen` is rebound by the loop below before each use.
                    numel_seen[0] = True
                    return None
                return NotImplemented

        for as_wrapper in (False, True):
            # One-element list so the handler can flip the flag in place.
            numel_seen = [False]
            tensor = NumelDefaultReturn(torch.randn(2, 2), as_wrapper)
            tensor.data_ptr()
            self.assertTrue(numel_seen[0])

    def test_layout_slow_path(self):
        # Subclasses built with dispatch_layout=True answer `.layout` from
        # __torch_dispatch__ (via prim.layout).
        for as_wrapper in (True, False):
            data = torch.randn(6, 2)

            class LayoutNotImplemented(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_layout=True
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class LayoutCustomReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_layout=True
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket != torch.ops.prim.layout:
                        return NotImplemented
                    return torch.sparse_csr

            class LayoutDefaultReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_layout=True
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket != torch.ops.prim.layout:
                        return NotImplemented
                    # Report the layout of the loop-local `data` tensor.
                    return data.layout

            expected_error = "Multiple dispatch failed for 'torch.ops.prim.layout'"
            unhandled = LayoutNotImplemented(torch.randn(3, 3), as_wrapper)
            with self.assertRaisesRegex(TypeError, expected_error):
                unhandled.layout

            custom = LayoutCustomReturn(torch.randn(3, 3), as_wrapper)
            self.assertEqual(custom.layout, torch.sparse_csr)

            default = LayoutDefaultReturn(torch.randn(4, 2), as_wrapper)
            self.assertEqual(default.layout, torch.strided)


class TestPythonDispatcher(TestCase):
    # Smoke tests that run ops after creating torch._C._EnablePythonDispatcher().
    # NOTE(review): the created object is presumably an RAII-style guard, which
    # is why each test keeps it bound to a local name — confirm against the
    # C++ binding before removing the "unused" variable.

    def test_basic(self):
        operand = torch.randn(2, requires_grad=True)
        guard = torch._C._EnablePythonDispatcher()
        # Only has to run without raising.
        torch.add(operand, operand)

    def test_lstsq(self):
        lhs = torch.randn(4, 3)
        rhs = torch.rand(4, 3)
        # Reference shape, computed before the guard object is created.
        expected_shape = torch.linalg.lstsq(lhs, rhs).solution.shape
        guard = torch._C._EnablePythonDispatcher()
        python_disp_shape = torch.linalg.lstsq(lhs, rhs).solution.shape
        self.assertEqual(expected_shape, python_disp_shape)


class TestWrapperSubclassAliasing(TestCase):
    # Each test runs an op once on plain tensors and once on TwoTensor-wrapped
    # copies of the same arguments, then requires the wrapped run to show the
    # same output/input identity and storage-sharing relationships.

    def _test_wrapper_subclass_aliasing(self, op, args, kwargs):
        """Compare the aliasing behaviour of ``op`` on plain vs. wrapped tensors.

        Args:
            op: callable invoked as ``op(*args, **kwargs)``.
            args: positional arguments; every ``torch.Tensor`` leaf is wrapped
                into a ``TwoTensor`` for the second run.
            kwargs: keyword arguments, wrapped the same way.

        For every (tensor output, tensor argument) pair this asserts that:
          * if the reference output *is* the reference argument, the wrapped
            output is the wrapped argument as well;
          * the wrapped output shares storage with the wrapped argument exactly
            when the reference output shares storage with the reference
            argument.
        """

        def to_subclass(t: torch.Tensor):
            return TwoTensor(t, t.clone())

        def tensors_in(leaves):
            # Non-tensor leaves (ints, None, ...) take no part in aliasing.
            return [leaf for leaf in leaves if isinstance(leaf, torch.Tensor)]

        def shares_storage(lhs, rhs):
            return StorageWeakRef(lhs.untyped_storage()) == StorageWeakRef(
                rhs.untyped_storage()
            )

        # The reference call runs first; the wrapped arguments are built
        # afterwards from the same `args` / `kwargs` objects.
        result_ref = op(*args, **kwargs)

        args_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, args)
        kwargs_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, kwargs)

        result_test = op(*args_subclass, **kwargs_subclass)

        # Flatten both argument sets with the same pytree call so that the
        # positional pairing below cannot drift between the two sides.
        args_ref_tensors = tensors_in(pytree.arg_tree_leaves(*args, **kwargs))
        args_test_tensors = tensors_in(
            pytree.arg_tree_leaves(*args_subclass, **kwargs_subclass)
        )
        result_ref_tensors = tensors_in(pytree.tree_leaves(result_ref))
        result_test_tensors = tensors_in(pytree.tree_leaves(result_test))

        # zip() stops at the shorter sequence, so a count mismatch would
        # silently skip checks; fail loudly instead.
        self.assertEqual(
            len(result_ref_tensors),
            len(result_test_tensors),
            msg="reference and wrapper-subclass runs returned a different number of tensors",
        )
        self.assertEqual(
            len(args_ref_tensors),
            len(args_test_tensors),
            msg="reference and wrapper-subclass runs have a different number of tensor arguments",
        )

        for o_ref, o_test in zip(result_ref_tensors, result_test_tensors):
            for a_ref, a_test in zip(args_ref_tensors, args_test_tensors):
                # Identity: an output that is literally the input object must
                # stay the input object for the subclass.
                if o_ref is a_ref:
                    self.assertIs(o_test, a_test)

                # Storage aliasing must match in both directions: present when
                # the reference aliases, absent when it does not.
                if shares_storage(o_ref, a_ref):
                    self.assertTrue(shares_storage(o_test, a_test))
                else:
                    self.assertFalse(shares_storage(o_test, a_test))

    def _check_first_sample(self, device, dtype, op):
        """Run the aliasing check on the first sample input that ``op`` provides."""
        sample = first_sample(self, op.sample_inputs(device, dtype))
        self._test_wrapper_subclass_aliasing(
            op, (sample.input, *sample.args), sample.kwargs
        )

    # This tests the correctness of `torch.utils._python_dispatch.return_and_correct_aliasing`,
    # a util for wrapper subclasses to promise correct aliasing behavior.
    # It's probably overkill to test every OpInfo,
    # so this uses a sampling of ops with representative schemas.
    @ops(
        [
            op
            for op in op_db
            if op.name
            in {
                "mul",  # out-of-place
                "cat",  # out-of-place (TensorList input)
                "index",  # out-of-place (Optional TensorList input)
                "mul_",  # inplace
                "view",  # view
                "t_",  # inplace-view
                "split",  # view (multi-return)
                "native_batch_norm",  # mutable op (returns outputs and mutates some inputs)
            }
        ],
        allowed_dtypes=(torch.float,),
    )
    def test_wrapper_subclass_aliasing(self, device, dtype, op):
        self._check_first_sample(device, dtype, op)

    @ops(custom_op_db, allowed_dtypes=(torch.float,))
    def test_wrapper_subclass_aliasing_custom(self, device, dtype, op):
        self._check_first_sample(device, dtype, op)

    def test_wrapper_subclass_aliasing_conv2d(self, device):
        # Build the inputs on the device under test so that each per-device
        # variant of this test exercises its own device.
        args = (
            torch.randn(4, 4, 4, 4, device=device),
            torch.randn(4, 4, 4, 4, device=device),
        )
        kwargs = {}
        # conv2d has a default arg 'int[2] strides=0',
        # which torchscript expands into 'int[2] strides=[0, 0]'
        # Make sure that return_and_correct_aliasing can handle this case
        # (inference_mode is used to make sure conv2d doesn't decompose and goes to torch_dispatch)
        with torch.inference_mode():
            self._test_wrapper_subclass_aliasing(
                torch.ops.aten.conv2d.default, args, kwargs
            )

    def test_wrapper_subclass_aliasing_out_op(self, device):
        # Make sure that return_and_correct_aliasing can handle kwargs w mutable tensors
        args = (torch.ones(4, device=device), torch.ones(4, device=device))
        kwargs = {"out": torch.empty(4, device=device)}
        self._test_wrapper_subclass_aliasing(torch.ops.aten.add.out, args, kwargs)


# TestWrapperSubclassAliasing's tests take a `device` argument, so the class is
# handed (together with this module's globals()) to instantiate_device_type_tests,
# which is expected to create the per-device test classes in this namespace.
instantiate_device_type_tests(TestWrapperSubclassAliasing, globals())

# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
