Revert "Add spdiags sparse matrix initialization (#78439)"

This reverts commit cfb2034b657e8527767f1f74854bc62b4d6d4927.

Reverted https://github.com/pytorch/pytorch/pull/78439 on behalf of https://github.com/suo due to broke windows builds, see: https://hud.pytorch.org/pytorch/pytorch/commit/cfb2034b657e8527767f1f74854bc62b4d6d4927
diff --git a/aten/src/ATen/native/cpu/SparseFactories.cpp b/aten/src/ATen/native/cpu/SparseFactories.cpp
deleted file mode 100644
index 0b0f73e..0000000
--- a/aten/src/ATen/native/cpu/SparseFactories.cpp
+++ /dev/null
@@ -1,74 +0,0 @@
-#include <ATen/Dispatch.h>
-#include <ATen/SparseTensorImpl.h>
-#include <ATen/SparseTensorUtils.h>
-#include <ATen/TensorIndexing.h>
-#include <ATen/TensorIterator.h>
-#include <ATen/core/ATen_fwd.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/native/cpu/Loops.h>
-#include <ATen/native/sparse/SparseFactories.h>
-#include <c10/core/Scalar.h>
-#include <c10/util/ArrayRef.h>
-#include <c10/util/Exception.h>
-
-#ifndef AT_PER_OPERATOR_HEADERS
-#include <ATen/Functions.h>
-#include <ATen/NativeFunctions.h>
-#else
-#include <ATen/ops/sparse_coo_tensor.h>
-#endif
-
-namespace at {
-namespace native {
-using namespace at::sparse;
-
-namespace {
-void _spdiags_kernel_cpu(
-    TensorIterator& iter,
-    const Tensor& diagonals,
-    Tensor& values,
-    Tensor& indices) {
-  auto* row_index_write_ptr = indices[0].data_ptr<int64_t>();
-  auto* col_index_write_ptr = indices[1].data_ptr<int64_t>();
-  const int64_t diagonals_read_stride = diagonals.stride(1);
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-      at::ScalarType::BFloat16,
-      at::ScalarType::Half,
-      at::ScalarType::Bool,
-      at::ScalarType::ComplexHalf,
-      diagonals.scalar_type(),
-      "spdiags_cpu",
-      [&] {
-        auto* values_write_ptr = values.data_ptr<scalar_t>();
-        cpu_kernel(
-            iter,
-            [&](int64_t diag_index,
-                int64_t diag_offset,
-                int64_t out_offset,
-                int64_t n_out) -> int64_t {
-              if (n_out > 0) {
-                auto* rows_start = row_index_write_ptr + out_offset;
-                auto* cols_start = col_index_write_ptr + out_offset;
-                auto* vals_start = values_write_ptr + out_offset;
-                const int64_t first_col = std::max<int64_t>(diag_offset, 0);
-                const int64_t first_row = first_col - diag_offset;
-                auto* data_read = diagonals[diag_index].data_ptr<scalar_t>() +
-                    first_col * diagonals_read_stride;
-                for (int64_t i = 0; i < n_out; ++i) {
-                  rows_start[i] = first_row + i;
-                  cols_start[i] = first_col + i;
-                  vals_start[i] = data_read[i * diagonals_read_stride];
-                }
-              }
-              // dummy return
-              return 0;
-            });
-      });
-}
-
-} // namespace
-
-REGISTER_DISPATCH(spdiags_kernel_stub, &_spdiags_kernel_cpu)
-
-} // namespace native
-} // namespace at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index d7a4e56..8a7656c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5281,11 +5281,6 @@
     SparseCPU: log_softmax_backward_sparse_cpu
     SparseCUDA: log_softmax_backward_sparse_cuda
 
-- func: _spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor
-  python_module: sparse
-  dispatch:
-    CPU: spdiags
-
 - func: norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor
   device_check: NoCheck   # TensorIterator
   variants: function, method
