Use fallback approach for nested matmul (#85311)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85311
Approved by: https://github.com/cpuhrsch, https://github.com/drisspg
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp
index 073cad7..84d62e9 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -864,9 +864,12 @@
   else if (!self.is_nested() && mat2.is_nested()) {
     AT_ERROR("Expected both to be nested, but got a non-nested self and nested other");
   }
+  // to_padded_tensor only supports contiguous inputs
+  auto self_contig = self.contiguous();
+  auto mat2_contig = mat2.contiguous();
   // dispatcher should have guaranteed that at least one is nested
-  auto self_ptr = get_nested_tensor_impl(self),
-      mat2_ptr = get_nested_tensor_impl(mat2);
+  const auto self_ptr = get_nested_tensor_impl(self_contig);
+  const auto mat2_ptr = get_nested_tensor_impl(mat2_contig);
   int64_t self_dim = self_ptr->dim(),
       mat2_dim = mat2_ptr->dim();
   TORCH_CHECK(
@@ -877,53 +880,41 @@
       mat2_dim >= 3,
       "matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: ",
       mat2_dim);
-  TORCH_CHECK(self_dim == mat2_dim, "matmul: both inputs must have same rank");
+  TORCH_CHECK(self_dim == mat2_dim, "matmul: both inputs must have the same rank");
   int64_t ntensors = self_ptr->size(0),
       ntensors2 = mat2_ptr->size(0);
   TORCH_CHECK(ntensors == ntensors2,
       "matmul: Expected size for the 1st dimension of 2nd input tensor to be: ", ntensors,
       " but got: ", ntensors2, ".");
-  const Tensor& self_buffer = self_ptr->get_buffer(),
-      & mat2_buffer = mat2_ptr->get_buffer();
-  std::vector<IntArrayRef> self_sizes = NestedTensor_get_sizes(self_ptr),
-      mat2_sizes = NestedTensor_get_sizes(mat2_ptr),
-      self_strides = NestedTensor_get_strides(self_ptr),
-      mat2_strides = NestedTensor_get_strides(mat2_ptr);
-  const std::vector<int64_t>& self_offsets = self_ptr->get_offsets(),
-      & mat2_offsets = mat2_ptr->get_offsets();
-  // create a contiguous output
-  std::vector<int64_t> batch_sizes;
-  Tensor output;
-  std::tie(batch_sizes, output) = matmul_nested_helper(
-      self_sizes, mat2_sizes, self_buffer.options(), self_ptr->get_nested_size_tensor().options());
-  // call tensor matmul
-  // TODO: `padding nested tensor -> bmm -> remove padding` may be more efficient
-  //       until we have specialized nested tensor bmm kernel
-  //       useful resource: `aten/src/ATen/native/cpu/LinearAlgebra.cpp/bmm_out_or_baddbmm_`
-  //                        `aten/src/ATen/native/cuda/Blas.cpp/baddbmm_out_cuda_impl`
-  std::vector<Tensor> output_unbind = output.unbind();
-  for (int64_t i = 0; i < ntensors; i++) {
-    const IntArrayRef& self_size = self_sizes[i],
-        & mat2_size = mat2_sizes[i];
-    const int64_t& batch_size = batch_sizes[i];
-    if (batch_size == 1) {
-      at::mm_out(
-          output_unbind[i],
-          self_buffer.as_strided(self_size, self_strides[i], self_offsets[i]),
-          mat2_buffer.as_strided(mat2_size, mat2_strides[i], mat2_offsets[i])
-          );
-    }
-    else {
-      at::bmm_out(
-          output_unbind[i],
-          self_buffer.as_strided(self_size, self_strides[i], self_offsets[i])
-              .reshape({batch_size, self_size[self_dim - 1 - 2], self_size[self_dim - 1 - 1]}),
-          mat2_buffer.as_strided(mat2_size, mat2_strides[i], mat2_offsets[i])
-              .reshape({batch_size, mat2_size[self_dim - 1 - 2], mat2_size[self_dim - 1 - 1]})
-          );
-    }
-  }
-  return output;
+  // Ensure batch dimensions have the same sizes (no broadcasting).
+  const auto& self_sizes = self_ptr->get_nested_size_tensor();
+  const auto& mat2_sizes = mat2_ptr->get_nested_size_tensor();
+  const auto& self_batch_sizes = self_sizes.narrow(1, 0, self_dim-3);
+  const auto& mat2_batch_sizes = mat2_sizes.narrow(1, 0, mat2_dim-3);
+  TORCH_CHECK(at::equal(self_batch_sizes, mat2_batch_sizes),
+    "matmul: For nested tensors, batch dimensions must have the same sizes, ",
+    "no broadcasting is currently performed. Got batch shapes for self ",
+    self_batch_sizes,
+    " and batch shapes for mat2 ",
+    mat2_batch_sizes);
+  // Ensure last dim of self and second last dim of mat2 have the same size
+  const auto& self_dim_size = self_sizes.select(1, -1);
+  const auto& mat2_dim_size = mat2_sizes.select(1, -2);
+  TORCH_CHECK(at::equal(self_dim_size, mat2_dim_size),
+    "matmul: Nested tensors cannot be matrix multiplied, last dimension of self has sizes",
+    self_dim_size,
+    "second last dimension of mat2 has sizes",
+    mat2_dim_size);
+  // Construct output size from input sizes
+  Tensor output_sizes = self_sizes.clone();
+  // The last entry in every row of output_sizes should be last column of mat2_sizes
+  output_sizes.index_put_({at::indexing::Slice(), -1}, mat2_sizes.select(1, -1).clone());
+
+  auto self_padded = self_contig.to_padded_tensor(0.);
+  auto mat2_padded = mat2_contig.to_padded_tensor(0.);
+  auto output_padded = at::matmul(self_padded, mat2_padded);
+  auto output_nested = nested_from_padded_generic(output_padded, output_sizes);
+  return output_nested;
 }
 
 Tensor& matmul_out_nested(const Tensor& tensor1, const Tensor& tensor2, Tensor& result) {
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 3944ef2..9de07c9 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -993,7 +993,7 @@
             r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.",
             lambda: torch.matmul(nt1, nt0)
         )
