Relax divisibility by 16 for leading dimension of mat1 in scaled_gemm (#108308)

# Summary
cuBLASLt requires that the matrices be 16-byte aligned. If mat1.size(-1) % 16 == 0 and the matrix is row major, then the leading dimension can be any value, so the divisibility-by-16 check on mat1's leading dimension can be relaxed. See this comment: https://github.com/pytorch/pytorch/pull/107341#discussion_r1310934737
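
A minimal sketch of the relaxed behavior, mirroring the new test added below; it assumes a CUDA device with float8_e4m3fn support and the `torch._scaled_mm` signature on this branch (returning an output/amax pair):

```python
import torch

# mat1 has a leading dimension of 17 (not divisible by 16), but its trailing
# dimension is 16 and it is row major, so cuBLASLt can still consume it.
x = torch.rand((17, 16), device="cuda").to(torch.float8_e4m3fn)
# mat2 is transposed to column major; both of its dimensions stay divisible by 16.
y = torch.rand((16, 16), device="cuda").to(torch.float8_e4m3fn).t()

# Previously this raised because mat1.sizes()[0] % 16 != 0; with this change it runs.
out_fp8, amax_fp8 = torch._scaled_mm(x, y)
```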

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108308
Approved by: https://github.com/eqy, https://github.com/vkuzo
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 3059885..a5d5315 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -736,8 +736,14 @@
        "scale_result must be a float scalar");
   TORCH_CHECK(!bias || bias->numel() == mat2.sizes()[1], "Bias must be size ", mat2.sizes()[1],
        " but got ", bias->numel());
-  TORCH_CHECK(mat1.sizes()[0] % 16 == 0 && mat1.sizes()[1] % 16 == 0, "mat1 shape (", mat1.sizes()[0], "x",
-       mat1.sizes()[1], " must be divisible by 16");
+  TORCH_CHECK(
+      mat1.sizes()[1] % 16 == 0,
+      "Expected trailing dimension of mat1 to be divisble by 16 ",
+      "but got mat1 shape: (",
+      mat1.sizes()[0],
+      "x",
+      mat1.sizes()[1],
+      ".");
   TORCH_CHECK(mat2.sizes()[0] % 16 == 0 && mat2.sizes()[1] % 16 == 0, "mat2 shape (", mat2.sizes()[0], "x",
        mat2.sizes()[1], " must be divisible by 16");
   // Check types
diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index e872e4b..37cb5f8 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -229,6 +229,15 @@
         outb_fp8, amaxb_fp8 = torch._scaled_mm(x, y, bias=bias)
         self.assertEqual((amaxb_fp8 - amax_fp8).item(), 4.0)
 
+    @parametrize("bias", [True, False])
+    def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
+        x = torch.rand((17, 16), device=device).to(torch.float8_e4m3fn)
+        y = torch.rand((16, 16), device=device).to(torch.float8_e4m3fn).t()
+        input_bias = None
+        if bias:
+            input_bias = torch.rand((16,), device=device).to(torch.half)
+        out_fp8, amax_fp8 = torch._scaled_mm(x, y, bias=input_bias)
+
     def test_float8_bias_relu_edgecase(self, device) -> None:
         (k, l, m) = (16, 48, 32)
         x = torch.full((k, l), 0.0, device=device).to(torch.float8_e4m3fn)