Fix incorrect sparse_dim in COO.zero_() and in binary operations with zero-sized COO operands (#98292)
Fixes https://github.com/pytorch/pytorch/issues/97627
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98292
Approved by: https://github.com/nikitaved, https://github.com/cpuhrsch, https://github.com/amjames
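Previously, `zero_sparse_` rebuilt a hybrid COO tensor through `at::zeros_out`, which reset it to an all-sparse tensor (`sparse_dim() == ndim`, `dense_dim() == 0`), and the empty-operand short-circuit in the sparse-dense `mul` path hard-coded `sparse_dim = res_shape.size()`, likewise folding the dense dimensions into `sparse_dim`. Both paths now preserve the input's sparse/dense split. A minimal repro sketch of the fixed behavior (the shapes are illustrative; see gh-97627 for the original report):

```python
import torch

# Hybrid COO tensor: 1 sparse dim, 1 dense dim, shape (3, 2).
i = torch.tensor([[0, 2]])                 # indices: (sparse_dim, nnz) = (1, 2)
v = torch.ones(2, 2)                       # values: (nnz, dense...) = (2, 2)
t = torch.sparse_coo_tensor(i, v, (3, 2))
assert (t.sparse_dim(), t.dense_dim()) == (1, 1)

t.zero_()
# Before this patch, zero_() rebuilt t as an all-sparse tensor, so
# sparse_dim() became 2 and dense_dim() 0; sparse_resize_and_clear_
# now preserves the split:
assert (t.sparse_dim(), t.dense_dim()) == (1, 1)

# Zero-sized hybrid operand (gh-97627): numel() == 0, dense_dim() > 0.
s = torch.sparse_coo_tensor(
    torch.empty(1, 0, dtype=torch.int64),  # no specified entries
    torch.empty(0, 2),
    (0, 2),
)
r = torch.mul(s, torch.zeros(0, 2))        # hits the empty short-circuit
# Before: r.sparse_dim() == 2 (the dense dim was dropped).
assert (r.sparse_dim(), r.dense_dim()) == (1, 1)
```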
diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
index 161220c..c29eaa5 100644
--- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -99,7 +99,7 @@
// hummu hummu
SparseTensor& zero_sparse_(SparseTensor& self) {
AT_ASSERT(self.is_sparse());
- at::zeros_out(self, get_sparse_impl(self)->sizes());
+ self.sparse_resize_and_clear_(self.sizes(), self.sparse_dim(), self.dense_dim());
return self._coalesced_(true);
}
@@ -863,12 +863,17 @@
// Short-circuit if either s_ or d is empty.
if (!s_._nnz() || !s_.numel() || !d.numel()) {
- const auto sparse_dim = static_cast<int64_t>(res_shape.size());
- const auto indices = at::empty({sparse_dim, 0}, s_._indices().options());
- const auto values = at::empty({0}, s_._values().options().dtype(res.scalar_type()));
- get_sparse_impl(res)->raw_resize_(sparse_dim, /*dense_dim=*/0, /*size=*/res_shape);
- get_sparse_impl(res)->set_indices_and_values_unsafe(indices, values);
- get_sparse_impl(res)->set_nnz_and_narrow(0);
+ const int64_t dense_dim = s_.dense_dim();
+ const int64_t sparse_dim = static_cast<int64_t>(res_shape.size()) - dense_dim;
+ const int64_t nnz = 0;
+ const auto indices = at::empty({sparse_dim, nnz}, s_._indices().options());
+ auto res_values_shape = s_._values().sizes().vec();
+ res_values_shape[0] = nnz;
+ const auto values = at::empty(res_values_shape, s_._values().options().dtype(res.scalar_type()));
+ auto* res_impl = get_sparse_impl(res);
+ res_impl->raw_resize_(sparse_dim, dense_dim, /*size=*/res_shape);
+ res_impl->set_indices_and_values_unsafe(indices, values);
+ res_impl->set_nnz_and_narrow(nnz);
return res._coalesced_(true);
}
@@ -900,9 +905,10 @@
// op(s.values, d).dtype == <common dtype>.
const auto values = op(d_filtered, s_values);
const auto res_values = is_same_tensor(s_, res) ? values : values.to(res.scalar_type());
- get_sparse_impl(res)->raw_resize_(sparse_dim, dense_dim, res_shape);
- get_sparse_impl(res)->set_indices_and_values_unsafe(res_indices, res_values);
- get_sparse_impl(res)->set_nnz_and_narrow(s._nnz());
+ auto* res_impl = get_sparse_impl(res);
+ res_impl->raw_resize_(sparse_dim, dense_dim, res_shape);
+ res_impl->set_indices_and_values_unsafe(res_indices, res_values);
+ res_impl->set_nnz_and_narrow(s._nnz());
return res._coalesced_(s.is_coalesced());
};
@@ -1000,10 +1006,10 @@
return indices;
}();
-
- get_sparse_impl(res)->raw_resize_(res_sparse_dim, res_dense_dim, res_shape);
- get_sparse_impl(res)->set_indices_and_values_unsafe(res_indices, res_values);
- get_sparse_impl(res)->set_nnz_and_narrow(res_nnz);
+ auto* res_impl = get_sparse_impl(res);
+ res_impl->raw_resize_(res_sparse_dim, res_dense_dim, res_shape);
+ res_impl->set_indices_and_values_unsafe(res_indices, res_values);
+ res_impl->set_nnz_and_narrow(res_nnz);
// By design of index expansion and that s is coalesced,
// the result is also coalesced.
return res._coalesced_(true);
@@ -1066,7 +1072,6 @@
AT_ASSERT(!t_.is_cuda()); // dispatch argument
TORCH_CHECK(!r.is_cuda(), "mul: expected 'out' to be CPU tensor, but got CUDA tensor");
TORCH_CHECK(!src_.is_cuda(), "mul: expected 'other' to be a CPU tensor, but got a CUDA tensor");
-
// case mul(sparse, dense)
if (!src_.is_sparse()) {
return _mul_dense_sparse_out(src_, t_, r);
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 3652d58..c6690bc 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -4774,24 +4774,16 @@
for sample in op.sample_inputs_sparse(layout, device, dtype):
t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs
batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim()
-
result = op.op(t_inp, *t_args, **t_kwargs)
# Check rop(inp, ...).shape == inp.shape
self.assertEqual(result.shape, t_inp.shape)
- if layout is torch.sparse_coo and t_inp.numel() == 0 and op.name == 'mul' and t_inp.dense_dim() > 0:
- # BUG: gh-97627
- with self.assertRaisesRegex(
- AssertionError,
- "Scalars are not equal!"):
- self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())
- else:
- # Check rop(inp, ...).sparse_dim() == inp.sparse_dim()
- self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())
+ # Check rop(inp, ...).sparse_dim() == inp.sparse_dim()
+ self.assertEqual(result.sparse_dim(), t_inp.sparse_dim())
- # Check rop(inp, ...).dense_dim() == inp.dense_dim()
- self.assertEqual(result.dense_dim(), t_inp.dense_dim())
+ # Check rop(inp, ...).dense_dim() == inp.dense_dim()
+ self.assertEqual(result.dense_dim(), t_inp.dense_dim())
# Check invariant rop(inp, ...).to_dense() == rop(inp.to_dense(), ...)
try:
diff --git a/torch/testing/_internal/opinfo/definitions/sparse.py b/torch/testing/_internal/opinfo/definitions/sparse.py
index 17ba097..fdb177f 100644
--- a/torch/testing/_internal/opinfo/definitions/sparse.py
+++ b/torch/testing/_internal/opinfo/definitions/sparse.py
@@ -569,19 +569,6 @@
sample,
error_regex="crow_indices is supposed to be a vector, but got 2 dimensional tensor",
)
- elif (
- layout is torch.sparse_csr
- and t_inp.dense_dim() > 0
- and t_inp._nnz() == 0
- and t_args[0].ndim > 0
- ):
- return ErrorInput(
- sample,
- error_regex=(
- "Only tensors with two sparse dimensions can be converted to the SparseCsr layout"
- ", got self with 3 sparse dimensions."
- ),
- )
elif layout is torch.sparse_csc and t_args[0].ndim > 0:
return ErrorInput(
sample, error_regex="Expected result Tensor to be of format CSR"