Revert "Extend CSR constructor to support batched indices and values"

This reverts commit c074a530029cd2fbbbe74c325842bc1aef6a8ac4.

Reverted https://github.com/pytorch/pytorch/pull/74542 on behalf of https://github.com/malfet
diff --git a/aten/src/ATen/SparseCsrTensorImpl.cpp b/aten/src/ATen/SparseCsrTensorImpl.cpp
index 7808f75..2029189 100644
--- a/aten/src/ATen/SparseCsrTensorImpl.cpp
+++ b/aten/src/ATen/SparseCsrTensorImpl.cpp
@@ -60,22 +60,17 @@
 }
 
 void SparseCsrTensorImpl::resize_(int64_t nnz, IntArrayRef size) {
-  auto rows = size[size.size() - 2];
-  auto cols = size[size.size() - 1];
+  auto rows = size[0];
+  auto cols = size[1];
   auto old_crow_indices_size = crow_indices_.size(-1);
-
-  auto new_crow_indices_size = DimVector(size.slice(0, size.size() - 2));
-  new_crow_indices_size.push_back(rows + 1);
-  crow_indices_.resize_(new_crow_indices_size);
+  crow_indices_.resize_({rows + 1});
   if (rows + 1 >= old_crow_indices_size) {
     crow_indices_.narrow(-1, old_crow_indices_size, rows + 1 - old_crow_indices_size).fill_(nnz);
   } else {
     crow_indices_.narrow(-1, rows, 1).fill_(std::min<int64_t>(nnz, rows*cols));
   }
-  auto col_indices_values_size = DimVector(size.slice(0, size.size() - 2));
-  col_indices_values_size.push_back(std::min<int64_t>(nnz, rows*cols));
-  col_indices_.resize_(col_indices_values_size);
-  values_.resize_(col_indices_values_size);
+  col_indices_.resize_({std::min<int64_t>(nnz, rows*cols)});
+  values_.resize_({std::min<int64_t>(nnz, rows*cols)});
   sizes_and_strides_.set_sizes(size);
 }
 
diff --git a/aten/src/ATen/SparseCsrTensorImpl.h b/aten/src/ATen/SparseCsrTensorImpl.h
index b90176f..850e0a0 100644
--- a/aten/src/ATen/SparseCsrTensorImpl.h
+++ b/aten/src/ATen/SparseCsrTensorImpl.h
@@ -43,7 +43,7 @@
   const Tensor& crow_indices() const { return crow_indices_; }
   const Tensor& col_indices() const { return col_indices_; }
   const Tensor& values() const { return values_; }
