Add sparse compressed fake tensor support (#120920)

Adds fake tensor support for sparse compressed tensors (CSR, CSC, BSR, and BSC layouts): `MetaConverter` now builds meta copies of sparse compressed tensors with nnz == 0 (mirroring the existing sparse COO path), the fake tensor dispatch cache bypasses sparse compressed arguments and outputs, and Dynamo emits an explicit `unimplemented` message for sparse tensor inputs. A new `test_fake` test in test_sparse.py exercises the conversion for all sparse layouts.
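
For illustration, a minimal sketch (not taken from the patch) of the behavior exercised by the new `test_fake` test: converting a sparse CSR tensor through `FakeTensorMode` yields a `FakeTensor` with the same layout, shape, and device, but with nnz == 0.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

# Sketch only: mirrors the assertions of the new test_fake test below.
t = torch.eye(3).to_sparse_csr()          # real sparse CSR tensor, nnz == 3
f = FakeTensor.from_tensor(t, FakeTensorMode())

assert isinstance(f, FakeTensor)
assert f.layout is torch.sparse_csr
assert f.shape == t.shape and f.device == t.device
# The fake tensor is created with nnz == 0: the compressed indices keep their
# shape, while the plain indices and values get zero-length nnz dimensions.
assert f.crow_indices().shape == t.crow_indices().shape
assert f.col_indices().shape[-1] == 0
assert f.values().shape[0] == 0
```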

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120920
Approved by: https://github.com/ezyang
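
A hedged sketch of the Dynamo side of the change (also not from the patch): wrapping a sparse tensor input now hits the new `unimplemented` branch in `VariableBuilder`, which with `fullgraph=True` is expected to surface as an error mentioning the tensor's layout.

```python
import torch

@torch.compile(fullgraph=True)
def double(x):
    return x * 2

try:
    # Expected to trigger the new unimplemented() call in
    # torch/_dynamo/variables/builder.py:
    # "torch.compile does not support sparse Tensor with torch.sparse_csr layout"
    double(torch.eye(3).to_sparse_csr())
except Exception as exc:  # torch._dynamo.exc.Unsupported under fullgraph=True
    print(f"{type(exc).__name__}: {exc}")
```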
diff --git a/test/dynamo_expected_failures/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_complex128 b/test/dynamo_expected_failures/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_complex128
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_complex128
+++ /dev/null
diff --git a/test/dynamo_expected_failures/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_float64 b/test/dynamo_expected_failures/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_float64
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestAutogradDeviceTypeCPU.test_sparse_ctor_getter_backward_cpu_float64
+++ /dev/null
diff --git a/test/nn/test_embedding.py b/test/nn/test_embedding.py
index 07623f3..f15a690 100644
--- a/test/nn/test_embedding.py
+++ b/test/nn/test_embedding.py
@@ -703,7 +703,7 @@
             torch._embedding_bag_forward_only
         )
         for i, f in enumerate(funcs):
-            err_type = ValueError if i == 0 else RuntimeError
+            err_type = (ValueError, RuntimeError) if i == 0 else RuntimeError
 
             weight = torch.full((2, 6,), 0, dtype=torch.float64, device=device)
             indices = torch.full((2, 0, 0, 6, 6,), 2, dtype=torch.int64, device=device)
diff --git a/test/test_meta.py b/test/test_meta.py
index 10f670c..a7f4325 100644
--- a/test/test_meta.py
+++ b/test/test_meta.py
@@ -8,7 +8,7 @@
 from torch.overrides import resolve_name
 from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
 from torch.utils import _pytree as pytree
-from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq
+from torch._subclasses.meta_utils import MetaConverter, assert_metadata_eq, is_sparse_any
 import torch.utils._python_dispatch
 from torch._dispatch.python import enable_python_dispatcher
 from torch._ops import OpOverload, OpOverloadPacket
@@ -490,7 +490,9 @@
             return self.s
 
     def go(t):
-        if isinstance(t, torch.Tensor):
+        if is_sparse_any(t):
+            return t
+        elif isinstance(t, torch.Tensor):
             return Lit(f"{t} stride={t.stride()}")
         else:
             return t
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 321020b..94cc4ce 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -4379,6 +4379,56 @@
             self.maxDiff = orig_maxDiff
             raise
 
+    @all_sparse_layouts('layout', include_strided=False)
+    @parametrize("dtype", [torch.float64])
+    def test_fake(self, dtype, layout):
+        from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor
+        fake_mode = FakeTensorMode()
+        index_dtype = torch.int64
+        device = 'cpu'
+        for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
+            f = FakeTensor.from_tensor(t, fake_mode)
+            self.assertIsInstance(f, FakeTensor)
+            self.assertEqual(f.layout, layout)
+            self.assertEqual(f.shape, t.shape)
+            self.assertEqual(f.device, t.device)
+            if layout is torch.sparse_coo:
+                nnz = 0
+                indices = f._indices()
+                self.assertEqual(indices.dtype, index_dtype)
+                self.assertEqual(indices.device, t.device)
+                self.assertEqual(indices.shape, (*t._indices().shape[:-1], nnz))
+                values = f._values()
+                self.assertEqual(values.dtype, dtype)
+                self.assertEqual(values.device, t.device)
+                self.assertEqual(values.shape, (nnz, *t._values().shape[1:]))
+            else:
+                nnz = 0
+                if layout in {torch.sparse_csr, torch.sparse_bsr}:
+                    f_compressed_indices, f_plain_indices = f.crow_indices(), f.col_indices()
+                    compressed_indices, plain_indices = t.crow_indices(), t.col_indices()
+                else:
+                    f_compressed_indices, f_plain_indices = f.ccol_indices(), f.row_indices()
+                    compressed_indices, plain_indices = t.ccol_indices(), t.row_indices()
+                f_values = f.values()
+                values = t.values()
+                batch_dims = len(compressed_indices.shape) - 1
+                self.assertEqual(f_compressed_indices.layout, compressed_indices.layout)
+                self.assertEqual(f_compressed_indices.shape, compressed_indices.shape)
+                self.assertEqual(f_compressed_indices.dtype, compressed_indices.dtype)
+                self.assertEqual(f_compressed_indices.device, compressed_indices.device)
+
+                self.assertEqual(f_plain_indices.layout, plain_indices.layout)
+                self.assertEqual(f_plain_indices.shape, (*plain_indices.shape[:-1], nnz))
+                self.assertEqual(f_plain_indices.dtype, plain_indices.dtype)
+                self.assertEqual(f_plain_indices.device, plain_indices.device)
+
+                batch_dim = plain_indices.ndim - 1
+                self.assertEqual(f_values.layout, values.layout)
+                self.assertEqual(f_values.shape, (*values.shape[:batch_dim], nnz, *values.shape[batch_dim + 1:]))
+                self.assertEqual(f_values.dtype, values.dtype)
+                self.assertEqual(f_values.device, values.device)
+
 
 class _SparseDataset(torch.utils.data.Dataset):
     # An utility class used in TestSparseAny.test_dataloader method.
diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py
index 368babd..14454c3 100644
--- a/torch/_dynamo/variables/builder.py
+++ b/torch/_dynamo/variables/builder.py
@@ -27,6 +27,7 @@
 from torch._ops import HigherOrderOperator
 from torch._streambase import _EventBase, _StreamBase
 from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
+from torch._subclasses.meta_utils import is_sparse_any
 from torch.fx.experimental._backward_state import BackwardState
 from torch.fx.experimental.symbolic_shapes import (
     _constrain_range_for_size,
@@ -1059,6 +1060,11 @@
         ):
             unimplemented("torch.compile does not support strided NestedTensor")
 
+        if is_sparse_any(value):
+            unimplemented(
+                f"torch.compile does not support sparse Tensor with {value.layout} layout"
+            )
+
         tensor_variable = wrap_fx_proxy(
             tx=self.tx,
             proxy=tensor_proxy,
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index cac26c3..95b22e5 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -25,6 +25,7 @@
 import torch.fx
 import torch.random
 from torch._dynamo import compiled_autograd
+from torch._subclasses.meta_utils import is_sparse_any
 
 from torch.fx.experimental.symbolic_shapes import (
     guard_scalar,
@@ -149,7 +150,11 @@
             "is_sparse": value.is_sparse,
             "class_type": type(value),
         }
-        if not has_free_symbols(value):
+        if is_sparse_any(value) and not has_free_symbols(value):
+            props["size"] = tuple(
+                [int(s) if is_symbolic(s) else s for s in value.size()]
+            )
+        elif not has_free_symbols(value):
             # this is a fully static shape, and the keys on props here inform specialization.
             # We have to cast to int here, because these might get accessed as ConstantVariable, which has
             # a strict no-symint policy. If we got here due to not having free symbols, this is a known constant
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 0ec490b..9871c42 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -23,6 +23,7 @@
     assert_eq,
     assert_metadata_eq,
     is_sparse_any,
+    is_sparse_compressed,
     MetaConverter,
 )
 from torch._utils import render_call
@@ -1057,6 +1058,8 @@
                     raise _BypassDispatchCache("constant attribute")
                 if arg.is_sparse:
                     raise _BypassDispatchCache("sparse tensor")
+                if is_sparse_compressed(arg):
+                    raise _BypassDispatchCache("sparse compressed tensor")
                 result.append(extract_tensor_metadata(arg))
             elif isinstance(arg, torch.Tensor):
                 raise _BypassDispatchCache("non-fake tensor")
@@ -1099,6 +1102,9 @@
         if output.is_sparse:
             raise _BypassDispatchCache("sparse output")
 
+        if is_sparse_compressed(output):
+            raise _BypassDispatchCache("sparse compressed output")
+
         # Can an in-place op really reference a kwarg? If so, then we need
         # to extend the implementation to handle it.
         for kval in kwargs.values():
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index f629f66..fc23516 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -152,7 +152,7 @@
         # hold a weak ref to self, otherwise it will be kept alive
         # by the del_ten closure
         self_weak_ref = weakref.ref(self)
-        if t.is_sparse or t.is_mkldnn or is_functorch_wrapped_tensor(t):
+        if is_sparse_any(t) or t.is_mkldnn or is_functorch_wrapped_tensor(t):
             weak_st = None
         else:
             weak_st = StorageWeakRef(t._typed_storage())
@@ -298,6 +298,10 @@
             with torch.inference_mode(t.is_inference()):
                 if t.is_sparse:
                     is_leaf = safe_is_leaf(t)
+
+                    # The lambda function below is similar to
+                    # `t.to(device='meta')`, except that the latter
+                    # preserves the nnz value.
                     r = callback(
                         lambda: torch.ops.aten._sparse_coo_tensor_with_dims(
                             t.sparse_dim(),
@@ -321,6 +325,64 @@
                         with torch.enable_grad():
                             r = r.clone()
                             r._coalesced_(t.is_coalesced())
+                elif is_sparse_compressed(t):
+                    is_leaf = safe_is_leaf(t)
+
+                    def mk_meta():
+                        nnz = 0
+                        batch_dim = t.ndim - t.sparse_dim() - t.dense_dim()
+                        batch_size = t.shape[:batch_dim]
+                        if t.layout in {torch.sparse_csr, torch.sparse_bsr}:
+                            index_dtype = t.crow_indices().dtype
+                            compressed_indices = torch.empty(
+                                t.crow_indices().shape, device="meta", dtype=index_dtype
+                            )
+                            plain_indices = torch.empty(
+                                (*t.col_indices().shape[:-1], nnz),
+                                device="meta",
+                                dtype=index_dtype,
+                            )
+                        else:
+                            index_dtype = t.ccol_indices().dtype
+                            compressed_indices = torch.empty(
+                                t.ccol_indices().shape, device="meta", dtype=index_dtype
+                            )
+                            plain_indices = torch.empty(
+                                (*t.row_indices().shape[:-1], nnz),
+                                device="meta",
+                                dtype=index_dtype,
+                            )
+                        values_shape = t.values().shape
+                        values = torch.empty(
+                            (
+                                *values_shape[:batch_dim],
+                                nnz,
+                                *values_shape[batch_dim + 1 :],
+                            ),
+                            dtype=t.dtype,
+                            device="meta",
+                        )
+                        return torch.ops.aten.sparse_compressed_tensor(
+                            compressed_indices,
+                            plain_indices,
+                            values,
+                            t.shape,
+                            layout=t.layout,
+                            dtype=t.dtype,
+                            device="meta",
+                        )
+
+                    # `mk_meta()` is similar to `t.to(device='meta')`,
+                    # except that `to('meta')` preserves the nnz value
+                    # while the `mk_meta` result has nnz == 0.
+                    r = callback(mk_meta)
+
+                    assert safe_is_leaf(r), "the callback you passed in doesn't detach"
+                    if t.requires_grad:
+                        r.requires_grad = True
+                    if t.requires_grad and not is_leaf:
+                        with torch.enable_grad():
+                            r = r.clone()
                 elif t.is_nested and not is_traceable_wrapper_subclass(t):
                     # TODO: Handle this better in Dynamo?
                     # There are checks there now, but this can still be triggered by a dense
@@ -679,8 +741,6 @@
         if isinstance(t, torch.Tensor) or is_traceable_wrapper_subclass(t):
             if t.device.type != "xla" and any(
                 [
-                    t.is_sparse_csr,
-                    t.layout in [torch.sparse_csc, torch.sparse_bsr, torch.sparse_bsc],
                     t.is_quantized,
                     t._is_view() and t._base is not None and t._base.is_sparse,
                     torch._is_functional_tensor(t),
diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py
index a5ccc9f..0471649 100644
--- a/torch/fx/experimental/symbolic_shapes.py
+++ b/torch/fx/experimental/symbolic_shapes.py
@@ -280,6 +280,8 @@
         yield val
     elif isinstance(val, (int, float, bool)):
         pass
+    elif is_sparse_any(val):
+        yield from _iterate_exprs(val.size())
     elif isinstance(val, torch.Tensor):
         yield from _iterate_exprs(val.size())
         yield from _iterate_exprs(val.stride())
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 2d8ecd8..9d9b7a2 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -11793,6 +11793,12 @@
                DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_gradgrad'),
                # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
                DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
+               # NotImplementedError: Could not run 'aten::sparse_sampled_addmm' with arguments from the 'SparseCsrMeta' backend.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace_all_strides'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestFakeTensor', 'test_fake_crossref_backward_no_amp'),
            )),
     OpInfo('sparse.mm',
            dtypes=floating_types_and(torch.bfloat16),
@@ -11836,6 +11842,10 @@
                DecorateInfo(unittest.skip("Skipped!"), 'TestFwdGradients', 'test_forward_mode_AD'),
                # ValueError: Sparse output is not supported at gradcheck yet. Please call to_dense(masked_grad=...) ...
                DecorateInfo(unittest.skip("Skipped!"), 'TestBwdGradients', 'test_fn_fail_gradgrad'),
+               # NotImplementedError: Could not run 'aten::_sparse_mm_reduce_impl' with arguments from the 'SparseCsrMeta' backend
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_dispatch_symbolic_meta_outplace'),
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMeta', 'test_meta_outplace'),
            )),
     UnaryUfuncInfo('i0',
                    ref=np_unary_ufunc_integer_promotion_wrapper(