enable fp8 cast for inductor CPU (#117737)
Enable FP8 cast for this issue https://github.com/pytorch/pytorch/issues/117119.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117737
Approved by: https://github.com/jgong5, https://github.com/jansel
diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py
index aaa1472..16c8199 100644
--- a/test/inductor/test_cpu_repro.py
+++ b/test/inductor/test_cpu_repro.py
@@ -33,7 +33,12 @@
from torch._inductor.virtualized import V
from torch.fx.experimental.proxy_tensor import make_fx
from torch.nn import functional as F
-from torch.testing._internal.common_utils import IS_MACOS, slowTest
+from torch.testing._internal.common_utils import (
+ instantiate_parametrized_tests,
+ IS_MACOS,
+ parametrize,
+ slowTest,
+)
from torch.utils._python_dispatch import TorchDispatchMode
try:
@@ -83,6 +88,7 @@
return x, h
+@instantiate_parametrized_tests
class CPUReproTests(TestCase):
common = check_model
@@ -2780,6 +2786,18 @@
"Vectorized<float>::loadu(tmpbuf.data())", 0, exactly=True
).run(code)
+ @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
+ @parametrize("shape", ("15,3,13", "4,2048,4096"))
+ def test_fp8_cast(self, dtype: torch.dtype, shape: str):
+ def fp8_cast(x):
+ y0 = x.to(dtype=torch.float8_e4m3fn).to(dtype)
+ y1 = x.to(dtype=torch.float8_e5m2).to(dtype)
+ return y0, y1
+
+ shape = [int(dim) for dim in shape.split(",")]
+ x = torch.rand(*shape, device="cpu", dtype=dtype)
+ self.common(fp8_cast, (x,))
+
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index f7b0c5d..d1ed0f7 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -66,6 +66,8 @@
torch.bool: "bool",
torch.bfloat16: "bfloat16",
torch.complex64: "complex64",
+ torch.float8_e4m3fn: "float8_e4m3fn",
+ torch.float8_e5m2: "float8_e5m2",
}
DTYPE_TO_ATEN = {
diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h
index 9e53b7f..e02f8e6 100644
--- a/torch/_inductor/codegen/cpp_prefix.h
+++ b/torch/_inductor/codegen/cpp_prefix.h
@@ -11,6 +11,8 @@
#include <ATen/core/PhiloxRNGEngine.h>
#include <ATen/native/Math.h>
+#include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e5m2.h>
#include <c10/util/BFloat16.h>
#include <c10/util/BFloat16-math.h>
#include <c10/util/generic_math.h>
@@ -31,6 +33,9 @@
typedef at::Half half;
typedef at::BFloat16 bfloat16;
+typedef at::Float8_e4m3fn float8_e4m3fn;
+typedef at::Float8_e5m2 float8_e5m2;
+
template <typename T>
struct Welford {
T mean = T(0);