-  int nnz() { return col_indices_.size(-1); }
+  int nnz() { return values_.size(0); }
 
   /**
    * Return a TensorImpl that is a shallow-copy of this TensorImpl.
diff --git a/aten/src/ATen/mkl/SparseDescriptors.h b/aten/src/ATen/mkl/SparseDescriptors.h
index 2c152e0..46d6568 100644
--- a/aten/src/ATen/mkl/SparseDescriptors.h
+++ b/aten/src/ATen/mkl/SparseDescriptors.h
@@ -101,7 +101,7 @@
     sparse_matrix_t raw_descriptor;
 
     // Assuming that the last two dimensions are block elements of the matrix
-    if (values.dim() == 3 && crow_indices.dim() == 1 && col_indices.dim() == 1) {
+    if (values.dim() == 3) {
       TORCH_CHECK(
           values.size(-1) == values.size(-2),
           "MKL Sparse doesn't support matrices with non-square blocks.");
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
index 3df5beb..f91d964 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -9,7 +9,6 @@
 #include <ATen/SparseCsrTensorImpl.h>
 #include <ATen/SparseCsrTensorUtils.h>
 #include <ATen/SparseTensorImpl.h>
-#include <ATen/native/LinearAlgebraUtils.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -57,51 +56,29 @@
 
   // Shape and Strides invariants
   TORCH_CHECK(
-      size.size() >= 2,
-      "size of a batched CSR tensor must have length >= 2, but got: ",
+      size.size() == 2,
+      "size of a CSR tensor must be of length 2, but got: ",
       size.size());
   TORCH_CHECK(
-      crow_indices.dim() >= 1,
-      "crow_indices must have dim >= 1 but got crow_indices.dim() = ",
+      crow_indices.dim() == 1,
+      "crow_indices must have dim=1 but got crow_indices.dim()=",
       crow_indices.dim());
   TORCH_CHECK(
-      col_indices.dim() >= 1,
-      "col_indices must have dim >= 1 but got col_indices.dim() = ",
+      col_indices.dim() == 1,
+      "col_indices must have dim=1 but got col_indices.dim()=",
       col_indices.dim());
   TORCH_CHECK(
-      values.dim() >= 1,
-      "values must have dim >= 1 but got values.dim() = ",
+      values.dim() == 1,
+      "values must have dim=1 but got values.dim()=",
       values.dim());
-
+  // Note, this check also enforces `crow_indices.numel() >= 1`
   TORCH_CHECK(
-      crow_indices.dim() == col_indices.dim(),
-      "Number of dimensions of crow_indices and col_indices must be the same.");
-  TORCH_CHECK(
-      crow_indices.dim() == values.dim(),
-      "Number of dimensions of indices and values must be the same.");
-  TORCH_CHECK(
-      crow_indices.dim() == size.size() - 1,
-      "Number of dimensions of indices must be one less than the number of dimensions of the provided size.");
-
-  // All batch sizes must be the same
-  auto batch_size = size.slice(0, size.size() - 2);
-  auto crow_indices_batch_size = crow_indices.sizes().slice(0, crow_indices.dim() - 1);
-  auto col_indices_batch_size = col_indices.sizes().slice(0, col_indices.dim() - 1);
-  auto values_batch_size = values.sizes().slice(0, values.dim() - 1);
-  TORCH_CHECK(
-      batch_size == crow_indices_batch_size &&
-      batch_size == col_indices_batch_size &&
-      batch_size == values_batch_size,
-      "All batch dimensions of the provided size, indices, and values must be the same.");
-
-  // Note, this check also enforces `crow_indices.size(-1) >= 1`
-  TORCH_CHECK(
-      crow_indices.size(-1) == (size[size.size() - 2] + 1),
-      "crow_indices.size(-1) must be equal to size[-2] + 1 (that is ", size[size.size() - 2] + 1, "), but got: ",
-      crow_indices.size(-1));
+      crow_indices.numel() == (size[0] + 1),
+      "crow_indices.numel() must be size(0) + 1, but got: ",
+      crow_indices.numel());
   TORCH_CHECK(
       col_indices.numel() == values.numel(),
-      "col_indices and values must have the same number of elements, but got col_indices.numel(): ",
+      "col_indices and values must have equal sizes, but got col_indices.numel(): ",
       col_indices.numel(),
       ", values.numel(): ",
       values.numel());
@@ -109,28 +86,22 @@
   // Indices invariants
   AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "csr_construct_check", [&] {
     Tensor crow_indices_cpu = crow_indices.to(kCPU);
-    auto crow_indices_data_ptr = crow_indices_cpu.data_ptr<index_t>();
-    auto batch_stride = crow_indices_cpu.dim() >= 2 ? crow_indices_cpu.stride(-2) : 0;
-    for (const auto batch_id : c10::irange(batchCount(crow_indices_cpu))) {
-      TORCH_CHECK(
-          crow_indices_data_ptr[batch_id*batch_stride] == 0,
-          "(Batch element ", batch_id, ") ",
-          ": 0th value of crow_indices must be 0, but it is ", crow_indices_data_ptr[batch_id*batch_stride]);
-      TORCH_CHECK(
-          crow_indices_data_ptr[batch_id*batch_stride + crow_indices.size(-1) - 1] == col_indices.size(-1),
-          "(Batch element ", batch_id, ") ",
-          "last value of crow_indices should be equal to the length of col_indices.");
+    auto crow_indices_accessor = crow_indices_cpu.accessor<index_t, 1>();
+    TORCH_CHECK(
+        crow_indices_accessor[0] == 0, "0th value of crow_indices must be 0.");
 
-      for (int i =  1; i <= size[size.size() - 2]; i++) {
-        TORCH_CHECK(
-            crow_indices_data_ptr[batch_id*batch_stride + i - 1] <= crow_indices_data_ptr[batch_id*batch_stride + i],
-            "(Batch element ", batch_id, ") ",
-            "at position i = ", i, ", the condition crow_indices[i - 1] <= crow_indices[i] fails");
-      }
+    TORCH_CHECK(
+        crow_indices_accessor[crow_indices.numel() - 1] == col_indices.numel(),
+        "last value of crow_indices should be equal to the length of col_indices.");
+
+    for (int i =  1; i <= size[0]; i++) {
+      TORCH_CHECK(
+          crow_indices_accessor[i - 1] <= crow_indices_accessor[i],
+          "at position i = ", i, ", this condition crow_indices[i - 1] <= crow_indices[i] fails");
     }
     if (col_indices.numel() > 0) {
       TORCH_CHECK(0 <= col_indices.min().item<index_t>(), "col_indices.min() should be greater or equal to zero");
-      TORCH_CHECK(size[size.size() - 1] > col_indices.max().item<index_t>(), "size[-1] should be greater than col_indices.max()");
+      TORCH_CHECK(size[1] > col_indices.max().item<index_t>(), "size(1) should be greater than col_indices.max()");
     }
   });
 
@@ -242,10 +213,13 @@
     c10::optional<bool> pin_memory) {
   // See [Note: hacky wrapper removal for TensorOptions]
   TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
-  // std::array<int64_t, 2> size = {0, 0};
-  auto size = DimVector(IntArrayRef(col_indices.sizes().data(), col_indices.dim() - 1));
-  size.push_back(crow_indices.size(-1) - 1);
-  size.push_back(col_indices.max().item<int64_t>() + 1);
+  std::array<int64_t, 2> size = {0, 0};
+  if (col_indices.numel() > 0) {
+    AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "csr_construct_check", [&] {
+      size[0] = crow_indices.numel() - 1;
+      size[1] = col_indices.max().item<index_t>() + 1;
+    });
+  }
 
   at::native::_validate_sparse_csr_tensor_args(crow_indices, col_indices, values, size);
 
@@ -269,21 +243,16 @@
     c10::optional<MemoryFormat> optional_memory_format) {
   check_size_nonnegative(size);
 
-  TORCH_CHECK(size.size() >= 2, "torch.empty: Only batched sparse CSR matrices are supported, but got size ", size);
+  TORCH_CHECK(size.size() == 2, "torch.empty: Only 2D sparse CSR tensors are supported.");
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(layout == Layout::SparseCsr);
 
-  auto rows = size[size.size() - 2];
+  auto rows = size[0];
   int64_t nnz = 0;
 
-  auto crow_indices_size = DimVector(size.slice(0, size.size() - 2));
-  crow_indices_size.push_back(rows + 1);
-  auto col_indices_values_size = DimVector(size.slice(0, size.size() - 2));
-  col_indices_values_size.push_back(nnz);
-
   TensorOptions options = TensorOptions().dtype(ScalarType::Long).layout(Layout::Strided).device(device).pinned_memory(pin_memory);
-  auto crow_indices = at::empty(crow_indices_size, options);
-  auto col_indices = at::empty(col_indices_values_size, options);
-  auto values = at::empty(col_indices_values_size, options.dtype(dtype));
+  auto crow_indices = at::empty({rows + 1}, options);
+  auto col_indices = at::empty({nnz}, options);
+  auto values = at::empty({nnz}, options.dtype(dtype));
 
   return at::native::_sparse_csr_tensor_unsafe(
       crow_indices,
@@ -301,13 +270,13 @@
     IntArrayRef size,
     c10::optional<MemoryFormat> optional_memory_format) {
   check_size_nonnegative(size);
-  TORCH_CHECK(size.size() >= 2, "torch.resize_: Only batched sparse CSR matrices are supported, but got size ", size);
+  TORCH_CHECK(size.size() == 2, "torch.resize_: Only 2D sparse CSR tensors are supported.");
   TORCH_CHECK(
-      self.size(-1) <= size[size.size() - 1],
+      self.size(1) <= size[1],
       "torch.resize_: Resizing columns of sparse CSR tensors to a smaller value is not supported. ",
       "The original number of columns is ",
-      self.size(-1),
-      " while the requested new number of columns is ", size[size.size() - 1], ".");
+      self.size(1),
+      " while the requested new number of columns is ", size[1], ".");
   get_sparse_csr_impl(self)->resize_(self._nnz(), size);
   return self;
 }
diff --git a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
index 7ae5f3b..76eaf61 100644
--- a/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
+++ b/aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -638,10 +638,13 @@
       " in add operation");
 
   auto src_values = src.values();
+  auto src_crow_indices = src.crow_indices();
+  auto src_col_indices = src.col_indices();
 
   resize_output(out, dense.sizes());
 
   Tensor resultBuffer = out;
+  Tensor valuesBuffer = src_values.to(commonDtype);
 
   if (out.scalar_type() != commonDtype) {
     resultBuffer = dense.to(commonDtype);
@@ -649,15 +652,6 @@
     resultBuffer.copy_(dense);
   }
 
-  if (src._nnz() == 0) {
-    return;
-  }
-
-  auto valuesBuffer = src_values.to(commonDtype).view({-1, src_values.size(-1)});
-  resultBuffer = resultBuffer.view({-1, out.size(-2), out.size(-1)});
-  auto src_crow_indices = src.crow_indices().view({-1, src.crow_indices().size(-1)});
-  auto src_col_indices = src.col_indices().view({-1, src.col_indices().size(-1)});
-
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
       kHalf,
       kBool,
@@ -677,26 +671,27 @@
              &alpha,
              &src_crow_indices,
              &src_col_indices]() {
-              auto batch_count = resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
-              auto values_accessor = valuesBuffer.accessor<scalar_t, 2>();
+              auto values_accessor = valuesBuffer.accessor<scalar_t, 1>();
               scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
               scalar_t cast_value = alpha.to<scalar_t>();
 
               auto crow_indices_accessor =
-                  src_crow_indices.accessor<index_t, 2>();
+                  src_crow_indices.accessor<index_t, 1>();
               auto col_indices_accessor =
-                  src_col_indices.accessor<index_t, 2>();
-              auto out_strides = resultBuffer.strides();
+                  src_col_indices.accessor<index_t, 1>();
+              auto out_strides0 = resultBuffer.strides()[0];
+              auto out_strides1 = resultBuffer.strides()[1];
 
-              for (const auto batch_idx : c10::irange(batch_count)) {
-                for (const auto irow : c10::irange(src_crow_indices.size(-1) - 1)) {
-                  index_t start_index = crow_indices_accessor[batch_idx][irow];
-                  index_t end_index = crow_indices_accessor[batch_idx][irow + 1];
-                  for (const auto i : c10::irange(start_index, end_index)) {
-                    auto icol = col_indices_accessor[batch_idx][i];
-                    auto index = batch_idx * out_strides[0] + irow * out_strides[1] + icol * out_strides[2];
-                    out_ptr[index] += cast_value * values_accessor[batch_idx][i];
-                  }
+              for (index_t irow = 0; irow < src_crow_indices.size(0) - 1;
+                   ++irow) {
+                index_t start_index = crow_indices_accessor[irow];
+                index_t end_index = crow_indices_accessor[irow + 1];
+
+                for (index_t i = start_index; i < end_index; ++i) {
+                  auto icol = col_indices_accessor[i];
+                  auto index = resultBuffer.storage_offset() +
+                      irow * out_strides0 + icol * out_strides1;
+                  out_ptr[index] += cast_value * values_accessor[i];
                 }
               }
             });
diff --git a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
index a567bbe..3444942 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
+++ b/aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp
@@ -978,18 +978,6 @@
   auto B_col_indices_ptr = B_col_indices.data_ptr<int>();
   auto C_col_indices_ptr = C_col_indices.data_ptr<int>();
 
-  // Windows compilers don't support nested macros
-  // so we need this lambda outside of the AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES
-  auto fix_nnz = [&C_crow_indices, &m](int nnz) -> int {
-    // For some reason POINTER_MODE_HOST is not working here
-    // Let's extract manually the nnz from the C_crow_indices
-    #if AT_ROCM_ENABLED()
-    return std::max({nnz, C_crow_indices.narrow(-1, m, 1).item<int>()});
-    #else
-    return nnz;
-    #endif
-  };
-
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
       C.scalar_type(), "add_out_sparse_csr_cuda_impl", [&] {
         auto beta_ = beta.to<scalar_t>();
@@ -1050,8 +1038,6 @@
             &nnzC,
             work_data.get());
 
-        nnzC = fix_nnz(nnzC);
-
         // Resize result using nnz information from cusparse
         col_indices_and_values_resize_(C, nnzC);
         C_col_indices = C.col_indices();
diff --git a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
index 436aefa..c13984f 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseCsrTensorMath.cu
@@ -159,26 +159,18 @@
       " in add operation");
 
   Tensor src_values = src.values();
+  Tensor src_crow_indices = src.crow_indices();
+  Tensor src_col_indices = src.col_indices();
 
   resize_output(output, dense.sizes());
 
   Tensor resultBuffer = output;
-
+  Tensor valuesBuffer = src_values.to(commonDtype);
   if (output.scalar_type() != commonDtype) {
     resultBuffer = dense.to(commonDtype);
   } else if (!is_same_tensor(output, dense)) {
     resultBuffer.copy_(dense);
   }
-
-  if (src._nnz() == 0) {
-    return output;
-  }
-
-  auto valuesBuffer = src_values.to(commonDtype).view({-1, src_values.size(-1)});
-  resultBuffer = resultBuffer.view({-1, output.size(-2), output.size(-1)});
-  auto src_crow_indices = src.crow_indices().view({-1, src.crow_indices().size(-1)});
-  auto src_col_indices = src.col_indices().view({-1, src.col_indices().size(-1)});
-
   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
       kHalf, kBool, kBFloat16,
       commonDtype,
@@ -188,7 +180,6 @@
             src_crow_indices.scalar_type(),
             "csr_add_out_crow_indices",
               [&valuesBuffer, &resultBuffer, &alpha, &src_crow_indices, &src_col_indices]() {
-                auto batch_count = resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
                 scalar_t* values_accessor = valuesBuffer.data_ptr<scalar_t>();
                 scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
                 scalar_t cast_value = alpha.to<scalar_t>();
@@ -198,11 +189,8 @@
                 int64_t out_storage_offset = resultBuffer.storage_offset();
 
                 auto out_strides = resultBuffer.strides();
-                auto out_strides0 = out_strides[0];
-                auto out_strides1 = out_strides[1];
-                auto crow_stride0 = src_crow_indices.stride(0);
-                auto col_stride0 = src_col_indices.stride(0);
-                auto val_stride0 = valuesBuffer.stride(0);
+                int64_t out_strides0 = out_strides[0];
+                int64_t out_strides1 = out_strides[1];
 
                 cudaStream_t stream = at::cuda::getCurrentCUDAStream();
                 at::cuda::ThrustAllocator allocator;
@@ -212,29 +200,24 @@
                thrust::for_each(
                     policy,
                     thrust::make_counting_iterator(int64_t(0)),
-                    thrust::make_counting_iterator(int64_t(src_crow_indices.size(-1) - 1)),
+                    thrust::make_counting_iterator(int64_t(src_crow_indices.size(0) - 1)),
                     [values_accessor,
                     crow_indices_accessor,
                     col_indices_accessor,
                     out_ptr,
-                    cast_value,
+                    out_storage_offset,
                     out_strides0,
-                    out_strides1,
-                    crow_stride0,
-                    col_stride0,
-                    val_stride0,
-                    batch_count
+                    cast_value,
+                    out_strides1
                     ]__device__(int64_t irow) {
-                      for (index_t batch_idx = 0; batch_idx < batch_count; batch_idx++) {
-                        index_t start_index = crow_indices_accessor[batch_idx*crow_stride0 + irow];
-                        index_t end_index = crow_indices_accessor[batch_idx*crow_stride0 + irow + 1];
+                        index_t start_index = crow_indices_accessor[irow];
+                        index_t end_index = crow_indices_accessor[irow + 1];
 
                         for (index_t i = start_index; i < end_index; ++i) {
-                            auto icol = col_indices_accessor[batch_idx*col_stride0 + i];
-                            auto index = batch_idx * out_strides0 + irow * out_strides1 + icol;
-                            out_ptr[index] += cast_value * values_accessor[batch_idx*val_stride0 + i];
+                            auto icol = col_indices_accessor[i];
+                            auto index = out_storage_offset + irow * out_strides0 + icol * out_strides1;
+                            out_ptr[index] += cast_value * values_accessor[i];
                         }
-                      }
                     });
               });
       });
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index 004bcf9..60ebebf 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -166,33 +166,6 @@
             self.assertEqual(torch.tensor(values, dtype=dtype), sparse.values())
 
     @dtypes(*get_all_dtypes())
-    def test_sparse_csr_batch_constructor(self, device, dtype):
-        batch_shape = (2, 3)
-        crow_indices = torch.tensor([0, 2, 4], device=device).repeat(6, 1).reshape(*batch_shape, -1)
-        col_indices = torch.tensor([0, 1, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
-        values = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
-        for index_dtype in [torch.int32, torch.int64]:
-            sparse = torch.sparse_csr_tensor(crow_indices.to(index_dtype),
-                                             col_indices.to(index_dtype),
-                                             values,
-                                             size=(*batch_shape, 2, 10),
-                                             dtype=dtype,
-                                             device=device)
-            self.assertEqual((*batch_shape, 2, 10), sparse.shape)
-            self.assertEqual(crow_indices.to(index_dtype), sparse.crow_indices())
-            self.assertEqual(col_indices.to(index_dtype), sparse.col_indices())
-            self.assertEqual(values, sparse.values())
-
-    @dtypes(*get_all_dtypes())
-    def test_sparse_csr_batch_constructor_shape_inference(self, device, dtype):
-        batch_shape = (2, 3)
-        crow_indices = torch.tensor([0, 2, 4], device=device).repeat(6, 1).reshape(*batch_shape, -1)
-        col_indices = torch.tensor([0, 1, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
-        values = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
-        sparse = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=dtype, device=device)
-        self.assertEqual((*batch_shape, crow_indices.shape[-1] - 1, col_indices.max() + 1), sparse.shape)
-
-    @dtypes(*get_all_dtypes())
     def test_sparse_csr_constructor_from_lists(self, device, dtype):
         # without size
         sparse = torch.sparse_csr_tensor([0, 2, 4],
@@ -225,17 +198,15 @@
     @dtypes(*get_all_dtypes())
     def test_empty(self, device, dtype):
         ns = [5, 2, 0]
-        batch_shapes = [(), (2,), (2, 3)]
-        for m, n, b in itertools.product(ns, ns, batch_shapes):
-            shape = (*b, m, n)
+        for shape in itertools.product(ns, ns):
             result = torch.empty(shape, dtype=dtype, device=device, layout=torch.sparse_csr)
             self.assertEqual(result.shape, shape)
             self.assertEqual(result.dtype, dtype)
             self.assertEqual(result.device, torch.device(device))
             self.assertEqual(result.layout, torch.sparse_csr)
-            self.assertEqual(result.crow_indices().shape, (*b, shape[-2] + 1,))
-            self.assertEqual(result.col_indices().shape, (*b, 0,))
-            self.assertEqual(result.values().shape, (*b, 0,))
+            self.assertEqual(result.crow_indices().shape, (shape[0] + 1,))
+            self.assertEqual(result.col_indices().shape, (0,))
+            self.assertEqual(result.values().shape, (0,))
             self.assertEqual(result._nnz(), 0)
             self.assertEqual(result.crow_indices().device, torch.device(device))
             self.assertEqual(result.col_indices().device, torch.device(device))
@@ -247,22 +218,23 @@
     @skipMeta
     @dtypes(*get_all_dtypes())
     def test_empty_errors(self, device, dtype):
-        with self.assertRaisesRegex(RuntimeError, "torch.empty: Only batched sparse CSR matrices are supported, but got size"):
+        with self.assertRaisesRegex(RuntimeError, "torch.empty: Only 2D sparse CSR tensors are supported."):
             torch.empty((5,), dtype=dtype, device=device, layout=torch.sparse_csr)
 
+        with self.assertRaisesRegex(RuntimeError, "torch.empty: Only 2D sparse CSR tensors are supported."):
+            torch.empty((2, 3, 4), dtype=dtype, device=device, layout=torch.sparse_csr)
+
     @skipMeta
     @dtypes(*get_all_dtypes())
     def test_clone(self, device, dtype):
-        from operator import mul
-        from functools import reduce
-        for batch_shape in ((), (2,), (2, 3)):
-            prod = reduce(mul, batch_shape, 1)
-            crow_indices = torch.tensor([0, 2, 4], device=device).repeat(prod, 1).reshape(*batch_shape, -1)
-            col_indices = torch.tensor([0, 1, 0, 1], device=device).repeat(prod, 1).reshape(*batch_shape, -1)
-            values = torch.tensor([1, 2, 3, 4], device=device, dtype=dtype).repeat(prod, 1).reshape(*batch_shape, -1)
-            sparse = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=dtype, device=device)
-            cloned_sparse = sparse.clone()
-            self.assertEqual(sparse, cloned_sparse)
+        x = torch.sparse_csr_tensor([0, 2, 4],
+                                    [0, 1, 0, 1],
+                                    [1, 2, 3, 4],
+                                    dtype=dtype,
+                                    device=device)
+        y = x.clone()
+
+        self.assertEqual(x, y)
 
     @skipMeta
     @dtypes(*get_all_dtypes())
@@ -277,10 +249,9 @@
             self.assertEqual(a, b)
 
         ns = [5, 2, 0]
-        batch_shapes = [(), (2,), (2, 3)]
-        for (m, n, b), index_dtype in zip(itertools.product(ns, ns, batch_shapes), [torch.int32, torch.int64]):
-            run_test((*b, m, n), 0, index_dtype)
-            run_test((*b, m, n), m * n, index_dtype)
+        for shape, index_dtype in zip(itertools.product(ns, ns), [torch.int32, torch.int64]):
+            run_test(shape, 0, index_dtype)
+            run_test(shape, shape[0] * shape[1], index_dtype)
 
     @skipMeta
     @dtypes(*get_all_dtypes())
@@ -304,31 +275,25 @@
     @skipMeta
     @dtypes(*get_all_dtypes())
     def test_resize(self, device, dtype):
-        batch_shapes = [(), (2,), (2, 3)]
-        for index_dtype, b in zip([torch.int32, torch.int64], batch_shapes):
-            shape = (*b, 2, 3)
+        for index_dtype in [torch.int32, torch.int64]:
+            shape = (2, 3)
             nnz = 6
             a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
 
-            new_shape = (*b, 4, 5)
+            new_shape = (4, 5)
             a.resize_(new_shape)
 
             self.assertEqual(a.shape, new_shape)
             # resize to larger shape doesn't add specified elements
             self.assertEqual(a._nnz(), nnz)
 
-            new_shape = (*b, 1, 5)
+            new_shape = (1, 5)
             a.resize_(new_shape)
 
             self.assertEqual(a.shape, new_shape)
             # resize to smaller shape trims specified elements
             self.assertEqual(a._nnz(), 5)
 
-            # trim batched dimensions
-            a.resize_(new_shape[-2], new_shape[-1])
-            self.assertEqual(a.shape, (new_shape[-2], new_shape[-1]))
-            self.assertEqual(a._nnz(), 5)
-
     @skipMeta
     @dtypes(*get_all_dtypes())
     def test_resize_errors(self, device, dtype):
@@ -337,7 +302,7 @@
             nnz = 6
             a = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=index_dtype)
 
-            with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only batched sparse CSR matrices are supported"):
+            with self.assertRaisesRegex(RuntimeError, "torch.resize_: Only 2D sparse CSR tensors are supported."):
                 new_shape = (4,)
                 a.resize_(new_shape)
 
@@ -382,62 +347,49 @@
                                     torch.tensor([1, 2, 3, 4]))
 
     def test_factory_shape_invariants_check(self, device):
-        crow_indices = torch.tensor([0, 2, 4], device=device)
-        col_indices = torch.tensor([0, 1, 0, 1], device=device)
-        values = torch.tensor([1, 2, 3, 4], device=device)
+        crow_indices = [0, 2, 4]
+        col_indices = [0, 1, 0, 1]
+        values = [1, 2, 3, 4]
         size = (2, 10)
-        torch.sparse_csr_tensor(crow_indices, col_indices, values, size, device=device)
+        torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), size,
+                                device=device)
 
-        with self.assertRaisesRegex(RuntimeError, r"size of a batched CSR tensor must have length >= 2, but got: 1"):
-            torch.sparse_csr_tensor(crow_indices, col_indices, values,
-                                    size=(2,),
+        with self.assertRaisesRegex(RuntimeError, r"size of a CSR tensor must be of length 2, but got: 3"):
+            torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values),
+                                    size=(2, 10, 2),
                                     device=device)
 
-        with self.assertRaisesRegex(RuntimeError, r"crow_indices must have dim >= 1 but got crow_indices\.dim\(\)\ = 0"):
-            torch.sparse_csr_tensor(torch.zeros((), device=device, dtype=torch.int64),
-                                    col_indices,
-                                    values,
+        with self.assertRaisesRegex(RuntimeError, r"crow_indices must have dim\=1 but got crow_indices\.dim\(\)\=2"):
+            torch.sparse_csr_tensor(torch.tensor(crow_indices).repeat(2, 1),
+                                    torch.tensor(col_indices),
+                                    torch.tensor(values),
                                     size,
                                     device=device)
 
-        with self.assertRaisesRegex(RuntimeError, r"col_indices must have dim >= 1 but got col_indices\.dim\(\)\ = 0"):
-            torch.sparse_csr_tensor(crow_indices,
-                                    torch.zeros((), device=device, dtype=torch.int64),
-                                    values,
+        with self.assertRaisesRegex(RuntimeError, r"col_indices must have dim\=1 but got col_indices\.dim\(\)\=2"):
+            torch.sparse_csr_tensor(torch.tensor(crow_indices),
+                                    torch.tensor(col_indices).repeat(2, 1),
+                                    torch.tensor(values),
                                     size,
                                     device=device)
 
-        with self.assertRaisesRegex(RuntimeError, r"values must have dim >= 1 but got values\.dim\(\)\ = 0"):
-            torch.sparse_csr_tensor(crow_indices,
-                                    col_indices,
-                                    torch.zeros((), device=device, dtype=torch.int64),
+        with self.assertRaisesRegex(RuntimeError, r"values must have dim\=1 but got values\.dim\(\)\=2"):
+            torch.sparse_csr_tensor(torch.tensor(crow_indices),
+                                    torch.tensor(col_indices),
+                                    torch.tensor(values).repeat(2, 1),
                                     size,
                                     device=device)
 
         with self.assertRaisesRegex(RuntimeError,
-                                    r"crow_indices\.size\(-1\) must be equal to size\[-2\] \+ 1 \(that is 2\), but got: 3"):
-            torch.sparse_csr_tensor(crow_indices, col_indices, values, (1, 1),
+                                    r"crow_indices\.numel\(\) must be size\(0\) \+ 1, but got: 3"):
+            torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), (1, 1),
                                     device=device)
 
 
         with self.assertRaisesRegex(RuntimeError,
-                                    r"Number of dimensions of crow_indices and col_indices must be the same"):
-            torch.sparse_csr_tensor(crow_indices, col_indices.repeat(2, 1), values, size,
-                                    device=device)
-
-        with self.assertRaisesRegex(RuntimeError,
-                                    r"Number of dimensions of indices and values must be the same"):
-            torch.sparse_csr_tensor(crow_indices, col_indices, values.repeat(2, 1), size,
-                                    device=device)
-
-        with self.assertRaisesRegex(RuntimeError,
-                                    r"Number of dimensions of indices must be one less"):
-            torch.sparse_csr_tensor(crow_indices.repeat(2, 1), col_indices.repeat(2, 1), values.repeat(2, 1), size,
-                                    device=device)
-
-        with self.assertRaisesRegex(RuntimeError,
-                                    r"All batch dimensions of the provided size, indices, and values must be the same"):
-            torch.sparse_csr_tensor(crow_indices.repeat(2, 1), col_indices.repeat(3, 1), values.repeat(4, 1), (2, 2, 10),
+                                    r"col_indices and values must have equal sizes, " +
+                                    r"but got col_indices\.numel\(\): 3, values\.numel\(\): 4"):
+            torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 1, 0]), torch.tensor(values), size,
                                     device=device)
 
     def test_factory_indices_invariants_check(self, device):
@@ -456,7 +408,7 @@
 
         with self.assertRaisesRegex(RuntimeError,
                                     r"at position i \= 2," +
-                                    r" the condition crow_indices\[i - 1\] <\= crow_indices\[i\] fails"):
+                                    r" this condition crow_indices\[i - 1\] <\= crow_indices\[i\] fails"):
             torch.sparse_csr_tensor(torch.tensor([0, 5, 4]), torch.tensor(col_indices), torch.tensor(values), size,
                                     device=device)
 
@@ -464,7 +416,7 @@
             torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, -1, 0, 1]), torch.tensor(values), size,
                                     device=device)
 
-        with self.assertRaisesRegex(RuntimeError, r"size\[-1\] should be greater than col_indices\.max\(\)"):
+        with self.assertRaisesRegex(RuntimeError, r"size\(1\) should be greater than col_indices\.max\(\)"):
             torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 11, 0, 1]), torch.tensor(values), size,
                                     device=device)
 
@@ -569,12 +521,12 @@
             sparse = dense.to_sparse_csr()
             self.assertEqual(sparse.to_dense(), dense)
 
-        batch_shape = (2, 3)
-        crow_indices = torch.tensor([0, 3, 5], device=device).repeat(6, 1).reshape(*batch_shape, -1)
-        col_indices = torch.tensor([0, 1, 2, 0, 1], device=device).repeat(6, 1).reshape(*batch_shape, -1)
-        values = torch.tensor([1, 2, 1, 3, 4], device=device, dtype=dtype).repeat(6, 1).reshape(*batch_shape, -1)
-        csr = torch.sparse_csr_tensor(crow_indices, col_indices, values, dtype=dtype, device=device)
-        dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device).repeat(6, 1).reshape(csr.shape)
+        crow_indices = torch.tensor([0, 3, 5])
+        col_indices = torch.tensor([0, 1, 2, 0, 1])
+        values = torch.tensor([1, 2, 1, 3, 4], dtype=dtype)
+        csr = torch.sparse_csr_tensor(crow_indices, col_indices,
+                                      values, dtype=dtype, device=device)
+        dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device)
         self.assertEqual(csr.to_dense(), dense)
 
     @skipCPUIfNoMklSparse
@@ -1147,9 +1099,6 @@
     @dtypes(torch.float, torch.double)
     def test_add(self, device, dtype):
         def _test_spadd_shape(nnz, shape):
-            # sparse.to_dense() uses torch.add internally so if torch.add is wrong,
-            # the dense tensor will be wrong but this test would still pass
-            # there's a separate test that checks for the correctness of the .to_dense() call
             x = self.genSparseCSRTensor(shape, nnz, dtype=dtype, device=device, index_dtype=torch.int32)
             y = torch.randn(*shape, dtype=dtype, device=device)
             r = random.random()
@@ -1171,12 +1120,10 @@
 
             self.assertEqual(res, expected)
 
-        ns = [2, 5]
-        batch_shapes = [(), (2,), (2, 3)]
-        for b, m, n in itertools.product(batch_shapes, ns, ns):
-            _test_spadd_shape(0, (*b, m, n))
-            _test_spadd_shape(m * n // 2, (*b, m, n))
-            _test_spadd_shape(m * n, (*b, m, n))
+        _test_spadd_shape(10, [100, 100])
+        _test_spadd_shape(0, [100, 100])
+        _test_spadd_shape(10, [100, 1])
+        _test_spadd_shape(10, [1, 100])
 
     @dtypes(torch.float, torch.double)
     def test_mul(self, device, dtype):
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index bd8732c..1e2d246 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1999,11 +1999,9 @@
         return crow_indices.to(device=device)
 
     def genSparseCSRTensor(self, size, nnz, *, device, dtype, index_dtype):
-        from operator import mul
-        from functools import reduce
         sparse_dim = 2
-        assert all(size[d] > 0 for d in range(len(size))) or nnz == 0, 'invalid arguments'
-        assert len(size) >= sparse_dim
+        assert all(size[d] > 0 for d in range(sparse_dim)) or nnz == 0, 'invalid arguments'
+        assert len(size) == sparse_dim
 
         def random_sparse_csr(n_rows, n_cols, nnz):
             crow_indices = self._make_crow_indices(n_rows, n_cols, nnz, device=device, dtype=index_dtype)
@@ -2017,15 +2015,7 @@
             values = make_tensor([nnz], device=device, dtype=dtype, low=low, high=high)
             return values, crow_indices, col_indices
 
-        batch_shape = size[:-2]
-        n_batch = reduce(mul, batch_shape, 1)
-
-        sparse_tensors = [random_sparse_csr(size[-2], size[-1], nnz) for _ in range(n_batch)]
-        sparse_tensors_it = map(list, zip(*sparse_tensors))
-        values = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
-        crow_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
-        col_indices = torch.stack(next(sparse_tensors_it)).reshape(*batch_shape, -1)
-
+        values, crow_indices, col_indices = random_sparse_csr(size[0], size[1], nnz)
         return torch.sparse_csr_tensor(crow_indices,
                                        col_indices,
                                        values, size=size, dtype=dtype, device=device)