Enable FP8 cast for Inductor CPU (#117737)

Enable casting to and from the FP8 dtypes (`torch.float8_e4m3fn`, `torch.float8_e5m2`) in the Inductor CPU backend. Fixes https://github.com/pytorch/pytorch/issues/117119.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117737
Approved by: https://github.com/jgong5, https://github.com/jansel
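
For reference, a minimal usage sketch (mirroring the new test case below) of the pattern this change enables under `torch.compile` on CPU:

```python
import torch

def fp8_cast(x):
    # Round-trip through both FP8 formats; before this change the
    # Inductor CPU backend could not compile these casts.
    y0 = x.to(dtype=torch.float8_e4m3fn).to(x.dtype)
    y1 = x.to(dtype=torch.float8_e5m2).to(x.dtype)
    return y0, y1

compiled = torch.compile(fp8_cast)  # default backend is inductor
y0, y1 = compiled(torch.rand(15, 3, 13, device="cpu"))
```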
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index aaa1472..16c8199 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -33,7 +33,12 @@
 from torch._inductor.virtualized import V
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.nn import functional as F
-from torch.testing._internal.common_utils import IS_MACOS, slowTest
+from torch.testing._internal.common_utils import (
+    instantiate_parametrized_tests,
+    IS_MACOS,
+    parametrize,
+    slowTest,
+)
 from torch.utils._python_dispatch import TorchDispatchMode
 
 try:
@@ -83,6 +88,7 @@
         return x, h
 
 
+@instantiate_parametrized_tests
 class CPUReproTests(TestCase):
     common = check_model
 
@@ -2780,6 +2786,18 @@
                 "Vectorized<float>::loadu(tmpbuf.data())", 0, exactly=True
             ).run(code)
 
+    @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
+    @parametrize("shape", ("15,3,13", "4,2048,4096"))
+    def test_fp8_cast(self, dtype: torch.dtype, shape: str):
+        def fp8_cast(x):
+            y0 = x.to(dtype=torch.float8_e4m3fn).to(dtype)
+            y1 = x.to(dtype=torch.float8_e5m2).to(dtype)
+            return y0, y1
+
+        shape = [int(dim) for dim in shape.split(",")]
+        x = torch.rand(*shape, device="cpu", dtype=dtype)
+        self.common(fp8_cast, (x,))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index f7b0c5d..d1ed0f7 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -66,6 +66,8 @@
     torch.bool: "bool",
     torch.bfloat16: "bfloat16",
     torch.complex64: "complex64",
+    torch.float8_e4m3fn: "float8_e4m3fn",
+    torch.float8_e5m2: "float8_e5m2",
 }
 
 DTYPE_TO_ATEN = {
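
For context, `DTYPE_TO_CPP` maps torch dtypes to the type names Inductor writes into generated C++ source, so the two new entries let the codegen spell the FP8 types. A minimal sketch of the scalar conversions these names resolve to, written against the c10 headers directly (the function names are illustrative):

```cpp
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>

// c10's FP8 scalar types provide a converting constructor from float
// and an operator float(), so static_cast works in both directions.
float roundtrip_e4m3(float x) {
  return static_cast<float>(static_cast<c10::Float8_e4m3fn>(x));
}

float roundtrip_e5m2(float x) {
  return static_cast<float>(static_cast<c10::Float8_e5m2>(x));
}
```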
diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h
index 9e53b7f..e02f8e6 100644
--- a/torch/_inductor/codegen/cpp_prefix.h
+++ b/torch/_inductor/codegen/cpp_prefix.h
@@ -11,6 +11,8 @@
 #include <ATen/core/PhiloxRNGEngine.h>
 #include <ATen/native/Math.h>
 
+#include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e5m2.h>
 #include <c10/util/BFloat16.h>
 #include <c10/util/BFloat16-math.h>
 #include <c10/util/generic_math.h>
@@ -31,6 +33,9 @@
 typedef at::Half half;
 typedef at::BFloat16 bfloat16;
 
+typedef at::Float8_e4m3fn float8_e4m3fn;
+typedef at::Float8_e5m2 float8_e5m2;
+
 template <typename T>
 struct Welford {
   T mean = T(0);
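
These typedefs bind the bare `DTYPE_TO_CPP` names to the c10 scalar types so that generated kernel bodies can use them unqualified. A hand-written sketch (illustrative, not captured from actual Inductor output) of the kind of scalar loop the codegen produces for the round-trip cast:

```cpp
// Assumes the typedefs from cpp_prefix.h above are in scope; the real
// kernel text is emitted by torch/_inductor/codegen/cpp.py.
void fp8_cast_kernel(const float* in_ptr0, float* out_ptr0, long numel) {
  for (long i0 = 0; i0 < numel; ++i0) {
    auto tmp0 = in_ptr0[i0];
    auto tmp1 = static_cast<float8_e4m3fn>(tmp0);  // narrow to FP8
    auto tmp2 = static_cast<float>(tmp1);          // widen back to float
    out_ptr0[i0] = tmp2;
  }
}
```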