Add meta func for scaled mm (#112609)
# Summary
Adds a meta implementation for `_scaled_mm`, which is required for dynamic shapes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112609
Approved by: https://github.com/eellison, https://github.com/malfet
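To illustrate what the registration enables, here is a minimal usage sketch (not part of this patch, and hypothetical in its exact values): with the meta kernel registered, `torch._scaled_mm` can be shape-propagated on the `"meta"` device without a GPU. The shapes and the `.t().contiguous().t()` layout trick mirror the constraints the meta kernel checks below (row-major fp8 `mat1`, column-major fp8 `mat2`, dimensions divisible by 16), and the two returned tensors match the `(out, amax)` pair the meta function produces.
```python
import torch

# Hypothetical sketch: shape propagation for _scaled_mm on the meta device.
# mat1 is row-major fp8; mat2 is made column-major via t().contiguous().t().
mat1 = torch.empty(16, 32, device="meta", dtype=torch.float8_e4m3fn)
mat2 = torch.empty(32, 48, device="meta", dtype=torch.float8_e4m3fn).t().contiguous().t()

out, amax = torch._scaled_mm(mat1, mat2, out_dtype=torch.bfloat16)
print(out.shape, out.dtype)    # torch.Size([16, 48]) torch.bfloat16
print(amax.shape, amax.dtype)  # torch.Size([]) torch.float32
```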
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 7024e89..38cce45 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -753,7 +753,7 @@
TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
TORCH_CHECK(amax.scalar_type() == kFloat, "amax must be a float scalar");
TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
- TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat1.scalar_type());
+ TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
// Type restrictions imposed by CuBLASLt as of CUDA-12.1
TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
"Multiplication of two Float8_e5m2 matrices is not supported");
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 75d5e06..68a4353 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5179,6 +5179,56 @@
return grad_q, grad_k, grad_v, grad_bias
+@register_meta([aten._scaled_mm.default])
+def meta_scaled_mm(
+ self: torch.Tensor,
+ mat2: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ out_dtype: Optional[torch.dtype] = None,
+ scale_a: Optional[torch.Tensor] = None,
+ scale_b: Optional[torch.Tensor] = None,
+ scale_result: Optional[torch.Tensor] = None,
+ use_fast_accum: bool = False,
+):
+ def is_row_major(stride):
+ return stride[0] > stride[1] and stride[1] == 1
+
+ def is_col_major(shape, stride):
+ return stride[0] == 1 and stride[1] == shape[0]
+
+ def is_fp8_type(dtype):
+ return dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+
+ torch._check(
+ self.dim() == 2 and mat2.dim() == 2,
+ lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
+ )
+ torch._check(
+ is_row_major(self.stride()),
+ lambda: "self must be row_major",
+ )
+ torch._check(
+ is_col_major(mat2.shape, mat2.stride()),
+ lambda: "mat2 must be col_major",
+ )
+ torch._check(
+ self.size(1) % 16 == 0,
+ lambda: f"Expected self.size(0) to be divisible by 16, but got self.size(1)={self.size(1)}",
+ )
+ torch._check(
+ mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0,
+ lambda: f"Expected both dimensions of mat2 to be divisble by 16 but got {mat2.shape}",
+ )
+ torch._check(
+ is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype),
+ lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
+ )
+ _out_dtype = out_dtype if out_dtype is not None else self.dtype
+ return torch.empty(
+ self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device
+ ), torch.empty((), dtype=torch.float32, device=self.device)
+
+
@register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
@out_wrapper()
def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py
index 6ded7ff..a46d8cf 100644
--- a/torch/testing/_creation.py
+++ b/torch/testing/_creation.py
@@ -11,6 +11,7 @@
_INTEGRAL_TYPES = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
_FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
+_FLOATING_8BIT_TYPES = [torch.float8_e4m3fn, torch.float8_e5m2]
_COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128]
_BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES]
_FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES]
@@ -217,6 +218,18 @@
_uniform_random_(
torch.view_as_real(result) if dtype in _COMPLEX_TYPES else result, low, high
)
+ elif dtype in _FLOATING_8BIT_TYPES:
+ low, high = modify_low_high(
+ low,
+ high,
+ lowest_inclusive=torch.finfo(dtype).min,
+ highest_exclusive=torch.finfo(dtype).max,
+ default_low=-9,
+ default_high=9,
+ )
+ result = torch.empty(shape, device=device, dtype=torch.float32)
+ _uniform_random_(result, low, high)
+ result = result.to(dtype)
else:
raise TypeError(
f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()."
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 1817b61..3a66f57 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -26,7 +26,7 @@
skipCPUIfNoMklSparse,
toleranceOverride, tol)
from torch.testing._internal.common_cuda import (
- SM53OrLater, SM60OrLater, SM80OrLater, with_tf32_off, TEST_CUDNN,
+ SM53OrLater, SM60OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
_get_torch_cuda_version, _get_torch_rocm_version,
)
from torch.testing._internal.common_utils import (
@@ -8176,6 +8176,25 @@
yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
error_type=error_type, error_regex=error_regex)
+def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
+ make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad)
+ make_mat_e5m2 = partial(make_tensor, device=device, dtype=torch.float8_e5m2, requires_grad=requires_grad)
+ M, N, K = 15, 32, 16
+ samples = []
+ # two e4m3
+ mat1 = make_mat_e4m3((M, K))
+ mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
+ samples.append(SampleInput(mat1, mat2))
+ # mat1 e4m3 mat2 e5m2
+ mat1 = make_mat_e4m3((M, K))
+ mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
+ samples.append(SampleInput(mat1, mat2))
+ # mat1 e5m2 mat2 e4m3
+ mat1 = make_mat_e5m2((M, K))
+ mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
+ samples.append(SampleInput(mat1, mat2))
+
+ yield from samples
def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -13691,6 +13710,17 @@
), ],
),
OpInfo(
+ 'torch._scaled_mm',
+ sample_inputs_func=sample_inputs_scaled_mm,
+ dtypes=empty_types(),
+ dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
+ supports_out=True,
+ supports_forward_ad=False,
+ supports_autograd=False,
+ decorators=[skipCUDAIf(not SM90OrLater, 'Requires CUDA SM >= 9.0')],
+ skips=()
+ ),
+ OpInfo(
'nn.functional.scaled_dot_product_attention',
op=lambda *args, **kwargs:
wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),