Add torch.float8_e5m2 and torch.float8_e4m3 data types (#104242)

Proposal of two float8 variants - e5m2 and e4m3 - based on https://arxiv.org/pdf/2209.05433.pdf

Hide all Float8 operator implementations behind `#if !defined(C10_MOBILE)` guard to keep Android build size almost unchanged

TODO:
 - Refactor duplicated code
 - Cleanup unbalanced pragma pop in dtype utils
 - Add native implementation on the CUDA size

Co-authored-by: Nikita Shulga <nshulga@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104242
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/AccumulateType.h b/aten/src/ATen/AccumulateType.h
index 9b5b9b7..945f61f 100644
--- a/aten/src/ATen/AccumulateType.h
+++ b/aten/src/ATen/AccumulateType.h
@@ -2,6 +2,8 @@
 #include <ATen/Config.h>
 #include <c10/core/ScalarType.h>
 #include <c10/util/BFloat16.h>
+#include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e5m2.h>
 #include <c10/util/Half.h>
 
 // Defines the accumulation type for a scalar type.
@@ -67,6 +69,14 @@
   using type = float;
 };
 template <>
+struct AccumulateType<Float8_e5m2, true> {
+  using type = float;
+};
+template <>
+struct AccumulateType<Float8_e4m3fn, true> {
+  using type = float;
+};
+template <>
 struct AccumulateType<float, true> {
   using type = float;
 };
@@ -111,6 +121,14 @@
   using type = float;
 };
 template <>
+struct AccumulateType<Float8_e5m2, false> {
+  using type = float;
+};
+template <>
+struct AccumulateType<Float8_e4m3fn, false> {
+  using type = float;
+};
+template <>
 struct AccumulateType<c10::complex<Half>, false> {
   using type = c10::complex<float>;
 };
diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp
index 168eb3a..57b9ce0 100644
--- a/aten/src/ATen/DLConvertor.cpp
+++ b/aten/src/ATen/DLConvertor.cpp
@@ -53,6 +53,10 @@
     case ScalarType::BFloat16:
       dtype.code = DLDataTypeCode::kDLBfloat;
       break;
+    case ScalarType::Float8_e5m2:
+    case ScalarType::Float8_e4m3fn:
+      TORCH_CHECK(false, "float8 types are not supported by dlpack");
+      break;
     case ScalarType::QInt8:
     case ScalarType::QUInt8:
     case ScalarType::QInt32:
diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h
index 2405e1f..931cc72 100644
--- a/aten/src/ATen/Dispatch.h
+++ b/aten/src/ATen/Dispatch.h
@@ -291,6 +291,22 @@
       AT_DISPATCH_CASE_FLOATING_TYPES_AND3(                 \
           SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__))
 
+#define AT_DISPATCH_CASE_FLOATING_TYPES_AND4(                \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \
+  AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)               \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)                 \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)                 \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)                 \
+  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)
+
+#define AT_DISPATCH_FLOATING_TYPES_AND4(                                 \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(                                                    \
+      TYPE,                                                              \
+      NAME,                                                              \
+      AT_DISPATCH_CASE_FLOATING_TYPES_AND4(                              \
+          SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
+
 #define AT_DISPATCH_CASE_COMPLEX_TYPES(...)                    \
   AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \
   AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__)
@@ -515,6 +531,73 @@
       AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4(                       \
           SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__))
 
+#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5(                      \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \
+  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)                     \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)                              \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)                              \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)                              \
+  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)                              \
+  AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5(    \
+    SCALARTYPE1,                                   \
+    SCALARTYPE2,                                   \
+    SCALARTYPE3,                                   \
+    SCALARTYPE4,                                   \
+    SCALARTYPE5,                                   \
+    TYPE,                                          \
+    NAME,                                          \
+    ...)                                           \
+  AT_DISPATCH_SWITCH(                              \
+      TYPE,                                        \
+      NAME,                                        \
+      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \
+          SCALARTYPE1,                             \
+          SCALARTYPE2,                             \
+          SCALARTYPE3,                             \
+          SCALARTYPE4,                             \
+          SCALARTYPE5,                             \
+          __VA_ARGS__))
+
+#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6(  \
+    SCALARTYPE1,                                      \
+    SCALARTYPE2,                                      \
+    SCALARTYPE3,                                      \
+    SCALARTYPE4,                                      \
+    SCALARTYPE5,                                      \
+    SCALARTYPE6,                                      \
+    ...)                                              \
+  AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \
+  AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__)          \
+  AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__)
+
+#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(    \
+    SCALARTYPE1,                                   \
+    SCALARTYPE2,                                   \
+    SCALARTYPE3,                                   \
+    SCALARTYPE4,                                   \
+    SCALARTYPE5,                                   \
+    SCALARTYPE6,                                   \
+    TYPE,                                          \
+    NAME,                                          \
+    ...)                                           \
+  AT_DISPATCH_SWITCH(                              \
+      TYPE,                                        \
+      NAME,                                        \
+      AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \
+          SCALARTYPE1,                             \
+          SCALARTYPE2,                             \
+          SCALARTYPE3,                             \
+          SCALARTYPE4,                             \
+          SCALARTYPE5,                             \
+          SCALARTYPE6,                             \
+          __VA_ARGS__))
+
 #define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...)     \
   AT_DISPATCH_SWITCH(                                \
       TYPE,                                          \
diff --git a/aten/src/ATen/NumericUtils.h b/aten/src/ATen/NumericUtils.h
index 4e1c087..06b2533 100644
--- a/aten/src/ATen/NumericUtils.h
+++ b/aten/src/ATen/NumericUtils.h
@@ -6,6 +6,8 @@
 
 #include <c10/macros/Macros.h>
 #include <c10/util/BFloat16.h>
+#include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e5m2.h>
 #include <c10/util/Half.h>
 #include <c10/util/complex.h>
 
@@ -62,6 +64,22 @@
   return at::_isnan(static_cast<float>(val));
 }
 
+template <
+    typename T,
+    typename std::enable_if<std::is_same<T, at::Float8_e5m2>::value, int>::
+        type = 0>
+inline C10_HOST_DEVICE bool _isnan(T val) {
+  return val.isnan();
+}
+
+template <
+    typename T,
+    typename std::enable_if<std::is_same<T, at::Float8_e4m3fn>::value, int>::
+        type = 0>
+inline C10_HOST_DEVICE bool _isnan(T val) {
+  return val.isnan();
+}
+
 // std::isinf isn't performant to use on integral types; it will
 // (uselessly) convert to floating point and then do the test.
 // This function is.
@@ -92,6 +110,14 @@
   return at::_isinf(static_cast<float>(val));
 }
 
+inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) {
+  return val.isinf();
+}
+
+inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val) {
+  return false;
+}
+
 template <typename T>
 C10_HOST_DEVICE inline T exp(T x) {
   static_assert(
diff --git a/aten/src/ATen/OpMathType.h b/aten/src/ATen/OpMathType.h
index f08e420..ddb2ce7 100644
--- a/aten/src/ATen/OpMathType.h
+++ b/aten/src/ATen/OpMathType.h
@@ -3,6 +3,8 @@
 #include <c10/core/ScalarType.h>
 #include <c10/util/BFloat16.h>
 #include <c10/util/Exception.h>
+#include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e5m2.h>
 #include <c10/util/Half.h>
 
 namespace at {
@@ -21,6 +23,14 @@
   using type = float;
 };
 template <>
+struct OpMathType<at::Float8_e5m2> {
+  using type = float;
+};
+template <>
+struct OpMathType<at::Float8_e4m3fn> {
+  using type = float;
+};
+template <>
 struct OpMathType<c10::complex<Half>> {
   using type = c10::complex<float>;
 };
diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h
index da9c6b6..d0995ca 100644
--- a/aten/src/ATen/cpu/vec/vec_base.h
+++ b/aten/src/ATen/cpu/vec/vec_base.h
@@ -72,7 +72,9 @@
     std::integral_constant<bool,
       std::is_floating_point<T>::value ||
       std::is_same<T, at::Half>::value ||
-      std::is_same<T, at::BFloat16>::value> {
+      std::is_same<T, at::BFloat16>::value ||
+      std::is_same<T, at::Float8_e5m2>::value ||
+      std::is_same<T, at::Float8_e4m3fn>::value> {
 };
 
 template<typename T>
diff --git a/aten/src/ATen/native/Copy.cpp b/aten/src/ATen/native/Copy.cpp
index 2978cc6..34bedc8 100644
--- a/aten/src/ATen/native/Copy.cpp
+++ b/aten/src/ATen/native/Copy.cpp
@@ -48,6 +48,18 @@
       self.numel() >= MIN_SZ;
 }
 
+#if !defined(C10_MOBILE)
+#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...)                                   \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                                  \
+            kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
+            TYPE, NAME, __VA_ARGS__)
+#else
+#define _AT_DISPATCH_CP_TYPES(TYPE, NAME, ...)     \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(    \
+            kComplexHalf, kHalf, kBool, kBFloat16, \
+            TYPE, NAME, __VA_ARGS__)
+#endif
+
 // special case copy where tensor is contiguous and src is a transposed matrix
 // This can be generalized to most copies, but it's trickier
 void copy_same_type_transpose_(Tensor& self, const Tensor& src) {
@@ -65,7 +77,7 @@
   // The code below is implemented with the assumption that sizes are equal
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self.sizes().equals(src.sizes()));
 
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kHalf, kBool, kBFloat16, kComplexHalf, self.scalar_type(), "copy_", [&] {
+  _AT_DISPATCH_CP_TYPES(self.scalar_type(), "copy_", [&] {
     scalar_t* sp = src.data_ptr<scalar_t>();
     scalar_t* rp = self.data_ptr<scalar_t>();
     scalar_t* bp = buf.data_ptr<scalar_t>();
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 988729b..46ae3a5 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1310,6 +1310,18 @@
   return self.reshape_symint({self.sym_size(0), 1}) * vec2;
 }
 
+
+#if !defined(C10_MOBILE)
+#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...)    \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(      \
+            kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
+            TYPE, NAME, __VA_ARGS__)
+#else
+#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...)        \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, \
+            TYPE, NAME, __VA_ARGS__)
+#endif
+
 static void addmm_impl_cpu_(
     Tensor &result, const Tensor &self, Tensor m1, Tensor m2, const Scalar& beta, const Scalar& alpha) {
   TORCH_INTERNAL_ASSERT(self.dim() == 2 && m1.dim() == 2 && m2.dim() == 2);
@@ -1438,9 +1450,7 @@
 
   if(!dispatched) {
     // Apply BLAS routine
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
-        result.scalar_type(), "addmm_impl_cpu_",
-        [&]{
+    _AT_DISPATCH_ADDMM_TYPES(result.scalar_type(), "addmm_impl_cpu_", [&]{
           using opmath_t = at::opmath_type<scalar_t>;
           at::native::cpublas::gemm(
               transpose_a ? a.is_conj() ? TransposeType::ConjTranspose : TransposeType::Transpose : TransposeType::NoTranspose,
diff --git a/aten/src/ATen/native/Scalar.cpp b/aten/src/ATen/native/Scalar.cpp
index b192d38..4906948 100644
--- a/aten/src/ATen/native/Scalar.cpp
+++ b/aten/src/ATen/native/Scalar.cpp
@@ -28,10 +28,21 @@
   }
 }
 
+#if !defined(C10_MOBILE)
+#define _AT_DISPATCH_SD_TYPES(TYPE, NAME, ...)                                   \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                                  \
+            kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
+            TYPE, NAME, __VA_ARGS__)
+#else
+#define _AT_DISPATCH_SD_TYPES(TYPE, NAME, ...)     \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(    \
+            kComplexHalf, kHalf, kBool, kBFloat16, \
+            TYPE, NAME, __VA_ARGS__)
+#endif
+
 Scalar _local_scalar_dense_cpu(const Tensor& self) {
   Scalar r;
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-    kComplexHalf, kHalf, kBool, kBFloat16, self.scalar_type(), "_local_scalar_dense_cpu", [&] {
+  _AT_DISPATCH_SD_TYPES(self.scalar_type(), "_local_scalar_dense_cpu", [&] {
         scalar_t value = *self.data_ptr<scalar_t>();
         r = Scalar(value);
       });
diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index e606f0b..6178054 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -369,6 +369,18 @@
   return at::imag(self) == 0;
 }
 
+
+#if !defined(C10_MOBILE)
+#define _AT_DISPATCH_INF_TYPES(TYPE, NAME, ...)                          \
+        AT_DISPATCH_FLOATING_TYPES_AND3( kHalf, kBFloat16, kFloat8_e5m2, \
+            TYPE, NAME, __VA_ARGS__)
+#else
+#define _AT_DISPATCH_INF_TYPES(TYPE, NAME, ...)           \
+        AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, \
+            TYPE, NAME, __VA_ARGS__)
+#endif
+
+
 Tensor isinf(const Tensor &self) {
   // Note: Integral tensor values are never infinite
   if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/true)) {
@@ -381,7 +393,7 @@
           (at::isinf(at::imag(self)));
   }
 
-  return AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, self.scalar_type(), "isinf", [&]() {
+  return _AT_DISPATCH_INF_TYPES(self.scalar_type(), "isinf", [&]() {
     return self.abs() == std::numeric_limits<scalar_t>::infinity();
   });
 }
@@ -397,7 +409,7 @@
     return at::isfinite(at::real(self)).__iand__(at::isfinite(at::imag(self)));
   }
 
