should_check_strides (#85416)
This PR ports the `should_check_strides` checks from `origin/symbolic-shapes` to `master` as part of our dynamic shapes landing effort.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85416
Approved by: https://github.com/ezyang
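In short: meta-kernel stride checks are now opted in per-operator instead of being disabled across the board. Below is a minimal sketch of how the policy behaves, assuming `test/test_meta.py` is on the import path; the operator overloads used are illustrative examples, not part of this PR:

```python
import torch
from test_meta import should_check_strides  # the helper added in this PR

# Plain functional ops are not stride-checked (TensorIterator coverage is a TODO):
assert not should_check_strides(torch.ops.aten.add.Tensor)

# Prims are expected to model strides correctly, so they are always checked:
assert should_check_strides(torch.ops.prims.add.default)

# View ops alias their inputs (schema returns like `Tensor(a)`), so their
# meta kernels must reproduce strides exactly:
assert should_check_strides(torch.ops.aten.expand.default)

# Explicitly listed callables are checked too:
assert should_check_strides(torch.Tensor.__getitem__)
```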
diff --git a/test/test_meta.py b/test/test_meta.py
index 652b531..ca702ff 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -185,7 +185,26 @@
del m
self.assertIs(ref(), None)

-def assert_ref_meta_equal(test_case, meta_rs, rs, msg_callable):
+CHECK_STRIDES = {
+ torch.Tensor.__getitem__,
+}
+
+def should_check_strides(func):
+ if func in CHECK_STRIDES:
+ return True
+ if not isinstance(func, torch._ops.OpOverload):
+ return False
+ # Prims are expected to model strides correctly
+ if func.namespace == "prims":
+ return True
+ # Check whether it's a view by testing if any of the
+ # returns have a non-empty alias set
+ if any(r.alias_info.before_set for r in func._schema.returns if r.alias_info):
+ return True
+ # TODO: check for TensorIterator
+ return False
+
+def assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable):
flat_meta_rs, _ = tree_flatten(meta_rs)
flat_rs, _ = tree_flatten(rs)
test_case.assertEqual(len(flat_meta_rs), len(flat_rs))
@@ -200,8 +219,9 @@
test_assert(meta_r.shape == r.shape, f"but real shape was {r.shape}")
# NOTE: stride checking is currently disabled
# See https://github.com/pytorch/pytorch/issues/78050
- # same_strides, _ = prims.utils.check_significant_strides(meta_r, r)
- # test_assert(same_strides, f"but real stride was {r.stride()}")
+ if should_check_strides(func):
+ same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r)
+ test_assert(same_strides, f"but real stride was {r.stride()}")
test_assert(
meta_r.storage_offset() == r.storage_offset(),
f"but real storage_offset was {r.storage_offset()}")
@@ -363,7 +383,7 @@
else:
try:
delim = ',\n '
- assert_ref_meta_equal(test_case, meta_rs, rs, lambda msg: f"""\
+ assert_ref_meta_equal(test_case, func, meta_rs, rs, lambda msg: f"""\
meta disagrees with real impl:
{resolve_name(func)}(
{delim.join(map(verbose_print, meta_args))},
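When the policy fires, the test compares strides via `torch._prims_common.check_significant_strides`. As a rough illustration of what "significant" means here (a simplified sketch, not the actual `_prims_common` implementation): only dimensions of size greater than one have meaningful strides, since size-0/1 dimensions are never stepped over and can carry arbitrary stride values without changing memory layout.

```python
import torch

def same_significant_strides(a: torch.Tensor, b: torch.Tensor) -> bool:
    # A dimension of size 0 or 1 is never actually strided over, so any
    # stride value there is layout-equivalent; compare the rest exactly.
    return all(
        sa == sb
        for size, sa, sb in zip(a.shape, a.stride(), b.stride())
        if size > 1
    )

x = torch.empty(4, 1, 3)                       # strides (3, 3, 1)
y = x.clone().as_strided(x.shape, (3, 99, 1))  # junk stride on the size-1 dim
assert same_significant_strides(x, y)          # still layout-equivalent
```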
diff --git a/torch/_ops.py b/torch/_ops.py
index 4fa22fe..0c9478a 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -260,6 +260,10 @@
def __str__(self):
return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)
+ @property
+ def namespace(self):
+ return self._schema.name.split("::")[0]
+
def decompose(self, *args, **kwargs):
dk = torch._C.DispatchKey.CompositeImplicitAutograd
if dk in self.py_kernels:
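The new `namespace` property simply peels the namespace off the schema's qualified `ns::opname` string, which is what lets `should_check_strides` special-case prims above. A quick illustrative check:

```python
import torch

# The schema name is "aten::add"; the namespace is the part before "::".
assert torch.ops.aten.add.Tensor.namespace == "aten"
assert torch.ops.prims.add.default.namespace == "prims"
```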