Reenable assert sanity testing by reenabling ADInplaceOrView (#88102)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88102
Approved by: https://github.com/albanD
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index aea6fcb..5ed23f8 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -717,7 +717,11 @@
 
         # Some attribute queries that can be serviced directly
         # See Note [is_coalesced is dispatched]
-        if func in [torch.ops.aten.is_coalesced.default]:
+        if func in {
+            torch.ops.aten.is_coalesced.default,
+            torch.ops.aten.dense_dim.default,
+            torch.ops.aten.sparse_dim.default,
+        }:
             # NB: no_dispatch is ok here too, this func is very simple
             with in_kernel_invocation_manager(self):
                 return func(*args, **kwargs)
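
The three overloads added above are pure metadata queries, so the fake tensor can answer them directly inside in_kernel_invocation_manager without dispatching a real kernel. As a rough illustration (not part of the diff), the same overloads answered from an ordinary CPU sparse COO tensor:

    import torch

    s = torch.sparse_coo_tensor(
        torch.tensor([[0, 2]]),    # indices: one sparse dimension
        torch.tensor([1.0, 2.0]),  # values: no dense dimensions
        (4,),
    )
    print(torch.ops.aten.sparse_dim.default(s))    # 1
    print(torch.ops.aten.dense_dim.default(s))     # 0
    print(torch.ops.aten.is_coalesced.default(s))  # False until s.coalesce()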
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 0e2bbe4..7e2039f 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -25,10 +25,11 @@
     assert a == b, f"{a} != {b}"
 
 
-def assert_metadata_eq(assert_eq, m1, m2):
+def assert_metadata_eq(assert_eq, m1, m2, *, skip_symbolic=False):
     def go(m1, m2):
         assert_eq(m1.dtype, m2.dtype)
-        assert_eq(m1.shape, m2.shape)
+        if not skip_symbolic:
+            assert_eq(m1.shape, m2.shape)
         assert_eq(m1.requires_grad, m2.requires_grad)
         assert_eq(m1.is_leaf, m2.is_leaf)
         assert_eq(m1.grad_fn is None, m2.grad_fn is None)
@@ -38,14 +39,15 @@
         assert_eq(m1.is_neg(), m2.is_neg())
         assert_eq(safe_grad(m1) is not None, safe_grad(m2) is not None)
         if safe_grad(m1) is not None:
-            go(m1.grad, m2.grad)
+            go(safe_grad(m1), safe_grad(m2))
         if m1.is_sparse:
             assert_eq(m1.dense_dim(), m2.dense_dim())
             assert_eq(m1.sparse_dim(), m2.sparse_dim())
             assert_eq(m1.is_coalesced(), m2.is_coalesced())
         else:
-            assert_eq(m1.stride(), m2.stride())
-            assert_eq(m1.storage_offset(), m2.storage_offset())
+            if not skip_symbolic:
+                assert_eq(m1.stride(), m2.stride())
+                assert_eq(m1.storage_offset(), m2.storage_offset())
             assert_eq(m1._is_view(), m2._is_view())
             if m1._is_view():
                 go(m1._base, m2._base)
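
The new skip_symbolic flag lets the metadata check run even when the converted tensor carries symbolic sizes and strides that would not compare equal to the original's concrete ints. A rough usage sketch (assuming MetaConverter and the module-level assert_eq helper defined earlier in this file):

    import torch
    from torch._subclasses.meta_utils import (
        MetaConverter,
        assert_eq,
        assert_metadata_eq,
    )

    t = torch.randn(2, 3, requires_grad=True)
    m = MetaConverter()(t)  # meta copy of t
    # Compare dtype, autograd state, conj/neg bits, etc., but skip sizes,
    # strides and storage_offset, which may be symbolic on the converted side.
    assert_metadata_eq(assert_eq, t, m, skip_symbolic=True)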
@@ -262,63 +264,83 @@
                             == real_dtype
                         )
 
-                    if base.dtype == t.dtype:
-                        pass
-                    elif is_c_of_r(base.dtype, t.dtype):
-                        base = torch.view_as_real(base)
-                    elif is_c_of_r(t.dtype, base.dtype):
-                        base = torch.view_as_complex(base)
-                    else:
-                        # This is not guaranteed to succeed.  If it fails, it
-                        # means there is another dtype-converting view function
-                        # that hasn't been handled here
-                        base = base.view(t.dtype)
+                    # In some situations, MetaConverter may be called in a
+                    # context where autograd is disabled.  For the _is_view
+                    # assert to pass, we have to setup the autograd view
+                    # metadata anyway.  Do this by reenabling the
+                    # ADInplaceOrView key.  This is kind of a hack.
+                    old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(
+                        torch._C.DispatchKey.ADInplaceOrView
+                    )
+                    torch._C._dispatch_tls_set_dispatch_key_excluded(
+                        torch._C.DispatchKey.ADInplaceOrView, False
+                    )
+                    try:
 
-                    # This is very tricky.  Naively, you might expect this
-                    # to hold:
-                    #
-                    #   if t.requires_grad and not safe_is_leaf(t)
-                    #       assert t._base.requires_grad
-                    #
-                    # But it's not true!  As you can see in the following
-                    # program:
-                    #
-                    #   x = torch.zeros(4)
-                    #   y = x.view(1, 4)
-                    #   y.requires_grad = True
-                    #   z = y.view(1, 1, 4)
-                    #   assert z._base is x
-                    #
-                    # So we may have to do *two* views out of the base to
-                    # recreate this situation.
+                        if base.dtype == t.dtype:
+                            pass
+                        elif is_c_of_r(base.dtype, t.dtype):
+                            base = torch.view_as_real(base)
+                        elif is_c_of_r(t.dtype, base.dtype):
+                            base = torch.view_as_complex(base)
+                        else:
+                            # This is not guaranteed to succeed.  If it fails, it
+                            # means there is another dtype-converting view function
+                            # that hasn't been handled here
+                            base = base.view(t.dtype)
 
-                    sizes, strides = sym_sizes_strides(t)
-                    if safe_is_leaf(t):
-                        # Leaf views that track view metadata are created by
-                        # creating a view inside a no_grad block
-                        with torch.no_grad():
-                            r = base.as_strided(sizes, strides, sym(t.storage_offset()))
-                        # As it's a leaf, we can directly assign requires_grad
-                        r.requires_grad = t.requires_grad
-                    else:
-                        if t._base.requires_grad == t.requires_grad:
-                            # Easy case, just run the view op
-                            with torch.enable_grad():
+                        # This is very tricky.  Naively, you might expect this
+                        # to hold:
+                        #
+                        #   if t.requires_grad and not safe_is_leaf(t)
+                        #       assert t._base.requires_grad
+                        #
+                        # But it's not true!  As you can see in the following
+                        # program:
+                        #
+                        #   x = torch.zeros(4)
+                        #   y = x.view(1, 4)
+                        #   y.requires_grad = True
+                        #   z = y.view(1, 1, 4)
+                        #   assert z._base is x
+                        #
+                        # So we may have to do *two* views out of the base to
+                        # recreate this situation.
+
+                        sizes, strides = sym_sizes_strides(t)
+
+                        if safe_is_leaf(t):
+                            # Leaf views that track view metadata are created by
+                            # creating a view inside a no_grad block
+                            with torch.no_grad():
                                 r = base.as_strided(
                                     sizes, strides, sym(t.storage_offset())
                                 )
+                            # As it's a leaf, we can directly assign requires_grad
+                            r.requires_grad = t.requires_grad
                         else:
-                            # Obscure case.  Create a leaf view and give it the
-                            # correct requires_grad, then do the final view.
-                            # NB: Can't have a non-leaf without requiring grad!
-                            assert t.requires_grad
-                            with torch.no_grad():
-                                mid = base.view(base.shape)
-                            mid.requires_grad = t.requires_grad
-                            with torch.enable_grad():
-                                r = mid.as_strided(
-                                    sizes, strides, sym(t.storage_offset())
-                                )
+                            if t._base.requires_grad == t.requires_grad:
+                                # Easy case, just run the view op
+                                with torch.enable_grad():
+                                    r = base.as_strided(
+                                        sizes, strides, sym(t.storage_offset())
+                                    )
+                            else:
+                                # Obscure case.  Create a leaf view and give it the
+                                # correct requires_grad, then do the final view.
+                                # NB: Can't have a non-leaf without requiring grad!
+                                assert t.requires_grad
+                                with torch.no_grad():
+                                    mid = base.view(base.shape)
+                                mid.requires_grad = t.requires_grad
+                                with torch.enable_grad():
+                                    r = mid.as_strided(
+                                        sizes, strides, sym(t.storage_offset())
+                                    )
+                    finally:
+                        torch._C._dispatch_tls_set_dispatch_key_excluded(
+                            torch._C.DispatchKey.ADInplaceOrView, old_exclude
+                        )
 
                 else:
                     is_leaf = safe_is_leaf(t)
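
The save/clear/restore sequence above keeps the ADInplaceOrView key included only for the duration of the view reconstruction. The same pattern could be packaged as a context manager; a sketch of that idea (hypothetical helper, built on the same private _dispatch_tls_* bindings the hunk calls):

    import contextlib
    import torch

    @contextlib.contextmanager
    def _include_ad_inplace_or_view():
        # Hypothetical helper: temporarily drop ADInplaceOrView from the TLS
        # exclude set so view ops record view/autograd metadata, restoring the
        # previous state even if the body raises.
        key = torch._C.DispatchKey.ADInplaceOrView
        old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(key)
        torch._C._dispatch_tls_set_dispatch_key_excluded(key, False)
        try:
            yield
        finally:
            torch._C._dispatch_tls_set_dispatch_key_excluded(key, old_exclude)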
@@ -389,15 +411,12 @@
                         with maybe_fake_mgr, torch.no_grad():
                             r.set_(r_s, storage_offset, sizes, strides)
 
-                with warnings.catch_warnings():
-                    warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
-                    grad_not_none = t.grad is not None
-                if grad_not_none:
-                    r.grad = self.meta_tensor(t.grad, shape_env, callback)
+                if safe_grad(t) is not None:
+                    r.grad = self.meta_tensor(safe_grad(t), shape_env, callback)
                 torch._C._set_conj(r, t.is_conj())
                 torch._C._set_neg(r, t.is_neg())
             # This can be skipped if necessary for performance reasons
-            # assert_metadata_eq(assert_eq, t, r)
+            assert_metadata_eq(assert_eq, t, r, skip_symbolic=True)
             self.set_tensor_memo(t, r)
 
         return self.get_tensor_memo(t)
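
Reading gradients through the safe_grad helper (already used by assert_metadata_eq above) replaces the inline warnings.catch_warnings block: plain .grad access on a non-leaf tensor emits the "The .grad attribute of a Tensor ..." UserWarning that the old code filtered by hand. A small standalone demonstration of that warning (sketch, not part of the diff):

    import warnings
    import torch

    x = torch.ones(3, requires_grad=True)
    y = x * 2  # non-leaf tensor
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        _ = y.grad  # warns: "The .grad attribute of a Tensor that is not a leaf ..."
    assert any("The .grad attribute" in str(w.message) for w in caught)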
diff --git a/torchgen/model.py b/torchgen/model.py
index f87f2be..c1b906d 100644
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -59,7 +59,7 @@
 ]
 
 # This doesn't have to be in sync with the header, it only needs to contain
-# entries that we actually use in the codegen
+# entries that we actually use in the codegen or want pyi entries for
 class DispatchKey(Enum):
     Undefined = 0
     CatchAll = Undefined
@@ -92,6 +92,7 @@
     TESTING_ONLY_GenericWrapper = auto()
     TESTING_ONLY_GenericMode = auto()
 
+    ADInplaceOrView = auto()
     Autograd = auto()
     CompositeImplicitAutograd = auto()
     CompositeImplicitAutogradNestedTensor = auto()
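
With ADInplaceOrView added to torchgen's DispatchKey enum, codegen and the .pyi generator can refer to the key by name. A quick check (a sketch, assuming torchgen is importable from a PyTorch source checkout and that DispatchKey.parse behaves for the new member as it does for the existing entries):

    from torchgen.model import DispatchKey

    assert DispatchKey.parse("ADInplaceOrView") is DispatchKey.ADInplaceOrView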