Add FakeCrossRef tests for backwards, Fix Layer Norm Backward Decomp (#85417)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85417
Approved by: https://github.com/ezyang
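For reference, the aliasing / metadata checks that CrossRefFakeMode gains below can also be exercised directly, outside of the OpInfo test added in this change. A rough usage sketch (not part of this diff), assuming the constructor signature introduced here (ignore_op_fn, check_strides, check_aliasing) and mirroring how test_fake_crossref_backward drives backward under the mode; aliasing checks are left off, matching the TODO in the test:

    import torch
    from torch._subclasses import CrossRefFakeMode

    x = torch.randn(3, 4, requires_grad=True)
    # Every aten call issued under the mode is re-run with fake tensors and the
    # fake outputs are compared against the real ones (shape/dtype/strides,
    # requires_grad, storage offset, and optionally aliasing).
    with CrossRefFakeMode(ignore_op_fn=lambda fn: False, check_aliasing=False):
        y = torch.nn.functional.layer_norm(x, (4,))
        (grad,) = torch.autograd.grad(y.sum(), x)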
diff --git a/test/test_ops.py b/test/test_ops.py
index fcf2cdf..52cc42a 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -16,6 +16,7 @@
     floating_and_complex_types_and,
     all_types_and_complex_and,
 )
+from test_proxy_tensor import xfail, skipOps
 
 from torch.testing._internal.common_utils import (
     TestCase,
@@ -59,6 +60,8 @@
     FakeTensor,
     FakeTensorMode,
 )
+from torch._subclasses.fake_utils import outputs_alias_inputs
+
 from torch.utils._python_dispatch import enable_torch_dispatch_mode
 import torch._prims as prims
 from torch._prims.context import TorchRefsMode
@@ -94,6 +97,8 @@
 )
 _ops_and_refs = op_db + python_ref_db
 
+aten = torch.ops.aten
+
 # Tests that apply to all operators and aren't related to any particular
 #   system
 @skipIfSlowGradcheckEnv
@@ -1775,8 +1780,10 @@
     "nn.functional.pixel_unshuffle",
 )
 
-fake_striding_skips = (
-    "diag_embed",
+# ops whose fake tensor stride propagation is inconsistent with eager
+# XXX: no new entries should be added to this list as a result of a
+# decomp or prim, see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
+fake_tensor_stride_failing_ops = [
     "fft.fft2",
     "fft.fft",
     "fft.fftn",
@@ -1797,11 +1804,35 @@
     "fft.rfftn",
     "svd",
     "linalg.svd",
-)
+]
 
+fake_backward_xfails = fake_tensor_stride_failing_ops + [
+    "linalg.cond",
+    "linalg.matrix_norm",
+    "linalg.norm",
+    "linalg.svd",
+    "linalg.svdvals",
+    "nn.functional.binary_cross_entropy_with_logits",
+    "nn.functional.huber_loss",
+    "nn.functional.logsigmoid",
+    "nn.functional.multilabel_soft_margin_loss",
+    "pca_lowrank",
+    "roll",
+    "svd_lowrank",
+    "sgn",
+    "cholesky",
+    "linalg.eigh",
+    "symeig",
+]
+
+fake_backward_xfails = [xfail(stride_skip) for stride_skip in fake_backward_xfails] + [
+    xfail("segment_reduce", "lengths"),
+    xfail("norm", "nuc"),
+    xfail("linalg.norm", "subgradients_at_zero"),  # can accept vector inputs
+]
 
 @skipIfSlowGradcheckEnv
-class TestFakeTensorNonErroring(TestCase):
+class TestFakeTensor(TestCase):
     def _test_fake_helper(self, device, dtype, op, context):
         name = op.name
         if op.variant_test_name:
@@ -1834,15 +1865,6 @@
                     with enable_torch_dispatch_mode(mode):
                         res_fake = op(input, *args, **kwargs)
 
-                def outputs_alias_inputs(outputs, inputs):
-                    input_storages = set()
-                    for out in tree_flatten(outputs)[0]:
-                        if isinstance(out, torch.Tensor):
-                            input_storages.add(out.storage()._cdata)
-                    for inp in tree_flatten(inputs)[0]:
-                        if isinstance(inp, torch.Tensor) and inp.storage()._cdata in input_storages:
-                            return True
-                    return False
 
                 for fake_out, real_out in zip(
                     tree_flatten(res_fake)[0], tree_flatten(res)[0]
@@ -1855,12 +1877,10 @@
                     # if you see a shape exception here, you may need to add
                     # a `dynamic_output_shape` tag to an operator
 
-                    check_strides = name not in fake_striding_skips
+                    check_strides = name not in fake_tensor_stride_failing_ops
 
-                    # if there is a striding failure here as a result of adding a primtorch ref,
-                    # feel free to add the op to `fake_striding_skips` but please tag
-                    # @eellison on the pr.
-                    # see: https://github.com/pytorch/pytorch/issues/78050
+                    # prims/decomps must correctly model strides,
+                    # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
                     prims.utils.compare_tensor_meta(fake_out, real_out, check_strides)
 
                     if name not in aliasing_failures:
@@ -1888,12 +1908,40 @@
         context = torch.cuda.amp.autocast if device == "cuda" else torch.cpu.amp.autocast
         self._test_fake_helper(device, dtype, op, context)
 
+    @onlyCUDA
+    @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
+    @skipOps('TestFakeTensor', 'test_fake_crossref_backward', fake_backward_xfails)
+    def test_fake_crossref_backward(self, device, dtype, op):
+        # Tests fake tensor property propagation through a cross-ref mode
+        # on ops that support backward.
+        samples = op.sample_inputs(device, dtype, requires_grad=True)
+
+        for sample in samples:
+            args = [sample.input] + list(sample.args)
+            kwargs = sample.kwargs
+
+            # skip these to speed up tests
+            common_skip_ops = (
+                aten.detach.default,
+                aten.empty_strided.default,
+                aten.copy_.default,
+                aten.is_same_size.default,
+
+            )
+            # TODO: enable check_aliasing, too many failures :/
+            with torch._subclasses.CrossRefFakeMode(ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=False):
+                with warnings.catch_warnings():
+                    composite_compliance.compute_expected_grads(
+                        op.get_op(), args, kwargs,
+                        sample.output_process_fn_grad,
+                        op.gradcheck_wrapper)
+
 
 instantiate_device_type_tests(TestCommon, globals())
 instantiate_device_type_tests(TestCompositeCompliance, globals())
 instantiate_device_type_tests(TestMathBits, globals())
 instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")
-instantiate_device_type_tests(TestFakeTensorNonErroring, globals())
+instantiate_device_type_tests(TestFakeTensor, globals())
 instantiate_device_type_tests(TestTags, globals())
 
 if __name__ == "__main__":
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 10e2f25..70b5c00 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -42,7 +42,10 @@
 
 class CrossRefSparseFakeMode(torch._subclasses.CrossRefFakeMode):
     def __init__(self):
-        super(CrossRefSparseFakeMode, self).__init__(self.ignore_op, check_strides=False)  # TODO: enable stride checking
+        super(CrossRefSparseFakeMode, self).__init__(
+            self.ignore_op, check_strides=False,
+            check_aliasing=False,
+        )  # TODO: enable stride/alias checking
 
     # empty_like excluded for now due to sparse complex
     # aten._to_dense.default this one is getting called with csc
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 7ee6ed7..b5399d7 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -839,6 +839,7 @@
 def _set_neg(x: Tensor, neg: _bool) -> None: ...
 def _add_meta_to_tls_dispatch_include() -> None: ...
 def _remove_meta_from_tls_dispatch_include() -> None: ...
+def _has_storage(x: Tensor) -> _bool: ...
 # NB: There is no Capsule type in typing, see
 # https://code.activestate.com/lists/python-dev/139675/
 def _to_dlpack(data: Tensor) -> Any: ...  # THPModule_toDLPack
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 05dc65d..125d9d5 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1065,7 +1065,7 @@
     input_ndim = input.dim()
     computation_dtype = utils.get_computation_dtype(input.dtype)
     grad_out_cast, input_cast, weight_cast, bias_cast = [
-        x.to(computation_dtype) if x is not None else x
+        x.to(computation_dtype).contiguous() if x is not None else x
         for x in (grad_out, input, weight, bias)
     ]
     assert grad_out_cast is not None
@@ -1085,9 +1085,9 @@
     M = prod(outer_dims)  # type: ignore[arg-type]
     if M <= 0 or N <= 0:
         return (
-            input.new_zeros(input_shape),
-            input.new_zeros(input_shape[axis:]),
-            input.new_zeros(input_shape[axis:]),
+            input.new_zeros(input_shape) if output_mask[0] else None,
+            input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
+            input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
         )
 
     x_hat = (input_cast - mean) * rstd
@@ -1118,7 +1118,7 @@
         if len(outer_dim_indices) > 0:
             d_bias = torch.sum(grad_out_cast, outer_dim_indices, False)
         else:
-            d_bias = grad_out_cast
+            d_bias = grad_out_cast.clone()
 
     return (
         _maybe_cast(d_input, input.dtype),
diff --git a/torch/_subclasses/fake_utils.py b/torch/_subclasses/fake_utils.py
index 580d8f8..1e533a5 100644
--- a/torch/_subclasses/fake_utils.py
+++ b/torch/_subclasses/fake_utils.py
@@ -1,3 +1,4 @@
+import warnings
 from typing import Callable, Union
 
 import torch
@@ -9,17 +10,45 @@
 aten = torch.ops.aten
 
 
+def outputs_alias_inputs(outputs, inputs):
+    output_storages = set()
+    for out in tree_flatten(outputs)[0]:
+        if isinstance(out, torch.Tensor) and torch._C._has_storage(out):
+            output_storages.add(out.storage()._cdata)
+    for inp in tree_flatten(inputs)[0]:
+        if (
+            isinstance(inp, torch.Tensor)
+            and torch._C._has_storage(inp)
+            and inp.storage()._cdata in output_storages
+        ):
+            return True
+    return False
+
+
+def outputs_are_inputs(outputs, inputs):
+    output_ids = set()
+    for out in tree_flatten(outputs)[0]:
+        if isinstance(out, torch.Tensor):
+            output_ids.add(id(out))
+    for inp in tree_flatten(inputs)[0]:
+        if isinstance(inp, torch.Tensor) and id(inp) in output_ids:
+            return True
+    return False
+
+
 class CrossRefFakeMode(TorchDispatchMode):
     def __init__(
         self,
         ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
         *,
         check_strides=True,
+        check_aliasing=True,
     ):
         self.ignore_op_fn = (
             ignore_op_fn if ignore_op_fn is not None else lambda fn: False
         )
         self.check_strides = check_strides
+        self.check_aliasing = check_aliasing
 
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         kwargs = kwargs or {}
@@ -50,16 +79,48 @@
                     fake_args, fake_kwargs = pytree.tree_map_only(
                         torch.Tensor, fake_mode.from_tensor, (args, kwargs)
                     )
-                    fake_r = func(*fake_args, **fake_kwargs)
+                    with warnings.catch_warnings():
+                        fake_r = func(*fake_args, **fake_kwargs)
             except UnsupportedFakeTensorException:
                 pass
 
         r = func(*args, **kwargs)
         if fake_r is not None:
+            r_flat, _ = tree_flatten(r)
+            f_flat, _ = tree_flatten(fake_r)
+            assert len(r_flat) == len(
+                f_flat
+            ), f"Mismatch {len(r_flat)} != {len(f_flat)} on {func}"
+
+            if self.check_aliasing:
+                r_aliasing = outputs_alias_inputs(r, (args, kwargs))
+                f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs))
+                assert (
+                    r_aliasing == f_aliasing
+                ), f"Mismatch on {func}: {r_aliasing} != {f_aliasing}"
+
+                r_identity_eq = outputs_are_inputs(r, (args, kwargs))
+                f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs))
+                assert (
+                    r_identity_eq == f_identity_eq
+                ), f"Mismatch on {func}: {r_identity_eq} != {f_identity_eq}"
+
             for r_out, fake_out in zip(tree_flatten(r)[0], tree_flatten(fake_r)[0]):
-                r_ten = isinstance(r_out, torch.Tensor)
-                assert r_ten == isinstance(fake_out, torch.Tensor)
-                if r_ten:
+                r_is_ten = isinstance(r_out, torch.Tensor)
+                assert r_is_ten == isinstance(
+                    fake_out, torch.Tensor
+                ), f"Mismatched number of tensor outputs on {func}"
+                if r_is_ten:
+                    assert (
+                        r_out.requires_grad == fake_out.requires_grad
+                    ), f"Mismatch on {func}"
+                    if torch._C._has_storage(r_out):
+                        r_offset = r_out.storage_offset()
+                        f_offset = fake_out.storage_offset()
+                        assert (
+                            r_offset == f_offset
+                        ), f"Mismatch on {func}: {r_offset} != {f_offset}"
+
                     try:
                         torch._prims.utils.compare_tensor_meta(
                             r_out, fake_out, check_strides=self.check_strides
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 5a7263d..06b7014 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -1269,6 +1269,8 @@
   py_module.def("_dispatch_key_set", [](const at::Tensor& x) {
     return toString(x.key_set());
   });
+  py_module.def(
+      "_has_storage", [](const at::Tensor& x) { return x.has_storage(); });
 
   py_module.def("_add_meta_to_tls_dispatch_include", []() {
     auto local_keyset = c10::impl::tls_local_dispatch_key_set();
diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py
index dc423eb..dadc967 100644
--- a/torch/testing/_internal/composite_compliance.py
+++ b/torch/testing/_internal/composite_compliance.py
@@ -399,6 +399,26 @@
     return leaf_tensors
 
 
+def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradcheck_wrapper=None):
+    if gradcheck_wrapper is None:
+        results = op(*args, **kwargs)
+    else:
+        results = gradcheck_wrapper(op, *args, **kwargs)
+
+    if output_process_fn_grad is not None:
+        results = output_process_fn_grad(results)
+
+    flat_results, _ = tree_flatten(results)
+    flat_diff_results = [r for r in flat_results if r.requires_grad]
+    assert len(flat_diff_results) > 0
+
+    grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) for r in flat_diff_results]
+    leaf_tensors = gather_leaf_tensors(args, kwargs)
+    assert len(leaf_tensors) > 0
+    return torch.autograd.grad(flat_diff_results, leaf_tensors,
+                               grads, allow_unused=True, retain_graph=True)
+
+
 # Checks if the backward formula is composite compliant by testing
 # all possible permutations of {inputs, grad_outputs} being
 # CompositeCompliantTensor or regular Tensors.
@@ -411,27 +431,7 @@
                            gradcheck_wrapper=None, assert_equal_fn=None):
     CCT = generate_cct()
 
-    def compute_expected_grads(args, kwargs):
-        if gradcheck_wrapper is None:
-            results = op(*args, **kwargs)
-        else:
-            results = gradcheck_wrapper(op, *args, **kwargs)
-
-        if output_process_fn_grad is not None:
-            results = output_process_fn_grad(results)
-
-        flat_results, _ = tree_flatten(results)
-        flat_diff_results = [r for r in flat_results if r.requires_grad]
-        assert len(flat_diff_results) > 0
-
-        grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype)
-                 for r in flat_diff_results]
-        leaf_tensors = gather_leaf_tensors(args, kwargs)
-        assert len(leaf_tensors) > 0
-        return torch.autograd.grad(flat_diff_results, leaf_tensors,
-                                   grads, allow_unused=True, retain_graph=True)
-
-    expected = compute_expected_grads(args, kwargs)
+    expected = compute_expected_grads(op, args, kwargs, output_process_fn_grad, gradcheck_wrapper)
 
     for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT):
         new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice