Add FakeCrossRef tests for backwards, Fix Layer Norm Backward Decomp (#85417)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85417
Approved by: https://github.com/ezyang
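
For context, the new test_fake_crossref_backward test runs each op's forward and
backward under torch._subclasses.CrossRefFakeMode, a dispatch mode that re-runs
every dispatched aten op on fake tensors and compares the fake outputs' metadata
(dtype, shape, strides, storage offset, requires_grad, aliasing) against the real
outputs. A minimal sketch of that flow, with the op, shapes, and device chosen
purely for illustration (the added test itself iterates OpInfo samples on CUDA)
and with aliasing checks disabled as in the test:

    import torch
    import torch._subclasses

    # Sketch only: run a forward + backward under CrossRefFakeMode so each
    # dispatched aten call is also executed on fake tensors and its output
    # metadata is cross-checked against the real result.
    # check_aliasing=False mirrors the test added in this PR.
    x = torch.randn(4, 8, requires_grad=True)
    with torch._subclasses.CrossRefFakeMode(check_aliasing=False):
        out = torch.nn.functional.layer_norm(x, (8,))
        out.sum().backward()

Layer norm is used above because the other half of this PR fixes the
native_layer_norm_backward decomposition (contiguous casts, honoring output_mask
by returning None, and cloning d_bias) so that its outputs line up with eager
under exactly this kind of cross-ref checking.
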
diff --git a/test/test_ops.py b/test/test_ops.py
index fcf2cdf..52cc42a 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -16,6 +16,7 @@
floating_and_complex_types_and,
all_types_and_complex_and,
)
+from test_proxy_tensor import xfail, skipOps
from torch.testing._internal.common_utils import (
TestCase,
@@ -59,6 +60,8 @@
FakeTensor,
FakeTensorMode,
)
+from torch._subclasses.fake_utils import outputs_alias_inputs
+
from torch.utils._python_dispatch import enable_torch_dispatch_mode
import torch._prims as prims
from torch._prims.context import TorchRefsMode
@@ -94,6 +97,8 @@
)
_ops_and_refs = op_db + python_ref_db
+aten = torch.ops.aten
+
# Tests that apply to all operators and aren't related to any particular
# system
@skipIfSlowGradcheckEnv
@@ -1775,8 +1780,10 @@
"nn.functional.pixel_unshuffle",
)
-fake_striding_skips = (
- "diag_embed",
+# ops whose fake tensor stride propagation is inconsistent with eager
+# XXX: no new ops should be added to this list as a result of a
+# decomp or prim, see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
+fake_tensor_stride_failing_ops = [
"fft.fft2",
"fft.fft",
"fft.fftn",
@@ -1797,11 +1804,35 @@
"fft.rfftn",
"svd",
"linalg.svd",
-)
+]
+fake_backward_failing_ops = fake_tensor_stride_failing_ops + [
+ "linalg.cond",
+ "linalg.matrix_norm",
+ "linalg.norm",
+ "linalg.svd",
+ "linalg.svdvals",
+ "nn.functional.binary_cross_entropy_with_logits",
+ "nn.functional.huber_loss",
+ "nn.functional.logsigmoid",
+ "nn.functional.multilabel_soft_margin_loss",
+ "pca_lowrank",
+ "roll",
+ "svd_lowrank",
+ "sgn",
+ "cholesky",
+ "linalg.eigh",
+ "symeig",
+]
+
+fake_backward_xfails = [xfail(op_name) for op_name in fake_backward_failing_ops] + [
+ xfail("segment_reduce", "lengths"),
+ xfail("norm", "nuc"),
+ xfail('linalg.norm', 'subgradients_at_zero'), # can accept vector inputs
+]
@skipIfSlowGradcheckEnv
-class TestFakeTensorNonErroring(TestCase):
+class TestFakeTensor(TestCase):
def _test_fake_helper(self, device, dtype, op, context):
name = op.name
if op.variant_test_name:
@@ -1834,15 +1865,6 @@
with enable_torch_dispatch_mode(mode):
res_fake = op(input, *args, **kwargs)
- def outputs_alias_inputs(outputs, inputs):
- input_storages = set()
- for out in tree_flatten(outputs)[0]:
- if isinstance(out, torch.Tensor):
- input_storages.add(out.storage()._cdata)
- for inp in tree_flatten(inputs)[0]:
- if isinstance(inp, torch.Tensor) and inp.storage()._cdata in input_storages:
- return True
- return False
for fake_out, real_out in zip(
tree_flatten(res_fake)[0], tree_flatten(res)[0]
@@ -1855,12 +1877,10 @@
# if you see a shape exception here, you may need to add
# a `dynamic_output_shape` tag to an operator
- check_strides = name not in fake_striding_skips
+ check_strides = name not in fake_tensor_stride_failing_ops
- # if there is a striding failure here as a result of adding a primtorch ref,
- # feel free to add the op to `fake_striding_skips` but please tag
- # @eellison on the pr.
- # see: https://github.com/pytorch/pytorch/issues/78050
+ # prims/decomps must correctly model strides,
+ # see https://github.com/pytorch/pytorch/issues/78050#issuecomment-1253950325
prims.utils.compare_tensor_meta(fake_out, real_out, check_strides)
if name not in aliasing_failures:
@@ -1888,12 +1908,40 @@
context = torch.cuda.amp.autocast if device == "cuda" else torch.cpu.amp.autocast
self._test_fake_helper(device, dtype, op, context)
+ @onlyCUDA
+ @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
+ @skipOps('TestFakeTensor', 'test_fake_crossref_backward', fake_backward_xfails)
+ def test_fake_crossref_backward(self, device, dtype, op):
+        # tests fake tensor property propagation through a cross-ref mode
+        # on the backward of ops which support autograd
+ samples = op.sample_inputs(device, dtype, requires_grad=True)
+
+        for sample in samples:
+ args = [sample.input] + list(sample.args)
+ kwargs = sample.kwargs
+
+ # skip these to speed up tests
+ common_skip_ops = (
+ aten.detach.default,
+ aten.empty_strided.default,
+ aten.copy_.default,
+ aten.is_same_size.default,
+
+ )
+ # TODO: enable check_aliasing, too many failures :/
+ with torch._subclasses.CrossRefFakeMode(ignore_op_fn=lambda fn: fn in common_skip_ops, check_aliasing=False):
+ with warnings.catch_warnings():
+ composite_compliance.compute_expected_grads(
+ op.get_op(), args, kwargs,
+ sample.output_process_fn_grad,
+ op.gradcheck_wrapper)
+
instantiate_device_type_tests(TestCommon, globals())
instantiate_device_type_tests(TestCompositeCompliance, globals())
instantiate_device_type_tests(TestMathBits, globals())
instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu")
-instantiate_device_type_tests(TestFakeTensorNonErroring, globals())
+instantiate_device_type_tests(TestFakeTensor, globals())
instantiate_device_type_tests(TestTags, globals())
if __name__ == "__main__":
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 10e2f25..70b5c00 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -42,7 +42,10 @@
class CrossRefSparseFakeMode(torch._subclasses.CrossRefFakeMode):
def __init__(self):
- super(CrossRefSparseFakeMode, self).__init__(self.ignore_op, check_strides=False) # TODO: enable stride checking
+ super(CrossRefSparseFakeMode, self).__init__(
+ self.ignore_op, check_strides=False,
+ check_aliasing=False,
+ ) # TODO: enable stride/alias checking
# empty_like excluded for now due to sparse complex
# aten._to_dense.default this one is getting called with csc
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 7ee6ed7..b5399d7 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -839,6 +839,7 @@
def _set_neg(x: Tensor, neg: _bool) -> None: ...
def _add_meta_to_tls_dispatch_include() -> None: ...
def _remove_meta_from_tls_dispatch_include() -> None: ...
+def _has_storage(x: Tensor) -> _bool: ...
# NB: There is no Capsule type in typing, see
# https://code.activestate.com/lists/python-dev/139675/
def _to_dlpack(data: Tensor) -> Any: ... # THPModule_toDLPack
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 05dc65d..125d9d5 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1065,7 +1065,7 @@
input_ndim = input.dim()
computation_dtype = utils.get_computation_dtype(input.dtype)
grad_out_cast, input_cast, weight_cast, bias_cast = [
- x.to(computation_dtype) if x is not None else x
+ x.to(computation_dtype).contiguous() if x is not None else x
for x in (grad_out, input, weight, bias)
]
assert grad_out_cast is not None
@@ -1085,9 +1085,9 @@
M = prod(outer_dims) # type: ignore[arg-type]
if M <= 0 or N <= 0:
return (
- input.new_zeros(input_shape),
- input.new_zeros(input_shape[axis:]),
- input.new_zeros(input_shape[axis:]),
+ input.new_zeros(input_shape) if output_mask[0] else None,
+ input.new_zeros(input_shape[axis:]) if output_mask[1] else None,
+ input.new_zeros(input_shape[axis:]) if output_mask[2] else None,
)
x_hat = (input_cast - mean) * rstd
@@ -1118,7 +1118,7 @@
if len(outer_dim_indices) > 0:
d_bias = torch.sum(grad_out_cast, outer_dim_indices, False)
else:
- d_bias = grad_out_cast
+ d_bias = grad_out_cast.clone()
return (
_maybe_cast(d_input, input.dtype),
diff --git a/torch/_subclasses/fake_utils.py b/torch/_subclasses/fake_utils.py
index 580d8f8..1e533a5 100644
--- a/torch/_subclasses/fake_utils.py
+++ b/torch/_subclasses/fake_utils.py
@@ -1,3 +1,4 @@
+import warnings
from typing import Callable, Union
import torch
@@ -9,17 +10,45 @@
aten = torch.ops.aten
+def outputs_alias_inputs(outputs, inputs):
+    output_storages = set()
+    for out in tree_flatten(outputs)[0]:
+        if isinstance(out, torch.Tensor) and torch._C._has_storage(out):
+            output_storages.add(out.storage()._cdata)
+    for inp in tree_flatten(inputs)[0]:
+        if (
+            isinstance(inp, torch.Tensor)
+            and torch._C._has_storage(inp)
+            and inp.storage()._cdata in output_storages
+        ):
+            return True
+    return False
+
+
+def outputs_are_inputs(outputs, inputs):
+    output_ids = set()
+    for out in tree_flatten(outputs)[0]:
+        if isinstance(out, torch.Tensor):
+            output_ids.add(id(out))
+    for inp in tree_flatten(inputs)[0]:
+        if isinstance(inp, torch.Tensor) and id(inp) in output_ids:
+            return True
+    return False
+
+
class CrossRefFakeMode(TorchDispatchMode):
def __init__(
self,
ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
*,
check_strides=True,
+ check_aliasing=True,
):
self.ignore_op_fn = (
ignore_op_fn if ignore_op_fn is not None else lambda fn: False
)
self.check_strides = check_strides
+ self.check_aliasing = check_aliasing
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
@@ -50,16 +79,48 @@
fake_args, fake_kwargs = pytree.tree_map_only(
torch.Tensor, fake_mode.from_tensor, (args, kwargs)
)
- fake_r = func(*fake_args, **fake_kwargs)
+ with warnings.catch_warnings():
+ fake_r = func(*fake_args, **fake_kwargs)
except UnsupportedFakeTensorException:
pass
r = func(*args, **kwargs)
if fake_r is not None:
+ r_flat, _ = tree_flatten(r)
+ f_flat, _ = tree_flatten(fake_r)
+            assert len(r_flat) == len(
+                f_flat
+            ), f"Mismatch {len(r_flat)} != {len(f_flat)} on {func}"
+
+ if self.check_aliasing:
+ r_aliasing = outputs_alias_inputs(r, (args, kwargs))
+ f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs))
+ assert (
+ r_aliasing == f_aliasing
+ ), f"Mismatch on {func}: {r_aliasing} != {f_aliasing}"
+
+ r_identity_eq = outputs_are_inputs(r, (args, kwargs))
+ f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs))
+ assert (
+ r_identity_eq == f_identity_eq
+ ), f"Mismatch on {func}: {r_identity_eq} != {f_identity_eq}"
+
for r_out, fake_out in zip(tree_flatten(r)[0], tree_flatten(fake_r)[0]):
- r_ten = isinstance(r_out, torch.Tensor)
- assert r_ten == isinstance(fake_out, torch.Tensor)
- if r_ten:
+ r_is_ten = isinstance(r_out, torch.Tensor)
+ assert r_is_ten == isinstance(
+ fake_out, torch.Tensor
+ ), f"Mismatched number of tensor outputs on {func}"
+ if r_is_ten:
+ assert (
+ r_out.requires_grad == fake_out.requires_grad
+ ), f"Mismatch on {func}"
+ if torch._C._has_storage(r_out):
+ r_offset = r_out.storage_offset()
+ f_offset = fake_out.storage_offset()
+ assert (
+ r_offset == f_offset
+ ), f"Mismatch on {func}: {r_offset} != {f_offset}"
+
try:
torch._prims.utils.compare_tensor_meta(
r_out, fake_out, check_strides=self.check_strides
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 5a7263d..06b7014 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -1269,6 +1269,8 @@
py_module.def("_dispatch_key_set", [](const at::Tensor& x) {
return toString(x.key_set());
});
+ py_module.def(
+ "_has_storage", [](const at::Tensor& x) { return x.has_storage(); });
py_module.def("_add_meta_to_tls_dispatch_include", []() {
auto local_keyset = c10::impl::tls_local_dispatch_key_set();
diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py
index dc423eb..dadc967 100644
--- a/torch/testing/_internal/composite_compliance.py
+++ b/torch/testing/_internal/composite_compliance.py
@@ -399,6 +399,26 @@
return leaf_tensors
+def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradcheck_wrapper=None):
+ if gradcheck_wrapper is None:
+ results = op(*args, **kwargs)
+ else:
+ results = gradcheck_wrapper(op, *args, **kwargs)
+
+ if output_process_fn_grad is not None:
+ results = output_process_fn_grad(results)
+
+ flat_results, _ = tree_flatten(results)
+ flat_diff_results = [r for r in flat_results if r.requires_grad]
+ assert len(flat_diff_results) > 0
+
+ grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) for r in flat_diff_results]
+ leaf_tensors = gather_leaf_tensors(args, kwargs)
+ assert len(leaf_tensors) > 0
+ return torch.autograd.grad(flat_diff_results, leaf_tensors,
+ grads, allow_unused=True, retain_graph=True)
+
+
# Checks if the backward formula is composite compliant by testing
# all possible permutations of {inputs, grad_outputs} being
# CompositeCompliantTensor or regular Tensors.
@@ -411,27 +431,7 @@
gradcheck_wrapper=None, assert_equal_fn=None):
CCT = generate_cct()
- def compute_expected_grads(args, kwargs):
- if gradcheck_wrapper is None:
- results = op(*args, **kwargs)
- else:
- results = gradcheck_wrapper(op, *args, **kwargs)
-
- if output_process_fn_grad is not None:
- results = output_process_fn_grad(results)
-
- flat_results, _ = tree_flatten(results)
- flat_diff_results = [r for r in flat_results if r.requires_grad]
- assert len(flat_diff_results) > 0
-
- grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype)
- for r in flat_diff_results]
- leaf_tensors = gather_leaf_tensors(args, kwargs)
- assert len(leaf_tensors) > 0
- return torch.autograd.grad(flat_diff_results, leaf_tensors,
- grads, allow_unused=True, retain_graph=True)
-
- expected = compute_expected_grads(args, kwargs)
+ expected = compute_expected_grads(op, args, kwargs, output_process_fn_grad, gradcheck_wrapper)
for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT):
new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice