Add sparse compressed fake tensor support (#120920)
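Teach MetaConverter to convert tensors with sparse compressed layouts (CSR, CSC, BSR, BSC) to meta tensors so that they can be wrapped as FakeTensors; the converted tensor has nnz == 0. Dynamo reports sparse tensors as unsupported, and the fake tensor dispatch cache is bypassed for sparse compressed inputs and outputs.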
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120920
Approved by: https://github.com/ezyang
diff --git a/test/dynamo_expected_failures/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_complex128 b/test/dynamo_expected_failures/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_complex128
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_complex128
+++ /dev/null
diff --git a/test/dynamo_expected_failures/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_float64 b/test/dynamo_expected_failures/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_float64
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_float64
+++ /dev/null
diff --git a/test/nn/test_embedding.py b/test/nn/test_embedding.py
index 07623f3..f15a690 100644
--- a/test/nn/test_embedding.py
+++ b/test/nn/test_embedding.py
@@ -703,7 +703,7 @@
torch._embedding_bag_forward_only
)
for i, f in enumerate(funcs):
- err_type = ValueError if i == 0 else RuntimeError
+ err_type = (ValueError, RuntimeError) if i == 0 else RuntimeError
weight = torch.full((2, 6,), 0, dtype=torch.float64, device=device)
indices = torch.full((2, 0, 0, 6, 6,), 2, dtype=torch.int64, device=device)
diff --git a/test/test_meta.py b/test/test_meta.py
index 10f670c..a7f4325 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -8,7 +8,7 @@
from torch.overrides import resolve_name
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.utils import _pytree as pytree
-from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq
+from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq, is_sparse_any
import torch.utils._python_dispatch
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import OpOverload, OpOverloadPacket
@@ -490,7 +490,9 @@
return self.s
def go(t):
- if isinstance(t, torch.Tensor):
+ if is_sparse_any(t):
+ return t
+ elif isinstance(t, torch.Tensor):
return Lit(f"{t} stride={t.stride()}")
else:
return t
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 321020b..94cc4ce 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -4379,6 +4379,56 @@
self.maxDiff = orig_maxDiff
raise
+ @all_sparse_layouts('layout', include_strided=False)
+ @parametrize("dtype", [torch.float64])
+ def test_fake(self, dtype, layout):
+ from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
+ fake_mode = FakeTensorMode()
+ index_dtype = torch.int64
+ device = 'cpu'
+ for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
+ f = FakeTensor.from_tensor(t, fake_mode)
+ self.assertIsInstance(f, FakeTensor)
+ self.assertEqual(f.layout, layout)
+ self.assertEqual(f.shape, t.shape)
+ self.assertEqual(f.device, t.device)
+ if layout is torch.sparse_coo:
+ nnz = 0
+ indices = f._indices()
+ self.assertEqual(indices.dtype, index_dtype)
+ self.assertEqual(indices.device, t.device)
+ self.assertEqual(indices.shape, (*t._indices().shape[:-1], nnz))
+ values = f._values()
+ self.assertEqual(values.dtype, dtype)
+ self.assertEqual(values.device, t.device)
+ self.assertEqual(values.shape, (nnz, *t._values().shape[1:]))
+ else:
+ nnz = 0
+ if layout in {torch.sparse_csr, torch.sparse_bsr}:
+ f_compressed_indices, f_plain_indices = f.crow_indices(), f.col_indices()
+ compressed_indices, plain_indices = t.crow_indices(), t.col_indices()
+ else:
+ f_compressed_indices, f_plain_indices = f.ccol_indices(), f.row_indices()
+ compressed_indices, plain_indices = t.ccol_indices(), t.row_indices()
+ f_values = f.values()
+ values = t.values()
+ batch_dims = len(compressed_indices.shape) - 1
+ self.assertEqual(f_compressed_indices.layout, compressed_indices.layout)
+ self.assertEqual(f_compressed_indices.shape, compressed_indices.shape)
+ self.assertEqual(f_compressed_indices.dtype, compressed_indices.dtype)
+ self.assertEqual(f_compressed_indices.device, compressed_indices.device)
+
+ self.assertEqual(f_plain_indices.layout, plain_indices.layout)
+ self.assertEqual(f_plain_indices.shape, (*plain_indices.shape[:-1], nnz))
+ self.assertEqual(f_plain_indices.dtype, plain_indices.dtype)
+ self.assertEqual(f_plain_indices.device, plain_indices.device)
+
+ batch_dim = plain_indices.ndim - 1
+ self.assertEqual(f_values.layout, values.layout)
+ self.assertEqual(f_values.shape, (*values.shape[:batch_dim], nnz, *values.shape[batch_dim + 1:]))
+ self.assertEqual(f_values.dtype, values.dtype)
+ self.assertEqual(f_values.device, values.device)
+
class _SparseDataset(torch.utils.data.Dataset):
# An utility class used in TestSparseAny.test_dataloader method.
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 368babd..14454c3 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -27,6 +27,7 @@
from torch._ops import HigherOrderOperator
from torch._streambase import _EventBase, _StreamBase
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
+from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
@@ -1059,6 +1060,11 @@
):
unimplemented("torch.compile does not support strided NestedTensor")
+ if is_sparse_any(value):
+ unimplemented(
+ f"torch.compile does not support sparse Tensor with {value.layout} layout"
+ )
+
tensor_variable = wrap_fx_proxy(
tx=self.tx,
proxy=tensor_proxy,
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index cac26c3..95b22e5 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -25,6 +25,7 @@
import torch.fx
import torch.random
from torch._dynamo import compiled_autograd
+from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental.symbolic_shapes import (
guard_scalar,
@@ -149,7 +150,11 @@
"is_sparse": value.is_sparse,
"class_type": type(value),
}
- if not has_free_symbols(value):
+ if is_sparse_any(value) and not has_free_symbols(value):
+ props["size"] = tuple(
+ [int(s) if is_symbolic(s) else s for s in value.size()]
+ )
+ elif not has_free_symbols(value):
# this is a fully static shape, and the keys on props here inform specialization.
# We have to cast to int here, because these might get accessed as ConstantVariable, which has
# a strict no-symint policy. If we got here due to not having free symbols, this is a known constant
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 0ec490b..9871c42 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -23,6 +23,7 @@
assert_eq,
assert_metadata_eq,
is_sparse_any,
+ is_sparse_compressed,
MetaConverter,
)
from torch._utils import render_call
@@ -1057,6 +1058,8 @@
raise _BypassDispatchCache("constant attribute")
if arg.is_sparse:
raise _BypassDispatchCache("sparse tensor")
+ if is_sparse_compressed(arg):
+ raise _BypassDispatchCache("sparse compressed tensor")
result.append(extract_tensor_metadata(arg))
elif isinstance(arg, torch.Tensor):
raise _BypassDispatchCache("non-fake tensor")
@@ -1099,6 +1102,9 @@
if output.is_sparse:
raise _BypassDispatchCache("sparse output")
+ if is_sparse_compressed(output):
+ raise _BypassDispatchCache("sparse compressed output")
+
# Can an in-place op really reference a kwarg? If so, then we need
# to extend the implementation to handle it.
for kval in kwargs.values():
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index f629f66..fc23516 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -152,7 +152,7 @@
# hold a weak ref to self, otherwise it will be kept alive
# by the del_ten closure
self_weak_ref = weakref.ref(self)
- if t.is_sparse or t.is_mkldnn or is_functorch_wrapped_tensor(t):
+ if is_sparse_any(t) or t.is_mkldnn or is_functorch_wrapped_tensor(t):
weak_st = None
else:
weak_st = StorageWeakRef(t._typed_storage())
@@ -298,6 +298,10 @@
with torch.inference_mode(t.is_inference()):
if t.is_sparse:
is_leaf = safe_is_leaf(t)
+
+ # The lambda function below is similar to
+ # `t.to(device='meta')` except the latter
+ # preserves nnz value
r = callback(
lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
t.sparse_dim(),
@@ -321,6 +325,64 @@
with torch.enable_grad():
r = r.clone()
r._coalesced_(t.is_coalesced())
+ elif is_sparse_compressed(t):
+ is_leaf = safe_is_leaf(t)
+
+ def mk_meta():
+ nnz = 0
+ batch_dim = t.ndim - t.sparse_dim() - t.dense_dim()
+ batch_size = t.shape[:batch_dim]
+ if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
+ index_dtype = t.crow_indices().dtype
+ compressed_indices = torch.empty(
+ t.crow_indices().shape, device="meta", dtype=index_dtype
+ )
+ plain_indices = torch.empty(
+ (*t.col_indices().shape[:-1], nnz),
+ device="meta",
+ dtype=index_dtype,
+ )
+ else:
+ index_dtype = t.ccol_indices().dtype
+ compressed_indices = torch.empty(
+ t.ccol_indices().shape, device="meta", dtype=index_dtype
+ )
+ plain_indices = torch.empty(
+ (*t.row_indices().shape[:-1], nnz),
+ device="meta",
+ dtype=index_dtype,
+ )
+ values_shape = t.values().shape
+ values = torch.empty(
+ (
+ *values_shape[:batch_dim],
+ nnz,
+ *values_shape[batch_dim + 1 :],
+ ),
+ dtype=t.dtype,
+ device="meta",
+ )
+ return torch.ops.aten.sparse_compressed_tensor(
+ compressed_indices,
+ plain_indices,
+ values,
+ t.shape,
+ layout=t.layout,
+ dtype=t.dtype,
+ device="meta",
+ )
+
+ # `mk_meta()` is similar to `t.to(device='meta')`
+ # except `to('meta')` preserves nnz value while
+ # `mk_meta` result has nnz == 0.
+ r = callback(mk_meta)
+
+ assert safe_is_leaf(r), "the callback you passed in doesn't detach"
+ if t.requires_grad:
+ r.requires_grad = True
+ if t.requires_grad and not is_leaf:
+ with torch.enable_grad():
+ r = r.clone()
elif t.is_nested and not is_traceable_wrapper_subclass(t):
# TODO: Handle this better in Dynamo?
# There are checks there now, but this can still be triggered by a dense
@@ -679,8 +741,6 @@
if isinstance(t, torch.Tensor) or is_traceable_wrapper_subclass(t):
if t.device.type != "xla" and any(
[
- t.is_sparse_csr,
- t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
t.is_quantized,
t._is_view() and t._base is not None and t._base.is_sparse,
torch._is_functional_tensor(t),
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index a5ccc9f..0471649 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -280,6 +280,8 @@
yield val
elif isinstance(val, (int, float, bool)):
pass
+ elif is_sparse_any(val):
+ yield from _iterate_exprs(val.size())
elif isinstance(val, torch.Tensor):
yield from _iterate_exprs(val.size())
yield from _iterate_exprs(val.stride())
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 2d8ecd8..9d9b7a2 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -11793,6 +11793,12 @@
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
# ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
+ # NotImplementedError: Could not run 'aten::sparse_sampled_addmm' with arguments from the 'SparseCsrMeta' backend.
+ DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'),
)),
OpInfo('sparse.mm',
dtypes=floating_types_and(torch.bfloat16),
@@ -11836,6 +11842,10 @@
DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
# ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_fail_gradgrad'),
+ # NotImplementedError: Could not run 'aten::_sparse_mm_reduce_impl' with arguments from the 'SparseCsrMeta' backend
+ DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
+ DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
)),
UnaryUfuncInfo('i0',
ref=np_unary_ufunc_integer_promotion_wrapper(