diff --git a/aten/src/ATen/native/sparse/SparseFactories.cpp b/aten/src/ATen/native/sparse/SparseFactories.cpp
deleted file mode 100644
index f000774..0000000
--- a/aten/src/ATen/native/sparse/SparseFactories.cpp
+++ /dev/null
@@ -1,95 +0,0 @@
-#include <ATen/Dispatch.h>
-#include <ATen/native/sparse/SparseFactories.h>
-
-#ifndef AT_PER_OPERATOR_HEADERS
-#include <ATen/Functions.h>
-#include <ATen/NativeFunctions.h>
-#else
-#include <ATen/ops/_unique.h>
-#include <ATen/ops/arange.h>
-#include <ATen/ops/empty.h>
-#include <ATen/ops/sparse_coo_tensor.h>
-#include <ATen/ops/where.h>
-#endif
-
-namespace at {
-namespace native {
-
-DEFINE_DISPATCH(spdiags_kernel_stub);
-
-Tensor spdiags(
-    const Tensor& diagonals,
-    const Tensor& offsets,
-    IntArrayRef shape,
-    c10::optional<Layout> layout) {
-  auto diagonals_2d = diagonals.dim() == 1 ? diagonals.unsqueeze(0) : diagonals;
-  TORCH_CHECK(diagonals_2d.dim() == 2, "Diagonals must be vector or matrix");
-  TORCH_CHECK(shape.size() == 2, "Output shape must be 2d");
-  auto offsets_1d = offsets.dim() == 0 ? offsets.unsqueeze(0) : offsets;
-  TORCH_CHECK(offsets_1d.dim() == 1, "Offsets must be scalar or vector");
-  TORCH_CHECK(
-      diagonals_2d.size(0) == offsets_1d.size(0),
-      "Number of diagonals (",
-      diagonals_2d.size(0),
-      ") does not match the number of offsets (",
-      offsets_1d.size(0),
-      ")");
-  if (layout) {
-    TORCH_CHECK(
-        (*layout == Layout::Sparse) || (*layout == Layout::SparseCsc) ||
-            (*layout == Layout::SparseCsr),
-        "Only output layouts (Sparse, SparseCsc, SparseCsr) are supported, got ",
-        *layout);
-  }
-  TORCH_CHECK(
-      offsets_1d.scalar_type() == at::kLong,
-      "Offset Tensor must have dtype Long but got ",
-      offsets_1d.scalar_type());
-
-  TORCH_CHECK(
-      offsets_1d.numel() == std::get<0>(at::_unique(offsets_1d)).numel(),
-      "Offset tensor contains duplicate values");
-
-  auto nnz_per_diag = at::where(
-      offsets_1d.le(0),
-      offsets_1d.add(shape[0]).clamp_max_(diagonals_2d.size(1)),
-      offsets_1d.add(-std::min<int64_t>(shape[1], diagonals_2d.size(1))).neg());
-
-  auto nnz_per_diag_cumsum = nnz_per_diag.cumsum(-1);
-  const auto nnz = diagonals_2d.size(0) > 0
-      ? nnz_per_diag_cumsum.select(-1, -1).item<int64_t>()
-      : int64_t{0};
-  // Offsets into nnz for each diagonal
-  auto result_mem_offsets = nnz_per_diag_cumsum.sub(nnz_per_diag);
-  // coo tensor guts
-  auto indices = at::empty({2, nnz}, offsets_1d.options());
-  auto values = at::empty({nnz}, diagonals_2d.options());
-  // We add this indexer to lookup the row of diagonals we are reading from at
-  // each iteration
-  const auto n_diag = offsets_1d.size(0);
-  Tensor diag_index = at::arange(n_diag, offsets_1d.options());
-  // cpu_kernel requires an output
-  auto dummy = at::empty({1}, offsets_1d.options()).resize_({0});
-  auto iter = TensorIteratorConfig()
-                  .set_check_mem_overlap(false)
-                  .add_output(dummy)
-                  .add_input(diag_index)
-                  .add_input(offsets_1d)
-                  .add_input(result_mem_offsets)
-                  .add_input(nnz_per_diag)
-                  .build();
-  spdiags_kernel_stub(iter.device_type(), iter, diagonals_2d, values, indices);
-  auto result_coo = at::sparse_coo_tensor(indices, values, shape);
-  if (layout) {
-    if (*layout == Layout::SparseCsr) {
-      return result_coo.to_sparse_csr();
-    }
-    if (*layout == Layout::SparseCsc) {
-      return result_coo.to_sparse_csc();
-    }
-  }
-  return result_coo;
-}
-
-} // namespace native
-} // namespace at
diff --git a/aten/src/ATen/native/sparse/SparseFactories.h b/aten/src/ATen/native/sparse/SparseFactories.h
deleted file mode 100644
index 3fd6893..0000000
--- a/aten/src/ATen/native/sparse/SparseFactories.h
+++ /dev/null
@@ -1,15 +0,0 @@
-#pragma once
-#include <ATen/TensorIterator.h>
-#include <ATen/core/ATen_fwd.h>
-#include <ATen/core/Tensor.h>
-#include <ATen/native/DispatchStub.h>
-
-namespace at {
-namespace native {
-
-using spdiags_kernel_fn_t =
-    void (*)(TensorIterator&, const Tensor&, Tensor&, Tensor&);
-
-DECLARE_DISPATCH(spdiags_kernel_fn_t, spdiags_kernel_stub);
-} // namespace native
-} // namespace at
diff --git a/build_variables.bzl b/build_variables.bzl
index 83f18ee..aed30fd 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -1155,7 +1155,6 @@
     "aten/src/ATen/native/cpu/scaled_modified_bessel_k0.cpp",
     "aten/src/ATen/native/cpu/scaled_modified_bessel_k1.cpp",
     "aten/src/ATen/native/cpu/spherical_bessel_j0.cpp",
