Revert "Extend CSR constructor to support batched indices and values"
This reverts commit c074a530029cd2fbbbe74c325842bc1aef6a8ac4.
Reverted https://github.com/pytorch/pytorch/pull/74542 on behalf of https://github.com/malfet
diff --git a/aten/src/ATen/SparseCsrTensorImpl.cpp b/aten/src/ATen/SparseCsrTensorImpl.cpp
index 7808f75..2029189 100644
--- a/aten/src/ATen/SparseCsrTensorImpl.cpp
+++ b/aten/src/ATen/SparseCsrTensorImpl.cpp
@@ -60,22 +60,17 @@
}
void SparseCsrTensorImpl::resize_(int64_t nnz, IntArrayRef size) {
- auto rows = size[size.size() - 2];
- auto cols = size[size.size() - 1];
+ auto rows = size[0];
+ auto cols = size[1];
auto old_crow_indices_size = crow_indices_.size(-1);
-
- auto new_crow_indices_size = DimVector(size.slice(0, size.size() - 2));
- new_crow_indices_size.push_back(rows + 1);
- crow_indices_.resize_(new_crow_indices_size);
+ crow_indices_.resize_({rows + 1});
if (rows + 1 >= old_crow_indices_size) {
crow_indices_.narrow(-1, old_crow_indices_size, rows + 1 - old_crow_indices_size).fill_(nnz);
} else {
crow_indices_.narrow(-1, rows, 1).fill_(std::min<int64_t>(nnz, rows*cols));
}
- auto col_indices_values_size = DimVector(size.slice(0, size.size() - 2));
- col_indices_values_size.push_back(std::min<int64_t>(nnz, rows*cols));
- col_indices_.resize_(col_indices_values_size);
- values_.resize_(col_indices_values_size);
+ col_indices_.resize_({std::min<int64_t>(nnz, rows*cols)});
+ values_.resize_({std::min<int64_t>(nnz, rows*cols)});
sizes_and_strides_.set_sizes(size);
}
diff --git a/aten/src/ATen/SparseCsrTensorImpl.h b/aten/src/ATen/SparseCsrTensorImpl.h
index b90176f..850e0a0 100644
--- a/aten/src/ATen/SparseCsrTensorImpl.h
+++ b/aten/src/ATen/SparseCsrTensorImpl.h
@@ -43,7 +43,7 @@
const Tensor& crow_indices() const { return crow_indices_; }
const Tensor& col_indices() const { return col_indices_; }
const Tensor& values() const { return values_; }
- int nnz() { return col_indices_.size(-1); }
+ int nnz() { return values_.size(0); }
/**
* Return a TensorImpl that is a shallow-copy of this TensorImpl.
diff --git a/aten/src/ATen/mkl/SparseDescriptors.h b/aten/src/ATen/mkl/SparseDescriptors.h
index 2c152e0..46d6568 100644
--- a/aten/src/ATen/mkl/SparseDescriptors.h
+++ b/aten/src/ATen/mkl/SparseDescriptors.h
@@ -101,7 +101,7 @@
sparse_matrix_t raw_descriptor;
// Assuming that the last two dimensions are block elements of the matrix
- if (values.dim() == 3 && crow_indices.dim() == 1 && col_indices.dim() == 1) {
+ if (values.dim() == 3) {
TORCH_CHECK(
values.size(-1) == values.size(-2),
"MKL Sparse doesn't support matrices with non-square blocks.");
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
index 3df5beb..f91d964 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -9,7 +9,6 @@
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
-#include <ATen/native/LinearAlgebraUtils.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -57,51 +56,29 @@
// Shape and Strides invariants
TORCH_CHECK(
- size.size() >= 2,
- "size of a batched CSR tensor must have length >= 2, but got: ",
+ size.size() == 2,
+ "size of a CSR tensor must be of length 2, but got: ",
size.size());
TORCH_CHECK(
- crow_indices.dim() >= 1,
- "crow_indices must have dim >= 1 but got crow_indices.dim() = ",
+ crow_indices.dim() == 1,
+ "crow_indices must have dim=1 but got crow_indices.dim()=",
crow_indices.dim());
TORCH_CHECK(
- col_indices.dim() >= 1,
- "col_indices must have dim >= 1 but got col_indices.dim() = ",
+ col_indices.dim() == 1,
+ "col_indices must have dim=1 but got col_indices.dim()=",
col_indices.dim());
TORCH_CHECK(
- values.dim() >= 1,
- "values must have dim >= 1 but got values.dim() = ",
+ values.dim() == 1,
+ "values must have dim=1 but got values.dim()=",
values.dim());
-
+ // Note, this check also enforces `crow_indices.numel() >= 1`
TORCH_CHECK(
- crow_indices.dim() == col_indices.dim(),
- "Number of dimensions of crow_indices and col_indices must be the same.");
- TORCH_CHECK(
- crow_indices.dim() == values.dim(),
- "Number of dimensions of indices and values must be the same.");
- TORCH_CHECK(
- crow_indices.dim() == size.size() - 1,
- "Number of dimensions of indices must be one less than the number of dimensions of the provided size.");
-
- // All batch sizes must be the same
- auto batch_size = size.slice(0, size.size() - 2);
- auto crow_indices_batch_size = crow_indices.sizes().slice(0, crow_indices.dim() - 1);
- auto col_indices_batch_size = col_indices.sizes().slice(0, col_indices.dim() - 1);
- auto values_batch_size = values.sizes().slice(0, values.dim() - 1);
- TORCH_CHECK(
- batch_size == crow_indices_batch_size &&
- batch_size == col_indices_batch_size &&
- batch_size == values_batch_size,
- "All batch dimensions of the provided size, indices, and values must be the same.");
-
- // Note, this check also enforces `crow_indices.size(-1) >= 1`
- TORCH_CHECK(
- crow_indices.size(-1) == (size[size.size() - 2] + 1),
- "crow_indices.size(-1) must be equal to size[-2] + 1 (that is ", size[size.size() - 2] + 1, "), but got: ",
- crow_indices.size(-1));
+ crow_indices.numel() == (size[0] + 1),
+ "crow_indices.numel() must be size(0) + 1, but got: ",
+ crow_indices.numel());
TORCH_CHECK(
col_indices.numel() == values.numel(),
- "col_indices and values must have the same number of elements, but got col_indices.numel(): ",
+ "col_indices and values must have equal sizes, but got col_indices.numel(): ",
col_indices.numel(),
", values.numel(): ",
values.numel());
@@ -109,28 +86,22 @@
// Indices invariants
AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "csr_construct_check", [&] {
Tensor crow_indices_cpu = crow_indices.to(kCPU);
- auto crow_indices_data_ptr = crow_indices_cpu.data_ptr<index_t>();
- auto batch_stride = crow_indices_cpu.dim() >= 2 ? crow_indices_cpu.stride(-2) : 0;
- for (const auto batch_id : c10::irange(batchCount(crow_indices_cpu))) {
- TORCH_CHECK(
- crow_indices_data_ptr[batch_id*batch_stride] == 0,
- "(Batch element ", batch_id, ") ",
- ": 0th value of crow_indices must be 0, but it is ", crow_indices_data_ptr[batch_id*batch_stride]);
- TORCH_CHECK(
- crow_indices_data_ptr[batch_id*batch_stride + crow_indices.size(-1) - 1] == col_indices.size(-1),
- "(Batch element ", batch_id, ") ",
- "last value of crow_indices should be equal to the length of col_indices.");
+ auto crow_indices_accessor = crow_indices_cpu.accessor<index_t, 1>();
+ TORCH_CHECK(
+ crow_indices_accessor[0] == 0, "0th value of crow_indices must be 0.");
- for (int i = 1; i <= size[size.size() - 2]; i++) {
- TORCH_CHECK(
- crow_indices_data_ptr[batch_id*batch_stride + i - 1] <= crow_indices_data_ptr[batch_id*batch_stride + i],
- "(Batch element ", batch_id, ") ",
- "at position i = ", i, ", the condition crow_indices[i - 1] <= crow_indices[i] fails");
- }
+ TORCH_CHECK(
+ crow_indices_accessor[crow_indices.numel() - 1] == col_indices.numel(),
+ "last value of crow_indices should be equal to the length of col_indices.");
+
+ for (int i = 1; i <= size[0]; i++) {
+ TORCH_CHECK(
+ crow_indices_accessor[i - 1] <= crow_indices_accessor[i],
+ "at position i = ", i, ", this condition crow_indices[i - 1] <= crow_indices[i] fails");
}
if (col_indices.numel() > 0) {
TORCH_CHECK(0 <= col_indices.min().item<index_t>(), "col_indices.min() should be greater or equal to zero");
- TORCH_CHECK(size[size.size() - 1] > col_indices.max().item<index_t>(), "size[-1] should be greater than col_indices.max()");
+ TORCH_CHECK(size[1] > col_indices.max().item<index_t>(), "size(1) should be greater than col_indices.max()");
}
});
@@ -242,10 +213,13 @@
c10::optional<bool> pin_memory) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
- // std::array<int64_t, 2> size = {0, 0};
- auto size = DimVector(IntArrayRef(col_indices.sizes().data(), col_indices.dim() - 1));
- size.push_back(crow_indices.size(-1) - 1);
- size.push_back(col_indices.max().item<int64_t>() + 1);
+ std::array<int64_t, 2> size = {0, 0};
+ if (col_indices.numel() > 0) {
+ AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "csr_construct_check", [&] {
+ size[0] = crow_indices.numel() - 1;
+ size[1] = col_indices.max().item<index_t>() + 1;
+ });
+ }
at::native::_validate_sparse_csr_tensor_args(crow_indices, col_indices, values, size);
@@ -269,21 +243,16 @@
c10::optional<MemoryFormat> optional_memory_format) {
check_size_nonnegative(size);
- TORCH_CHECK(size.size() >= 2, "torch.empty: Only batched sparse CSR matrices are supported, but got size ", size);
+ TORCH_CHECK(size.size() == 2, "torch.empty: Only 2D sparse CSR tensors are supported.");
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(layout == Layout::SparseCsr);
- auto rows = size[size.size() - 2];
+ auto rows = size[0];
int64_t nnz = 0;
- auto crow_indices_size = DimVector(size.slice(0, size.size() - 2));
- crow_indices_size.push_back(rows + 1);
- auto col_indices_values_size = DimVector(size.slice(0, size.size() - 2));
- col_indices_values_size.push_back(nnz);
-
TensorOptions options = TensorOptions().dtype(ScalarType::Long).layout(Layout::Strided).device(device).pinned_memory(pin_memory);
- auto crow_indices = at::empty(crow_indices_size, options);
- auto col_indices = at::empty(col_indices_values_size, options);
- auto values = at::empty(col_indices_values_size, options.dtype(dtype));
+ auto crow_indices = at::empty({rows + 1}, options);
+ auto col_indices = at::empty({nnz}, options);
+ auto values = at::empty({nnz}, options.dtype(dtype));
return at::native::_sparse_csr_tensor_unsafe(
crow_indices,
@@ -301,13 +270,13 @@
IntArrayRef size,
c10::optional<MemoryFormat> optional_memory_format) {
check_size_nonnegative(size);
- TORCH_CHECK(size.size() >= 2, "torch.resize_: Only batched sparse CSR matrices are supported, but got size ", size);
+ TORCH_CHECK(size.size() == 2, "torch.resize_: Only 2D sparse CSR tensors are supported.");
TORCH_CHECK(
- self.size(-1) <= size[size.size() - 1],
+ self.size(1) <= size[1],
"torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported. ",
"The original number of columns is ",
- self.size(-1),
- " while the requested new number of columns is ", size[size.size() - 1], ".");
+ self.size(1),
+ " while the requested new number of columns is ", size[1], ".");
get_sparse_csr_impl(self)->resize_(self._nnz(), size);
return self;
}
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index 7ae5f3b..76eaf61 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -638,10 +638,13 @@
" in add operation");
auto src_values = src.values();
+ auto src_crow_indices = src.crow_indices();
+ auto src_col_indices = src.col_indices();
resize_output(out, dense.sizes());
Tensor resultBuffer = out;
+ Tensor valuesBuffer = src_values.to(commonDtype);
if (out.scalar_type() != commonDtype) {
resultBuffer = dense.to(commonDtype);
@@ -649,15 +652,6 @@
resultBuffer.copy_(dense);
}
- if (src._nnz() == 0) {
- return;
- }
-
- auto valuesBuffer = src_values.to(commonDtype).view({-1, src_values.size(-1)});
- resultBuffer = resultBuffer.view({-1, out.size(-2), out.size(-1)});
- auto src_crow_indices = src.crow_indices().view({-1, src.crow_indices().size(-1)});
- auto src_col_indices = src.col_indices().view({-1, src.col_indices().size(-1)});
-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kHalf,
kBool,
@@ -677,26 +671,27 @@
&alpha,
&src_crow_indices,
&src_col_indices]() {
- auto batch_count = resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
- auto values_accessor = valuesBuffer.accessor<scalar_t, 2>();
+ auto values_accessor = valuesBuffer.accessor<scalar_t, 1>();
scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
scalar_t cast_value = alpha.to<scalar_t>();
auto crow_indices_accessor =
- src_crow_indices.accessor<index_t, 2>();
+ src_crow_indices.accessor<index_t, 1>();
auto col_indices_accessor =
- src_col_indices.accessor<index_t, 2>();
- auto out_strides = resultBuffer.strides();
+ src_col_indices.accessor<index_t, 1>();
+ auto out_strides0 = resultBuffer.strides()[0];
+ auto out_strides1 = resultBuffer.strides()[1];
- for (const auto batch_idx : c10::irange(batch_count)) {
- for (const auto irow : c10::irange(src_crow_indices.size(-1) - 1)) {
- index_t start_index = crow_indices_accessor[batch_idx][irow];
- index_t end_index = crow_indices_accessor[batch_idx][irow + 1];
- for (const auto i : c10::irange(start_index, end_index)) {
- auto icol = col_indices_accessor[batch_idx][i];
- auto index = batch_idx * out_strides[0] + irow * out_strides[1] + icol * out_strides[2];
- out_ptr[index] += cast_value * values_accessor[batch_idx][i];
- }
+ for (index_t irow = 0; irow < src_crow_indices.size(0) - 1;
+ ++irow) {
+ index_t start_index = crow_indices_accessor[irow];
+ index_t end_index = crow_indices_accessor[irow + 1];
+
+ for (index_t i = start_index; i < end_index; ++i) {
+ auto icol = col_indices_accessor[i];
+ auto index = resultBuffer.storage_offset() +
+ irow * out_strides0 + icol * out_strides1;
+ out_ptr[index] += cast_value * values_accessor[i];
}
}
});
diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
index a567bbe..3444942 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
@@ -978,18 +978,6 @@
auto B_col_indices_ptr = B_col_indices.data_ptr<int>();
auto C_col_indices_ptr = C_col_indices.data_ptr<int>();
- // Windows compilers don't support nested macros
- // so we need this lambda outside of the AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES
- auto fix_nnz = [&C_crow_indices, &m](int nnz) -> int {
- // For some reason POINTER_MODE_HOST is not working here
- // Let's extract manually the nnz from the C_crow_indices
- #if AT_ROCM_ENABLED()
- return std::max({nnz, C_crow_indices.narrow(-1, m, 1).item<int>()});
- #else
- return nnz;
- #endif
- };
-
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
C.scalar_type(), "add_out_sparse_csr_cuda_impl", [&] {
auto beta_ = beta.to<scalar_t>();
@@ -1050,8 +1038,6 @@
&nnzC,
work_data.get());
- nnzC = fix_nnz(nnzC);
-
// Resize result using nnz information from cusparse
col_indices_and_values_resize_(C, nnzC);
C_col_indices = C.col_indices();
diff --git a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
index 436aefa..c13984f 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
@@ -159,26 +159,18 @@
" in add operation");
Tensor src_values = src.values();
+ Tensor src_crow_indices = src.crow_indices();
+ Tensor src_col_indices = src.col_indices();
resize_output(output, dense.sizes());
Tensor resultBuffer = output;
-
+ Tensor valuesBuffer = src_values.to(commonDtype);
if (output.scalar_type() != commonDtype) {
resultBuffer = dense.to(commonDtype);
} else if (!is_same_tensor(output, dense)) {
resultBuffer.copy_(dense);
}
-
- if (src._nnz() == 0) {
- return output;
- }
-
- auto valuesBuffer = src_values.to(commonDtype).view({-1, src_values.size(-1)});
- resultBuffer = resultBuffer.view({-1, output.size(-2), output.size(-1)});
- auto src_crow_indices = src.crow_indices().view({-1, src.crow_indices().size(-1)});
- auto src_col_indices = src.col_indices().view({-1, src.col_indices().size(-1)});
-
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
kHalf, kBool, kBFloat16,
commonDtype,
@@ -188,7 +180,6 @@
src_crow_indices.scalar_type(),
"csr_add_out_crow_indices",
[&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() {
- auto batch_count = resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
scalar_t* values_accessor = valuesBuffer.data_ptr<scalar_t>();
scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
scalar_t cast_value = alpha.to<scalar_t>();
@@ -198,11 +189,8 @@
int64_t out_storage_offset = resultBuffer.storage_offset();
auto out_strides = resultBuffer.strides();
- auto out_strides0 = out_strides[0];
- auto out_strides1 = out_strides[1];
- auto crow_stride0 = src_crow_indices.stride(0);
- auto col_stride0 = src_col_indices.stride(0);
- auto val_stride0 = valuesBuffer.stride(0);
+ int64_t out_strides0 = out_strides[0];
+ int64_t out_strides1 = out_strides[1];
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
at::cuda::ThrustAllocator allocator;
@@ -212,29 +200,24 @@
thrust::for_each(
policy,
thrust::make_counting_iterator(int64_t(0)),
- thrust::make_counting_iterator(int64_t(src_crow_indices.size(-1) - 1)),
+ thrust::make_counting_iterator(int64_t(src_crow_indices.size(0) - 1)),
[values_accessor,
crow_indices_accessor,
col_indices_accessor,
out_ptr,
- cast_value,
+ out_storage_offset,
out_strides0,
- out_strides1,
- crow_stride0,
- col_stride0,
- val_stride0,
- batch_count
+ cast_value,
+ out_strides1
]__device__(int64_t irow) {
- for (index_t batch_idx = 0; batch_idx < batch_count; batch_idx++) {
- index_t start_index = crow_indices_accessor[batch_idx*crow_stride0 + irow];
- index_t end_index = crow_indices_accessor[batch_idx*crow_stride0 + irow + 1];
+ index_t start_index = crow_indices_accessor[irow];
+ index_t end_index = crow_indices_accessor[irow + 1];
for (index_t i = start_index; i < end_index; ++i) {
- auto icol = col_indices_accessor[batch_idx*col_stride0 + i];
- auto index = batch_idx * out_strides0 + irow * out_strides1 + icol;
- out_ptr[index] += cast_value * values_accessor[batch_idx*val_stride0 + i];
+ auto icol = col_indices_accessor[i];
+ auto index = out_storage_offset + irow * out_strides0 + icol * out_strides1;
+ out_ptr[index] += cast_value * values_accessor[i];
}
- }
});
});
});
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 004bcf9..60ebebf 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -166,33 +166,6 @@
self.assertEqual(torch.tensor(values, dtype=dtype), sparse.values())
@dtypes(*get_all_dtypes())
- def test_sparse_csr_batch_constructor(self, device, dtype):
- batch_shape = (2, 3)
- crow_indices = torch.tensor([0, 2, 4], device=device).repeat(6, 1).reshape(*batch_shape, -1)
- col_indices = torch.tensor([0, 1, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
- values = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
- for index_dtype in [torch.int32, torch.int64]:
- sparse = torch.sparse_csr_tensor(crow_indices.to(index_dtype),
- col_indices.to(index_dtype),
- values,
- size=(*batch_shape, 2, 10),
- dtype=dtype,
- device=device)
- self.assertEqual((*batch_shape, 2, 10), sparse.shape)
- self.assertEqual(crow_indices.to(index_dtype), sparse.crow_indices())
- self.assertEqual(col_indices.to(index_dtype), sparse.col_indices())
- self.assertEqual(values, sparse.values())
-
- @dtypes(*get_all_dtypes())
- def test_sparse_csr_batch_constructor_shape_inference(self, device, dtype):
- batch_shape = (2, 3)
- crow_indices = torch.tensor([0, 2, 4], device=device).repeat(6, 1).reshape(*batch_shape, -1)
- col_indices = torch.tensor([0, 1, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
- values = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
- sparse = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=dtype, device=device)
- self.assertEqual((*batch_shape, crow_indices.shape[-1] - 1, col_indices.max() + 1), sparse.shape)
-
- @dtypes(*get_all_dtypes())
def test_sparse_csr_constructor_from_lists(self, device, dtype):
# without size
sparse = torch.sparse_csr_tensor([0, 2, 4],
@@ -225,17 +198,15 @@
@dtypes(*get_all_dtypes())
def test_empty(self, device, dtype):
ns = [5, 2, 0]
- batch_shapes = [(), (2,), (2, 3)]
- for m, n, b in itertools.product(ns, ns, batch_shapes):
- shape = (*b, m, n)
+ for shape in itertools.product(ns, ns):
result = torch.empty(shape, dtype=dtype, device=device, layout=torch.sparse_csr)
self.assertEqual(result.shape, shape)
self.assertEqual(result.dtype, dtype)
self.assertEqual(result.device, torch.device(device))
self.assertEqual(result.layout, torch.sparse_csr)
- self.assertEqual(result.crow_indices().shape, (*b, shape[-2] + 1,))
- self.assertEqual(result.col_indices().shape, (*b, 0,))
- self.assertEqual(result.values().shape, (*b, 0,))
+ self.assertEqual(result.crow_indices().shape, (shape[0] + 1,))
+ self.assertEqual(result.col_indices().shape, (0,))
+ self.assertEqual(result.values().shape, (0,))
self.assertEqual(result._nnz(), 0)
self.assertEqual(result.crow_indices().device, torch.device(device))
self.assertEqual(result.col_indices().device, torch.device(device))
@@ -247,22 +218,23 @@
@skipMeta
@dtypes(*get_all_dtypes())
def test_empty_errors(self, device, dtype):
- with self.assertRaisesRegex(RuntimeError, "torch.empty: Only batched sparse CSR matrices are supported, but got size"):
+ with self.assertRaisesRegex(RuntimeError, "torch.empty: Only 2D sparse CSR tensors are supported."):
torch.empty((5,), dtype=dtype, device=device, layout=torch.sparse_csr)
+ with self.assertRaisesRegex(RuntimeError, "torch.empty: Only 2D sparse CSR tensors are supported."):
+ torch.empty((2, 3, 4), dtype=dtype, device=device, layout=torch.sparse_csr)
+
@skipMeta
@dtypes(*get_all_dtypes())
def test_clone(self, device, dtype):
- from operator import mul
- from functools import reduce
- for batch_shape in ((), (2,), (2, 3)):
- prod = reduce(mul, batch_shape, 1)
- crow_indices = torch.tensor([0, 2, 4], device=device).repeat(prod, 1).reshape(*batch_shape, -1)
- col_indices = torch.tensor([0, 1, 0, 1], device=device).repeat(prod, 1).reshape(*batch_shape, -1)
- values = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(prod, 1).reshape(*batch_shape, -1)
- sparse = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=dtype, device=device)
- cloned_sparse = sparse.clone()
- self.assertEqual(sparse, cloned_sparse)
+ x = torch.sparse_csr_tensor([0, 2, 4],
+ [0, 1, 0, 1],
+ [1, 2, 3, 4],
+ dtype=dtype,
+ device=device)
+ y = x.clone()
+
+ self.assertEqual(x, y)
@skipMeta
@dtypes(*get_all_dtypes())
@@ -277,10 +249,9 @@
self.assertEqual(a, b)
ns = [5, 2, 0]
- batch_shapes = [(), (2,), (2, 3)]
- for (m, n, b), index_dtype in zip(itertools.product(ns, ns, batch_shapes), [torch.int32, torch.int64]):
- run_test((*b, m, n), 0, index_dtype)
- run_test((*b, m, n), m * n, index_dtype)
+ for shape, index_dtype in zip(itertools.product(ns, ns), [torch.int32, torch.int64]):
+ run_test(shape, 0, index_dtype)
+ run_test(shape, shape[0] * shape[1], index_dtype)
@skipMeta
@dtypes(*get_all_dtypes())
@@ -304,31 +275,25 @@
@skipMeta
@dtypes(*get_all_dtypes())
def test_resize(self, device, dtype):
- batch_shapes = [(), (2,), (2, 3)]
- for index_dtype, b in zip([torch.int32, torch.int64], batch_shapes):
- shape = (*b, 2, 3)
+ for index_dtype in [torch.int32, torch.int64]:
+ shape = (2, 3)
nnz = 6
a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
- new_shape = (*b, 4, 5)
+ new_shape = (4, 5)
a.resize_(new_shape)
self.assertEqual(a.shape, new_shape)
# resize to larger shape doesn't add specified elements
self.assertEqual(a._nnz(), nnz)
- new_shape = (*b, 1, 5)
+ new_shape = (1, 5)
a.resize_(new_shape)
self.assertEqual(a.shape, new_shape)
# resize to smaller shape trims specified elements
self.assertEqual(a._nnz(), 5)
- # trim batched dimensions
- a.resize_(new_shape[-2], new_shape[-1])
- self.assertEqual(a.shape, (new_shape[-2], new_shape[-1]))
- self.assertEqual(a._nnz(), 5)
-
@skipMeta
@dtypes(*get_all_dtypes())
def test_resize_errors(self, device, dtype):
@@ -337,7 +302,7 @@
nnz = 6
a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
- with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only batched sparse CSR matrices are supported"):
+ with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only 2D sparse CSR tensors are supported."):
new_shape = (4,)
a.resize_(new_shape)
@@ -382,62 +347,49 @@
torch.tensor([1, 2, 3, 4]))
def test_factory_shape_invariants_check(self, device):
- crow_indices = torch.tensor([0, 2, 4], device=device)
- col_indices = torch.tensor([0, 1, 0, 1], device=device)
- values = torch.tensor([1, 2, 3, 4], device=device)
+ crow_indices = [0, 2, 4]
+ col_indices = [0, 1, 0, 1]
+ values = [1, 2, 3, 4]
size = (2, 10)
- torch.sparse_csr_tensor(crow_indices, col_indices, values, size, device=device)
+ torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), size,
+ device=device)
- with self.assertRaisesRegex(RuntimeError, r"size of a batched CSR tensor must have length >= 2, but got: 1"):
- torch.sparse_csr_tensor(crow_indices, col_indices, values,
- size=(2,),
+ with self.assertRaisesRegex(RuntimeError, r"size of a CSR tensor must be of length 2, but got: 3"):
+ torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values),
+ size=(2, 10, 2),
device=device)
- with self.assertRaisesRegex(RuntimeError, r"crow_indices must have dim >= 1 but got crow_indices\.dim\(\)\ = 0"):
- torch.sparse_csr_tensor(torch.zeros((), device=device, dtype=torch.int64),
- col_indices,
- values,
+ with self.assertRaisesRegex(RuntimeError, r"crow_indices must have dim\=1 but got crow_indices\.dim\(\)\=2"):
+ torch.sparse_csr_tensor(torch.tensor(crow_indices).repeat(2, 1),
+ torch.tensor(col_indices),
+ torch.tensor(values),
size,
device=device)
- with self.assertRaisesRegex(RuntimeError, r"col_indices must have dim >= 1 but got col_indices\.dim\(\)\ = 0"):
- torch.sparse_csr_tensor(crow_indices,
- torch.zeros((), device=device, dtype=torch.int64),
- values,
+ with self.assertRaisesRegex(RuntimeError, r"col_indices must have dim\=1 but got col_indices\.dim\(\)\=2"):
+ torch.sparse_csr_tensor(torch.tensor(crow_indices),
+ torch.tensor(col_indices).repeat(2, 1),
+ torch.tensor(values),
size,
device=device)
- with self.assertRaisesRegex(RuntimeError, r"values must have dim >= 1 but got values\.dim\(\)\ = 0"):
- torch.sparse_csr_tensor(crow_indices,
- col_indices,
- torch.zeros((), device=device, dtype=torch.int64),
+ with self.assertRaisesRegex(RuntimeError, r"values must have dim\=1 but got values\.dim\(\)\=2"):
+ torch.sparse_csr_tensor(torch.tensor(crow_indices),
+ torch.tensor(col_indices),
+ torch.tensor(values).repeat(2, 1),
size,
device=device)
with self.assertRaisesRegex(RuntimeError,
- r"crow_indices\.size\(-1\) must be equal to size\[-2\] \+ 1 \(that is 2\), but got: 3"):
- torch.sparse_csr_tensor(crow_indices, col_indices, values, (1, 1),
+ r"crow_indices\.numel\(\) must be size\(0\) \+ 1, but got: 3"):
+ torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), (1, 1),
device=device)
with self.assertRaisesRegex(RuntimeError,
- r"Number of dimensions of crow_indices and col_indices must be the same"):
- torch.sparse_csr_tensor(crow_indices, col_indices.repeat(2, 1), values, size,
- device=device)
-
- with self.assertRaisesRegex(RuntimeError,
- r"Number of dimensions of indices and values must be the same"):
- torch.sparse_csr_tensor(crow_indices, col_indices, values.repeat(2, 1), size,
- device=device)
-
- with self.assertRaisesRegex(RuntimeError,
- r"Number of dimensions of indices must be one less"):
- torch.sparse_csr_tensor(crow_indices.repeat(2, 1), col_indices.repeat(2, 1), values.repeat(2, 1), size,
- device=device)
-
- with self.assertRaisesRegex(RuntimeError,
- r"All batch dimensions of the provided size, indices, and values must be the same"):
- torch.sparse_csr_tensor(crow_indices.repeat(2, 1), col_indices.repeat(3, 1), values.repeat(4, 1), (2, 2, 10),
+ r"col_indices and values must have equal sizes, " +
+ r"but got col_indices\.numel\(\): 3, values\.numel\(\): 4"):
+ torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 1, 0]), torch.tensor(values), size,
device=device)
def test_factory_indices_invariants_check(self, device):
@@ -456,7 +408,7 @@
with self.assertRaisesRegex(RuntimeError,
r"at position i \= 2," +
- r" the condition crow_indices\[i - 1\] <\= crow_indices\[i\] fails"):
+ r" this condition crow_indices\[i - 1\] <\= crow_indices\[i\] fails"):
torch.sparse_csr_tensor(torch.tensor([0, 5, 4]), torch.tensor(col_indices), torch.tensor(values), size,
device=device)
@@ -464,7 +416,7 @@
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, -1, 0, 1]), torch.tensor(values), size,
device=device)
- with self.assertRaisesRegex(RuntimeError, r"size\[-1\] should be greater than col_indices\.max\(\)"):
+ with self.assertRaisesRegex(RuntimeError, r"size\(1\) should be greater than col_indices\.max\(\)"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 11, 0, 1]), torch.tensor(values), size,
device=device)
@@ -569,12 +521,12 @@
sparse = dense.to_sparse_csr()
self.assertEqual(sparse.to_dense(), dense)
- batch_shape = (2, 3)
- crow_indices = torch.tensor([0, 3, 5], device=device).repeat(6, 1).reshape(*batch_shape, -1)
- col_indices = torch.tensor([0, 1, 2, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
- values = torch.tensor([1, 2, 1, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
- csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=dtype, device=device)
- dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device).repeat(6, 1).reshape(csr.shape)
+ crow_indices = torch.tensor([0, 3, 5])
+ col_indices = torch.tensor([0, 1, 2, 0, 1])
+ values = torch.tensor([1, 2, 1, 3, 4], dtype=dtype)
+ csr = torch.sparse_csr_tensor(crow_indices, col_indices,
+ values, dtype=dtype, device=device)
+ dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device)
self.assertEqual(csr.to_dense(), dense)
@skipCPUIfNoMklSparse
@@ -1147,9 +1099,6 @@
@dtypes(torch.float, torch.double)
def test_add(self, device, dtype):
def _test_spadd_shape(nnz, shape):
- # sparse.to_dense() uses torch.add internally so if torch.add is wrong,
- # the dense tensor will be wrong but this test would still pass
- # there's a separate test that checks for the correctness of the .to_dense() call
x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
y = torch.randn(*shape, dtype=dtype, device=device)
r = random.random()
@@ -1171,12 +1120,10 @@
self.assertEqual(res, expected)
- ns = [2, 5]
- batch_shapes = [(), (2,), (2, 3)]
- for b, m, n in itertools.product(batch_shapes, ns, ns):
- _test_spadd_shape(0, (*b, m, n))
- _test_spadd_shape(m * n // 2, (*b, m, n))
- _test_spadd_shape(m * n, (*b, m, n))
+ _test_spadd_shape(10, [100, 100])
+ _test_spadd_shape(0, [100, 100])
+ _test_spadd_shape(10, [100, 1])
+ _test_spadd_shape(10, [1, 100])
@dtypes(torch.float, torch.double)
def test_mul(self, device, dtype):
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index bd8732c..1e2d246 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1999,11 +1999,9 @@
return crow_indices.to(device=device)
def genSparseCSRTensor(self, size, nnz, *, device, dtype, index_dtype):
- from operator import mul
- from functools import reduce
sparse_dim = 2
- assert all(size[d] > 0 for d in range(len(size))) or nnz == 0, 'invalid arguments'
- assert len(size) >= sparse_dim
+ assert all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0, 'invalid arguments'
+ assert len(size) == sparse_dim
def random_sparse_csr(n_rows, n_cols, nnz):
crow_indices = self._make_crow_indices(n_rows, n_cols, nnz, device=device, dtype=index_dtype)
@@ -2017,15 +2015,7 @@
values = make_tensor([nnz], device=device, dtype=dtype, low=low, high=high)
return values, crow_indices, col_indices
- batch_shape = size[:-2]
- n_batch = reduce(mul, batch_shape, 1)
-
- sparse_tensors = [random_sparse_csr(size[-2], size[-1], nnz) for _ in range(n_batch)]
- sparse_tensors_it = map(list, zip(*sparse_tensors))
- values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
- crow_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
- col_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
-
+ values, crow_indices, col_indices = random_sparse_csr(size[0], size[1], nnz)
return torch.sparse_csr_tensor(crow_indices,
col_indices,
values, size=size, dtype=dtype, device=device)