should_check_strides (#85416)

This PR ports `should_check_strides` checks from `origin/symbolic-shapes` to `master` as the part of our dynamic shapes landing effort.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85416
Approved by: https://github.com/ezyang
diff --git a/test/test_meta.py b/test/test_meta.py
index 652b531..ca702ff 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -185,7 +185,26 @@
         del m
         self.assertIs(ref(), None)
 
-def assert_ref_meta_equal(test_case, meta_rs, rs, msg_callable):
+CHECK_STRIDES = {
+    torch.Tensor.__getitem__,
+}
+
+def should_check_strides(func):
+    if func in CHECK_STRIDES:
+        return True
+    if not isinstance(func, torch._ops.OpOverload):
+        return False
+    # Prims are expected to model strides correctly
+    if func.namespace == "prims":
+        return True
+    # Check if it's a view, by testing if any of the returns have
+    # a non-empty alias set
+    if any(r.alias_info.before_set for r in func._schema.returns if r.alias_info):
+        return True
+    # TODO: check for TensorIterator
+    return False
+
+def assert_ref_meta_equal(test_case, func, meta_rs, rs, msg_callable):
     flat_meta_rs, _ = tree_flatten(meta_rs)
     flat_rs, _ = tree_flatten(rs)
     test_case.assertEqual(len(flat_meta_rs), len(flat_rs))
@@ -200,8 +219,9 @@
         test_assert(meta_r.shape == r.shape, f"but real shape was {r.shape}")
         # NOTE: stride checking is currently disabled
         # See https://github.com/pytorch/pytorch/issues/78050
-        # same_strides, _ = prims.utils.check_significant_strides(meta_r, r)
-        # test_assert(same_strides, f"but real stride was {r.stride()}")
+        if should_check_strides(func):
+            same_strides, _ = torch._prims_common.check_significant_strides(meta_r, r)
+            test_assert(same_strides, f"but real stride was {r.stride()}")
         test_assert(
             meta_r.storage_offset() == r.storage_offset(),
             f"but real storage_offset was {r.storage_offset()}")
@@ -363,7 +383,7 @@
         else:
             try:
                 delim = ',\n  '
-                assert_ref_meta_equal(test_case, meta_rs, rs, lambda msg: f"""\
+                assert_ref_meta_equal(test_case, func, meta_rs, rs, lambda msg: f"""\
 meta disagrees with real impl:
 {resolve_name(func)}(
   {delim.join(map(verbose_print, meta_args))},
diff --git a/torch/_ops.py b/torch/_ops.py
index 4fa22fe..0c9478a 100644
--- a/torch/_ops.py
+++ b/torch/_ops.py
@@ -260,6 +260,10 @@
     def __str__(self):
         return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname)
 
+    @property
+    def namespace(self):
+        return self._schema.name.split("::")[0]
+
     def decompose(self, *args, **kwargs):
         dk = torch._C.DispatchKey.CompositeImplicitAutograd
         if dk in self.py_kernels: