# Owner(s): ["module: linear algebra"]

import unittest
from itertools import product
from functools import partial
from typing import Optional, Tuple

import torch

from torch.quantization._quantized_conversions import (
    pack_int4_to_int8,
    quantized_weight_reorder_for_mixed_dtypes_linear_cutlass,
)

from torch.testing import make_tensor
from torch.testing._internal.common_cuda import SM53OrLater, _get_torch_cuda_version
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCUDA,
    tol as xtol,
    toleranceOverride,
)

from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_JETSON,
    IS_WINDOWS,
    parametrize,
    run_tests,
    skipIfRocmVersionLessThan,
    TEST_WITH_ROCM,
    skipIfRocm,
    TestCase,
)

_IS_SM8X = False
if torch.cuda.is_available():
    _IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32
|  |  | 
|  | @unittest.skipIf(IS_ARM64, "Issue with numpy version on arm") | 
|  | class TestMatmulCuda(TestCase): | 
|  | def setUp(self): | 
|  | super(self.__class__, self).setUp() | 
|  | torch.backends.cuda.matmul.allow_tf32 = False | 
|  |  | 
|  | def tearDown(self): | 
|  | torch.backends.cuda.matmul.allow_tf32 = True | 
|  | super(self.__class__, self).tearDown() | 
|  |  | 
|  | def cublas_addmm(self, size: int, dtype: torch.dtype, reduced_precision: bool = False): | 
|  | # | 
|  | # Check for catastrophic cuBLAS inaccuracy by measuring the deviation between | 
|  | # results from the CUDA invocation of torch.addmm and the CPU invocation | 
|  | # (which does not use CUDA backend). | 
|  | # | 
|  | # Get dims | 
|  | n, m, p = (size + 1, size, size + 2) | 
|  | # Disable reduced precision reductions in BFloat16 to bypass some kernels | 
|  | # which fail the threshold check | 
|  | orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction | 
|  | orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction | 
|  | torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = reduced_precision | 
|  | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = reduced_precision | 
|  | # Make random tensors on CPU (seed set on common_utils.py import) | 
|  | # (Not using numpy because it does not support bfloat16) | 
|  | make_arg = partial(make_tensor, dtype=dtype, device="cpu") | 
|  | m_beta = make_arg(1) | 
|  | m_input = make_arg((n, p)) | 
|  | m_1 = make_arg((n, m)) | 
|  | m_2 = make_arg((m, p)) | 
|  | # *(B)FLOAT16 Special Handling* | 
|  | # Backend does not tensorize float16 on CPU, | 
|  | # and bloat16 may present accuracy issues, | 
|  | # so convert to float32 for these cases | 
|  | # (but keep same for other types, e.g. float32 and int*) | 
|  | if dtype == torch.float16 or dtype == torch.bfloat16: | 
|  | m_beta = m_beta.to(dtype=torch.float32) | 
|  | m_input = m_input.to(dtype=torch.float32) | 
|  | m_1 = m_1.to(dtype=torch.float32) | 
|  | m_2 = m_2.to(dtype=torch.float32) | 
|  | # Get CPU result | 
|  | res_cpu = torch.addmm(m_input, m_1, m_2, beta=m_beta.item()) | 
|  | # *(B)FLOAT16 Special Handling*`` | 
|  | # Convert back to (b)float16 | 
|  | if dtype == torch.float16 or dtype == torch.bfloat16: | 
|  | m_beta = m_beta.to(dtype=dtype) | 
|  | m_input = m_input.to(dtype=dtype) | 
|  | m_1 = m_1.to(dtype=dtype) | 
|  | m_2 = m_2.to(dtype=dtype) | 
|  | res_cpu = res_cpu.to(dtype=dtype) | 
|  | # Move arg tensors to CUDA | 
|  | m_beta = m_beta.to("cuda") | 
|  | m_input = m_input.to("cuda") | 
|  | m_1 = m_1.to("cuda") | 
|  | m_2 = m_2.to("cuda") | 
|  | # Get CUDA result | 
|  | res_cuda = torch.addmm(m_input, m_1, m_2, beta=m_beta.item()) | 
|  | # Move to CPU for comparison | 
|  | res_cuda = res_cuda.to("cpu") | 
|  | # Compare | 
|  | self.assertEqual(res_cpu, res_cuda) | 
|  | torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16 | 
|  | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16 | 
|  |  | 
    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=1e-1, rtol=1e-1),
                        torch.bfloat16: xtol(atol=1e-1, rtol=1e-1),
                        torch.float32: xtol(atol=1e-1, rtol=1e-1)})
    @dtypes(torch.float16, torch.bfloat16, torch.float32)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, False)

    @onlyCUDA
    @skipIfRocmVersionLessThan((5, 2))
    # imported 'tol' as 'xtol' to avoid aliasing in code above
    @toleranceOverride({torch.float16: xtol(atol=7e-1, rtol=2e-1),
                        torch.bfloat16: xtol(atol=1e1, rtol=2e-1)})
    @dtypes(torch.float16, torch.bfloat16)
    @parametrize("size", [100, 1000, 10000])
    def test_cublas_addmm_reduced_precision(self, size: int, dtype: torch.dtype):
        self.cublas_addmm(size, dtype, True)
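
    # Note on the looser tolerances above: with the reduced-precision flags enabled,
    # cuBLAS may use reduced-precision reductions for FP16/BF16 GEMMs, so the CUDA
    # result can drift further from the FP32 CPU reference than in test_cublas_addmm.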

    @onlyCUDA
    @toleranceOverride({torch.float16: xtol(atol=1e-3, rtol=2e-3)})
    @dtypes(torch.float16)
    def test_cublas_addmm_alignment(self, dtype):
        device = 'cuda'
        # perturb X, A, or B alignment
        for idx in range(0, 3):
            for offset in range(1, 3):
                offsets = [0, 0, 0]
                offsets[idx] = offset
                x_offset, a_offset, b_offset = offsets
                A = torch.rand((5120 * 2560 + a_offset), requires_grad=True, dtype=dtype, device=device)
                A = A[a_offset:].reshape(5120, 2560)
                X = torch.rand((26 * 2560 + x_offset), requires_grad=True, dtype=dtype, device=device)
                X = X[x_offset:].reshape(26, 1, 2560)
                B = torch.rand((5120 + b_offset), requires_grad=True, dtype=dtype, device=device)
                B = B[b_offset:].reshape(5120)
                out = torch.nn.functional.linear(X, A, B)
                self.assertEqual(out, torch.matmul(X, A.transpose(1, 0)) + B)

    @onlyCUDA
    @unittest.skipIf(IS_JETSON, "Too large for Jetson")
    @toleranceOverride({torch.float32: xtol(atol=1e-5, rtol=1.1e-5)})
    @dtypes(*([torch.float32, torch.float16] +
              ([torch.bfloat16] if TEST_WITH_ROCM or SM53OrLater else [])))
    @parametrize(
        "batch_size, N, M, P",
        [(2, 100, 100, 100),
         (2, 1000, 1000, 1000),
         (1, 10000, 1000, 10000),
         (1, 10000, 10000, 10000)],
        name_fn=lambda batch_size, N, M, P: f"{batch_size}_{N}_{M}_{P}",
    )
    @skipIfRocm
    def test_cublas_baddbmm_large_input(self, device, batch_size, N, M, P, dtype):
        cpu_dtype = dtype
        if dtype == torch.float16 or dtype == torch.bfloat16:
            cpu_dtype = torch.float32

        M1 = torch.rand((N, M), device=device, dtype=dtype)
        M2 = torch.rand((M, P), device=device, dtype=dtype)
        A = torch.rand((N, P), device=device, dtype=dtype)

        def _convert_to_cpu(t):
            return t.to(device='cpu', dtype=cpu_dtype)
        M1_cpu, M2_cpu, A_cpu = map(_convert_to_cpu, [M1, M2, A])

        # linear
        out1_cpu = torch.nn.functional.linear(M1_cpu, M2_cpu.t(), A_cpu).to(dtype=dtype)
        out1_gpu = torch.nn.functional.linear(M1, M2.t(), A).cpu()
        self.assertEqual(out1_cpu, out1_gpu)
        # test multiplying by the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype)
            out1_eye_gpu = torch.nn.functional.linear(M1, M2_eye.t(), torch.zeros_like(A))
            self.assertEqual(M1_cpu.to(dtype=dtype), out1_eye_gpu.cpu())

        # baddbmm
        def _expand_to_batch(t: torch.Tensor):
            return t.expand((batch_size, ) + t.size())
        alpha, beta = 1.0, 1.0
        M1, M2, A, M1_cpu, M2_cpu, A_cpu = map(_expand_to_batch, [M1, M2, A, M1_cpu, M2_cpu, A_cpu])

        out2_cpu = torch.baddbmm(A_cpu, M1_cpu, M2_cpu, beta=beta, alpha=alpha).to(dtype=dtype)
        out2_gpu = torch.baddbmm(A, M1, M2, beta=beta, alpha=alpha).cpu()
        self.assertEqual(out2_cpu, out2_gpu)
        # test multiplying by the identity matrix
        if N == M and M == P:
            M2_eye = torch.eye(N, device=device, dtype=dtype).expand(batch_size, N, N)
            out2_eye_gpu = torch.baddbmm(torch.zeros_like(A), M1, M2_eye, beta=beta, alpha=alpha)
            self.assertEqual(M1_cpu.to(dtype=dtype), out2_eye_gpu.cpu())

        # cross comparison
        self.assertEqual(out1_gpu, out2_gpu[0])


f8_msg = "FP8 is only supported on H100+, SM 8.9, and MI300+ devices"

if torch.version.hip:
    e4m3_type = torch.float8_e4m3fnuz
    e5m2_type = torch.float8_e5m2fnuz
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max
else:
    e4m3_type = torch.float8_e4m3fn
    e5m2_type = torch.float8_e5m2
    E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
    E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max

# avoid division by zero when calculating scale
EPS = 1e-12

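# Informational note: scaled_mm_supported_device() below encodes the support matrix
# from f8_msg above: compute capability (9, 0) is Hopper (e.g. H100), (8, 9) is Ada
# (e.g. RTX 4090 / L40), and a 'gfx94'-prefixed arch name corresponds to the AMD
# MI300 series.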
def scaled_mm_supported_device():
    if torch.cuda.is_available():
        if torch.version.hip:
            return 'gfx94' in torch.cuda.get_device_properties(0).gcnArchName
        else:
            return torch.cuda.get_device_capability() >= (9, 0) or torch.cuda.get_device_capability() == (8, 9)
    return False


def amax_to_scale(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
):
    """ Converts the amax value of a tensor to the fp8 scale.
    Args:
        amax: The amax value of the tensor.
        float8_dtype: The float8 dtype.
        orig_dtype: The original dtype of the tensor.
    """
    scale = torch.empty_like(amax, dtype=torch.float32)
    if float8_dtype == e4m3_type:
        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
    elif float8_dtype == e5m2_type:
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
    else:
        raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

    # Ensure the scale is representable in float16;
    # this helps when amax is small. We are assuming that we don't need
    # to care about this for float32/bfloat16.
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=torch.finfo(torch.float16).max)

    scale.copy_(res)
    return scale

def tensor_to_scale(x: torch.Tensor, float8_dtype: torch.dtype):
    amax = torch.max(torch.abs(x))
    return amax_to_scale(amax, float8_dtype, x.dtype)
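
# Worked example (illustrative, not executed by the tests): on a CUDA build,
# e4m3_type is torch.float8_e4m3fn, whose finfo max is 448.0, so a tensor with
# amax == 2.0 gets scale == 448.0 / 2.0 == 224.0; multiplying by that scale
# stretches the tensor to the full representable fp8 range before casting.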

def mm_float8_emulated(x, x_scale, y, y_scale, out_dtype):
    # naive implementation: dq -> op -> q
    x_fp32 = x.to(torch.float) / x_scale
    y_fp32 = y.to(torch.float) / y_scale
    out_fp32 = torch.mm(x_fp32, y_fp32)

    return out_fp32.to(out_dtype), torch.max(torch.abs(out_fp32))
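
# The division by the scales above is the dequantization step: the test inputs are
# quantized as x_fp8 ≈ x * x_scale (see to_fp8_saturated below), so dividing by the
# scale recovers an approximation of the original values before the float32 matmul;
# the second return value is the amax of that fp32 result.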

def addmm_float8_unwrapped(
    a_data: torch.Tensor,
    a_scale: torch.Tensor,
    b_data: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,
    output_scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    a_inverse_scale = a_scale.reciprocal()
    b_inverse_scale = b_scale.reciprocal()
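    # The fp8 operands were produced as data ≈ original * scale, so the reciprocal
    # scales are handed to torch._scaled_mm to undo that scaling as part of the
    # matmul (mirroring the division in mm_float8_emulated).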
    if output_dtype == torch.float32 and bias is not None:
        # Bias is not supported by _scaled_mm when output is fp32
        output, output_amax = torch._scaled_mm(
            a_data,
            b_data,
            out_dtype=output_dtype,
            scale_a=a_inverse_scale,
            scale_b=b_inverse_scale,
            scale_result=output_scale,
        )
        output += bias
        return output, output_amax
    output, output_amax = torch._scaled_mm(
        a_data,
        b_data,
        bias=bias,
        out_dtype=output_dtype,
        scale_a=a_inverse_scale,
        scale_b=b_inverse_scale,
        scale_result=output_scale,
    )
    return output, output_amax

def mm_float8(
    a: torch.Tensor,
    b: torch.Tensor,
    a_scale: torch.Tensor,
    b_scale: torch.Tensor,
    output_dtype: torch.dtype,  # output dtype
    output_scale: Optional[torch.Tensor] = None,  # output scale, precomputed
) -> Tuple[torch.Tensor, torch.Tensor]:
    return addmm_float8_unwrapped(
        a, a_scale, b, b_scale, output_dtype, output_scale
    )

def to_fp8_saturated(
    x: torch.Tensor,
    x_scale: torch.Tensor,
    fp8_dtype: torch.dtype
):
    """
    Converts a tensor to a saturated fp8 tensor.

    Args:
        x: Input tensor.
        x_scale: Scale to multiply `x` by before the cast.
        fp8_dtype: Target fp8 dtype (`e4m3_type` or `e5m2_type`).

    Returns:
        torch.Tensor: `x * x_scale`, clamped to the representable range of
        `fp8_dtype` and cast to that dtype.

    Note:
        The default behavior in PyTorch for casting to `e4m3_type`
        and `e5m2_type` is to not saturate. In this context, we should
        saturate. A common case where we want to saturate is when the history
        of a tensor has a maximum value of `amax1`, and the current amax value
        is `amax2`, where `amax1 < amax2`.
    """
    x_scaled = x * x_scale

    if fp8_dtype == e4m3_type:
        x_scaled = x_scaled.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
    elif fp8_dtype == e5m2_type:
        x_scaled = x_scaled.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
    else:
        raise ValueError(f"to_fp8_saturated(): Unsupported fp8_dtype: {fp8_dtype}")

    return x_scaled.to(fp8_dtype)
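

# Minimal usage sketch (not exercised by the test suite): how the helpers above are
# meant to compose into a scaled fp8 matmul on a supported device. The function name
# and the bfloat16 default are illustrative choices, not part of any tested API.
def _example_scaled_fp8_mm(x: torch.Tensor, y: torch.Tensor,
                           out_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # Derive per-tensor scales from the amax of each input.
    x_scale = tensor_to_scale(x, e4m3_type).float()
    y_scale = tensor_to_scale(y, e4m3_type).float()
    # Quantize with saturation, then run the scaled matmul. As in the tests below,
    # `y` is passed in column-major (transposed) layout.
    x_fp8 = to_fp8_saturated(x, x_scale, e4m3_type)
    y_fp8 = to_fp8_saturated(y, y_scale, e4m3_type)
    out, _amax = mm_float8(x_fp8, y_fp8, x_scale, y_scale, output_dtype=out_dtype)
    return out
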

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not found")
class TestFP8MatmulCuda(TestCase):

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    def _test_tautological_mm(self, device: str = "cuda",
                              x_dtype: torch.dtype = e4m3_type,
                              y_dtype: torch.dtype = e4m3_type,
                              out_dtype: Optional[torch.dtype] = None,
                              size: int = 16) -> None:
        x_fp8 = torch.rand(size, size, device=device).to(x_dtype)
        y_fp8 = torch.eye(size, device=device, dtype=y_dtype).t()
        out_fp32 = torch.mm(x_fp8.to(torch.float), y_fp8.to(torch.float))
        (out_fp8, amax_fp8) = torch._scaled_mm(x_fp8, y_fp8, out_dtype=out_dtype)
        if out_dtype is not None:
            self.assertEqual(out_dtype, out_fp8.dtype)
        if out_dtype not in [torch.float16, torch.bfloat16, torch.float]:
            self.assertEqual(out_fp32.amax(), amax_fp8)
        self.assertEqual(out_fp32, out_fp8.to(torch.float))

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    def test_float8_basics(self, device) -> None:
        self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16)
        # hipblaslt does not yet support mixed e4m3_type input
        if torch.version.hip is None:
            self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
            self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
        # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, e5m2_type, e5m2_type)

        self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
        self._test_tautological_mm(device, size=96, out_dtype=torch.float32)
        # hipblaslt does not yet support bfloat16 output
        if torch.version.hip is None:
            self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)
        with self.assertRaises(RuntimeError):
            self._test_tautological_mm(device, out_dtype=e5m2_type)

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    def test_float8_scale(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8, amax_fp8 = torch._scaled_mm(x, y)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s, amax_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b)
        self.assertEqual(out_fp8, out_fp8_s)

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    @parametrize("base_dtype", [torch.float16, torch.bfloat16, torch.float32])
    def test_scaled_mm_vs_emulated(self, base_dtype):
        torch.manual_seed(42)
        input_dtype = e4m3_type
        output_dtype = base_dtype
        compare_type = torch.float32

        x = torch.randn(16, 16, device="cuda", dtype=base_dtype)
        y = torch.randn(32, 16, device="cuda", dtype=base_dtype).t()

        x_scale = tensor_to_scale(x, input_dtype).float()
        y_scale = tensor_to_scale(y, input_dtype).float()

        x_fp8 = to_fp8_saturated(x, x_scale, e4m3_type)
        y_fp8 = to_fp8_saturated(y, y_scale, e4m3_type)

        # Calculate actual F8 mm
        out_scaled_mm, output_amax_scaled = mm_float8(
            x_fp8,
            y_fp8,
            a_scale=x_scale,
            b_scale=y_scale,
            output_dtype=output_dtype
        )

        # Calculate emulated F8 mm
        out_emulated, output_amax_emulated = mm_float8_emulated(
            x_fp8,
            x_scale,
            y_fp8,
            y_scale,
            output_dtype
        )

        if output_dtype != base_dtype:
            out_scaled_mm = out_scaled_mm.to(compare_type)
            out_emulated = out_emulated.to(compare_type)

            out_scaled_mm = out_scaled_mm / amax_to_scale(
                output_amax_scaled, input_dtype, base_dtype
            )
            out_emulated = out_emulated / amax_to_scale(
                output_amax_emulated, input_dtype, base_dtype
            )

        if base_dtype in {torch.bfloat16, torch.float16}:
            atol, rtol = 7e-2, 7e-2
        else:
            atol, rtol = 2e-3, 2e-3

        torch.testing.assert_close(out_scaled_mm, out_emulated, atol=atol, rtol=rtol)

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    def test_float8_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), 4.0, device=device, dtype=torch.half)
        out_fp8, amax_fp8 = torch._scaled_mm(x, y)
        outb_fp8, amaxb_fp8 = torch._scaled_mm(x, y, bias=bias)
        # this fails on ROCm currently because hipblaslt doesn't have an amax op
        if torch.version.hip is None:
            self.assertEqual((amaxb_fp8 - amax_fp8).item(), 4.0)

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    @parametrize("bias", [True, False])
    def test_non_divisible_leading_dim(self, device, bias: bool) -> None:
        x = torch.rand((17, 16), device=device).to(e4m3_type)
        y = torch.rand((16, 16), device=device).to(e4m3_type).t()
        input_bias = None
        if bias:
            input_bias = torch.rand((16,), device=device).to(torch.half)
        out_fp8, amax_fp8 = torch._scaled_mm(x, y, bias=input_bias)

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    def test_float8_bias_relu_edgecase(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.full((k, l), 0.0, device=device).to(e4m3_type)
        y = torch.full((m, l), 1.0, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), -3.0, device=device, dtype=torch.half)
        outb_fp8, amaxb_fp8 = torch._scaled_mm(x, y, bias=bias)
        self.assertEqual(amaxb_fp8.item(), 3.0)

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    def test_float32_output_errors_with_bias(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.full((m, l), .25, device=device, dtype=e4m3_type).t()
        bias = torch.full((m,), 4.0, device=device, dtype=torch.bfloat16)
        self.assertRaisesRegex(
            RuntimeError,
            "Bias is not supported when out_dtype is set to Float32",
            lambda: torch._scaled_mm(x, y, bias=bias, out_dtype=torch.float32),
        )

    @unittest.skipIf(scaled_mm_supported_device(),
                     "This test is only for devices with compute capability < 8.9")
    def test_error_message_fp8_pre_sm89(self, device) -> None:
        (k, l, m) = (16, 48, 32)
        x = torch.rand((k, l), device=device).to(e4m3_type)
        y = torch.rand((m, l), device=device).to(e4m3_type).t()
        self.assertRaisesRegex(
            RuntimeError,
            r"torch\.\_scaled\_mm is only supported on CUDA devices with compute capability \>\= 9\.0 or 8\.9, or ROCm MI300\+",
            lambda: torch._scaled_mm(x, y, out_dtype=torch.float32),
        )

    @unittest.skipIf(not scaled_mm_supported_device(), f8_msg)
    def test_float8_scale_fast_accum(self, device) -> None:
        size = (16, 16)
        x = torch.full(size, .5, device=device, dtype=e4m3_type)
        # hipblaslt does not yet support mixed e4m3_type input
        y_type = e4m3_type if torch.version.hip else e5m2_type
        y = torch.full(size, .5, device=device, dtype=y_type).t()
        scale_a = torch.tensor(1.5, device=device)
        scale_b = torch.tensor(0.66, device=device)
        out_fp8, amax_fp8 = torch._scaled_mm(x, y, use_fast_accum=True)
        self.assertEqual(out_fp8.to(torch.float), torch.full(size, 4., device=device))
        out_fp8_s, amax_fp8_s = torch._scaled_mm(x, y, scale_a=scale_a, scale_b=scale_b, use_fast_accum=True)
        self.assertEqual(out_fp8, out_fp8_s)


@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support CUTLASS")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support CUTLASS extensions")
@unittest.skipIf(not _IS_SM8X, "mixed dtypes linear only supported on SM 8.x")
class TestMixedDtypesLinearCuda(TestCase):
    @dtypes(torch.float16, torch.bfloat16)
    def test_mixed_dtypes_linear(self, dtype: torch.dtype, device: str = "cuda"):
        version = _get_torch_cuda_version()
        if version < (11, 8):
            self.skipTest("_mixed_dtypes_linear only compiled for CUDA 11.8+")

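        # run_test builds a reference result in `dtype` by dequantizing the int8
        # (or packed int4) weight with `scale`, then checks
        # torch.ops.aten._mixed_dtypes_linear against torch.mm (plain multiplication)
        # and torch.nn.functional.linear (bias/activation path).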
        def run_test(
            batch_shape,
            m,
            n,
            k,
            add_bias,
            activation,
            dtype,
            dtypeq,
            device,
            rtol,
            atol,
        ):
            if not add_bias and activation != "none":
                return

            val_lo, val_hi = -1, 1
            valq_lo, valq_hi = -2, 2
            input = make_tensor(
                *batch_shape, m, k, low=val_lo, high=val_hi, dtype=dtype, device=device
            )
            weight = make_tensor(
                n, k, low=valq_lo, high=valq_hi, dtype=torch.int8, device=device
            )
            scale = make_tensor(
                (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
            )
            bias = (
                make_tensor(
                    (n,), low=val_lo, high=val_hi, dtype=input.dtype, device=device
                )
                if add_bias
                else None
            )

            input_ref = input.reshape(-1, input.shape[-1])

            # First, test plain multiplication.
            weight_ref = weight.T.to(input.dtype) * scale.view(1, n)
            weightq = (
                pack_int4_to_int8(weight.T) if dtypeq == torch.quint4x2 else weight.T
            )
            output_ref = torch.mm(input_ref, weight_ref).reshape(*input.shape[:-1], n)
            output = torch.ops.aten._mixed_dtypes_linear(
                input,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=False
                ),
                scale,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

            # Second, test the linear operator itself.
            weight_ref = weight.to(input.dtype) * scale.view(n, 1)
            weightq = pack_int4_to_int8(weight) if dtypeq == torch.quint4x2 else weight
            bias_ref = bias.view(1, n) if add_bias else None
            output_ref = torch.nn.functional.linear(
                input_ref, weight_ref, bias=bias_ref
            ).reshape(*input.shape[:-1], n)
            if activation == "relu":
                relu = torch.nn.ReLU()
                output_ref = relu(output_ref)
            elif activation == "silu":
                silu = torch.nn.SiLU()
                output_ref = silu(output_ref)
            output = torch.ops.aten._mixed_dtypes_linear(
                input,
                quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
                    weightq, dtypeq, transpose=True
                ),
                scale,
                bias=bias,
                activation=activation,
            )
            torch.testing.assert_close(output, output_ref, rtol=rtol, atol=atol)

        dtypeqs = [torch.int8, torch.quint4x2]
        batch_shapes = [[], [2], [2, 1]]
        shapes = [
            [8, 64, 64],
            [8, 64, 128],
            [8, 128, 64],
            [8, 128, 128],
            [8, 128, 192],
            [8, 128, 256],
            [8, 256, 128],
            [8, 256, 384],
            [8, 384, 256],
        ]
        activations = [None, "relu", "silu"]
        rtol, atol = 1e-3, 1e-3
        if dtype == torch.bfloat16:
            rtol, atol = 1e-2, 1e-3
        for dtypeq, batch_shape, (m, n, k), add_bias, activation in product(
            dtypeqs, batch_shapes, shapes, (False, True), activations
        ):
            run_test(
                batch_shape,
                m,
                n,
                k,
                add_bias,
                activation,
                dtype,
                dtypeq,
                device,
                rtol,
                atol,
            )


instantiate_device_type_tests(TestMatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestFP8MatmulCuda, globals(), except_for="cpu")
instantiate_device_type_tests(TestMixedDtypesLinearCuda, globals(), except_for="cpu")

if __name__ == '__main__':
    TestCase._default_dtype_check_enabled = True
    run_tests()