-  return AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, self.scalar_type(), "isfinite", [&]() {
+  return _AT_DISPATCH_INF_TYPES(self.scalar_type(), "isfinite", [&]() {
     return (self == self) * (self.abs() != std::numeric_limits<scalar_t>::infinity());
   });
 }
diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
index 67c8314..6758001 100644
--- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -61,6 +61,34 @@
   });
 }
 
+#if !defined(C10_MOBILE)
+#define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...)                         \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                                  \
+            kComplexHalf, kHalf, kBool, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
+            TYPE, NAME, __VA_ARGS__)
+#define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...)                   \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5(                           \
+            kComplexHalf, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
+            TYPE, NAME, __VA_ARGS__)
+#define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...)                   \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(                   \
+            kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn,       \
+            TYPE, NAME, __VA_ARGS__)
+#else
+#define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...)  \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(           \
+            kComplexHalf, kHalf, kBool, kBFloat16,        \
+            TYPE, NAME, __VA_ARGS__)
+#define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...)  \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(          \
+            kComplexHalf, kHalf, kBFloat16,              \
+            TYPE, NAME, __VA_ARGS__)
+#define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...)          \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(          \
+            kHalf, kBFloat16,                            \
+            TYPE, NAME, __VA_ARGS__)
+#endif
+
 void mul_kernel(TensorIteratorBase& iter) {
   auto dtype = iter.common_dtype();
   if (dtype == ScalarType::Bool) {
@@ -85,7 +113,7 @@
         });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, dtype, "mul_cpu", [&]() {
+    _AT_DISPATCH_MUL_TYPES(dtype, "mul_cpu", [&]() {
       cpu_kernel_vec(iter,
         [=](scalar_t a, scalar_t b) __ubsan_ignore_undefined__ -> scalar_t { return a * b; },
         [=](Vectorized<scalar_t> a, Vectorized<scalar_t> b) __ubsan_ignore_undefined__ {
@@ -528,14 +556,14 @@
 void eq_kernel(TensorIteratorBase& iter) {
   // See Note [special-case bool outputs]
   if (iter.dtype() == ScalarType::Bool) {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kBool, kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() {
+    _AT_DISPATCH_ALL_TYPES_AND_BOOL(iter.common_dtype(), "eq_cpu", [&]() {
       cpu_kernel(iter,
         [](scalar_t a, scalar_t b) -> bool {
           return a == b;
         });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kComplexHalf, kBFloat16, kHalf, iter.common_dtype(), "eq_cpu", [&]() {
+    _AT_DISPATCH_ALL_TYPES_NO_BOOL(iter.common_dtype(), "eq_cpu", [&]() {
       cpu_kernel_vec(
         iter,
         [](scalar_t a, scalar_t b) -> scalar_t {
@@ -551,14 +579,14 @@
 void ne_kernel(TensorIteratorBase& iter) {
   // See Note [special-case bool outputs]
   if (iter.dtype() == ScalarType::Bool) {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kBool, kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() {
+    _AT_DISPATCH_ALL_TYPES_AND_BOOL(iter.common_dtype(), "ne_cpu", [&]() {
       cpu_kernel(iter,
         [](scalar_t a, scalar_t b) -> bool {
           return a != b;
         });
     });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(kComplexHalf, kBFloat16, kHalf, iter.common_dtype(), "ne_cpu", [&]() {
+    _AT_DISPATCH_ALL_TYPES_NO_BOOL(iter.common_dtype(), "ne_cpu", [&]() {
       cpu_kernel_vec(
         iter,
         [](scalar_t a, scalar_t b) -> scalar_t {
diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp
index 3114e0b..d076158 100644
--- a/aten/src/ATen/native/cpu/BlasKernel.cpp
+++ b/aten/src/ATen/native/cpu/BlasKernel.cpp
@@ -263,6 +263,17 @@
   }
 }
 
+#if !defined(C10_MOBILE)
+#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...)                  \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(                   \
+            kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn,       \
+            TYPE, NAME, __VA_ARGS__)
+#else
+#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...)         \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(          \
+            kHalf, kBFloat16,                            \
+            TYPE, NAME, __VA_ARGS__)
+#endif
 void cpublas_gemm_impl(
     at::ScalarType type,
     TransposeType transa, TransposeType transb,
@@ -272,9 +283,7 @@
     const void *b, int64_t ldb,
     const Scalar& beta,
     void *c, int64_t ldc) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(at::kHalf, at::kBFloat16,
-    type, "cpublas_gemm_impl",
-      [&]{
+  _AT_DISPATCH_GEMM_TYPES(type, "cpublas_gemm_impl", [&]{
         using opmath_t = at::opmath_type<scalar_t>;
         gemm_core_(
             transa, transb, m, n, k,
diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp
index f6784bf..c3bedea 100644
--- a/aten/src/ATen/native/cpu/CopyKernel.cpp
+++ b/aten/src/ATen/native/cpu/CopyKernel.cpp
@@ -164,6 +164,27 @@
   }
 }
 
+#if !defined(C10_MOBILE)
+#define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...)                                       \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6(                                       \
+            ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool,              \
+            ScalarType::BFloat16, ScalarType::Float8_e5m2, ScalarType::Float8_e4m3fn, \
+            TYPE, NAME, __VA_ARGS__)
+#define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...)              \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5(                    \
+            kBool, kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
+            TYPE, NAME, __VA_ARGS__)
+#else
+#define _AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...)                                               \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(                                               \
+            ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool,ScalarType::BFloat16, \
+            TYPE, NAME, __VA_ARGS__)
+#define _AT_DISPATCH_ALL_TYPES_NO_CF(TYPE, NAME, ...) \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(       \
+            kBool, kHalf, kBFloat16,                  \
+            TYPE, NAME, __VA_ARGS__)
+#endif
+
 void direct_copy_kernel(TensorIteratorBase &iter) {
   // TODO: we don't actually need separate instantiations per dtype;
   // we only need a separate instantiation per dtype size. This would
@@ -183,8 +204,7 @@
   } else if (dtype == ScalarType::ComplexHalf) {
     cpu_kernel(iter, [=](c10::complex<at::Half> a) -> c10::complex<at::Half> { return a; });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
-        kBool, kHalf, kBFloat16, dtype, "copy_kernel", [&] {
+    _AT_DISPATCH_ALL_TYPES_NO_CF(dtype, "copy_kernel", [&] {
       cpu_kernel_vec(
           iter,
           [=](scalar_t a) -> scalar_t { return a; },
@@ -237,9 +257,9 @@
     sizeof(BFloat16) == strides_out[0] && (sizeof(float) == strides_in[0] || strides_in[0] == 0)))) {
     float_bfloat16_copy_kernel(iter, requires_neg);
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, dtype, "copy_", [&] {
+    _AT_DISPATCH_ALL_TYPES(dtype, "copy_", [&] {
       using dest_t = scalar_t;
-      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(ScalarType::ComplexHalf, ScalarType::Half, ScalarType::Bool, ScalarType::BFloat16, iter.dtype(1), "copy_", [&] {
+      _AT_DISPATCH_ALL_TYPES(iter.dtype(1), "copy_", [&] {
         if (iter.has_contiguous_first_dim()) {
           TORCH_INTERNAL_ASSERT(iter.ninputs() == 1);
           TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
index 6a54be7..90b41a3 100644
--- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
+++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
@@ -179,6 +179,18 @@
       });
 }
 
+#if !defined(C10_MOBILE)
+#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...)                   \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(                   \
+            kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn,       \
+            TYPE, NAME, __VA_ARGS__)
+#else
+#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...)          \
+        AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(          \
+            kHalf, kBFloat16,                            \
+            TYPE, NAME, __VA_ARGS__)
+#endif
+
 static void abs_kernel(TensorIteratorBase& iter) {
   auto dtype = iter.dtype();
   if (dtype == kComplexHalf) {
@@ -186,7 +198,7 @@
     using opmath_t = at::opmath_type<scalar_t>;
     cpu_kernel(iter, [=](scalar_t a) -> scalar_t { return abs_impl(opmath_t{a}); });
   } else {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, iter.dtype(), "abs_cpu", [&]() {
+    _AT_DISPATCH_ABS_TYPES(iter.dtype(), "abs_cpu", [&]() {
       cpu_kernel_vec(
           iter,
           [=](scalar_t a) -> scalar_t { return abs_impl(a); },
diff --git a/aten/src/ATen/native/cuda/jit_utils.h b/aten/src/ATen/native/cuda/jit_utils.h
index 40841c2..1a17c22 100644
--- a/aten/src/ATen/native/cuda/jit_utils.h
+++ b/aten/src/ATen/native/cuda/jit_utils.h
@@ -186,6 +186,12 @@
 template <> inline std::string typeName<at::BFloat16>(){
     return "at::BFloat16";
 }
+template <> inline std::string typeName<at::Float8_e5m2>(){
+    return "at::Float8_e5m2";
+}
+template <> inline std::string typeName<at::Float8_e4m3fn>(){
+    return "at::Float8_e4m3fn";
+}
 
 #define TYPE_NAME_CASE(ctype, scalartype)                    \
   case ScalarType::scalartype:  return typeName<ctype>();
diff --git a/c10/core/Scalar.h b/c10/core/Scalar.h
index 0ea0913..8152632 100644
--- a/c10/core/Scalar.h
+++ b/c10/core/Scalar.h
@@ -48,7 +48,13 @@
 #define DEFINE_IMPLICIT_CTOR(type, name) \
   Scalar(type vv) : Scalar(vv, true) {}
 
-  AT_FORALL_SCALAR_TYPES_AND3(Half, BFloat16, ComplexHalf, DEFINE_IMPLICIT_CTOR)
+  AT_FORALL_SCALAR_TYPES_AND5(
+      Half,
+      BFloat16,
+      Float8_e5m2,
+      Float8_e4m3fn,
+      ComplexHalf,
+      DEFINE_IMPLICIT_CTOR)
   AT_FORALL_COMPLEX_TYPES(DEFINE_IMPLICIT_CTOR)
 
 #undef DEFINE_IMPLICIT_CTOR
diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h
index c80ca2a..fa36a8f 100644
--- a/c10/core/ScalarType.h
+++ b/c10/core/ScalarType.h
@@ -3,6 +3,8 @@
 #include <c10/util/BFloat16.h>
 #include <c10/util/Deprecated.h>
 #include <c10/util/Exception.h>
+#include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e5m2.h>
 #include <c10/util/Half.h>
 #include <c10/util/bits.h>
 #include <c10/util/complex.h>
@@ -50,7 +52,9 @@
   _(c10::bits2x4, Bits2x4) /* 19 */                      \
   _(c10::bits4x2, Bits4x2) /* 20 */                      \
   _(c10::bits8, Bits8) /* 21 */                          \
-  _(c10::bits16, Bits16) /* 22 */
+  _(c10::bits16, Bits16) /* 22 */                        \
+  _(c10::Float8_e5m2, Float8_e5m2) /* 23 */              \
+  _(c10::Float8_e4m3fn, Float8_e4m3fn) /* 24 */
 
 // If you want to support ComplexHalf for real, add ComplexHalf
 // into this macro (and change the name).  But beware: convert()
@@ -67,7 +71,9 @@
   _(c10::complex<float>, ComplexFloat)                             \
   _(c10::complex<double>, ComplexDouble)                           \
   _(bool, Bool)                                                    \
-  _(at::BFloat16, BFloat16)
+  _(at::BFloat16, BFloat16)                                        \
+  _(at::Float8_e5m2, Float8_e5m2)                                  \
+  _(at::Float8_e4m3fn, Float8_e4m3fn)
 
 #define AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(_) \
   _(uint8_t, Byte)                             \
@@ -82,7 +88,9 @@
   _(c10::complex<float>, ComplexFloat)         \
   _(c10::complex<double>, ComplexDouble)       \
   _(bool, Bool)                                \
-  _(at::BFloat16, BFloat16)
+  _(at::BFloat16, BFloat16)                    \
+  _(at::Float8_e5m2, Float8_e5m2)              \
+  _(at::Float8_e4m3fn, Float8_e4m3fn)
 
 enum class ScalarType : int8_t {
 #define DEFINE_ENUM(_1, n) n,
@@ -201,6 +209,53 @@
              ::c10::ScalarType::SCALARTYPE3>::t),                             \
     SCALARTYPE3)
 
+#define AT_FORALL_SCALAR_TYPES_AND4(                       \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, _) \
+  _(uint8_t, Byte)                                         \
+  _(int8_t, Char)                                          \
+  _(int16_t, Short)                                        \
+  _(int, Int)                                              \
+  _(int64_t, Long)                                         \
+  _(float, Float)                                          \
+  _(double, Double)                                        \
+  _(decltype(::c10::impl::ScalarTypeToCPPType<             \
+             ::c10::ScalarType::SCALARTYPE1>::t),          \
+    SCALARTYPE1)                                           \
+  _(decltype(::c10::impl::ScalarTypeToCPPType<             \
+             ::c10::ScalarType::SCALARTYPE2>::t),          \
+    SCALARTYPE2)                                           \
+  _(decltype(::c10::impl::ScalarTypeToCPPType<             \
+             ::c10::ScalarType::SCALARTYPE3>::t),          \
+    SCALARTYPE3)                                           \
+  _(decltype(::c10::impl::ScalarTypeToCPPType<             \
+             ::c10::ScalarType::SCALARTYPE4>::t),          \
+    SCALARTYPE4)
+
+#define AT_FORALL_SCALAR_TYPES_AND5(                                    \
+    SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, _) \
+  _(uint8_t, Byte)                                                      \
+  _(int8_t, Char)                                                       \
+  _(int16_t, Short)                                                     \
+  _(int, Int)                                                           \
+  _(int64_t, Long)                                                      \
+  _(float, Float)                                                       \
+  _(double, Double)                                                     \
+  _(decltype(::c10::impl::ScalarTypeToCPPType<                          \
+             ::c10::ScalarType::SCALARTYPE1>::t),                       \
+    SCALARTYPE1)                                                        \
+  _(decltype(::c10::impl::ScalarTypeToCPPType<                          \
+             ::c10::ScalarType::SCALARTYPE2>::t),                       \
+    SCALARTYPE2)                                                        \
+  _(decltype(::c10::impl::ScalarTypeToCPPType<                          \
+             ::c10::ScalarType::SCALARTYPE3>::t),                       \
+    SCALARTYPE3)                                                        \
+  _(decltype(::c10::impl::ScalarTypeToCPPType<                          \
+             ::c10::ScalarType::SCALARTYPE4>::t),                       \
+    SCALARTYPE4)                                                        \
+  _(decltype(::c10::impl::ScalarTypeToCPPType<                          \
+             ::c10::ScalarType::SCALARTYPE5>::t),                       \
+    SCALARTYPE5)
+
 #define AT_FORALL_QINT_TYPES(_) \
   _(c10::qint8, QInt8)          \
   _(c10::quint8, QUInt8)        \
@@ -261,7 +316,8 @@
 static inline bool isFloatingType(ScalarType t) {
   return (
       t == ScalarType::Double || t == ScalarType::Float ||
-      t == ScalarType::Half || t == ScalarType::BFloat16);
+      t == ScalarType::Half || t == ScalarType::BFloat16 ||
+      t == ScalarType::Float8_e5m2 || t == ScalarType::Float8_e4m3fn);
 }
 
 static inline bool isReducedFloatingType(ScalarType t) {
@@ -334,7 +390,8 @@
     case ScalarType::ComplexFloat:
     case ScalarType::ComplexDouble:
       return true;
-      AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED)
+      AT_FORALL_SCALAR_TYPES_AND5(
+          Half, Bool, BFloat16, Float8_e5m2, Float8_e4m3fn, CASE_SIGNED)
     default:
       TORCH_CHECK(false, "Unknown ScalarType");
   }
@@ -425,6 +482,8 @@
   constexpr auto c8 = ScalarType::ComplexDouble;
   constexpr auto b1 = ScalarType::Bool;
   constexpr auto bf = ScalarType::BFloat16;
+  constexpr auto b8 = ScalarType::Float8_e5m2;
+  constexpr auto h8 = ScalarType::Float8_e4m3fn;
   constexpr auto ud = ScalarType::Undefined;
   if (a == ud || b == ud) {
     return ScalarType::Undefined;
@@ -462,23 +521,25 @@
   // clang-format off
   static constexpr ScalarType _promoteTypesLookup[
       NUM_PROMOTE_TYPES][NUM_PROMOTE_TYPES] = {
-      /*        u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  q1  q2  q3  bf*/
-      /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf},
-      /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf},
-      /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf},
-      /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf},
-      /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf},
-      /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4},
-      /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4},
-      /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8},
-      /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4},
-      /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4},
-      /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8},
-      /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf},
-      /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
-      /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
-      /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
-      /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf},
+      /*        u1  i1  i2  i4  i8  f2  f4  f8  c2  c4  c8  b1  q1  q2  q3  bf  b8  h8*/
+      /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, ud, ud, ud, bf, b8, h8},
+      /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, ud, ud, ud, bf, b8, h8},
+      /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, ud, ud, ud, bf, b8, h8},
+      /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, ud, ud, ud, bf, b8, h8},
+      /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, ud, ud, ud, bf, b8, h8},
+      /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, ud, ud, ud, f4, f4, f4},
+      /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, ud, ud, ud, f4, f4, f4},
+      /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, ud, ud, ud, f8, f8, f8},
+      /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, ud, ud, ud, c4, c4, c4},
+      /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, ud, ud, ud, c4, c4, c4},
+      /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, ud, ud, ud, c8, c8, c8},
+      /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, ud, ud, ud, bf, b8, h8},
+      /* q1 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+      /* q2 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+      /* q3 */ {ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud, ud},
+      /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, ud, ud, ud, bf, bf, bf},
+      /* b8 */ {b8, b8, b8, b8, b8, f4, f4, f8, c4, c4, c8, b8, ud, ud, ud, bf, b8, ud},
+      /* h8 */ {h8, h8, h8, h8, h8, f4, f4, f8, c4, c4, c8, h8, ud, ud, ud, bf, ud, h8},
   };
   // clang-format on
   return _promoteTypesLookup[static_cast<int>(a)][static_cast<int>(b)];
@@ -490,8 +551,4 @@
   return stream << toString(scalar_type);
 }
 
-#define AT_FORAUTOCAST_SCALAR_TYPES(_) \
-  _(half, Half) /* 0 */                \
-  _(bfloat16, BFloat16) /* 1 */
-
 } // namespace c10
diff --git a/c10/util/Float8_e4m3fn-inl.h b/c10/util/Float8_e4m3fn-inl.h
new file mode 100644
index 0000000..fc52b49
--- /dev/null
+++ b/c10/util/Float8_e4m3fn-inl.h
@@ -0,0 +1,275 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+#include <cstring>
+#include <limits>
+
+C10_CLANG_DIAGNOSTIC_PUSH()
+#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
+C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
+#endif
+
+namespace c10 {
+
+/// Constructors
+
+inline C10_HOST_DEVICE Float8_e4m3fn::Float8_e4m3fn(float value) {
+  x = detail::fp8e4m3fn_from_fp32_value(value);
+}
+
+/// Implicit conversions
+
+inline C10_HOST_DEVICE Float8_e4m3fn::operator float() const {
+  return detail::fp8e4m3fn_to_fp32_value(x);
+}
+
+/// Special values helper
+
+inline C10_HOST_DEVICE bool Float8_e4m3fn::isnan() const {
+  return (x & 0b01111111) == 0b01111111;
+}
+
+/// Arithmetic
+
+inline C10_HOST_DEVICE Float8_e4m3fn
+operator+(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
+  return static_cast<float>(a) + static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fn
+operator-(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
+  return static_cast<float>(a) - static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fn
+operator*(const Float8_e4m3fn& a, const Float8_e4m3fn& b) {
+  return static_cast<float>(a) * static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fn operator/(
+    const Float8_e4m3fn& a,
+    const Float8_e4m3fn& b) __ubsan_ignore_float_divide_by_zero__ {
+  return static_cast<float>(a) / static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fn operator-(const Float8_e4m3fn& a) {
+  return -static_cast<float>(a);
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fn& operator+=(
+    Float8_e4m3fn& a,
+    const Float8_e4m3fn& b) {
+  a = a + b;
+  return a;
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fn& operator-=(
+    Float8_e4m3fn& a,
+    const Float8_e4m3fn& b) {
+  a = a - b;
+  return a;
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fn& operator*=(
+    Float8_e4m3fn& a,
+    const Float8_e4m3fn& b) {
+  a = a * b;
+  return a;
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fn& operator/=(
+    Float8_e4m3fn& a,
+    const Float8_e4m3fn& b) {
+  a = a / b;
+  return a;
+}
+
+/// Arithmetic with floats
+
+inline C10_HOST_DEVICE float operator+(Float8_e4m3fn a, float b) {
+  return static_cast<float>(a) + b;
+}
+inline C10_HOST_DEVICE float operator-(Float8_e4m3fn a, float b) {
+  return static_cast<float>(a) - b;
+}
+inline C10_HOST_DEVICE float operator*(Float8_e4m3fn a, float b) {
+  return static_cast<float>(a) * b;
+}
+inline C10_HOST_DEVICE float operator/(Float8_e4m3fn a, float b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return static_cast<float>(a) / b;
+}
+
+inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fn b) {
+  return a + static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fn b) {
+  return a - static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fn b) {
+  return a * static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fn b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return a / static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fn& b) {
+  return a += static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fn& b) {
+  return a -= static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fn& b) {
+  return a *= static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fn& b) {
+  return a /= static_cast<float>(b);
+}
+
+/// Arithmetic with doubles
+
+inline C10_HOST_DEVICE double operator+(Float8_e4m3fn a, double b) {
+  return static_cast<double>(a) + b;
+}
+inline C10_HOST_DEVICE double operator-(Float8_e4m3fn a, double b) {
+  return static_cast<double>(a) - b;
+}
+inline C10_HOST_DEVICE double operator*(Float8_e4m3fn a, double b) {
+  return static_cast<double>(a) * b;
+}
+inline C10_HOST_DEVICE double operator/(Float8_e4m3fn a, double b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return static_cast<double>(a) / b;
+}
+
+inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fn b) {
+  return a + static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fn b) {
+  return a - static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fn b) {
+  return a * static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fn b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return a / static_cast<double>(b);
+}
+
+/// Arithmetic with ints
+
+inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int b) {
+  return a + static_cast<Float8_e4m3fn>(b);
+}
+inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int b) {
+  return a - static_cast<Float8_e4m3fn>(b);
+}
+inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int b) {
+  return a * static_cast<Float8_e4m3fn>(b);
+}
+inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int b) {
+  return a / static_cast<Float8_e4m3fn>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fn operator+(int a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) + b;
+}
+inline C10_HOST_DEVICE Float8_e4m3fn operator-(int a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) - b;
+}
+inline C10_HOST_DEVICE Float8_e4m3fn operator*(int a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) * b;
+}
+inline C10_HOST_DEVICE Float8_e4m3fn operator/(int a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) / b;
+}
+
+//// Arithmetic with int64_t
+
+inline C10_HOST_DEVICE Float8_e4m3fn operator+(Float8_e4m3fn a, int64_t b) {
+  return a + static_cast<Float8_e4m3fn>(b);
+}
+inline C10_HOST_DEVICE Float8_e4m3fn operator-(Float8_e4m3fn a, int64_t b) {
+  return a - static_cast<Float8_e4m3fn>(b);
+}
+inline C10_HOST_DEVICE Float8_e4m3fn operator*(Float8_e4m3fn a, int64_t b) {
+  return a * static_cast<Float8_e4m3fn>(b);
+}
+inline C10_HOST_DEVICE Float8_e4m3fn operator/(Float8_e4m3fn a, int64_t b) {
+  return a / static_cast<Float8_e4m3fn>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e4m3fn operator+(int64_t a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) + b;
+}
+inline C10_HOST_DEVICE Float8_e4m3fn operator-(int64_t a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) - b;
+}
+inline C10_HOST_DEVICE Float8_e4m3fn operator*(int64_t a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) * b;
+}
+inline C10_HOST_DEVICE Float8_e4m3fn operator/(int64_t a, Float8_e4m3fn b) {
+  return static_cast<Float8_e4m3fn>(a) / b;
+}
+
+/// NOTE: we do not define comparisons directly and instead rely on the implicit
+/// conversion from c10::Float8_e4m3fn to float.
+
+} // namespace c10
+
+namespace std {
+
+template <>
+class numeric_limits<c10::Float8_e4m3fn> {
+ public:
+  static constexpr bool is_specialized = true;
+  static constexpr bool is_signed = true;
+  static constexpr bool is_integer = false;
+  static constexpr bool is_exact = false;
+  static constexpr bool has_infinity = false;
+  static constexpr bool has_quiet_NaN = true;
+  static constexpr bool has_signaling_NaN = false;
+  static constexpr auto has_denorm = true;
+  static constexpr auto has_denorm_loss = true;
+  static constexpr auto round_style = numeric_limits<float>::round_style;
+  static constexpr bool is_iec559 = false;
+  static constexpr bool is_bounded = true;
+  static constexpr bool is_modulo = false;
+  static constexpr int digits = 4;
+  static constexpr int digits10 = 0;
+  static constexpr int max_digits10 = 3;
+  static constexpr int radix = 2;
+  static constexpr int min_exponent = -5;
+  static constexpr int min_exponent10 = -1;
+  static constexpr int max_exponent = 8;
+  static constexpr int max_exponent10 = 2;
+  static constexpr auto traps = numeric_limits<float>::traps;
+  static constexpr auto tinyness_before = false;
+
+  static constexpr c10::Float8_e4m3fn min() {
+    return c10::Float8_e4m3fn(0x08, c10::Float8_e4m3fn::from_bits());
+  }
+  static constexpr c10::Float8_e4m3fn lowest() {
+    return c10::Float8_e4m3fn(0xFE, c10::Float8_e4m3fn::from_bits());
+  }
+  static constexpr c10::Float8_e4m3fn max() {
+    return c10::Float8_e4m3fn(0x7E, c10::Float8_e4m3fn::from_bits());
+  }
+  static constexpr c10::Float8_e4m3fn epsilon() {
+    return c10::Float8_e4m3fn(0x20, c10::Float8_e4m3fn::from_bits());
+  }
+  static constexpr c10::Float8_e4m3fn round_error() {
+    return c10::Float8_e4m3fn(0x30, c10::Float8_e4m3fn::from_bits());
+  }
+  static constexpr c10::Float8_e4m3fn quiet_NaN() {
+    return c10::Float8_e4m3fn(0x7F, c10::Float8_e4m3fn::from_bits());
+  }
+  static constexpr c10::Float8_e4m3fn denorm_min() {
+    return c10::Float8_e4m3fn(0x01, c10::Float8_e4m3fn::from_bits());
+  }
+};
+
+} // namespace std
+
+C10_CLANG_DIAGNOSTIC_POP()
diff --git a/c10/util/Float8_e4m3fn.cpp b/c10/util/Float8_e4m3fn.cpp
new file mode 100644
index 0000000..b5f866d
--- /dev/null
+++ b/c10/util/Float8_e4m3fn.cpp
@@ -0,0 +1,14 @@
+#include <c10/util/Float8_e4m3fn.h>
+#include <iostream>
+
+namespace c10 {
+
+static_assert(
+    std::is_standard_layout<Float8_e4m3fn>::value,
+    "c10::Float8_e4m3fn must be standard layout.");
+
+std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value) {
+  out << (float)value;
+  return out;
+}
+} // namespace c10
diff --git a/c10/util/Float8_e4m3fn.h b/c10/util/Float8_e4m3fn.h
new file mode 100644
index 0000000..74bc3b6
--- /dev/null
+++ b/c10/util/Float8_e4m3fn.h
@@ -0,0 +1,240 @@
+#pragma once
+
+/// Defines the Float8_e4m3fn type (8-bit floating-point) including conversions
+/// to standard C types and basic arithmetic operations. Note that arithmetic
+/// operations are implemented by converting to floating point and
+/// performing the operation in float32.
+/// Binary configuration:
+/// s eeee mmm
+/// 1 sign bit
+/// 4 exponent bits
+/// 3 mantissa bits
+/// bias = 7
+///
+/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
+/// and inspired by Half implementation from pytorch/c10/util/Half.h
+
+#include <c10/macros/Macros.h>
+#include <c10/util/C++17.h>
+#include <c10/util/TypeSafeSignMath.h>
+#include <c10/util/floating_point_utils.h>
+#include <type_traits>
+
+#if defined(__cplusplus) && (__cplusplus >= 201103L)
+#include <cmath>
+#include <cstdint>
+#elif !defined(__OPENCL_VERSION__)
+#include <math.h>
+#include <stdint.h>
+#endif
+
+#ifdef _MSC_VER
+#include <intrin.h>
+#endif
+
+#include <cstdint>
+#include <cstring>
+#include <iosfwd>
+#include <limits>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <utility>
+
+#include <typeinfo> // operator typeid
+
+namespace c10 {
+
+namespace detail {
+
+/*
+ * Convert a 8-bit floating-point number in fp8 E4M3FN format, in bit
+ * representation, to a 32-bit floating-point number in IEEE single-precision
+ * format, in bit representation.
+ *
+ * @note The implementation doesn't use any floating-point operations.
+ */
+inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) {
+  /*
+   * Extend the fp8 E4M3FN number to 32 bits and shift to the
+   * upper part of the 32-bit word:
+   *      +---+----+---+-----------------------------+
+   *      | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
+   *      +---+----+---+-----------------------------+
+   * Bits  31 27-30 24-26          0-23
+   *
+   * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
+   * - zero bits.
+   */
+  const uint32_t w = (uint32_t)input << 24;
+  /*
+   * Extract the sign of the input number into the high bit of the 32-bit word:
+   *
+   *      +---+----------------------------------+
+   *      | S |0000000 00000000 00000000 00000000|
+   *      +---+----------------------------------+
+   * Bits  31                 0-31
+   */
+  const uint32_t sign = w & UINT32_C(0x80000000);
+  /*
+   * Extract mantissa and biased exponent of the input number into the bits 0-30
+   * of the 32-bit word:
+   *
+   *      +---+----+---+-----------------------------+
+   *      | S |EEEE|MMM|0000 0000 0000 0000 0000 0000|
+   *      +---+----+---+-----------------------------+
+   * Bits  31  27-30 24-26      0-23
+   */
+  const uint32_t nonsign = w & UINT32_C(0x7FFFFFFF);
+  /*
+   * Renorm shift is the number of bits to shift mantissa left to make the
+   * half-precision number normalized. If the initial number is normalized, some
+   * of its high 5 bits (sign == 0 and 4-bit exponent) equals one. In this case
+   * renorm_shift == 0. If the number is denormalize, renorm_shift > 0. Note
+   * that if we shift denormalized nonsign by renorm_shift, the unit bit of
+   * mantissa will shift into exponent, turning the biased exponent into 1, and
+   * making mantissa normalized (i.e. without leading 1).
+   */
+#if defined(__CUDA_ARCH__)
+  uint32_t renorm_shift = __clz(nonsign);
+#elif defined(_MSC_VER)
+  unsigned long nonsign_bsr;
+  _BitScanReverse(&nonsign_bsr, (unsigned long)nonsign);
+  uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31;
+#else
+  uint32_t renorm_shift = __builtin_clz(nonsign);
+#endif
+  renorm_shift = renorm_shift > 4 ? renorm_shift - 4 : 0;
+  /*
+   * Iff fp8e4m3fn number has all exponent and mantissa bits set to 1,
+   * the addition overflows it into bit 31, and the subsequent shift turns the
+   * high 9 bits into 1. Thus inf_nan_mask == 0x7F800000 if the fp8e4m3fn number
+   * is Nan, 0x00000000 otherwise
+   */
+  const int32_t inf_nan_mask =
+      ((int32_t)(nonsign + 0x01000000) >> 8) & INT32_C(0x7F800000);
+  /*
+   * Iff nonsign is 0, it overflows into 0xFFFFFFFF, turning bit 31
+   * into 1. Otherwise, bit 31 remains 0. The signed shift right by 31
+   * broadcasts bit 31 into all bits of the zero_mask. Thus zero_mask ==
+   * 0xFFFFFFFF if the half-precision number was zero (+0.0h or -0.0h)
+   * 0x00000000 otherwise
+   */
+  const int32_t zero_mask = (int32_t)(nonsign - 1) >> 31;
+  /*
+   * 1. Shift nonsign left by renorm_shift to normalize it (if the input
+   * was denormal)
+   * 2. Shift nonsign right by 4 so the exponent (4 bits originally)
+   * becomes an 8-bit field and 3-bit mantissa shifts into the 3 high
+   * bits of the 23-bit mantissa of IEEE single-precision number.
+   * 3. Add 0x78 to the exponent (starting at bit 23) to compensate the
+   * different in exponent bias (0x7F for single-precision number less 0x07
+   * for fp8e4m3fn number).
+   * 4. Subtract renorm_shift from the exponent (starting at bit 23) to
+   * account for renormalization. As renorm_shift is less than 0x78, this
+   * can be combined with step 3.
+   * 5. Binary OR with inf_nan_mask to turn the exponent into 0xFF if the
+   * input was NaN or infinity.
+   * 6. Binary ANDNOT with zero_mask to turn the mantissa and exponent
+   * into zero if the input was zero.
+   * 7. Combine with the sign of the input number.
+   */
+  uint32_t result = sign |
+      ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) |
+        inf_nan_mask) &
+       ~zero_mask);
+  return fp32_from_bits(result);
+}
+
+/*
+ * Convert a 32-bit floating-point number in IEEE single-precision format to a
+ * 8-bit floating-point number in fp8 E4M3FN format, in bit representation.
+ */
+inline C10_HOST_DEVICE uint8_t fp8e4m3fn_from_fp32_value(float f) {
+  /*
+   * Binary representation of 480.0f, which is the first value
+   * not representable in fp8e4m3fn range:
+   * 0 1111 111 - fp8e4m3fn
+   * 0 10000111 11100000000000000000000 - fp32
+   */
+  constexpr uint32_t fp8_max = UINT32_C(1087) << 20;
+
+  /*
+   * A mask for converting fp32 numbers lower than fp8e4m3fn normal range
+   * into denorm representation
+   * magic number: ((127 - 7) + (23 - 3) + 1)
+   */
+  constexpr uint32_t denorm_mask = UINT32_C(141) << 23;
+
+  uint32_t f_bits = fp32_to_bits(f);
+
+  uint8_t result = 0u;
+
+  /*
+   * Extract the sign of the input number into the high bit of the 32-bit word:
+   *
+   *      +---+----------------------------------+
+   *      | S |0000000 00000000 00000000 00000000|
+   *      +---+----------------------------------+
+   * Bits  31                 0-31
+   */
+  const uint32_t sign = f_bits & UINT32_C(0x80000000);
+
+  /*
+   * Set sign bit to 0
+   */
+  f_bits ^= sign;
+
+  if (f_bits >= fp8_max) {
+    // NaN - all exponent and mantissa bits set to 1
+    result = 0x7f;
+  } else {
+    if (f_bits < (UINT32_C(121) << 23)) {
+      // Input number is smaller than 2^(-6), which is the smallest
+      // fp8e4m3fn normal number
+      f_bits =
+          fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
+      result = static_cast<uint8_t>(f_bits - denorm_mask);
+    } else {
+      // resulting mantissa is odd
+      uint8_t mant_odd = (f_bits >> 20) & 1;
+
+      // update exponent, rounding bias part 1
+      f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
+
+      // rounding bias part 2
+      f_bits += mant_odd;
+
+      // take the bits!
+      result = static_cast<uint8_t>(f_bits >> 20);
+    }
+  }
+
+  result |= static_cast<uint8_t>(sign >> 24);
+  return result;
+}
+
+} // namespace detail
+
+struct alignas(1) Float8_e4m3fn {
+  uint8_t x;
+
+  struct from_bits_t {};
+  C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
+    return from_bits_t();
+  }
+
+  Float8_e4m3fn() = default;
+
+  constexpr C10_HOST_DEVICE Float8_e4m3fn(uint8_t bits, from_bits_t)
+      : x(bits){};
+  inline C10_HOST_DEVICE Float8_e4m3fn(float value);
+  inline C10_HOST_DEVICE operator float() const;
+  inline C10_HOST_DEVICE bool isnan() const;
+};
+
+C10_API std::ostream& operator<<(std::ostream& out, const Float8_e4m3fn& value);
+
+} // namespace c10
+
+#include <c10/util/Float8_e4m3fn-inl.h> // IWYU pragma: keep
diff --git a/c10/util/Float8_e5m2-inl.h b/c10/util/Float8_e5m2-inl.h
new file mode 100644
index 0000000..71056ec
--- /dev/null
+++ b/c10/util/Float8_e5m2-inl.h
@@ -0,0 +1,284 @@
+#pragma once
+
+#include <c10/macros/Macros.h>
+#include <cstring>
+#include <limits>
+
+C10_CLANG_DIAGNOSTIC_PUSH()
+#if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
+C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
+#endif
+
+#define EXP_WIDTH_FP8 5
+#define MAN_WIDTH_FP8 2
+#define EXP_BIAS_FP8 15
+
+namespace c10 {
+
+/// Constructors
+
+inline C10_HOST_DEVICE Float8_e5m2::Float8_e5m2(float value) {
+  x = detail::fp8e5m2_from_fp32_value(value);
+}
+
+/// Implicit conversions
+
+inline C10_HOST_DEVICE Float8_e5m2::operator float() const {
+  return detail::fp8e5m2_to_fp32_value(x);
+}
+
+/// Special values helpers
+
+inline C10_HOST_DEVICE bool Float8_e5m2::isnan() const {
+  return (x & 0b01111111) > 0b01111100;
+}
+
+inline C10_HOST_DEVICE bool Float8_e5m2::isinf() const {
+  return (x & 0b01111111) == 0b01111100;
+}
+
+/// Arithmetic
+
+inline C10_HOST_DEVICE Float8_e5m2
+operator+(const Float8_e5m2& a, const Float8_e5m2& b) {
+  return static_cast<float>(a) + static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e5m2
+operator-(const Float8_e5m2& a, const Float8_e5m2& b) {
+  return static_cast<float>(a) - static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e5m2
+operator*(const Float8_e5m2& a, const Float8_e5m2& b) {
+  return static_cast<float>(a) * static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e5m2 operator/(
+    const Float8_e5m2& a,
+    const Float8_e5m2& b) __ubsan_ignore_float_divide_by_zero__ {
+  return static_cast<float>(a) / static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e5m2 operator-(const Float8_e5m2& a) {
+  return -static_cast<float>(a);
+}
+
+inline C10_HOST_DEVICE Float8_e5m2& operator+=(
+    Float8_e5m2& a,
+    const Float8_e5m2& b) {
+  a = a + b;
+  return a;
+}
+
+inline C10_HOST_DEVICE Float8_e5m2& operator-=(
+    Float8_e5m2& a,
+    const Float8_e5m2& b) {
+  a = a - b;
+  return a;
+}
+
+inline C10_HOST_DEVICE Float8_e5m2& operator*=(
+    Float8_e5m2& a,
+    const Float8_e5m2& b) {
+  a = a * b;
+  return a;
+}
+
+inline C10_HOST_DEVICE Float8_e5m2& operator/=(
+    Float8_e5m2& a,
+    const Float8_e5m2& b) {
+  a = a / b;
+  return a;
+}
+
+/// Arithmetic with floats
+
+inline C10_HOST_DEVICE float operator+(Float8_e5m2 a, float b) {
+  return static_cast<float>(a) + b;
+}
+inline C10_HOST_DEVICE float operator-(Float8_e5m2 a, float b) {
+  return static_cast<float>(a) - b;
+}
+inline C10_HOST_DEVICE float operator*(Float8_e5m2 a, float b) {
+  return static_cast<float>(a) * b;
+}
+inline C10_HOST_DEVICE float operator/(Float8_e5m2 a, float b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return static_cast<float>(a) / b;
+}
+
+inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2 b) {
+  return a + static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2 b) {
+  return a - static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2 b) {
+  return a * static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2 b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return a / static_cast<float>(b);
+}
+
+inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2& b) {
+  return a += static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2& b) {
+  return a -= static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2& b) {
+  return a *= static_cast<float>(b);
+}
+inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2& b) {
+  return a /= static_cast<float>(b);
+}
+
+/// Arithmetic with doubles
+
+inline C10_HOST_DEVICE double operator+(Float8_e5m2 a, double b) {
+  return static_cast<double>(a) + b;
+}
+inline C10_HOST_DEVICE double operator-(Float8_e5m2 a, double b) {
+  return static_cast<double>(a) - b;
+}
+inline C10_HOST_DEVICE double operator*(Float8_e5m2 a, double b) {
+  return static_cast<double>(a) * b;
+}
+inline C10_HOST_DEVICE double operator/(Float8_e5m2 a, double b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return static_cast<double>(a) / b;
+}
+
+inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2 b) {
+  return a + static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2 b) {
+  return a - static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2 b) {
+  return a * static_cast<double>(b);
+}
+inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2 b)
+    __ubsan_ignore_float_divide_by_zero__ {
+  return a / static_cast<double>(b);
+}
+
+/// Arithmetic with ints
+
+inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int b) {
+  return a + static_cast<Float8_e5m2>(b);
+}
+inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int b) {
+  return a - static_cast<Float8_e5m2>(b);
+}
+inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int b) {
+  return a * static_cast<Float8_e5m2>(b);
+}
+inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int b) {
+  return a / static_cast<Float8_e5m2>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e5m2 operator+(int a, Float8_e5m2 b) {
+  return static_cast<Float8_e5m2>(a) + b;
+}
+inline C10_HOST_DEVICE Float8_e5m2 operator-(int a, Float8_e5m2 b) {
+  return static_cast<Float8_e5m2>(a) - b;
+}
+inline C10_HOST_DEVICE Float8_e5m2 operator*(int a, Float8_e5m2 b) {
+  return static_cast<Float8_e5m2>(a) * b;
+}
+inline C10_HOST_DEVICE Float8_e5m2 operator/(int a, Float8_e5m2 b) {
+  return static_cast<Float8_e5m2>(a) / b;
+}
+
+//// Arithmetic with int64_t
+
+inline C10_HOST_DEVICE Float8_e5m2 operator+(Float8_e5m2 a, int64_t b) {
+  return a + static_cast<Float8_e5m2>(b);
+}
+inline C10_HOST_DEVICE Float8_e5m2 operator-(Float8_e5m2 a, int64_t b) {
+  return a - static_cast<Float8_e5m2>(b);
+}
+inline C10_HOST_DEVICE Float8_e5m2 operator*(Float8_e5m2 a, int64_t b) {
+  return a * static_cast<Float8_e5m2>(b);
+}
+inline C10_HOST_DEVICE Float8_e5m2 operator/(Float8_e5m2 a, int64_t b) {
+  return a / static_cast<Float8_e5m2>(b);
+}
+
+inline C10_HOST_DEVICE Float8_e5m2 operator+(int64_t a, Float8_e5m2 b) {
+  return static_cast<Float8_e5m2>(a) + b;
+}
+inline C10_HOST_DEVICE Float8_e5m2 operator-(int64_t a, Float8_e5m2 b) {
+  return static_cast<Float8_e5m2>(a) - b;
+}
+inline C10_HOST_DEVICE Float8_e5m2 operator*(int64_t a, Float8_e5m2 b) {
+  return static_cast<Float8_e5m2>(a) * b;
+}
+inline C10_HOST_DEVICE Float8_e5m2 operator/(int64_t a, Float8_e5m2 b) {
+  return static_cast<Float8_e5m2>(a) / b;
+}
+
+/// NOTE: we do not define comparisons directly and instead rely on the implicit
+/// conversion from c10::Float8_e5m2 to float.
+
+} // namespace c10
+
+namespace std {
+
+template <>
+class numeric_limits<c10::Float8_e5m2> {
+ public:
+  static constexpr bool is_signed = true;
+  static constexpr bool is_integer = false;
+  static constexpr bool is_specialized = true;
+  static constexpr bool is_exact = false;
+  static constexpr bool has_infinity = true;
+  static constexpr bool has_quiet_NaN = false;
+  static constexpr bool has_signaling_NaN = false;
+  static constexpr auto has_denorm = true;
+  static constexpr auto has_denorm_loss = true;
+  static constexpr auto round_style = numeric_limits<float>::round_style;
+  static constexpr bool is_iec559 = false;
+  static constexpr bool is_bounded = true;
+  static constexpr bool is_modulo = false;
+  static constexpr int digits = 3;
+  static constexpr int digits10 = 0;
+  static constexpr int max_digits10 = 2;
+  static constexpr int radix = 2;
+  static constexpr int min_exponent = -13;
+  static constexpr int min_exponent10 = -4;
+  static constexpr int max_exponent = 16;
+  static constexpr int max_exponent10 = 4;
+  static constexpr auto traps = numeric_limits<float>::traps;
+  static constexpr auto tinyness_before =
+      numeric_limits<float>::tinyness_before;
+
+  static constexpr c10::Float8_e5m2 min() {
+    return c10::Float8_e5m2(0x4, c10::Float8_e5m2::from_bits());
+  }
+  static constexpr c10::Float8_e5m2 max() {
+    return c10::Float8_e5m2(0x7B, c10::Float8_e5m2::from_bits());
+  }
+  static constexpr c10::Float8_e5m2 lowest() {
+    return c10::Float8_e5m2(0xFB, c10::Float8_e5m2::from_bits());
+  }
+  static constexpr c10::Float8_e5m2 epsilon() {
+    return c10::Float8_e5m2(0x34, c10::Float8_e5m2::from_bits());
+  }
+  static constexpr c10::Float8_e5m2 round_error() {
+    return c10::Float8_e5m2(0x38, c10::Float8_e5m2::from_bits());
+  }
+  static constexpr c10::Float8_e5m2 infinity() {
+    return c10::Float8_e5m2(0x7C, c10::Float8_e5m2::from_bits());
+  }
+  static constexpr c10::Float8_e5m2 denorm_min() {
+    return c10::Float8_e5m2(0x01, c10::Float8_e5m2::from_bits());
+  }
+};
+
+} // namespace std
+
+C10_CLANG_DIAGNOSTIC_POP()
diff --git a/c10/util/Float8_e5m2.cpp b/c10/util/Float8_e5m2.cpp
new file mode 100644
index 0000000..edbd99f
--- /dev/null
+++ b/c10/util/Float8_e5m2.cpp
@@ -0,0 +1,14 @@
+#include <c10/util/Float8_e5m2.h>
+#include <iostream>
+
+namespace c10 {
+
+static_assert(
+    std::is_standard_layout<Float8_e5m2>::value,
+    "c10::Float8_e5m2 must be standard layout.");
+
+std::ostream& operator<<(std::ostream& out, const Float8_e5m2& value) {
+  out << (float)value;
+  return out;
+}
+} // namespace c10
diff --git a/c10/util/Float8_e5m2.h b/c10/util/Float8_e5m2.h
new file mode 100644
index 0000000..f01821d
--- /dev/null
+++ b/c10/util/Float8_e5m2.h
@@ -0,0 +1,143 @@
+#pragma once
+
+/// Defines the Float8_e5m2 type (8-bit floating-point) including conversions
+/// to standard C types and basic arithmetic operations. Note that arithmetic
+/// operations are implemented by converting to floating point and
+/// performing the operation in float32.
+/// Binary configuration:
+/// s eeeee mm
+/// 1 sign bit
+/// 5 exponent bits
+/// 2 mantissa bits
+/// bias = 15
+///
+/// Implementation based on the paper https://arxiv.org/pdf/2209.05433.pdf
+/// and inspired by Half implementation from pytorch/c10/util/Half.h
+
+#include <c10/util/Half.h>
+
+namespace c10 {
+
+namespace detail {
+
+/*
+ * Convert a 8-bit floating-point number in fp8 E5M2 format, in bit
+ * representation, to a 32-bit floating-point number in IEEE single-precision
+ * format, in bit representation.
+ *
+ * @note The implementation doesn't use any floating-point operations.
+ */
+inline C10_HOST_DEVICE float fp8e5m2_to_fp32_value(uint8_t input) {
+  /*
+   * Extend the fp8 E5M2 number to 32 bits and shift to the
+   * upper part of the 32-bit word:
+   *      +---+----+---+-----------------------------+
+   *      | S |EEEEE|MM|0000 0000 0000 0000 0000 0000|
+   *      +---+----+---+-----------------------------+
+   * Bits  31 26-30 24-25          0-23
+   *
+   * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, 0
+   * - zero bits.
+   */
+  uint16_t half_representation = input;
+  half_representation <<= 8;
+  return fp16_ieee_to_fp32_value(half_representation);
+}
+
+/*
+ * Convert a 32-bit floating-point number in IEEE single-precision format to a
+ * 8-bit floating-point number in fp8 E5M2 format, in bit representation.
+ */
+inline C10_HOST_DEVICE uint8_t fp8e5m2_from_fp32_value(float f) {
+  /*
+   * Binary representation of fp32 infinity
+   * 0 11111111 00000000000000000000000
+   */
+  constexpr uint32_t fp32_inf = UINT32_C(255) << 23;
+
+  /*
+   * Binary representation of 65536.0f, which is the first value
+   * not representable in fp8e5m2 range:
+   * 0 11111 00 - fp8e5m2
+   * 0 10001111 00000000000000000000000 - fp32
+   */
+  constexpr uint32_t fp8_max = UINT32_C(143) << 23;
+
+  /*
+   * A mask for converting fp32 numbers lower than fp8e5m2 normal range
+   * into denorm representation
+   * magic number: ((127 - 15) + (23 - 2) + 1)
+   */
+  constexpr uint32_t denorm_mask = UINT32_C(134) << 23;
+
+  uint32_t f_bits = fp32_to_bits(f);
+  uint8_t result = 0u;
+
+  /*
+   * Extract the sign of the input number into the high bit of the 32-bit word:
+   *
+   *      +---+----------------------------------+
+   *      | S |0000000 00000000 00000000 00000000|
+   *      +---+----------------------------------+
+   * Bits  31                 0-31
+   */
+  const uint32_t sign = f_bits & UINT32_C(0x80000000);
+
+  /*
+   * Set sign bit to 0
+   */
+  f_bits ^= sign;
+
+  if (f_bits >= fp8_max) {
+    // NaN - all exponent and mantissa bits set to 1
+    result = f_bits > fp32_inf ? UINT8_C(0x7F) : UINT8_C(0x7C);
+  } else {
+    if (f_bits < (UINT32_C(113) << 23)) {
+      // Input number is smaller than 2^(-14), which is the smallest
+      // fp8e5m2 normal number
+      f_bits =
+          fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask));
+      result = static_cast<uint8_t>(f_bits - denorm_mask);
+    } else {
+      // resulting mantissa is odd
+      uint32_t mant_odd = (f_bits >> 21) & 1;
+
+      // update exponent, rounding bias part 1
+      f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
+
+      // rounding bias part 2
+      f_bits += mant_odd;
+
+      // take the bits!
+      result = static_cast<uint8_t>(f_bits >> 21);
+    }
+  }
+
+  result |= static_cast<uint8_t>(sign >> 24);
+  return result;
+}
+
+} // namespace detail
+
+struct alignas(1) Float8_e5m2 {
+  uint8_t x;
+
+  struct from_bits_t {};
+  C10_HOST_DEVICE static constexpr from_bits_t from_bits() {
+    return from_bits_t();
+  }
+
+  Float8_e5m2() = default;
+
+  constexpr C10_HOST_DEVICE Float8_e5m2(uint8_t bits, from_bits_t) : x(bits){};
+  inline C10_HOST_DEVICE Float8_e5m2(float value);
+  inline C10_HOST_DEVICE operator float() const;
+  inline C10_HOST_DEVICE bool isnan() const;
+  inline C10_HOST_DEVICE bool isinf() const;
+};
+
+C10_API std::ostream& operator<<(std::ostream& out, const Float8_e5m2& value);
+
+} // namespace c10
+
+#include <c10/util/Float8_e5m2-inl.h> // IWYU pragma: keep
diff --git a/c10/util/Half.h b/c10/util/Half.h
index bb28d71..9a85daf 100644
--- a/c10/util/Half.h
+++ b/c10/util/Half.h
@@ -13,6 +13,7 @@
 #include <c10/util/C++17.h>
 #include <c10/util/TypeSafeSignMath.h>
 #include <c10/util/complex.h>
+#include <c10/util/floating_point_utils.h>
 #include <type_traits>
 
 #if defined(__cplusplus) && (__cplusplus >= 201103L)
@@ -51,51 +52,12 @@
 #include <sycl/sycl.hpp> // for SYCL 2020
 #endif
 
-// Standard check for compiling CUDA with clang
-#if defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
-#define C10_DEVICE_HOST_FUNCTION __device__ __host__
-#else
-#define C10_DEVICE_HOST_FUNCTION
-#endif
-
 #include <typeinfo> // operator typeid
 
 namespace c10 {
 
 namespace detail {
 
-C10_DEVICE_HOST_FUNCTION inline float fp32_from_bits(uint32_t w) {
-#if defined(__OPENCL_VERSION__)
-  return as_float(w);
-#elif defined(__CUDA_ARCH__)
-  return __uint_as_float((unsigned int)w);
-#elif defined(__INTEL_COMPILER)
-  return _castu32_f32(w);
-#else
-  union {
-    uint32_t as_bits;
-    float as_value;
-  } fp32 = {w};
-  return fp32.as_value;
-#endif
-}
-
-C10_DEVICE_HOST_FUNCTION inline uint32_t fp32_to_bits(float f) {
-#if defined(__OPENCL_VERSION__)
-  return as_uint(f);
-#elif defined(__CUDA_ARCH__)
-  return (uint32_t)__float_as_uint(f);
-#elif defined(__INTEL_COMPILER)
-  return _castf32_u32(f);
-#else
-  union {
-    float as_value;
-    uint32_t as_bits;
-  } fp32 = {f};
-  return fp32.as_bits;
-#endif
-}
-
 /*
  * Convert a 16-bit floating-point number in IEEE half-precision format, in bit
  * representation, to a 32-bit floating-point number in IEEE single-precision
@@ -201,7 +163,7 @@
  * mode and no operations on denormals) floating-point operations and bitcasts
  * between integer and floating-point variables.
  */
-inline float fp16_ieee_to_fp32_value(uint16_t h) {
+C10_HOST_DEVICE inline float fp16_ieee_to_fp32_value(uint16_t h) {
   /*
    * Extend the half-precision floating-point number to 32 bits and shift to the
    * upper part of the 32-bit word:
diff --git a/c10/util/TypeCast.h b/c10/util/TypeCast.h
index c24fcae..fab7cfd 100644
--- a/c10/util/TypeCast.h
+++ b/c10/util/TypeCast.h
@@ -1,6 +1,8 @@
 #pragma once
 #include <c10/macros/Macros.h>
 #include <c10/util/BFloat16.h>
+#include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e5m2.h>
 #include <c10/util/Half.h>
 
 #include <type_traits>
@@ -78,6 +80,26 @@
 };
 
 template <>
+struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Float8_e5m2> {
+  C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
+      c10::Half>
+  apply(c10::Float8_e5m2 src) {
+    return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
+  }
+};
+
+template <>
+struct static_cast_with_inter_type<
+    c10::complex<c10::Half>,
+    c10::Float8_e4m3fn> {
+  C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
+      c10::Half>
+  apply(c10::Float8_e4m3fn src) {
+    return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
+  }
+};
+
+template <>
 struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Half> {
   C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
       c10::Half>
diff --git a/c10/util/floating_point_utils.h b/c10/util/floating_point_utils.h
new file mode 100644
index 0000000..478e016
--- /dev/null
+++ b/c10/util/floating_point_utils.h
@@ -0,0 +1,39 @@
+#pragma once
+
+#include <cstdint>
+
+namespace c10::detail {
+
+C10_HOST_DEVICE inline float fp32_from_bits(uint32_t w) {
+#if defined(__OPENCL_VERSION__)
+  return as_float(w);
+#elif defined(__CUDA_ARCH__)
+  return __uint_as_float((unsigned int)w);
+#elif defined(__INTEL_COMPILER)
+  return _castu32_f32(w);
+#else
+  union {
+    uint32_t as_bits;
+    float as_value;
+  } fp32 = {w};
+  return fp32.as_value;
+#endif
+}
+
+C10_HOST_DEVICE inline uint32_t fp32_to_bits(float f) {
+#if defined(__OPENCL_VERSION__)
+  return as_uint(f);
+#elif defined(__CUDA_ARCH__)
+  return (uint32_t)__float_as_uint(f);
+#elif defined(__INTEL_COMPILER)
+  return _castf32_u32(f);
+#else
+  union {
+    float as_value;
+    uint32_t as_bits;
+  } fp32 = {f};
+  return fp32.as_bits;
+#endif
+}
+
+} // namespace c10::detail
diff --git a/test/quantization/core/experimental/test_float8.py b/test/quantization/core/experimental/test_float8.py
new file mode 100644
index 0000000..074d0cf
--- /dev/null
+++ b/test/quantization/core/experimental/test_float8.py
@@ -0,0 +1,154 @@
+# Owner(s): ["oncall: quantization"]
+
+import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    parametrize,
+    instantiate_parametrized_tests,
+    run_tests,
+)
+
+# Masks for float8 simulation
+
+# 0 11111111 11000000000000000000000b
+MASK_152 = torch.tensor(2145386496, dtype=torch.int)
+# 0 11111111 11100000000000000000000b
+MASK_143 = torch.tensor(2146435072, dtype=torch.int)
+MASK = {
+    torch.float8_e5m2: MASK_152,
+    torch.float8_e4m3fn: MASK_143,
+}
+
+# 0 00000000 00011111111111111111111b
+MASK_ROUND_152 = torch.tensor(1048575, dtype=torch.int)
+# 0 00000000 00001111111111111111111b
+MASK_ROUND_143 = torch.tensor(524287, dtype=torch.int)
+MASK_ROUND = {
+    torch.float8_e5m2: MASK_ROUND_152,
+    torch.float8_e4m3fn: MASK_ROUND_143,
+}
+
+FP8_MAX_152 = torch.tensor(57344, dtype=torch.float)
+FP8_MAX_143 = torch.tensor(448, dtype=torch.float)
+FP8_MAX = {torch.float8_e5m2: FP8_MAX_152, torch.float8_e4m3fn: FP8_MAX_143}
+
+SPECIAL_NUMBERS = {
+    torch.float8_e5m2: [
+        ("01111100", float("inf"), "inf"),
+        ("11111100", -1.0 * float("inf"), "neg_inf"),
+        ("01111101", float("nan"), "nan"),
+        ("11111101", float("nan"), "nan"),
+        ("01111110", float("nan"), "nan"),
+        ("11111110", float("nan"), "nan"),
+        ("01111111", float("nan"), "nan"),
+        ("11111111", float("nan"), "nan"),
+        ("00000000", 0.0, "zero"),
+        ("10000000", -0.0, "neg_zero"),
+        ("01111011", 57344.0, "max_normal"),
+        ("11111011", -57344.0, "neg_max_normal"),
+        ("00000100", 2**-14, "min_normal"),
+        ("10000100", -1 * (2**-14), "neg_min_normal"),
+        ("00000011", 0.75 * (2**-14), "max_subnorm"),
+        ("10000011", -0.75 * (2**-14), "neg_max_subnorm"),
+        ("00000001", 2**-16, "min_subnorm"),
+        ("10000001", -1 * (2**-16), "neg_min_subnorm"),
+    ],
+    torch.float8_e4m3fn: [
+        ("01111111", float("nan"), "nan"),
+        ("11111111", float("nan"), "nan"),
+        ("00000000", 0.0, "zero"),
+        ("10000000", -0.0, "neg_zero"),
+        ("01111110", 448.0, "max_normal"),
+        ("11111110", -448.0, "neg_max_normal"),
+        ("00001000", 2**-6, "min_normal"),
+        ("10001000", -1 * (2**-6), "neg_min_normal"),
+        ("00000111", 0.875 * (2**-6), "max_subnorm"),
+        ("10000111", -0.875 * (2**-6), "neg_max_subnorm"),
+        ("00000001", 2**-9, "min_subnorm"),
+        ("10000001", -1 * (2**-9), "neg_min_subnorm"),
+    ],
+}
+
+
+def simulateFp8Precision(input, variant):
+    dtype = torch.float
+    int_type = torch.int
+    mask = MASK[variant]
+    mask_round = MASK_ROUND[variant]
+    excessive_bits = torch.tensor(21, dtype=int_type)
+
+    signs = torch.where(input < 0.0, -1.0, 1.0).to(dtype)
+    asInt = torch.bitwise_and(input.view(int_type), 2147483647)
+
+    mant_odd = torch.bitwise_and(
+        torch.bitwise_right_shift(asInt, excessive_bits),
+        torch.tensor(1, dtype=int_type),
+    )
+    asInt_masked = asInt + mask_round
+    asInt_odded = asInt_masked + mant_odd
+    masked = torch.bitwise_and(asInt_odded, mask)
+    return masked.view(dtype) * signs
+
+
+class TestFloat8Dtype(TestCase):
+    """
+    Sanity test for zeros comparison
+    """
+    @parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
+    def test_creation_with_zeros(self, dtype):
+        x = torch.zeros(8, dtype=torch.float)
+        x8 = torch.zeros(8, dtype=dtype)
+        self.assertEqual(x, x8.float())
+
+    """
+        Numerical test of float8 conversion
+    """
+    @parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
+    def test_cast_to_float8(self, dtype):
+        x = torch.rand((100, 100)) * FP8_MAX[dtype]
+        x = torch.cat((x, -x))
+        x8 = x.to(dtype)
+        x8_simulated = simulateFp8Precision(x, dtype)
+        self.assertEqual(x8_simulated, x8.float())
+
+    """
+        Test of mul implementation
+    """
+    @parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
+    def test_mul(self, dtype):
+        shape = (10, 10)
+        a = torch.randn(shape)
+        a8_simulated = simulateFp8Precision(a, dtype)
+        a8 = a.to(dtype)
+        b = torch.randn(shape)
+        b8_simulated = simulateFp8Precision(b, dtype)
+        b8 = b.to(dtype)
+        mul8 = a8 * b8
+        mul8_simulated = (a8_simulated * b8_simulated).to(dtype)
+        self.assertEqual(mul8, mul8_simulated)
+
+    """
+        Test special numbers
+    """
+    @parametrize("dtype", [torch.float8_e5m2, torch.float8_e4m3fn])
+    def test_special_numbers(self, dtype):
+        def compare_binary_with_decimal(binary, decimal, number_name, dtype):
+            bits_int = int(binary, 2)
+            tensor_int = torch.tensor([bits_int], dtype=torch.uint8)
+            tensor_fp8 = tensor_int.view(dtype)
+            if number_name == "nan":
+                assert tensor_fp8.isnan()
+            else:
+                tensor_fp32 = tensor_fp8.float()
+                ref_tensor_fp32 = torch.tensor([decimal], dtype=torch.float)
+                self.assertEqual(tensor_fp32, ref_tensor_fp32)
+
+        for number in SPECIAL_NUMBERS[dtype]:
+            compare_binary_with_decimal(*number, dtype)
+
+
+instantiate_parametrized_tests(TestFloat8Dtype)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tools/iwyu/c10.imp b/tools/iwyu/c10.imp
index 1ec23eb..4679e8e 100644
--- a/tools/iwyu/c10.imp
+++ b/tools/iwyu/c10.imp
@@ -4,6 +4,8 @@
 
     { include: [ "<c10/util/BFloat16-inl.h>", private, "<c10/util/BFloat16.h>", public ] },
     { include: [ "<c10/util/Half-inl.h>", private, "<c10/util/Half.h>", public ] },
+    { include: [ "<c10/util/Float8_e5m2-inl.h>", private, "<c10/util/Float8_e5m2.h>", public ] },
+    { include: [ "<c10/util/Float8_e4m3fn-inl.h>", private, "<c10/util/Float8_e4m3fn.h>", public ] },
 
     { include: [ "<c10/util/complex_math.h>", private, "<c10/util/complex.h>", public ] },
     { include: [ "<c10/util/complex_utils.h>", private, "<c10/util/complex.h>", public ] },
diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py
index 7b519ad..a7a1f87 100644
--- a/torch/_tensor_str.py
+++ b/torch/_tensor_str.py
@@ -330,7 +330,12 @@
     if self.is_neg():
         self = self.resolve_neg()
 
-    if self.dtype is torch.float16 or self.dtype is torch.bfloat16:
+    if self.dtype in [
+        torch.float16,
+        torch.bfloat16,
+        torch.float8_e5m2,
+        torch.float8_e4m3fn,
+    ]:
         self = self.float()
 
     if self.dtype is torch.complex32:
diff --git a/torch/csrc/StorageMethods.cpp b/torch/csrc/StorageMethods.cpp
index fa34014..cadad52 100644
--- a/torch/csrc/StorageMethods.cpp
+++ b/torch/csrc/StorageMethods.cpp
@@ -213,15 +213,18 @@
   auto dtype = reinterpret_cast<THPDtype*>(dtype_obj);
   scalar_type = dtype->scalar_type;
 
+  const bool is_endian_independent = (scalar_type == at::kByte) ||
+      (scalar_type == at::kChar) || (scalar_type == at::kFloat8_e5m2) ||
+      (scalar_type == at::kFloat8_e4m3fn);
+
   TORCH_CHECK(
-      (scalar_type == at::kByte) || (scalar_type == at::kChar) ||
-          (byte_order_str != nullptr),
+      is_endian_independent || (byte_order_str != nullptr),
       "function missing required argument 'byte_order' (pos 2)");
   size_t element_size = c10::elementSize(scalar_type);
 
   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
   bool do_byte_swap;
-  if (scalar_type != at::kByte && scalar_type != at::kChar) {
+  if (!is_endian_independent) {
     if (strcmp(byte_order_str, "native") == 0) {
       do_byte_swap = false;
     } else if (strcmp(byte_order_str, "big") == 0) {
@@ -292,7 +295,7 @@
       c10::GetDefaultCPUAllocator(),
       /*resizable=*/true);
 
-  if (scalar_type == at::kByte || scalar_type == at::kChar) {
+  if (is_endian_independent) {
     memcpy(storage->mutable_data(), src + offset, count);
   } else if (scalar_type == at::kBool) {
     // Because of ASAN checks, that are failing whenever
diff --git a/torch/csrc/jit/tensorexpr/types.cpp b/torch/csrc/jit/tensorexpr/types.cpp
index f1bdd6a..c01504e 100644
--- a/torch/csrc/jit/tensorexpr/types.cpp
+++ b/torch/csrc/jit/tensorexpr/types.cpp
@@ -14,7 +14,13 @@
 // NOLINTNEXTLINE
 #define DTYPE_DEFINE(_1, n) TORCH_API Dtype k##n(ScalarType::n, 1);
 
-AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, DTYPE_DEFINE)
+AT_FORALL_SCALAR_TYPES_AND5(
+    Bool,
+    Half,
+    BFloat16,
+    Float8_e5m2,
+    Float8_e4m3fn,
+    DTYPE_DEFINE)
 DTYPE_DEFINE(c10::quint8, QUInt8);
 DTYPE_DEFINE(c10::qint8, QInt8);
 
@@ -28,7 +34,8 @@
 #define TYPE_CASE(_1, n) \
   case ScalarType::n:    \
     return k##n;
-    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE)
+    AT_FORALL_SCALAR_TYPES_AND5(
+        Bool, Half, BFloat16, Float8_e5m2, Float8_e4m3fn, TYPE_CASE)
     TYPE_CASE(c10::quint8, QUInt8);
     TYPE_CASE(c10::qint8, QInt8);
 #undef TYPE_CASE
@@ -58,7 +65,8 @@
     scalar_size = sizeof(Type); \
     break;
 
-    AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TYPE_CASE);
+    AT_FORALL_SCALAR_TYPES_AND5(
+        Bool, Half, BFloat16, Float8_e5m2, Float8_e4m3fn, TYPE_CASE);
     TYPE_CASE(c10::quint8, QUInt8);
     TYPE_CASE(c10::qint8, QInt8);
 #undef TYPE_CASE
@@ -83,6 +91,10 @@
       return "half";
     case ScalarType::BFloat16:
       return "bfloat16";
+    case ScalarType::Float8_e5m2:
+      return "float8_e5m2";
+    case ScalarType::Float8_e4m3fn:
+      return "float8_e4m3fn";
     case ScalarType::QInt8:
       return "qint8";
     case ScalarType::QUInt8:
diff --git a/torch/csrc/utils/byte_order.h b/torch/csrc/utils/byte_order.h
index 5732a04..1b60216 100644
--- a/torch/csrc/utils/byte_order.h
+++ b/torch/csrc/utils/byte_order.h
@@ -1,6 +1,8 @@
 #pragma once
 
 #include <c10/util/BFloat16.h>
+#include <c10/util/Float8_e4m3fn.h>
+#include <c10/util/Float8_e5m2.h>
 #include <c10/util/Half.h>
 #include <torch/csrc/Export.h>
 #include <cstddef>
@@ -154,6 +156,14 @@
     const uint8_t* src,
     THPByteOrder order,
     size_t len);
+TORCH_API void THP_decodeFloat8_e5m2Buffer(
+    at::Float8_e5m2* dst,
+    const uint8_t* src,
+    size_t len);
+TORCH_API void THP_decodeFloat8_e4m3fnBuffer(
+    at::Float8_e4m3fn* dst,
+    const uint8_t* src,
+    size_t len);
 TORCH_API void THP_decodeComplexFloatBuffer(
     c10::complex<float>* dst,
     const uint8_t* src,
diff --git a/torch/csrc/utils/python_scalars.h b/torch/csrc/utils/python_scalars.h
index ff766b5..eb560b5 100644
--- a/torch/csrc/utils/python_scalars.h
+++ b/torch/csrc/utils/python_scalars.h
@@ -70,6 +70,14 @@
       *(at::BFloat16*)data =
           at::convert<at::BFloat16, double>(THPUtils_unpackDouble(obj));
       break;
+    case at::kFloat8_e5m2:
+      *(at::Float8_e5m2*)data =
+          at::convert<at::Float8_e5m2, double>(THPUtils_unpackDouble(obj));
+      break;
+    case at::kFloat8_e4m3fn:
+      *(at::Float8_e4m3fn*)data =
+          at::convert<at::Float8_e4m3fn, double>(THPUtils_unpackDouble(obj));
+      break;
     default:
       throw std::runtime_error("invalid type");
   }
@@ -110,6 +118,12 @@
     case at::kBFloat16:
       return PyFloat_FromDouble(
           at::convert<double, at::BFloat16>(*(at::BFloat16*)data));
+    case at::kFloat8_e5m2:
+      return PyFloat_FromDouble(
+          at::convert<double, at::Float8_e5m2>(*(at::Float8_e5m2*)data));
+    case at::kFloat8_e4m3fn:
+      return PyFloat_FromDouble(
+          at::convert<double, at::Float8_e4m3fn>(*(at::Float8_e4m3fn*)data));
     default:
       throw std::runtime_error("invalid type");
   }
diff --git a/torch/csrc/utils/tensor_dtypes.cpp b/torch/csrc/utils/tensor_dtypes.cpp
index 84d7566..9fd9fa3 100644
--- a/torch/csrc/utils/tensor_dtypes.cpp
+++ b/torch/csrc/utils/tensor_dtypes.cpp
@@ -62,6 +62,10 @@
       return std::make_pair("bits8", "");
     case at::ScalarType::Bits16:
       return std::make_pair("bits16", "");
+    case at::ScalarType::Float8_e5m2:
+      return std::make_pair("float8_e5m2", "");
+    case at::ScalarType::Float8_e4m3fn:
+      return std::make_pair("float8_e4m3fn", "");
     default:
       throw std::runtime_error("Unimplemented scalar type");
   }
diff --git a/torchgen/api/types/types.py b/torchgen/api/types/types.py
index 1e56300..29f100e 100644
--- a/torchgen/api/types/types.py
+++ b/torchgen/api/types/types.py
@@ -47,6 +47,8 @@
 complexFloatT = BaseCppType("c10", "complex<float>")
 complexDoubleT = BaseCppType("c10", "complex<double>")
 bfloat16T = BaseCppType("at", "BFloat16")
+float8_e5m2T = BaseCppType("at", "Float8_e5m2")
+float8_e4m3fnT = BaseCppType("at", "Float8_e4m3fn")
 stringT = BaseCppType("c10", "string_view")
 generatorT = BaseCppType("at", "Generator")
 scalarTypeT = BaseCppType("at", "ScalarType")
@@ -93,7 +95,8 @@
     ScalarType.ComplexFloat: complexFloatT,
     ScalarType.ComplexDouble: complexDoubleT,
     ScalarType.Bool: boolT,
-    ScalarType.BFloat16: bfloat16T,
+    ScalarType.Float8_e5m2: float8_e5m2T,
+    ScalarType.Float8_e4m3fn: float8_e4m3fnT,
 }
 
 BaseTypeToCppMapping: Dict[BaseTy, BaseCppType] = {
diff --git a/torchgen/model.py b/torchgen/model.py
index 0b44732..0deb58a 100644
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -298,6 +298,8 @@
     ComplexDouble = auto()
     Bool = auto()
     BFloat16 = auto()
+    Float8_e5m2 = auto()
+    Float8_e4m3fn = auto()
 
     def __str__(self) -> str:
         return self.name