Revert "Add spdiags sparse matrix initialization (#78439)"
This reverts commit cfb2034b657e8527767f1f74854bc62b4d6d4927.
Reverted https://github.com/pytorch/pytorch/pull/78439 on behalf of https://github.com/suo because it broke Windows builds; see: https://hud.pytorch.org/pytorch/pytorch/commit/cfb2034b657e8527767f1f74854bc62b4d6d4927
diff --git a/aten/src/ATen/native/cpu/SparseFactories.cpp b/aten/src/ATen/native/cpu/SparseFactories.cpp
deleted file mode 100644
index 0b0f73e..0000000
--- a/aten/src/ATen/native/cpu/SparseFactories.cpp
+++ /dev/null
@@ -1,74 +0,0 @@
-#include <ATen/Dispatch.h>
-#include <ATen/SparseTensorImpl.h>
-#include <ATen/SparseTensorUtils.h>
-#include <ATen/TensorIndexing.h>
-#include <ATen/TensorIterator.h>
-#include <ATen/core/ATen_fwd.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/native/cpu/Loops.h>
-#include <ATen/native/sparse/SparseFactories.h>
-#include <c10/core/Scalar.h>
-#include <c10/util/ArrayRef.h>
-#include <c10/util/Exception.h>
-
-#ifndef AT_PER_OPERATOR_HEADERS
-#include <ATen/Functions.h>
-#include <ATen/NativeFunctions.h>
-#else
-#include <ATen/ops/sparse_coo_tensor.h>
-#endif
-
-namespace at {
-namespace native {
-using namespace at::sparse;
-
-namespace {
-void _spdiags_kernel_cpu(
-    TensorIterator& iter,
-    const Tensor& diagonals,
-    Tensor& values,
-    Tensor& indices) {
-  auto* row_index_write_ptr = indices[0].data_ptr<int64_t>();
-  auto* col_index_write_ptr = indices[1].data_ptr<int64_t>();
-  const int64_t diagonals_read_stride = diagonals.stride(1);
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-      at::ScalarType::BFloat16,
-      at::ScalarType::Half,
-      at::ScalarType::Bool,
-      at::ScalarType::ComplexHalf,
-      diagonals.scalar_type(),
-      "spdiags_cpu",
-      [&] {
-        auto* values_write_ptr = values.data_ptr<scalar_t>();
-        cpu_kernel(
-            iter,
-            [&](int64_t diag_index,
-                int64_t diag_offset,
-                int64_t out_offset,
-                int64_t n_out) -> int64_t {
-              if (n_out > 0) {
-                auto* rows_start = row_index_write_ptr + out_offset;
-                auto* cols_start = col_index_write_ptr + out_offset;
-                auto* vals_start = values_write_ptr + out_offset;
-                const int64_t first_col = std::max<int64_t>(diag_offset, 0);
-                const int64_t first_row = first_col - diag_offset;
-                auto* data_read = diagonals[diag_index].data_ptr<scalar_t>() +
-                    first_col * diagonals_read_stride;
-                for (int64_t i = 0; i < n_out; ++i) {
-                  rows_start[i] = first_row + i;
-                  cols_start[i] = first_col + i;
-                  vals_start[i] = data_read[i * diagonals_read_stride];
-                }
-              }
-              // dummy return
-              return 0;
-            });
-      });
-}
-
-} // namespace
-
-REGISTER_DISPATCH(spdiags_kernel_stub, &_spdiags_kernel_cpu)
-
-} // namespace native
-} // namespace at
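
For orientation while reading the revert: the deleted CPU kernel emitted COO triplets one diagonal at a time. A minimal Python sketch of that per-diagonal loop (the helper name is hypothetical; the real kernel runs through TensorIterator and cpu_kernel with raw pointers)::

    def fill_diagonal_coo(diag_row, diag_offset, out_offset, n_out, rows, cols, vals):
        # Mirrors _spdiags_kernel_cpu: the first populated column is
        # max(diag_offset, 0), and the row follows from the offset.
        first_col = max(diag_offset, 0)
        first_row = first_col - diag_offset
        for i in range(n_out):
            rows[out_offset + i] = first_row + i
            cols[out_offset + i] = first_col + i
            vals[out_offset + i] = diag_row[first_col + i]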
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index d7a4e56..8a7656c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5281,11 +5281,6 @@
SparseCPU: log_softmax_backward_sparse_cpu
SparseCUDA: log_softmax_backward_sparse_cuda
-- func: _spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor
-  python_module: sparse
-  dispatch:
-    CPU: spdiags
-
- func: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
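
For context on the removed schema: `python_module: sparse` exposed the generated op under the `torch._C._sparse` namespace, which `torch/sparse/__init__.py` (removed further below) wrapped as `torch.sparse.spdiags`. A minimal call, valid only on builds predating this revert::

    import torch

    diags = torch.tensor([[1, 2, 3]])
    offsets = torch.tensor([0])
    s = torch.sparse.spdiags(diags, offsets, (3, 3))  # sparse COO by default
    print(s.to_dense())
    # tensor([[1, 0, 0],
    #         [0, 2, 0],
    #         [0, 0, 3]])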
diff --git a/aten/src/ATen/native/sparse/SparseFactories.cpp b/aten/src/ATen/native/sparse/SparseFactories.cpp
deleted file mode 100644
index f000774..0000000
--- a/aten/src/ATen/native/sparse/SparseFactories.cpp
+++ /dev/null
@@ -1,95 +0,0 @@
-#include <ATen/Dispatch.h>
-#include <ATen/native/sparse/SparseFactories.h>
-
-#ifndef AT_PER_OPERATOR_HEADERS
-#include <ATen/Functions.h>
-#include <ATen/NativeFunctions.h>
-#else
-#include <ATen/ops/_unique.h>
-#include <ATen/ops/arange.h>
-#include <ATen/ops/empty.h>
-#include <ATen/ops/sparse_coo_tensor.h>
-#include <ATen/ops/where.h>
-#endif
-
-namespace at {
-namespace native {
-
-DEFINE_DISPATCH(spdiags_kernel_stub);
-
-Tensor spdiags(
-    const Tensor& diagonals,
-    const Tensor& offsets,
-    IntArrayRef shape,
-    c10::optional<Layout> layout) {
-  auto diagonals_2d = diagonals.dim() == 1 ? diagonals.unsqueeze(0) : diagonals;
-  TORCH_CHECK(diagonals_2d.dim() == 2, "Diagonals must be vector or matrix");
-  TORCH_CHECK(shape.size() == 2, "Output shape must be 2d");
-  auto offsets_1d = offsets.dim() == 0 ? offsets.unsqueeze(0) : offsets;
-  TORCH_CHECK(offsets_1d.dim() == 1, "Offsets must be scalar or vector");
-  TORCH_CHECK(
-      diagonals_2d.size(0) == offsets_1d.size(0),
-      "Number of diagonals (",
-      diagonals_2d.size(0),
-      ") does not match the number of offsets (",
-      offsets_1d.size(0),
-      ")");
-  if (layout) {
-    TORCH_CHECK(
-        (*layout == Layout::Sparse) || (*layout == Layout::SparseCsc) ||
-            (*layout == Layout::SparseCsr),
-        "Only output layouts (Sparse, SparseCsc, SparseCsr) are supported, got ",
-        *layout);
-  }
-  TORCH_CHECK(
-      offsets_1d.scalar_type() == at::kLong,
-      "Offset Tensor must have dtype Long but got ",
-      offsets_1d.scalar_type());
-
-  TORCH_CHECK(
-      offsets_1d.numel() == std::get<0>(at::_unique(offsets_1d)).numel(),
-      "Offset tensor contains duplicate values");
-
-  auto nnz_per_diag = at::where(
-      offsets_1d.le(0),
-      offsets_1d.add(shape[0]).clamp_max_(diagonals_2d.size(1)),
-      offsets_1d.add(-std::min<int64_t>(shape[1], diagonals_2d.size(1))).neg());
-
-  auto nnz_per_diag_cumsum = nnz_per_diag.cumsum(-1);
-  const auto nnz = diagonals_2d.size(0) > 0
-      ? nnz_per_diag_cumsum.select(-1, -1).item<int64_t>()
-      : int64_t{0};
-  // Offsets into nnz for each diagonal
-  auto result_mem_offsets = nnz_per_diag_cumsum.sub(nnz_per_diag);
-  // coo tensor guts
-  auto indices = at::empty({2, nnz}, offsets_1d.options());
-  auto values = at::empty({nnz}, diagonals_2d.options());
-  // We add this indexer to lookup the row of diagonals we are reading from at
-  // each iteration
-  const auto n_diag = offsets_1d.size(0);
-  Tensor diag_index = at::arange(n_diag, offsets_1d.options());
-  // cpu_kernel requires an output
-  auto dummy = at::empty({1}, offsets_1d.options()).resize_({0});
-  auto iter = TensorIteratorConfig()
-                  .set_check_mem_overlap(false)
-                  .add_output(dummy)
-                  .add_input(diag_index)
-                  .add_input(offsets_1d)
-                  .add_input(result_mem_offsets)
-                  .add_input(nnz_per_diag)
-                  .build();
-  spdiags_kernel_stub(iter.device_type(), iter, diagonals_2d, values, indices);
-  auto result_coo = at::sparse_coo_tensor(indices, values, shape);
-  if (layout) {
-    if (*layout == Layout::SparseCsr) {
-      return result_coo.to_sparse_csr();
-    }
-    if (*layout == Layout::SparseCsc) {
-      return result_coo.to_sparse_csc();
-    }
-  }
-  return result_coo;
-}
-
-} // namespace native
-} // namespace at
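
The sizing logic above reduces to a small amount of per-offset arithmetic. A hedged scalar transcription of the `at::where` expression for `nnz_per_diag` (hypothetical function name; `ncols` stands for `diagonals_2d.size(1)`)::

    def nnz_for_offset(offset, shape, ncols):
        # offset <= 0: at most shape[0] + offset rows remain below the start
        # of the diagonal, further capped by the stored width ncols.
        if offset <= 0:
            return min(shape[0] + offset, ncols)
        # offset > 0: equivalent to -(offset - min(shape[1], ncols)); the
        # first `offset` stored values are skipped by the forward rotation.
        return min(shape[1], ncols) - offset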
diff --git a/aten/src/ATen/native/sparse/SparseFactories.h b/aten/src/ATen/native/sparse/SparseFactories.h
deleted file mode 100644
index 3fd6893..0000000
--- a/aten/src/ATen/native/sparse/SparseFactories.h
+++ /dev/null
@@ -1,15 +0,0 @@
-#pragma once
-#include <ATen/TensorIterator.h>
-#include <ATen/core/ATen_fwd.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/native/DispatchStub.h>
-
-namespace at {
-namespace native {
-
-using spdiags_kernel_fn_t =
-    void (*)(TensorIterator&, const Tensor&, Tensor&, Tensor&);
-
-DECLARE_DISPATCH(spdiags_kernel_fn_t, spdiags_kernel_stub);
-} // namespace native
-} // namespace at
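
The stub declared here completes ATen's usual DispatchStub triple: DECLARE_DISPATCH in this header, DEFINE_DISPATCH in sparse/SparseFactories.cpp, and REGISTER_DISPATCH in cpu/SparseFactories.cpp. A toy Python analogue of that per-backend registry (names invented for illustration)::

    spdiags_stub = {}  # DEFINE_DISPATCH: one slot shared by all backends

    def register(device_type, kernel):  # REGISTER_DISPATCH
        spdiags_stub[device_type] = kernel

    def call_stub(device_type, *args):  # spdiags_kernel_stub(device_type, ...)
        return spdiags_stub[device_type](*args)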
diff --git a/build_variables.bzl b/build_variables.bzl
index 83f18ee..aed30fd 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -1155,7 +1155,6 @@
"aten/src/ATen/native/cpu/scaled_modified_bessel_k0.cpp",
"aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp",
"aten/src/ATen/native/cpu/spherical_bessel_j0.cpp",
- "aten/src/ATen/native/cpu/SparseFactories.cpp",
"aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp",
]
@@ -1358,7 +1357,6 @@
"aten/src/ATen/native/sparse/SparseTensorMath.cpp",
"aten/src/ATen/native/sparse/SparseUnaryOps.cpp",
"aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp",
- "aten/src/ATen/native/sparse/SparseFactories.cpp",
"aten/src/ATen/native/transformers/attention.cpp",
"aten/src/ATen/native/transformers/transformer.cpp",
"aten/src/ATen/native/utils/Factory.cpp",
diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst
index 5b2f526..564df4e 100644
--- a/docs/source/sparse.rst
+++ b/docs/source/sparse.rst
@@ -599,7 +599,6 @@
smm
sparse.softmax
sparse.log_softmax
- sparse.spdiags
Other functions
+++++++++++++++
diff --git a/test/test_sparse.py b/test/test_sparse.py
index b2293e3..f75c330 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -8,7 +8,7 @@
import unittest
from torch.testing import make_tensor
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
-    do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
+    do_test_empty_full, load_tests, TEST_NUMPY, IS_WINDOWS, gradcheck, coalescedonoff, \
DeterministicGuard, first_sample
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
from numbers import Number
@@ -26,9 +26,6 @@
floating_and_complex_types_and, integral_types, floating_types_and,
)
-if TEST_SCIPY:
-    import scipy.sparse
-
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
@@ -3561,94 +3558,6 @@
test(4, 6, [7, 3, 1, 3, 1, 3], [7, 3, 1, 3, 2, 3])
test(4, 6, [7, 3, 1, 3, 2, 1], [7, 3, 1, 3, 2, 3])
-    @unittest.skipIf(not TEST_NUMPY, "NumPy is not available")
-    @onlyCPU
-    @dtypes(*all_types_and_complex_and(torch.bool))
-    def test_sparse_spdiags(self, device, dtype):
-
-        make_diags = functools.partial(make_tensor, dtype=dtype, device=device)
-        make_offsets = functools.partial(torch.tensor, dtype=torch.long, device=device)
-
-        if TEST_SCIPY:
-            def reference(diags, offsets, shape):
-                return scipy.sparse.spdiags(diags, offsets, *shape).toarray()
-
-        else:
-            def reference(diags, offsets, shape):
-                result = torch.zeros(shape, dtype=dtype, device=device)
-                for i, off in enumerate(offsets):
-                    res_view = result.diagonal(off)
-                    data = diags[i]
-                    if off > 0:
-                        data = data[off:]
-
-                    m = min(res_view.shape[0], data.shape[0])
-                    res_view[:m] = data[:m]
-                return result
-
-        def check_valid(diags, offsets, shape, layout=None):
-            ref_out = reference(diags, offsets, shape)
-            out = torch.sparse.spdiags(diags, offsets, shape, layout=layout)
-            if layout is None:
-                ex_layout = torch.sparse_coo
-            else:
-                ex_layout = layout
-            out_dense = out.to_dense()
-            self.assertTrue(out.layout == ex_layout, f"Output layout {out.layout} expected {ex_layout}")
-            self.assertEqual(out_dense, ref_out, f"Result:\n{out_dense} does not match reference:\n{ref_out}")
-
-        def check_invalid(args, error):
-            with self.assertRaisesRegex(RuntimeError, error):
-                torch.sparse.spdiags(*args)
-
-        def valid_cases():
-            # some normal cases
-            yield (make_diags((1, 5)), make_offsets([0]), (5, 5))
-            yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4))
-            # noncontiguous diags
-            yield (make_diags((5, 4), noncontiguous=True), make_offsets([-1, 1, 0, 2, -2]), (5, 5))
-            # noncontiguous offsets
-            yield (make_diags((3, 4)), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5))
-            # noncontiguous diags + offsets
-            yield (make_diags((3, 4), noncontiguous=True), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5))
-            # correct dimensionality (2d, 2d) and matching shapes, but zero diagonals
-            yield (make_diags((0, 3)), make_offsets([]), (3, 3))
-            # forward rotation of upper diagonals
-            yield (make_diags((3, 8)), make_offsets([1, 2, 3]), (4, 4))
-            # rotation exhausts input space to read from
-            yield (make_diags((2, 3)), make_offsets([2, 1]), (3, 3))
-            # Simple cases repeated with special output format
-            yield (make_diags((1, 5)), make_offsets([0]), (5, 5), torch.sparse_csc)
-            yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4), torch.sparse_csr)
-            # vector diags
-            yield (make_diags((3, )), make_offsets([1]), (4, 4))
-            # Scalar offset
-            yield (make_diags((1, 3)), make_offsets(2), (4, 4))
-            # offsets out of range
-            yield (make_diags((1, 3)), make_offsets([3]), (3, 3))
-            yield (make_diags((1, 3)), make_offsets([-3]), (3, 3))
-
-        for case in valid_cases():
-            check_valid(*case)
-
-        def invalid_cases():
-            yield (make_diags((1, 3)), make_offsets([0]), (3, 2, 3)), "Output shape must be 2d"
-            yield (make_diags((2, 3)), make_offsets([[1, 2], [0, 3]]), (3, 3)), "Offsets must be scalar or vector"
-            yield (make_diags((3, 2, 3)), make_offsets([0, 1, 2]), (4, 4)), "Diagonals must be vector or matrix"
-            yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)),\
-                r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
-            yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)),\
-                r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
-            yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided),\
-                r"Only output layouts \(\w+, \w+, \w+\) are supported, got \w+"
-            yield (make_diags((2, 5)), make_offsets([0, 0]), (5, 5)), "Offset tensor contains duplicate values"
-            yield (make_diags((1, 5)), make_offsets([0]).to(torch.int32), (5, 5)), r"Offset Tensor must have dtype Long but got \w+"
-
-
-        for case, error_regex in invalid_cases():
-            check_invalid(case, error_regex)
-
-
class TestSparseOneOff(TestCase):
@unittest.skipIf(not TEST_CUDA, 'CUDA not available')
diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py
index 64b396e..29cd551 100644
--- a/torch/sparse/__init__.py
+++ b/torch/sparse/__init__.py
@@ -262,97 +262,3 @@
performed. This is useful for preventing data type
overflows. Default: None
""")
-
-
-spdiags = _add_docstr(
- _sparse._spdiags,
- r"""
-sparse.spdiags(diagonals, offsets, shape, layout=None) -> Tensor
-
-Creates a sparse 2D tensor by placing the values from rows of
-:attr:`diagonals` along specified diagonals of the output.
-
-The :attr:`offsets` tensor controls which diagonals are set.
-
-- If :attr:`offsets[i]` = 0, it is the main diagonal
-- If :attr:`offsets[i]` < 0, it is below the main diagonal
-- If :attr:`offsets[i]` > 0, it is above the main diagonal
-
-The number of rows in :attr:`diagonals` must match the length of :attr:`offsets`,
-and an offset may not be repeated.
-
-Args:
- diagonals (Tensor): Matrix storing diagonals row-wise
- offsets (Tensor): The diagonals to be set, stored as a vector
- shape (2-tuple of ints): The desired shape of the result
-Keyword args:
- layout (:class:`torch.layout`, optional): The desired layout of the
- returned tensor. ``torch.sparse_coo``, ``torch.sparse_csc`` and ``torch.sparse_csr``
- are supported. Default: ``torch.sparse_coo``
-
-Examples:
-
-Set the main and first two lower diagonals of a matrix::
-
-    >>> diags = torch.arange(9).reshape(3, 3)
-    >>> diags
-    tensor([[0, 1, 2],
-            [3, 4, 5],
-            [6, 7, 8]])
-    >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3))
-    >>> s
-    tensor(indices=tensor([[0, 1, 2, 1, 2, 2],
-                           [0, 1, 2, 0, 1, 0]]),
-           values=tensor([0, 1, 2, 3, 4, 6]),
-           size=(3, 3), nnz=6, layout=torch.sparse_coo)
-    >>> s.to_dense()
-    tensor([[0, 0, 0],
-            [3, 1, 0],
-            [6, 4, 2]])
-
-
-Change the output layout::
-
-    >>> diags = torch.arange(9).reshape(3, 3)
-    >>> diags
-    tensor([[0, 1, 2],
-            [3, 4, 5],
-            [6, 7, 8]])
-    >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3), layout=torch.sparse_csr)
-    >>> s
-    tensor(crow_indices=tensor([0, 1, 3, 6]),
-           col_indices=tensor([0, 0, 1, 0, 1, 2]),
-           values=tensor([0, 3, 1, 6, 4, 2]), size=(3, 3), nnz=6,
-           layout=torch.sparse_csr)
-    >>> s.to_dense()
-    tensor([[0, 0, 0],
-            [3, 1, 0],
-            [6, 4, 2]])
-
-Set partial diagonals of a large output::
-
-    >>> diags = torch.tensor([[1, 2], [3, 4]])
-    >>> offsets = torch.tensor([0, -1])
-    >>> torch.sparse.spdiags(diags, offsets, (5, 5)).to_dense()
-    tensor([[1, 0, 0, 0, 0],
-            [3, 2, 0, 0, 0],
-            [0, 4, 0, 0, 0],
-            [0, 0, 0, 0, 0],
-            [0, 0, 0, 0, 0]])
-
-.. note::
-
-    When setting the values along a given diagonal, the index into that
-    diagonal and the index into the row of :attr:`diagonals` are both taken
-    from the column index in the output. Consequently, when setting a diagonal
-    with a positive offset `k`, the first value written along that diagonal is
-    the value in position `k` of the corresponding row of :attr:`diagonals`.
-
-Specifying a positive offset::
-
-    >>> diags = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
-    >>> torch.sparse.spdiags(diags, torch.tensor([0, 1, 2]), (5, 5)).to_dense()
-    tensor([[1, 2, 3, 0, 0],
-            [0, 2, 3, 0, 0],
-            [0, 0, 3, 0, 0],
-            [0, 0, 0, 0, 0],
-            [0, 0, 0, 0, 0]])
-""")