Fix incorrect sparse_dim in COO.zero_() and in binary operations with zero-sized COO operands (#98292)

Fixes https://github.com/pytorch/pytorch/issues/97627

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98292
Approved by: https://github.com/nikitaved, https://github.com/cpuhrsch, https://github.com/amjames
diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
index 161220c..c29eaa5 100644
--- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -99,7 +99,7 @@
 // hummu hummu
 SparseTensor& zero_sparse_(SparseTensor& self) {
   AT_ASSERT(self.is_sparse());
-  at::zeros_out(self, get_sparse_impl(self)->sizes());
+  self.sparse_resize_and_clear_(self.sizes(), self.sparse_dim(), self.dense_dim());
   return self._coalesced_(true);
 }
 
@@ -863,12 +863,17 @@
 
   // Short-circuit if either s_ or d is empty.
   if (!s_._nnz() || !s_.numel() || !d.numel()) {
-    const auto sparse_dim = static_cast<int64_t>(res_shape.size());
-    const auto indices = at::empty({sparse_dim, 0}, s_._indices().options());
-    const auto values = at::empty({0}, s_._values().options().dtype(res.scalar_type()));
-    get_sparse_impl(res)->raw_resize_(sparse_dim, /*dense_dim=*/0, /*size=*/res_shape);
-    get_sparse_impl(res)->set_indices_and_values_unsafe(indices, values);
-    get_sparse_impl(res)->set_nnz_and_narrow(0);
+    const int64_t dense_dim = s_.dense_dim();
+    const int64_t sparse_dim = static_cast<int64_t>(res_shape.size()) - dense_dim;
+    const int64_t nnz = 0;
+    const auto indices = at::empty({sparse_dim, nnz}, s_._indices().options());
+    auto res_values_shape = s_._values().sizes().vec();
+    res_values_shape[0] = nnz;
+    const auto values = at::empty(res_values_shape, s_._values().options().dtype(res.scalar_type()));
+    auto* res_impl = get_sparse_impl(res);
+    res_impl->raw_resize_(sparse_dim, dense_dim, /*size=*/res_shape);
+    res_impl->set_indices_and_values_unsafe(indices, values);
+    res_impl->set_nnz_and_narrow(nnz);
     return res._coalesced_(true);
   }
 
@@ -900,9 +905,10 @@
     // op(s.values, d).dtype == <common dtype>.
     const auto values = op(d_filtered, s_values);
     const auto res_values = is_same_tensor(s_, res) ? values : values.to(res.scalar_type());
-    get_sparse_impl(res)->raw_resize_(sparse_dim, dense_dim, res_shape);
-    get_sparse_impl(res)->set_indices_and_values_unsafe(res_indices, res_values);
-    get_sparse_impl(res)->set_nnz_and_narrow(s._nnz());
+    auto* res_impl = get_sparse_impl(res);
+    res_impl->raw_resize_(sparse_dim, dense_dim, res_shape);
+    res_impl->set_indices_and_values_unsafe(res_indices, res_values);
+    res_impl->set_nnz_and_narrow(s._nnz());
     return res._coalesced_(s.is_coalesced());
   };
 
@@ -1000,10 +1006,10 @@
 
     return indices;
   }();
-
-  get_sparse_impl(res)->raw_resize_(res_sparse_dim, res_dense_dim, res_shape);
-  get_sparse_impl(res)->set_indices_and_values_unsafe(res_indices, res_values);
-  get_sparse_impl(res)->set_nnz_and_narrow(res_nnz);
+  auto* res_impl = get_sparse_impl(res);
+  res_impl->raw_resize_(res_sparse_dim, res_dense_dim, res_shape);
+  res_impl->set_indices_and_values_unsafe(res_indices, res_values);
+  res_impl->set_nnz_and_narrow(res_nnz);
   // By design of index expansion and that s is coalesced,
   // the result is also coalesced.
   return res._coalesced_(true);
@@ -1066,7 +1072,6 @@
   AT_ASSERT(!t_.is_cuda()); // dispatch argument
   TORCH_CHECK(!r.is_cuda(), "mul: expected 'out' to be CPU tensor, but got CUDA tensor");
   TORCH_CHECK(!src_.is_cuda(), "mul: expected 'other' to be a CPU tensor, but got a CUDA tensor");
-
   // case mul(sparse, dense)
   if (!src_.is_sparse()) {
     return _mul_dense_sparse_out(src_, t_, r);
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 3652d58..c6690bc 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -4774,24 +4774,16 @@
         for sample in op.sample_inputs_sparse(layout, device, dtype):
             t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
             batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
-
             result = op.op(t_inp, *t_args, **t_kwargs)
 
             # Check rop(inp, ...).shape == inp.shape
             self.assertEqual(result.shape, t_inp.shape)
 
-            if layout is torch.sparse_coo and t_inp.numel() == 0 and op.name == 'mul' and t_inp.dense_dim() > 0:
-                # BUG: gh-97627
-                with self.assertRaisesRegex(
-                        AssertionError,
-                        "Scalars are not equal!"):
-                    self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())
-            else:
-                # Check rop(inp, ...).sparse_dim() == inp.sparse_dim()
-                self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())
+            # Check rop(inp, ...).sparse_dim() == inp.sparse_dim()
+            self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())
 
-                # Check rop(inp, ...).dense_dim() == inp.dense_dim()
-                self.assertEqual(result.dense_dim(), t_inp.dense_dim())
+            # Check rop(inp, ...).dense_dim() == inp.dense_dim()
+            self.assertEqual(result.dense_dim(), t_inp.dense_dim())
 
             # Check invariant rop(inp, ...).to_dense() == rop(inp.to_dense(), ...)
             try:
diff --git a/torch/testing/_internal/opinfo/definitions/sparse.py b/torch/testing/_internal/opinfo/definitions/sparse.py
index 17ba097..fdb177f 100644
--- a/torch/testing/_internal/opinfo/definitions/sparse.py
+++ b/torch/testing/_internal/opinfo/definitions/sparse.py
@@ -569,19 +569,6 @@
             sample,
             error_regex="crow_indices is supposed to be a vector, but got 2 dimensional tensor",
         )
-    elif (
-        layout is torch.sparse_csr
-        and t_inp.dense_dim() > 0
-        and t_inp._nnz() == 0
-        and t_args[0].ndim > 0
-    ):
-        return ErrorInput(
-            sample,
-            error_regex=(
-                "Only tensors with two sparse dimensions can be converted to the SparseCsr layout"
-                ", got self with 3 sparse dimensions."
-            ),
-        )
     elif layout is torch.sparse_csc and t_args[0].ndim > 0:
         return ErrorInput(
             sample, error_regex="Expected result Tensor to be of format CSR"