check sparse sizes (#47148)

Summary:
checks sizes of sparse tensors when comparing them in assertEqual.
Removes additional checks in safeCoalesce, safeCoalesce should not be a test for `.coalesce()` function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47148

Reviewed By: mruberry

Differential Revision: D24823127

Pulled By: ngimel

fbshipit-source-id: 9303a6ff74aa3c9d9207803d05c0be2325fe392a
diff --git a/test/test_sparse.py b/test/test_sparse.py
index e365c94..37461da 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -14,6 +14,7 @@
 from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
 from numbers import Number
 from torch.autograd.gradcheck import gradcheck
+from typing import Dict, Any
 
 # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -161,7 +162,7 @@
             self.assertEqual(i, x._indices())
             self.assertEqual(v, x._values())
             self.assertEqual(x.ndimension(), len(with_size))
-            self.assertEqual(self.safeCoalesce(x)._nnz(), nnz)
+            self.assertEqual(x.coalesce()._nnz(), nnz)
             self.assertEqual(list(x.size()), with_size)
 
             # Test .indices() and .values()
@@ -183,7 +184,7 @@
         i = self.index_tensor([[9, 0, 0, 0, 8, 1, 1, 1, 2, 7, 2, 2, 3, 4, 6, 9]])
         v = self.value_tensor([[idx**2, idx] for idx in range(i.size(1))])
         x = self.sparse_tensor(i, v, torch.Size([10, 2]))
-        self.assertEqual(self.safeCoalesce(x)._nnz(), 9)
+        self.assertEqual(x.coalesce()._nnz(), 9)
 
         # Make sure we can access empty indices / values
         x = self.legacy_sparse_tensor()
@@ -191,13 +192,50 @@
         self.assertEqual(x._values().numel(), 0)
 
     def test_coalesce(self):
+
+        def _test_coalesce(x):
+            tc = t.coalesce()
+            self.assertEqual(tc.to_dense(), t.to_dense())
+            self.assertTrue(tc.is_coalesced())
+            # Our code below doesn't work when nnz is 0, because
+            # then it's a 0D tensor, not a 2D tensor.
+            if t._nnz() == 0:
+                self.assertEqual(t._indices(), tc._indices())
+                self.assertEqual(t._values(), tc._values())
+                return tc
+
+            value_map: Dict[Any, Any] = {}
+            for idx, val in zip(t._indices().t(), t._values()):
+                idx_tup = tuple(idx.tolist())
+                if idx_tup in value_map:
+                    value_map[idx_tup] += val
+                else:
+                    value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val
+
+            new_indices = sorted(list(value_map.keys()))
+            _new_values = [value_map[idx] for idx in new_indices]
+            if t._values().ndimension() < 2:
+                new_values = t._values().new(_new_values)
+            else:
+                new_values = torch.stack(_new_values)
+
+            new_indices = t._indices().new(new_indices).t()
+            tg = t.new(new_indices, new_values, t.size())
+
+            self.assertEqual(tc._indices(), tg._indices())
+            self.assertEqual(tc._values(), tg._values())
+
+            if t.is_coalesced():
+                self.assertEqual(tc._indices(), t._indices())
+                self.assertEqual(tc._values(), t._values())
+
         for empty_i, empty_v, empty_nnz in itertools.product([True, False], repeat=3):
             sparse_size = [] if empty_i else [2, 1]
             dense_size = [1, 0, 2] if empty_v else [1, 2]
             nnz = 0 if empty_nnz else 5
 
             t, _, _ = self._gen_sparse(len(sparse_size), nnz, sparse_size + dense_size)
-            self.safeCoalesce(t)  # this tests correctness
+            _test_coalesce(t)  # this tests correctness
 
     def test_ctor_size_checks(self):
         indices = self.index_tensor([
@@ -399,7 +437,7 @@
 
     def test_contig(self):
         def test_tensor(x, exp_i, exp_v):
-            x = self.safeCoalesce(x)
+            x = x.coalesce()
             self.assertEqual(exp_i, x._indices())
             self.assertEqual(exp_v, x._values())
 
@@ -479,7 +517,7 @@
 
     def test_contig_hybrid(self):
         def test_tensor(x, exp_i, exp_v):
-            x = self.safeCoalesce(x)
+            x = x.coalesce()
             self.assertEqual(exp_i, x._indices())
             self.assertEqual(exp_v, x._values())
 
@@ -2065,7 +2103,9 @@
             if not x.is_cuda:
                 # CUDA sparse tensors currently requires the size to be
                 # specified if nDimV > 0
-                self.assertEqual(x.new(indices, values), x)
+                out = x.new(indices, values).coalesce()
+                x_c = x.coalesce()
+                self.assertEqual((out.indices(), out.values()), (x_c.indices(), x_c.values()))
             self.assertEqual(x.new(indices, values, x.size()), x)
 
         test_shape(3, 10, 100)
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 92c5c52..df606c1 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -861,7 +861,6 @@
         if is_uncoalesced:
             v = torch.cat([v, torch.randn_like(v)], 0)
             i = torch.cat([i, i], 1)
-
         x = torch.sparse_coo_tensor(i, v, torch.Size(size))
 
         if not is_uncoalesced:
@@ -877,47 +876,7 @@
         return x, x._indices().clone(), x._values().clone()
 
     def safeToDense(self, t):
-        r = self.safeCoalesce(t)
-        return r.to_dense()
-
-    def safeCoalesce(self, t):
-        tc = t.coalesce()
-        self.assertEqual(tc.to_dense(), t.to_dense())
-        self.assertTrue(tc.is_coalesced())
-
-        # Our code below doesn't work when nnz is 0, because
-        # then it's a 0D tensor, not a 2D tensor.
-        if t._nnz() == 0:
-            self.assertEqual(t._indices(), tc._indices())
-            self.assertEqual(t._values(), tc._values())
-            return tc
-
-        value_map: Dict[Any, Any] = {}
-        for idx, val in zip(t._indices().t(), t._values()):
-            idx_tup = tuple(idx.tolist())
-            if idx_tup in value_map:
-                value_map[idx_tup] += val
-            else:
-                value_map[idx_tup] = val.clone() if isinstance(val, torch.Tensor) else val
-
-        new_indices = sorted(list(value_map.keys()))
-        _new_values = [value_map[idx] for idx in new_indices]
-        if t._values().ndimension() < 2:
-            new_values = t._values().new(_new_values)
-        else:
-            new_values = torch.stack(_new_values)
-
-        new_indices = t._indices().new(new_indices).t()
-        tg = t.new(new_indices, new_values, t.size())
-
-        self.assertEqual(tc._indices(), tg._indices())
-        self.assertEqual(tc._values(), tg._values())
-
-        if t.is_coalesced():
-            self.assertEqual(tc._indices(), t._indices())
-            self.assertEqual(tc._values(), t._values())
-
-        return tg
+        return t.coalesce().to_dense()
 
     # Compares the given Torch and NumPy functions on the given tensor-like object.
     # NOTE: both torch_fn and np_fn should be functions that take a single
@@ -1081,8 +1040,15 @@
             super().assertEqual(x.is_sparse, y.is_sparse, msg=msg)
             super().assertEqual(x.is_quantized, y.is_quantized, msg=msg)
             if x.is_sparse:
-                x = self.safeCoalesce(x)
-                y = self.safeCoalesce(y)
+                if x.size() != y.size():
+                    debug_msg_sparse = ("Attempted to compare equality of tensors with different sizes. "
+                                        f"Got sizes {x.size()} and {y.size()}.")
+                    if msg is None:
+                        msg = debug_msg_sparse
+                    self.assertTrue(False, msg=msg)
+
+                x = x.coalesce()
+                y = y.coalesce()
                 indices_result, debug_msg = self._compareTensors(x._indices(), y._indices(),
                                                                  rtol=rtol, atol=atol,
                                                                  equal_nan=equal_nan, exact_dtype=exact_dtype,