-    "aten/src/ATen/native/cpu/SparseFactories.cpp",
     "aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp",
 ]
 
@@ -1358,7 +1357,6 @@
     "aten/src/ATen/native/sparse/SparseTensorMath.cpp",
     "aten/src/ATen/native/sparse/SparseUnaryOps.cpp",
     "aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp",
-    "aten/src/ATen/native/sparse/SparseFactories.cpp",
     "aten/src/ATen/native/transformers/attention.cpp",
     "aten/src/ATen/native/transformers/transformer.cpp",
     "aten/src/ATen/native/utils/Factory.cpp",
diff --git a/docs/source/sparse.rst b/docs/source/sparse.rst
index 5b2f526..564df4e 100644
--- a/docs/source/sparse.rst
+++ b/docs/source/sparse.rst
@@ -599,7 +599,6 @@
     smm
     sparse.softmax
     sparse.log_softmax
-    sparse.spdiags
 
 Other functions
 +++++++++++++++
diff --git a/test/test_sparse.py b/test/test_sparse.py
index b2293e3..f75c330 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -8,7 +8,7 @@
 import unittest
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
-    do_test_empty_full, load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
+    do_test_empty_full, load_tests, TEST_NUMPY, IS_WINDOWS, gradcheck, coalescedonoff, \
     DeterministicGuard, first_sample
 from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
 from numbers import Number
@@ -26,9 +26,6 @@
     floating_and_complex_types_and, integral_types, floating_types_and,
 )
 
-if TEST_SCIPY:
-    import scipy.sparse
-
 # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
 load_tests = load_tests
@@ -3561,94 +3558,6 @@
         test(4, 6, [7, 3, 1, 3, 1, 3], [7, 3, 1, 3, 2, 3])
         test(4, 6, [7, 3, 1, 3, 2, 1], [7, 3, 1, 3, 2, 3])
 
-    @unittest.skipIf(not TEST_NUMPY, "NumPy is not availible")
-    @onlyCPU
-    @dtypes(*all_types_and_complex_and(torch.bool))
-    def test_sparse_spdiags(self, device, dtype):
-
-        make_diags = functools.partial(make_tensor, dtype=dtype, device=device)
-        make_offsets = functools.partial(torch.tensor, dtype=torch.long, device=device)
-
-        if TEST_SCIPY:
-            def reference(diags, offsets, shape):
-                return scipy.sparse.spdiags(diags, offsets, *shape).toarray()
-
-        else:
-            def reference(diags, offsets, shape):
-                result = torch.zeros(shape, dtype=dtype, device=device)
-                for i, off in enumerate(offsets):
-                    res_view = result.diagonal(off)
-                    data = diags[i]
-                    if off > 0:
-                        data = data[off:]
-
-                    m = min(res_view.shape[0], data.shape[0])
-                    res_view[:m] = data[:m]
-                return result
-
-        def check_valid(diags, offsets, shape, layout=None):
-            ref_out = reference(diags, offsets, shape)
-            out = torch.sparse.spdiags(diags, offsets, shape, layout=layout)
-            if layout is None:
-                ex_layout = torch.sparse_coo
-            else:
-                ex_layout = layout
-            out_dense = out.to_dense()
-            self.assertTrue(out.layout == ex_layout, f"Output layout {out.layout} expected {ex_layout}")
-            self.assertEqual(out_dense, ref_out, f"Result:\n{out_dense} does not match reference:\n{ref_out}")
-
-        def check_invalid(args, error):
-            with self.assertRaisesRegex(RuntimeError, error):
-                torch.sparse.spdiags(*args)
-
-        def valid_cases():
-            # some normal cases
-            yield (make_diags((1, 5)), make_offsets([0]), (5, 5))
-            yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4))
-            # noncontigous diags
-            yield (make_diags((5, 4), noncontiguous=True), make_offsets([-1, 1, 0, 2, -2]), (5, 5))
-            # noncontigous offsets
-            yield (make_diags((3, 4)), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5))
-            # noncontigous diags + offsets
-            yield (make_diags((3, 4), noncontiguous=True), make_offsets([1, -1, 0, -2, 2])[::2], (5, 5))
-            # correct dimensionality, 2d, 2d , and shapes match, but the number of diagonals is zero
-            yield (make_diags((0, 3)), make_offsets([]), (3, 3))
-            # forward rotation of upper diagonals
-            yield (make_diags((3, 8)), make_offsets([1, 2, 3]), (4, 4))
-            # rotation exausts input space to read from
-            yield (make_diags((2, 3)), make_offsets([2, 1]), (3, 3))
-            # Simple cases repeated with special output format
-            yield (make_diags((1, 5)), make_offsets([0]), (5, 5), torch.sparse_csc)
-            yield (make_diags((3, 3)), make_offsets([-1, 0, 1]), (4, 4), torch.sparse_csr)
-            # vector diags
-            yield (make_diags((3, )), make_offsets([1]), (4, 4))
-            # Scalar offset
-            yield (make_diags((1, 3)), make_offsets(2), (4, 4))
-            # offsets out of range
-            yield (make_diags((1, 3)), make_offsets([3]), (3, 3))
-            yield (make_diags((1, 3)), make_offsets([-3]), (3, 3))
-
-        for case in valid_cases():
-            check_valid(*case)
-
-        def invalid_cases():
-            yield (make_diags((1, 3)), make_offsets([0]), (3, 2, 3)), "Output shape must be 2d"
-            yield (make_diags((2, 3)), make_offsets([[1, 2], [0, 3]]), (3, 3)), "Offsets must be scalar or vector"
-            yield (make_diags((3, 2, 3)), make_offsets([0, 1, 2]), (4, 4)), "Diagonals must be vector or matrix"
-            yield (make_diags((3, 3)), make_offsets([-1, 0]), (3, 3)),\
-                r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
-            yield (make_diags((5,)), make_offsets([0, 1, 2, 3, 4]), (3, 3)),\
-                r"Number of diagonals \(\d\) does not match the number of offsets \(\d\)"
-            yield (make_diags((2, 2)), make_offsets([-1, 0]), (2, 3), torch.strided),\
-                r"Only output layouts \(\w+, \w+, \w+\) are supported, got \w+"
-            yield (make_diags((2, 5)), make_offsets([0, 0]), (5, 5)), "Offset tensor contains duplicate values"
-            yield (make_diags((1, 5)), make_offsets([0]).to(torch.int32), (5, 5)), r"Offset Tensor must have dtype Long but got \w+"
-
-
-        for case, error_regex in invalid_cases():
-            check_invalid(case, error_regex)
-
-
 
 class TestSparseOneOff(TestCase):
     @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
