Reenable assert sanity testing by reenabling ADInplaceOrView (#88102)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88102
Approved by: https://github.com/albanD
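
The metadata sanity assert in MetaConverter was previously commented out
because it could not pass when the converter ran with autograd disabled:
view metadata is recorded via the ADInplaceOrView dispatch key, so views
created while that key is TLS-excluded fail the _is_view() check. This
change turns the assert back on (skipping symbolic sizes/strides) and
temporarily clears the exclusion around the view-recreation code. A
minimal sketch of that toggle, using the same private APIs the diff
itself calls (the wrapper name is illustrative, not part of the diff):

    import torch

    def with_ad_inplace_or_view(fn):
        # Reenable ADInplaceOrView for the duration of fn(), then restore
        # whatever exclusion state was in effect before.
        key = torch._C.DispatchKey.ADInplaceOrView
        old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(key)
        torch._C._dispatch_tls_set_dispatch_key_excluded(key, False)
        try:
            return fn()
        finally:
            torch._C._dispatch_tls_set_dispatch_key_excluded(key, old_exclude)
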
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index aea6fcb..5ed23f8 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -717,7 +717,11 @@
# Some attribute queries that can be serviced directly
# See Note [is_coalesced is dispatched]
- if func in [torch.ops.aten.is_coalesced.default]:
+ if func in {
+ torch.ops.aten.is_coalesced.default,
+ torch.ops.aten.dense_dim.default,
+ torch.ops.aten.sparse_dim.default,
+ }:
# NB: no_dispatch is ok here too, this func is very simple
with in_kernel_invocation_manager(self):
return func(*args, **kwargs)
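
These three queries return plain Python scalars rather than tensors, so the
fake tensor mode can service them by dispatching straight into the kernel.
A rough illustration (that FakeTensorMode.from_tensor handles a sparse input
this way is an assumption, not something the diff shows):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    mode = FakeTensorMode()
    fake = mode.from_tensor(torch.eye(3).to_sparse())
    # Each call yields a plain Python value, not a tensor.
    assert isinstance(fake.sparse_dim(), int)
    assert isinstance(fake.dense_dim(), int)
    assert isinstance(fake.is_coalesced(), bool)
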
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 0e2bbe4..7e2039f 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -25,10 +25,11 @@
assert a == b, f"{a} != {b}"
-def assert_metadata_eq(assert_eq, m1, m2):
+def assert_metadata_eq(assert_eq, m1, m2, *, skip_symbolic=False):
def go(m1, m2):
assert_eq(m1.dtype, m2.dtype)
- assert_eq(m1.shape, m2.shape)
+ if not skip_symbolic:
+ assert_eq(m1.shape, m2.shape)
assert_eq(m1.requires_grad, m2.requires_grad)
assert_eq(m1.is_leaf, m2.is_leaf)
assert_eq(m1.grad_fn is None, m2.grad_fn is None)
@@ -38,14 +39,15 @@
assert_eq(m1.is_neg(), m2.is_neg())
assert_eq(safe_grad(m1) is not None, safe_grad(m2) is not None)
if safe_grad(m1) is not None:
- go(m1.grad, m2.grad)
+ go(safe_grad(m1), safe_grad(m2))
if m1.is_sparse:
assert_eq(m1.dense_dim(), m2.dense_dim())
assert_eq(m1.sparse_dim(), m2.sparse_dim())
assert_eq(m1.is_coalesced(), m2.is_coalesced())
else:
- assert_eq(m1.stride(), m2.stride())
- assert_eq(m1.storage_offset(), m2.storage_offset())
+ if not skip_symbolic:
+ assert_eq(m1.stride(), m2.stride())
+ assert_eq(m1.storage_offset(), m2.storage_offset())
assert_eq(m1._is_view(), m2._is_view())
if m1._is_view():
go(m1._base, m2._base)
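
With skip_symbolic=True, callers still get every non-shape check (dtype,
requires_grad, leaf-ness, grad_fn presence, conj/neg bits, view-ness, sparse
invariants) while shape/stride/storage_offset comparisons, which may involve
SymInts under a ShapeEnv, are skipped. A hedged usage sketch (import paths
as in this file; the assert_eq helper mirrors the one defined above):

    import torch
    from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq

    def assert_eq(a, b):
        assert a == b, f"{a} != {b}"

    t = torch.randn(3, 4).view(4, 3)
    m = MetaConverter()(t)  # meta copy; sizes may be symbolic under a ShapeEnv
    assert_metadata_eq(assert_eq, t, m, skip_symbolic=True)
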
@@ -262,63 +264,83 @@
== real_dtype
)
- if base.dtype == t.dtype:
- pass
- elif is_c_of_r(base.dtype, t.dtype):
- base = torch.view_as_real(base)
- elif is_c_of_r(t.dtype, base.dtype):
- base = torch.view_as_complex(base)
- else:
- # This is not guaranteed to succeed. If it fails, it
- # means there is another dtype-converting view function
- # that hasn't been handled here
- base = base.view(t.dtype)
+ # In some situations, MetaConverter may be called in a
+ # context where autograd is disabled. For the _is_view
+ # assert to pass, we have to set up the autograd view
+ # metadata anyway. Do this by reenabling the
+ # ADInplaceOrView key. This is kind of a hack.
+ old_exclude = torch._C._dispatch_tls_is_dispatch_key_excluded(
+ torch._C.DispatchKey.ADInplaceOrView
+ )
+ torch._C._dispatch_tls_set_dispatch_key_excluded(
+ torch._C.DispatchKey.ADInplaceOrView, False
+ )
+ try:
- # This is very tricky. Naively, you might expect this
- # to hold:
- #
- # if t.requires_grad and not safe_is_leaf(t)
- # assert t._base.requires_grad
- #
- # But it's not true! As you can see in the following
- # program:
- #
- # x = torch.zeros(4)
- # y = x.view(1, 4)
- # y.requires_grad = True
- # z = y.view(1, 1, 4)
- # assert z._base is x
- #
- # So we may have to do *two* views out of the base to
- # recreate this situation.
+ if base.dtype == t.dtype:
+ pass
+ elif is_c_of_r(base.dtype, t.dtype):
+ base = torch.view_as_real(base)
+ elif is_c_of_r(t.dtype, base.dtype):
+ base = torch.view_as_complex(base)
+ else:
+ # This is not guaranteed to succeed. If it fails, it
+ # means there is another dtype-converting view function
+ # that hasn't been handled here
+ base = base.view(t.dtype)
- sizes, strides = sym_sizes_strides(t)
- if safe_is_leaf(t):
- # Leaf views that track view metadata are created by
- # creating a view inside a no_grad block
- with torch.no_grad():
- r = base.as_strided(sizes, strides, sym(t.storage_offset()))
- # As it's a leaf, we can directly assign requires_grad
- r.requires_grad = t.requires_grad
- else:
- if t._base.requires_grad == t.requires_grad:
- # Easy case, just run the view op
- with torch.enable_grad():
+ # This is very tricky. Naively, you might expect this
+ # to hold:
+ #
+ # if t.requires_grad and not safe_is_leaf(t)
+ # assert t._base.requires_grad
+ #
+ # But it's not true! As you can see in the following
+ # program:
+ #
+ # x = torch.zeros(4)
+ # y = x.view(1, 4)
+ # y.requires_grad = True
+ # z = y.view(1, 1, 4)
+ # assert z._base is x
+ #
+ # So we may have to do *two* views out of the base to
+ # recreate this situation.
+
+ sizes, strides = sym_sizes_strides(t)
+
+ if safe_is_leaf(t):
+ # A leaf view that still tracks view metadata is made by
+ # creating the view inside a no_grad block
+ with torch.no_grad():
r = base.as_strided(
sizes, strides, sym(t.storage_offset())
)
+ # As it's a leaf, we can directly assign requires_grad
+ r.requires_grad = t.requires_grad
else:
- # Obscure case. Create a leaf view and give it the
- # correct requires_grad, then do the final view.
- # NB: Can't have a non-leaf without requiring grad!
- assert t.requires_grad
- with torch.no_grad():
- mid = base.view(base.shape)
- mid.requires_grad = t.requires_grad
- with torch.enable_grad():
- r = mid.as_strided(
- sizes, strides, sym(t.storage_offset())
- )
+ if t._base.requires_grad == t.requires_grad:
+ # Easy case, just run the view op
+ with torch.enable_grad():
+ r = base.as_strided(
+ sizes, strides, sym(t.storage_offset())
+ )
+ else:
+ # Obscure case. Create a leaf view and give it the
+ # correct requires_grad, then do the final view.
+ # NB: Can't have a non-leaf without requiring grad!
+ assert t.requires_grad
+ with torch.no_grad():
+ mid = base.view(base.shape)
+ mid.requires_grad = t.requires_grad
+ with torch.enable_grad():
+ r = mid.as_strided(
+ sizes, strides, sym(t.storage_offset())
+ )
+ finally:
+ torch._C._dispatch_tls_set_dispatch_key_excluded(
+ torch._C.DispatchKey.ADInplaceOrView, old_exclude
+ )
else:
is_leaf = safe_is_leaf(t)
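
A hedged, runnable restatement of the "obscure case" handled above: the
mid-view recipe reproduces a non-leaf, grad-requiring view whose ultimate
base does not require grad, matching the x/y/z program in the comment
(shape and strides here are just example values):

    import torch

    base = torch.zeros(4)
    with torch.no_grad():
        mid = base.view(base.shape)   # leaf view of base
    mid.requires_grad = True          # allowed: mid is a leaf
    with torch.enable_grad():
        r = mid.as_strided((1, 1, 4), (4, 4, 1), 0)
    assert r._base is base            # view-of-view chains to the root base
    assert r.requires_grad and not r.is_leaf
    assert not base.requires_grad
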
@@ -389,15 +411,12 @@
with maybe_fake_mgr, torch.no_grad():
r.set_(r_s, storage_offset, sizes, strides)
- with warnings.catch_warnings():
- warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
- grad_not_none = t.grad is not None
- if grad_not_none:
- r.grad = self.meta_tensor(t.grad, shape_env, callback)
+ if safe_grad(t) is not None:
+ r.grad = self.meta_tensor(safe_grad(t), shape_env, callback)
torch._C._set_conj(r, t.is_conj())
torch._C._set_neg(r, t.is_neg())
# This can be skipped if necessary for performance reasons
- # assert_metadata_eq(assert_eq, t, r)
+ assert_metadata_eq(assert_eq, t, r, skip_symbolic=True)
self.set_tensor_memo(t, r)
return self.get_tensor_memo(t)
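
The hunk above replaces the inline warnings dance with the safe_grad helper
defined earlier in meta_utils.py. A sketch of what it presumably does,
reconstructed from the removed lines (the real definition may differ in
detail):

    import warnings

    def safe_grad(t):
        # Read t.grad without tripping the "The .grad attribute of a
        # Tensor..." UserWarning for non-leaf tensors.
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
            return t.grad
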
diff --git a/torchgen/model.py b/torchgen/model.py
index f87f2be..c1b906d 100644
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -59,7 +59,7 @@
]
# This doesn't have to be in sync with the header, it only needs to contain
-# entries that we actually use in the codegen
+# entries that we actually use in the codegen or want pyi entries for
class DispatchKey(Enum):
Undefined = 0
CatchAll = Undefined
@@ -92,6 +92,7 @@
TESTING_ONLY_GenericWrapper = auto()
TESTING_ONLY_GenericMode = auto()
+ ADInplaceOrView = auto()
Autograd = auto()
CompositeImplicitAutograd = auto()
CompositeImplicitAutogradNestedTensor = auto()
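
With ADInplaceOrView added to torchgen's enum, name-based lookup used by the
codegen (and, per the updated comment, .pyi generation) can resolve the key.
A small sketch, assuming DispatchKey.parse behaves as in torchgen at the time:

    from torchgen.model import DispatchKey

    k = DispatchKey.parse("ADInplaceOrView")
    assert k is DispatchKey.ADInplaceOrView
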