Use fallback approach for nested matmul (#85311)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85311
Approved by: https://github.com/cpuhrsch, https://github.com/drisspg
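
This replaces the per-component mm/bmm loop over strided views of the nested buffers with a simple fallback: validate the per-component shapes, zero-pad both nested inputs to dense tensors, run the regular dense matmul, and convert the padded result back to a nested tensor with the precomputed output sizes. Inputs are made contiguous first because to_padded_tensor only supports contiguous inputs.

A rough Python sketch of the equivalent computation (illustrative only, using the public nested-tensor APIs exercised by the tests; the actual change is the C++ below, which converts back via `nested_from_padded_generic`):

```python
import torch

# Two nested inputs whose components can be matrix-multiplied pairwise:
# (2, 4) @ (4, 6) and (3, 4) @ (4, 5).
nt0 = torch.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)])
nt1 = torch.nested_tensor([torch.randn(4, 6), torch.randn(4, 5)])

# Fallback: zero-pad to dense, do a dense matmul, then repack as nested.
a = torch.nested.to_padded_tensor(nt0, 0.0)   # dense (2, 3, 4), zero-padded
b = torch.nested.to_padded_tensor(nt1, 0.0)   # dense (2, 4, 6), zero-padded
out_padded = torch.matmul(a, b)               # dense (2, 3, 6)

# Zero padding along the contracted dimension contributes nothing to the valid
# entries; the kernel then trims each component back to its true output size
# ((2, 6) and (3, 5) here) when converting the padded result back to nested.
```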
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp
index 073cad7..84d62e9 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -864,9 +864,12 @@
else if (!self.is_nested() && mat2.is_nested()) {
AT_ERROR("Expected both to be nested, but got a non-nested self and nested other");
}
+ // to_padded_tensor only supports contiguous inputs
+ auto self_contig = self.contiguous();
+ auto mat2_contig = mat2.contiguous();
// dispatcher should have guaranteed that at least one is nested
- auto self_ptr = get_nested_tensor_impl(self),
- mat2_ptr = get_nested_tensor_impl(mat2);
+ const auto self_ptr = get_nested_tensor_impl(self_contig);
+ const auto mat2_ptr = get_nested_tensor_impl(mat2_contig);
int64_t self_dim = self_ptr->dim(),
mat2_dim = mat2_ptr->dim();
TORCH_CHECK(
@@ -877,53 +880,41 @@
mat2_dim >= 3,
"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: ",
mat2_dim);
- TORCH_CHECK(self_dim == mat2_dim, "matmul: both inputs must have same rank");
+ TORCH_CHECK(self_dim == mat2_dim, "matmul: both inputs must have the same rank");
int64_t ntensors = self_ptr->size(0),
ntensors2 = mat2_ptr->size(0);
TORCH_CHECK(ntensors == ntensors2,
"matmul: Expected size for the 1st dimension of 2nd input tensor to be: ", ntensors,
" but got: ", ntensors2, ".");
- const Tensor& self_buffer = self_ptr->get_buffer(),
- & mat2_buffer = mat2_ptr->get_buffer();
- std::vector<IntArrayRef> self_sizes = NestedTensor_get_sizes(self_ptr),
- mat2_sizes = NestedTensor_get_sizes(mat2_ptr),
- self_strides = NestedTensor_get_strides(self_ptr),
- mat2_strides = NestedTensor_get_strides(mat2_ptr);
- const std::vector<int64_t>& self_offsets = self_ptr->get_offsets(),
- & mat2_offsets = mat2_ptr->get_offsets();
- // create a contiguous output
- std::vector<int64_t> batch_sizes;
- Tensor output;
- std::tie(batch_sizes, output) = matmul_nested_helper(
- self_sizes, mat2_sizes, self_buffer.options(), self_ptr->get_nested_size_tensor().options());
- // call tensor matmul
- // TODO: `padding nested tensor -> bmm -> remove padding` may be more efficient
- // until we have specialized nested tensor bmm kernel
- // useful resource: `aten/src/ATen/native/cpu/LinearAlgebra.cpp/bmm_out_or_baddbmm_`
- // `aten/src/ATen/native/cuda/Blas.cpp/baddbmm_out_cuda_impl`
- std::vector<Tensor> output_unbind = output.unbind();
- for (int64_t i = 0; i < ntensors; i++) {
- const IntArrayRef& self_size = self_sizes[i],
- & mat2_size = mat2_sizes[i];
- const int64_t& batch_size = batch_sizes[i];
- if (batch_size == 1) {
- at::mm_out(
- output_unbind[i],
- self_buffer.as_strided(self_size, self_strides[i], self_offsets[i]),
- mat2_buffer.as_strided(mat2_size, mat2_strides[i], mat2_offsets[i])
- );
- }
- else {
- at::bmm_out(
- output_unbind[i],
- self_buffer.as_strided(self_size, self_strides[i], self_offsets[i])
- .reshape({batch_size, self_size[self_dim - 1 - 2], self_size[self_dim - 1 - 1]}),
- mat2_buffer.as_strided(mat2_size, mat2_strides[i], mat2_offsets[i])
- .reshape({batch_size, mat2_size[self_dim - 1 - 2], mat2_size[self_dim - 1 - 1]})
- );
- }
- }
- return output;
+ // Ensure batch dimensions have the same sizes (no broadcasting).
+ const auto& self_sizes = self_ptr->get_nested_size_tensor();
+ const auto& mat2_sizes = mat2_ptr->get_nested_size_tensor();
+ const auto& self_batch_sizes = self_sizes.narrow(1, 0, self_dim-3);
+ const auto& mat2_batch_sizes = mat2_sizes.narrow(1, 0, mat2_dim-3);
+ TORCH_CHECK(at::equal(self_batch_sizes, mat2_batch_sizes),
+ "matmul: For nested tensors, batch dimensions must have the same sizes, ",
+ "no broadcasting is currently performed. Got batch shapes for self ",
+ self_batch_sizes,
+ " and batch shapes for mat2 ",
+ mat2_batch_sizes);
+ // Ensure last dim of self and second last dim of mat2 have the same size
+ const auto& self_dim_size = self_sizes.select(1, -1);
+ const auto& mat2_dim_size = mat2_sizes.select(1, -2);
+ TORCH_CHECK(at::equal(self_dim_size, mat2_dim_size),
+ "matmul: Nested tensors cannot be matrix multiplied, last dimension of self has sizes",
+ self_dim_size,
+ "second last dimension of mat2 has sizes",
+ mat2_dim_size);
+ // Construct output size from input sizes
+ Tensor output_sizes = self_sizes.clone();
+ // The last entry in every row of output_sizes should be the last column of mat2_sizes
+ output_sizes.index_put_({at::indexing::Slice(), -1}, mat2_sizes.select(1, -1).clone());
+
+ auto self_padded = self_contig.to_padded_tensor(0.);
+ auto mat2_padded = mat2_contig.to_padded_tensor(0.);
+ auto output_padded = at::matmul(self_padded, mat2_padded);
+ auto output_nested = nested_from_padded_generic(output_padded, output_sizes);
+ return output_nested;
}
Tensor& matmul_out_nested(const Tensor& tensor1, const Tensor& tensor2, Tensor& result) {
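
For reference, the shape checks added above operate on each nested tensor's size matrix (one row of per-component sizes per constituent, as returned by `get_nested_size_tensor()`). A minimal Python sketch of the same narrow/select logic, using hypothetical sizes that match the 4D test case below (illustrative only):

```python
import torch

# Per-component sizes for a hypothetical 4D nested tensor pair (dim = 4, so
# each row stores 3 sizes: [batch, rows, cols]).
self_sizes = torch.tensor([[1, 2, 4],
                           [8, 3, 7]])
mat2_sizes = torch.tensor([[1, 4, 6],
                           [8, 7, 5]])

dim = 4
# Batch dimensions: everything before the last two (matrix) dims.
self_batch = self_sizes.narrow(1, 0, dim - 3)
mat2_batch = mat2_sizes.narrow(1, 0, dim - 3)
assert torch.equal(self_batch, mat2_batch)      # no broadcasting allowed

# Contracted dims: last dim of self vs. second last dim of mat2.
assert torch.equal(self_sizes.select(1, -1), mat2_sizes.select(1, -2))

# Output sizes: copy self's sizes, replace the last column with mat2's last column.
out_sizes = self_sizes.clone()
out_sizes[:, -1] = mat2_sizes.select(1, -1)
print(out_sizes)   # tensor([[1, 2, 6], [8, 3, 5]])
```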
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index 3944ef2..9de07c9 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -993,7 +993,7 @@
r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.",
lambda: torch.matmul(nt1, nt0)
)
- # error case: incompatible generalized batch size
+ # error case: incompatible batch sizes that would not broadcast even for regular tensors
nt0 = torch.nested_tensor([torch.randn((2, 2, 4)),
torch.randn((2, 3, 4))],
device=device, dtype=dtype)
@@ -1002,23 +1002,26 @@
device=device, dtype=dtype)
self.assertRaisesRegex(
RuntimeError,
- r"matmul: For nested tensors, no broadcasting is currently performed: "
- r"[0-9]+-th nested matrices in batch at dimension [0-9]+ "
- r"have mismatching sizes [0-9]+ and [0-9]+",
+ "matmul(): For nested tensors, batch dimensions must have the same sizes,",
lambda: torch.matmul(nt0, nt1)
)
+ # error case: batch sizes that would broadcast for regular tensors, but broadcasting is not supported for nested tensors
+ nt0 = torch.nested_tensor([torch.randn((2, 2, 4)),
+ torch.randn((1, 3, 4))],
+ device=device, dtype=dtype)
+ nt1 = torch.nested_tensor([torch.randn((1, 4, 6)),
+ torch.randn((3, 4, 5))],
+ device=device, dtype=dtype)
self.assertRaisesRegex(
RuntimeError,
- r"matmul: For nested tensors, no broadcasting is currently performed: "
- r"[0-9]+-th nested matrices in batch at dimension [0-9]+ "
- r"have mismatching sizes [0-9]+ and [0-9]+",
- lambda: torch.matmul(nt1, nt0)
+ "matmul(): For nested tensors, batch dimensions must have the same sizes,",
+ lambda: torch.matmul(nt0, nt1)
)
# error case: underlying matrices cannot be multiplied
nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype)
self.assertRaisesRegex(
RuntimeError,
- r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)",
+ "matmul(): Nested tensors cannot be matrix multiplied",
lambda: torch.matmul(nt0, nt0)
)
# normal nested tensor: 3D
@@ -1027,11 +1030,11 @@
actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
expect = torch.matmul(torch.nested.to_padded_tensor(nt0, 0.0), torch.nested.to_padded_tensor(nt1, 0.0))
self.assertEqual(actual, expect)
- # normal nested tensor: 4D
- nt0 = torch.nested_tensor([torch.randn((8, 2, 4)),
+ # normal nested tensor: 4D (with testing for batch_size=1)
+ nt0 = torch.nested_tensor([torch.randn((1, 2, 4)),
torch.randn((8, 3, 7))],
device=device, dtype=dtype)
- nt1 = torch.nested_tensor([torch.randn((8, 4, 6)),
+ nt1 = torch.nested_tensor([torch.randn((1, 4, 6)),
torch.randn((8, 7, 5))],
device=device, dtype=dtype)
actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)