diff --git a/torch/sparse/__init__.py b/torch/sparse/__init__.py
index 64b396e..29cd551 100644
--- a/torch/sparse/__init__.py
+++ b/torch/sparse/__init__.py
@@ -262,97 +262,3 @@
         performed. This is useful for preventing data type
         overflows. Default: None
 """)
-
-
-spdiags = _add_docstr(
-    _sparse._spdiags,
-    r"""
-sparse.spdiags(diagonals, offsets, shape, layout=None) -> Tensor
-
-Creates a sparse 2D tensor by placing the values from rows of
-:attr:`diagonals` along specified diagonals of the output
-
-The :attr:`offsets` tensor controls which diagonals are set.
-
-- If :attr:`offsets[i]` = 0, it is the main diagonal
-- If :attr:`offsets[i]` < 0, it is below the main diagonal
-- If :attr:`offsets[i]` > 0, it is above the main diagonal
-
-The number of rows in :attr:`diagonals` must match the length of :attr:`offsets`,
-and an offset may not be repeated.
-
-Args:
-    diagonals (Tensor): Matrix storing diagonals row-wise
-    offsets (Tensor): The diagonals to be set, stored as a vector
-    shape (2-tuple of ints): The desired shape of the result
-Keyword args:
-    layout (:class:`torch.layout`, optional): The desired layout of the
-        returned tensor. ``torch.sparse_coo``, ``torch.sparse_csc`` and ``torch.sparse_csr``
-        are supported. Default: ``torch.sparse_coo``
-
-Examples:
-
-Set the main and first two lower diagonals of a matrix::
-
-    >>> diags = torch.arange(9).reshape(3, 3)
-    >>> diags
-    tensor([[0, 1, 2],
-            [3, 4, 5],
-            [6, 7, 8]])
-    >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3))
-    >>> s
-    tensor(indices=tensor([[0, 1, 2, 1, 2, 2],
-                           [0, 1, 2, 0, 1, 0]]),
-           values=tensor([0, 1, 2, 3, 4, 6]),
-           size=(3, 3), nnz=6, layout=torch.sparse_coo)
-    >>> s.to_dense()
-    tensor([[0, 0, 0],
-            [3, 1, 0],
-            [6, 4, 2]])
-
-
-Change the output layout::
-
-    >>> diags = torch.arange(9).reshape(3, 3)
-    >>> diags
-    tensor([[0, 1, 2],[3, 4, 5], [6, 7, 8])
-    >>> s = torch.sparse.spdiags(diags, torch.tensor([0, -1, -2]), (3, 3), layout=torch.sparse_csr)
-    >>> s
-    tensor(crow_indices=tensor([0, 1, 3, 6]),
-           col_indices=tensor([0, 0, 1, 0, 1, 2]),
-           values=tensor([0, 3, 1, 6, 4, 2]), size=(3, 3), nnz=6,
-           layout=torch.sparse_csr)
-    >>> s.to_dense()
-    tensor([[0, 0, 0],
-            [3, 1, 0],
-            [6, 4, 2]])
-
-Set partial diagonals of a large output::
-
-    >>> diags = torch.tensor([[1, 2], [3, 4]])
-    >>> offsets = torch.tensor([0, -1])
-    >>> torch.sparse.spdiags(diags, offsets, (5, 5)).to_dense()
-    tensor([[1, 0, 0, 0, 0],
-            [3, 2, 0, 0, 0],
-            [0, 4, 0, 0, 0],
-            [0, 0, 0, 0, 0],
-            [0, 0, 0, 0, 0]])
-
-.. note::
-
-    When setting the values along a given diagonal the index into the diagonal
-    and the index into the row of :attr:`diagonals` is taken as the
-    column index in the output. This has the effect that when setting a diagonal
-    with a positive offset `k` the first value along that diagonal will be
-    the value in position `k` of the row of :attr:`diagonals`
-
-Specifying a positive offset::
-
-    >>> diags = torch.tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
-    >>> torch.sparse.spdiags(diags, torch.tensor([0, 1, 2]), (5, 5)).to_dense()
-    tensor([[1, 2, 3, 0, 0],
-            [0, 2, 3, 0, 0],
-            [0, 0, 3, 0, 0],
-            [0, 0, 0, 0, 0],
-            [0, 0, 0, 0, 0]])
-""")