Relax divisibility-by-16 requirement for the leading dimension of mat1 in scaled_gemm (#108308)
# Summary
cuBLASLt requires that the matrices be 16-byte aligned. If mat1.size(-1) % 16 == 0 and the matrix is row-major, then the leading dimension can be any value, so the divisibility-by-16 check on mat1's leading dimension can be dropped. See this comment: https://github.com/pytorch/pytorch/pull/107341#discussion_r1310934737
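A minimal sketch of the relaxed behavior (an illustration only, not part of the change itself; it assumes a CUDA device with float8 support, e.g. H100, and mirrors the shapes and `torch._scaled_mm` call used in the new test below): mat1 may now have a leading dimension that is not divisible by 16 as long as its trailing dimension is, while mat2 still needs both dimensions divisible by 16.

```python
import torch

device = "cuda"  # assumes a float8-capable GPU (e.g. H100)

# Leading dim 17 is NOT divisible by 16; trailing dim 16 is, and x is row-major,
# so the relaxed check now accepts this mat1.
x = torch.rand((17, 16), device=device).to(torch.float8_e4m3fn)
# mat2 is unchanged: both of its dimensions must still be divisible by 16.
y = torch.rand((16, 16), device=device).to(torch.float8_e4m3fn).t()

out_fp8, amax_fp8 = torch._scaled_mm(x, y)

# Note: mat1's trailing dimension must still be divisible by 16; otherwise the
# new TORCH_CHECK fires ("Expected trailing dimension of mat1 to be divisible by 16 ...").
```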
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108308
Approved by: https://github.com/eqy, https://github.com/vkuzo
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 3059885..a5d5315 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -736,8 +736,14 @@
"scale_result must be a float scalar");
TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
" but got ", bias->numel());
- TORCH_CHECK(mat1.sizes()[0] % 16 == 0 && mat1.sizes()[1] % 16 == 0, "mat1 shape (", mat1.sizes()[0], "x",
- mat1.sizes()[1], " must be divisible by 16");
+ TORCH_CHECK(
+ mat1.sizes()[1] % 16 == 0,
+ "Expected trailing dimension of mat1 to be divisble by 16 ",
+ "but got mat1 shape: (",
+ mat1.sizes()[0],
+ "x",
+ mat1.sizes()[1],
+ ".");
TORCH_CHECK(mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0, "mat2 shape (", mat2.sizes()[0], "x",
mat2.sizes()[1], " must be divisible by 16");
// Check types
diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index e872e4b..37cb5f8 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -229,6 +229,15 @@
         outb_fp8, amaxb_fp8 = torch._scaled_mm(x, y, bias=bias)
         self.assertEqual((amaxb_fp8 - amax_fp8).item(), 4.0)
 
+    @parametrize("bias", [True, False])
+    def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
+        x = torch.rand((17, 16), device=device).to(torch.float8_e4m3fn)
+        y = torch.rand((16, 16), device=device).to(torch.float8_e4m3fn).t()
+        input_bias = None
+        if bias:
+            input_bias = torch.rand((16,), device=device).to(torch.half)
+        out_fp8, amax_fp8 = torch._scaled_mm(x, y, bias=input_bias)
+
     def test_float8_bias_relu_edgecase(self, device) -> None:
         (k, l, m) = (16, 48, 32)
         x = torch.full((k, l), 0.0, device=device).to(torch.float8_e4m3fn)