-        # error case: incompatible generalized batch size
+        # error case: incompatible (wrong) batch sizes that shouldn't even broadcast?
         nt0 = torch.nested_tensor([torch.randn((2, 2, 4)),
                                    torch.randn((2, 3, 4))],
                                   device=device, dtype=dtype)
@@ -1002,23 +1002,26 @@
                                   device=device, dtype=dtype)
         self.assertRaisesRegex(
             RuntimeError,
-            r"matmul: For nested tensors, no broadcasting is currently performed: "
-            r"[0-9]+-th nested matrices in batch at dimension [0-9]+ "
-            r"have mismatching sizes [0-9]+ and [0-9]+",
+            "matmul(): For nested tensors, batch dimensions must have the same sizes,",
             lambda: torch.matmul(nt0, nt1)
         )
+        # error case: incompatible batch sizes that should technically broadcast
+        nt0 = torch.nested_tensor([torch.randn((2, 2, 4)),
+                                   torch.randn((1, 3, 4))],
+                                  device=device, dtype=dtype)
+        nt1 = torch.nested_tensor([torch.randn((1, 4, 6)),
+                                   torch.randn((3, 4, 5))],
+                                  device=device, dtype=dtype)
         self.assertRaisesRegex(
             RuntimeError,
-            r"matmul: For nested tensors, no broadcasting is currently performed: "
-            r"[0-9]+-th nested matrices in batch at dimension [0-9]+ "
-            r"have mismatching sizes [0-9]+ and [0-9]+",
-            lambda: torch.matmul(nt1, nt0)
+            "matmul(): For nested tensors, batch dimensions must have the same sizes,",
+            lambda: torch.matmul(nt0, nt1)
         )
         # error case: underlying matrices cannot be multiplied
         nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype)
         self.assertRaisesRegex(
             RuntimeError,
-            r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)",
+            "matmul(): Nested tensors cannot be matrix multiplied",
             lambda: torch.matmul(nt0, nt0)
         )
         # normal nested tensor: 3D
@@ -1027,11 +1030,11 @@
         actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
         expect = torch.matmul(torch.nested.to_padded_tensor(nt0, 0.0), torch.nested.to_padded_tensor(nt1, 0.0))
         self.assertEqual(actual, expect)
-        # normal nested tensor: 4D
-        nt0 = torch.nested_tensor([torch.randn((8, 2, 4)),
+        # normal nested tensor: 4D (with testing for batch_size=1)
+        nt0 = torch.nested_tensor([torch.randn((1, 2, 4)),
                                    torch.randn((8, 3, 7))],
                                   device=device, dtype=dtype)
-        nt1 = torch.nested_tensor([torch.randn((8, 4, 6)),
+        nt1 = torch.nested_tensor([torch.randn((1, 4, 6)),
                                    torch.randn((8, 7, 5))],
                                   device=device, dtype=dtype)
         actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)