Add meta func for scaled mm (#112609)

# Summary
Adds a meta implementation for `_scaled_mm`, which is required for dynamic shapes.
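
As a quick illustration (not part of this PR), once the meta kernel is registered, `torch._scaled_mm` can be run on `meta` tensors to infer output shapes without touching a GPU, assuming a build that includes the float8 dtypes:

```python
import torch

M, K, N = 15, 32, 16  # self.size(1) and both mat2 dims must be multiples of 16
mat1 = torch.empty(M, K, device="meta", dtype=torch.float8_e4m3fn)      # row-major
mat2 = torch.empty(N, K, device="meta", dtype=torch.float8_e4m3fn).t()  # column-major view
out, amax = torch._scaled_mm(mat1, mat2)
print(out.shape, amax.shape)  # torch.Size([15, 16]) torch.Size([])
```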

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112609
Approved by: https://github.com/eellison, https://github.com/malfet
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 7024e89..38cce45 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -753,7 +753,7 @@
   TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
   TORCH_CHECK(amax.scalar_type() == kFloat, "amax must be a float scalar");
   TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
-  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat1.scalar_type());
+  TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
   // Type restrictions imposed by CuBLASLt as of CUDA-12.1
   TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
         "Multiplication of two Float8_e5m2 matrices is not supported");
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 75d5e06..68a4353 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -5179,6 +5179,56 @@
     return grad_q, grad_k, grad_v, grad_bias
 
 
+@register_meta([aten._scaled_mm.default])
+def meta_scaled_mm(
+    self: torch.Tensor,
+    mat2: torch.Tensor,
+    bias: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None,
+    scale_a: Optional[torch.Tensor] = None,
+    scale_b: Optional[torch.Tensor] = None,
+    scale_result: Optional[torch.Tensor] = None,
+    use_fast_accum: bool = False,
+):
+    def is_row_major(stride):
+        return stride[0] > stride[1] and stride[1] == 1
+
+    def is_col_major(shape, stride):
+        return stride[0] == 1 and stride[1] == shape[0]
+
+    def is_fp8_type(dtype):
+        return dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+
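+    # Mirror the runtime checks from the CUDA kernel (aten/src/ATen/native/cuda/Blas.cpp)
+    # so that fake/meta tensors reject the same inputs eager mode would.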
+    torch._check(
+        self.dim() == 2 and mat2.dim() == 2,
+        lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
+    )
+    torch._check(
+        is_row_major(self.stride()),
+        lambda: "self must be row_major",
+    )
+    torch._check(
+        is_col_major(mat2.shape, mat2.stride()),
+        lambda: "mat2 must be col_major",
+    )
+    torch._check(
+        self.size(1) % 16 == 0,
+        lambda: f"Expected self.size(0) to be divisible by 16, but got self.size(1)={self.size(1)}",
+    )
+    torch._check(
+        mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0,
+        lambda: f"Expected both dimensions of mat2 to be divisble by 16 but got {mat2.shape}",
+    )
+    torch._check(
+        is_fp8_type(self.dtype) and is_fp8_type(mat2.dtype),
+        lambda: f"Expected both inputs to be fp8 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
+    )
+    _out_dtype = out_dtype if out_dtype is not None else self.dtype
+    return torch.empty(
+        self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device
+    ), torch.empty((), dtype=torch.float32, device=self.device)
+
+
 @register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
 @out_wrapper()
 def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py
index 6ded7ff..a46d8cf 100644
--- a/torch/testing/_creation.py
+++ b/torch/testing/_creation.py
@@ -11,6 +11,7 @@
 
 _INTEGRAL_TYPES = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
 _FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64]
+_FLOATING_8BIT_TYPES = [torch.float8_e4m3fn, torch.float8_e5m2]
 _COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128]
 _BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES]
 _FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES]
@@ -217,6 +218,18 @@
         _uniform_random_(
             torch.view_as_real(result) if dtype in _COMPLEX_TYPES else result, low, high
         )
+    elif dtype in _FLOATING_8BIT_TYPES:
+        low, high = modify_low_high(
+            low,
+            high,
+            lowest_inclusive=torch.finfo(dtype).min,
+            highest_exclusive=torch.finfo(dtype).max,
+            default_low=-9,
+            default_high=9,
+        )
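+        # Generate in float32 and cast down, since uniform_ is not implemented
+        # for the float8 dtypes.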
+        result = torch.empty(shape, device=device, dtype=torch.float32)
+        _uniform_random_(result, low, high)
+        result = result.to(dtype)
     else:
         raise TypeError(
             f"The requested dtype '{dtype}' is not supported by torch.testing.make_tensor()."
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 1817b61..3a66f57 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -26,7 +26,7 @@
      skipCPUIfNoMklSparse,
      toleranceOverride, tol)
 from torch.testing._internal.common_cuda import (
-    SM53OrLater, SM60OrLater, SM80OrLater, with_tf32_off, TEST_CUDNN,
+    SM53OrLater, SM60OrLater, SM80OrLater, SM90OrLater, with_tf32_off, TEST_CUDNN,
     _get_torch_cuda_version, _get_torch_rocm_version,
 )
 from torch.testing._internal.common_utils import (
@@ -8176,6 +8176,25 @@
         yield ErrorInput(SampleInput(input, args=args, kwargs=kwargs),
                          error_type=error_type, error_regex=error_regex)
 
+def sample_inputs_scaled_mm(op_info, device, dtype, requires_grad, **kwargs):
+    make_mat_e4m3 = partial(make_tensor, device=device, dtype=torch.float8_e4m3fn, requires_grad=requires_grad)
+    make_mat_e5m2 = partial(make_tensor, device=device, dtype=torch.float8_e5m2, requires_grad=requires_grad)
+    M, N, K = 15, 32, 16
+    samples = []
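+    # _scaled_mm requires a row-major mat1 and a column-major mat2;
+    # .t().contiguous().t() produces the column-major layout.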
+    # two e4m3
+    mat1 = make_mat_e4m3((M, K))
+    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
+    samples.append(SampleInput(mat1, mat2))
+    # mat1 e4m3 mat2 e5m2
+    mat1 = make_mat_e4m3((M, K))
+    mat2 = make_mat_e5m2((K, N)).t().contiguous().t()
+    samples.append(SampleInput(mat1, mat2))
+    # mat1 e5m2 mat2 e4m3
+    mat1 = make_mat_e5m2((M, K))
+    mat2 = make_mat_e4m3((K, N)).t().contiguous().t()
+    samples.append(SampleInput(mat1, mat2))
+
+    yield from samples
 
 def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_grad, **kwargs):
     make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -13691,6 +13710,17 @@
             ), ],
     ),
     OpInfo(
+        'torch._scaled_mm',
+        sample_inputs_func=sample_inputs_scaled_mm,
+        dtypes=empty_types(),
+        dtypesIfCUDA=empty_types() + (torch.float8_e4m3fn,),
+        supports_out=True,
+        supports_forward_ad=False,
+        supports_autograd=False,
+        decorators=[skipCUDAIf(not SM90OrLater, 'Requires CUDA SM >= 9.0')],
+        skips=()
+    ),
+    OpInfo(
         'nn.functional.scaled_dot_product_attention',
         op=lambda *args, **kwargs:
                wrapper_set_seed(torch.nn.functional.scaled_dot_product_attention, *args, **kwargs),