Update BFloat16 implementation to be based on templates to allow different
types.

The motivation is to allow for the easy addition of other similar reduced
precision types.

Roll forward with fix (casts between custom types are not supported in older NumPy versions)

PiperOrigin-RevId: 455154215
diff --git a/tensorflow/python/lib/core/BUILD b/tensorflow/python/lib/core/BUILD
index d070623..ce9a903 100644
--- a/tensorflow/python/lib/core/BUILD
+++ b/tensorflow/python/lib/core/BUILD
@@ -48,8 +48,14 @@
 
 cc_library(
     name = "bfloat16_lib",
-    srcs = ["bfloat16.cc"],
-    hdrs = ["bfloat16.h"],
+    srcs = [
+        "bfloat16.cc",
+        "float8_e4m3b11.cc",
+    ],
+    hdrs = [
+        "bfloat16.h",
+        "float8_e4m3b11.h",
+    ],
     deps = [
         ":numpy_lib",
         "//tensorflow/core/platform:logging",
diff --git a/tensorflow/python/lib/core/bfloat16.cc b/tensorflow/python/lib/core/bfloat16.cc
index ab7e2f1..e6e61a0 100644
--- a/tensorflow/python/lib/core/bfloat16.cc
+++ b/tensorflow/python/lib/core/bfloat16.cc
@@ -19,6 +19,8 @@
 #include <cmath>
 #include <limits>
 #include <locale>
+
+#include "tensorflow/python/lib/core/float8_e4m3b11.h"
 // Place `<locale>` before <Python.h> to avoid a build failure in macOS.
 #include <Python.h>
 
@@ -30,8 +32,6 @@
 namespace tensorflow {
 namespace {
 
-using bfloat16 = Eigen::bfloat16;
-
 struct PyDecrefDeleter {
   void operator()(PyObject* p) const { Py_DECREF(p); }
 };
@@ -52,50 +52,75 @@
   return (overflow == 0);
 }
 
-// Registered numpy type ID. Global variable populated by the registration code.
-// Protected by the GIL.
-int npy_bfloat16 = NPY_NOTYPE;
-
-// Forward declaration.
-extern PyTypeObject bfloat16_type;
-
-// Pointer to the bfloat16 type object we are using. This is either a pointer
-// to bfloat16_type, if we choose to register it, or to the bfloat16 type
-// registered by another system into NumPy.
-PyTypeObject* bfloat16_type_ptr = nullptr;
-
-// Representation of a Python bfloat16 object.
-struct PyBfloat16 {
-  PyObject_HEAD;  // Python object header
-  bfloat16 value;
+template <typename T, typename Enable = void>
+struct TypeDescriptor {
+  // typedef ... T;  // In-memory representation type for NumPy values.
+  // static int Dtype() { return NPY_...; }  // Numpy type number for T.
 };
 
-// Returns true if 'object' is a PyBfloat16.
-bool PyBfloat16_Check(PyObject* object) {
-  return PyObject_IsInstance(object,
-                             reinterpret_cast<PyObject*>(&bfloat16_type));
+template <typename T>
+struct CustomFloatTypeDescriptor {
+  static int Dtype() { return npy_type; }
+
+  // Registered numpy type ID. Global variable populated by the registration
+  // code. Protected by the GIL.
+  static int npy_type;
+
+  static PyTypeObject type;
+  // Pointer to the python type object we are using. This is either a pointer
+  // to type, if we choose to register it, or to the python type
+  // registered by another system into NumPy.
+  static PyTypeObject* type_ptr;
+
+  static PyNumberMethods number_methods;
+
+  static PyArray_ArrFuncs arr_funcs;
+
+  static PyArray_Descr npy_descr;
+};
+template <typename T>
+int CustomFloatTypeDescriptor<T>::npy_type = NPY_NOTYPE;
+template <typename T>
+PyTypeObject* CustomFloatTypeDescriptor<T>::type_ptr = nullptr;
+
+// Representation of a Python custom float object.
+template <typename T>
+struct PyCustomFloat {
+  PyObject_HEAD;  // Python object header
+  T value;
+};
+
+// Returns true if 'object' is a PyCustomFloat.
+template <typename T>
+bool PyCustomFloat_Check(PyObject* object) {
+  return PyObject_IsInstance(
+      object, reinterpret_cast<PyObject*>(&TypeDescriptor<T>::type));
 }
 
-// Extracts the value of a PyBfloat16 object.
-bfloat16 PyBfloat16_Bfloat16(PyObject* object) {
-  return reinterpret_cast<PyBfloat16*>(object)->value;
+// Extracts the value of a PyCustomFloat object.
+template <typename T>
+T PyCustomFloat_CustomFloat(PyObject* object) {
+  return reinterpret_cast<PyCustomFloat<T>*>(object)->value;
 }
 
-// Constructs a PyBfloat16 object from a bfloat16.
-Safe_PyObjectPtr PyBfloat16_FromBfloat16(bfloat16 x) {
-  Safe_PyObjectPtr ref = make_safe(bfloat16_type.tp_alloc(&bfloat16_type, 0));
-  PyBfloat16* p = reinterpret_cast<PyBfloat16*>(ref.get());
+// Constructs a PyCustomFloat object from a value of type T.
+template <typename T>
+Safe_PyObjectPtr PyCustomFloat_FromT(T x) {
+  Safe_PyObjectPtr ref =
+      make_safe(TypeDescriptor<T>::type.tp_alloc(&TypeDescriptor<T>::type, 0));
+  PyCustomFloat<T>* p = reinterpret_cast<PyCustomFloat<T>*>(ref.get());
   if (p) {
     p->value = x;
   }
   return ref;
 }
 
-// Converts a Python object to a bfloat16 value. Returns true on success,
+// Converts a Python object to a reduced float value. Returns true on success,
 // returns false and reports a Python error on failure.
-bool CastToBfloat16(PyObject* arg, bfloat16* output) {
-  if (PyBfloat16_Check(arg)) {
-    *output = PyBfloat16_Bfloat16(arg);
+template <typename T>
+bool CastToCustomFloat(PyObject* arg, T* output) {
+  if (PyCustomFloat_Check<T>(arg)) {
+    *output = PyCustomFloat_CustomFloat<T>(arg);
     return true;
   }
   if (PyFloat_Check(arg)) {
@@ -104,7 +129,7 @@
       return false;
     }
     // TODO(phawkins): check for overflow
-    *output = bfloat16(d);
+    *output = T(d);
     return true;
   }
   if (PyLong_CheckNoOverflow(arg)) {
@@ -113,130 +138,139 @@
       return false;
     }
     // TODO(phawkins): check for overflow
-    *output = bfloat16(static_cast<float>(l));
+    *output = T(static_cast<float>(l));
     return true;
   }
   if (PyArray_IsScalar(arg, Half)) {
     Eigen::half f;
     PyArray_ScalarAsCtype(arg, &f);
-    *output = bfloat16(f);
+    *output = T(f);
     return true;
   }
   if (PyArray_IsScalar(arg, Float)) {
     float f;
     PyArray_ScalarAsCtype(arg, &f);
-    *output = bfloat16(f);
+    *output = T(f);
     return true;
   }
   if (PyArray_IsScalar(arg, Double)) {
     double f;
     PyArray_ScalarAsCtype(arg, &f);
-    *output = bfloat16(f);
+    *output = T(f);
     return true;
   }
   if (PyArray_IsScalar(arg, LongDouble)) {
     long double f;
     PyArray_ScalarAsCtype(arg, &f);
-    *output = bfloat16(f);
+    *output = T(f);
     return true;
   }
   if (PyArray_IsZeroDim(arg)) {
     Safe_PyObjectPtr ref;
     PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
-    if (PyArray_TYPE(arr) != npy_bfloat16) {
-      ref = make_safe(PyArray_Cast(arr, npy_bfloat16));
+    if (PyArray_TYPE(arr) != TypeDescriptor<T>::Dtype()) {
+      ref = make_safe(PyArray_Cast(arr, TypeDescriptor<T>::Dtype()));
       if (PyErr_Occurred()) {
         return false;
       }
       arg = ref.get();
       arr = reinterpret_cast<PyArrayObject*>(arg);
     }
-    *output = *reinterpret_cast<bfloat16*>(PyArray_DATA(arr));
+    *output = *reinterpret_cast<T*>(PyArray_DATA(arr));
     return true;
   }
   return false;
 }
 
-bool SafeCastToBfloat16(PyObject* arg, bfloat16* output) {
-  if (PyBfloat16_Check(arg)) {
-    *output = PyBfloat16_Bfloat16(arg);
+template <typename T>
+bool SafeCastToCustomFloat(PyObject* arg, T* output) {
+  if (PyCustomFloat_Check<T>(arg)) {
+    *output = PyCustomFloat_CustomFloat<T>(arg);
     return true;
   }
   return false;
 }
 
-// Converts a PyBfloat16 into a PyFloat.
-PyObject* PyBfloat16_Float(PyObject* self) {
-  bfloat16 x = PyBfloat16_Bfloat16(self);
-  return PyFloat_FromDouble(static_cast<double>(x));
+// Converts a PyCustomFloat into a PyFloat.
+template <typename T>
+PyObject* PyCustomFloat_Float(PyObject* self) {
+  T x = PyCustomFloat_CustomFloat<T>(self);
+  return PyFloat_FromDouble(static_cast<double>(static_cast<float>(x)));
 }
 
-// Converts a PyBfloat16 into a PyInt.
-PyObject* PyBfloat16_Int(PyObject* self) {
-  bfloat16 x = PyBfloat16_Bfloat16(self);
-  long y = static_cast<long>(x);  // NOLINT
+// Converts a PyCustomFloat into a PyInt.
+template <typename T>
+PyObject* PyCustomFloat_Int(PyObject* self) {
+  T x = PyCustomFloat_CustomFloat<T>(self);
+  long y = static_cast<long>(static_cast<float>(x));  // NOLINT
   return PyLong_FromLong(y);
 }
 
-// Negates a PyBfloat16.
-PyObject* PyBfloat16_Negative(PyObject* self) {
-  bfloat16 x = PyBfloat16_Bfloat16(self);
-  return PyBfloat16_FromBfloat16(-x).release();
+// Negates a PyCustomFloat.
+template <typename T>
+PyObject* PyCustomFloat_Negative(PyObject* self) {
+  T x = PyCustomFloat_CustomFloat<T>(self);
+  return PyCustomFloat_FromT<T>(-x).release();
 }
 
-PyObject* PyBfloat16_Add(PyObject* a, PyObject* b) {
-  bfloat16 x, y;
-  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
-    return PyBfloat16_FromBfloat16(x + y).release();
+template <typename T>
+PyObject* PyCustomFloat_Add(PyObject* a, PyObject* b) {
+  T x, y;
+  if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
+    return PyCustomFloat_FromT<T>(x + y).release();
   }
   return PyArray_Type.tp_as_number->nb_add(a, b);
 }
 
-PyObject* PyBfloat16_Subtract(PyObject* a, PyObject* b) {
-  bfloat16 x, y;
-  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
-    return PyBfloat16_FromBfloat16(x - y).release();
+template <typename T>
+PyObject* PyCustomFloat_Subtract(PyObject* a, PyObject* b) {
+  T x, y;
+  if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
+    return PyCustomFloat_FromT<T>(x - y).release();
   }
   return PyArray_Type.tp_as_number->nb_subtract(a, b);
 }
 
-PyObject* PyBfloat16_Multiply(PyObject* a, PyObject* b) {
-  bfloat16 x, y;
-  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
-    return PyBfloat16_FromBfloat16(x * y).release();
+template <typename T>
+PyObject* PyCustomFloat_Multiply(PyObject* a, PyObject* b) {
+  T x, y;
+  if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
+    return PyCustomFloat_FromT<T>(x * y).release();
   }
   return PyArray_Type.tp_as_number->nb_multiply(a, b);
 }
 
-PyObject* PyBfloat16_TrueDivide(PyObject* a, PyObject* b) {
-  bfloat16 x, y;
-  if (SafeCastToBfloat16(a, &x) && SafeCastToBfloat16(b, &y)) {
-    return PyBfloat16_FromBfloat16(x / y).release();
+template <typename T>
+PyObject* PyCustomFloat_TrueDivide(PyObject* a, PyObject* b) {
+  T x, y;
+  if (SafeCastToCustomFloat<T>(a, &x) && SafeCastToCustomFloat<T>(b, &y)) {
+    return PyCustomFloat_FromT<T>(x / y).release();
   }
   return PyArray_Type.tp_as_number->nb_true_divide(a, b);
 }
 
-// Python number methods for PyBfloat16 objects.
-PyNumberMethods PyBfloat16_AsNumber = {
-    PyBfloat16_Add,       // nb_add
-    PyBfloat16_Subtract,  // nb_subtract
-    PyBfloat16_Multiply,  // nb_multiply
-    nullptr,              // nb_remainder
-    nullptr,              // nb_divmod
-    nullptr,              // nb_power
-    PyBfloat16_Negative,  // nb_negative
-    nullptr,              // nb_positive
-    nullptr,              // nb_absolute
-    nullptr,              // nb_nonzero
-    nullptr,              // nb_invert
-    nullptr,              // nb_lshift
-    nullptr,              // nb_rshift
-    nullptr,              // nb_and
-    nullptr,              // nb_xor
-    nullptr,              // nb_or
-    PyBfloat16_Int,       // nb_int
-    nullptr,              // reserved
-    PyBfloat16_Float,     // nb_float
+// Python number methods for PyCustomFloat objects.
+template <typename T>
+PyNumberMethods CustomFloatTypeDescriptor<T>::number_methods = {
+    PyCustomFloat_Add<T>,       // nb_add
+    PyCustomFloat_Subtract<T>,  // nb_subtract
+    PyCustomFloat_Multiply<T>,  // nb_multiply
+    nullptr,                    // nb_remainder
+    nullptr,                    // nb_divmod
+    nullptr,                    // nb_power
+    PyCustomFloat_Negative<T>,  // nb_negative
+    nullptr,                    // nb_positive
+    nullptr,                    // nb_absolute
+    nullptr,                    // nb_nonzero
+    nullptr,                    // nb_invert
+    nullptr,                    // nb_lshift
+    nullptr,                    // nb_rshift
+    nullptr,                    // nb_and
+    nullptr,                    // nb_xor
+    nullptr,                    // nb_or
+    PyCustomFloat_Int<T>,       // nb_int
+    nullptr,                    // reserved
+    PyCustomFloat_Float<T>,     // nb_float
 
     nullptr,  // nb_inplace_add
     nullptr,  // nb_inplace_subtract
@@ -249,37 +283,40 @@
     nullptr,  // nb_inplace_xor
     nullptr,  // nb_inplace_or
 
-    nullptr,                // nb_floor_divide
-    PyBfloat16_TrueDivide,  // nb_true_divide
-    nullptr,                // nb_inplace_floor_divide
-    nullptr,                // nb_inplace_true_divide
-    nullptr,                // nb_index
+    nullptr,                      // nb_floor_divide
+    PyCustomFloat_TrueDivide<T>,  // nb_true_divide
+    nullptr,                      // nb_inplace_floor_divide
+    nullptr,                      // nb_inplace_true_divide
+    nullptr,                      // nb_index
 };
 
-// Constructs a new PyBfloat16.
-PyObject* PyBfloat16_New(PyTypeObject* type, PyObject* args, PyObject* kwds) {
+// Constructs a new PyCustomFloat.
+template <typename T>
+PyObject* PyCustomFloat_New(PyTypeObject* type, PyObject* args,
+                            PyObject* kwds) {
   if (kwds && PyDict_Size(kwds)) {
     PyErr_SetString(PyExc_TypeError, "constructor takes no keyword arguments");
     return nullptr;
   }
   Py_ssize_t size = PyTuple_Size(args);
   if (size != 1) {
-    PyErr_SetString(PyExc_TypeError,
-                    "expected number as argument to bfloat16 constructor");
+    PyErr_Format(PyExc_TypeError,
+                 "expected number as argument to %s constructor",
+                 TypeDescriptor<T>::kTypeName);
     return nullptr;
   }
   PyObject* arg = PyTuple_GetItem(args, 0);
 
-  bfloat16 value;
-  if (PyBfloat16_Check(arg)) {
+  T value;
+  if (PyCustomFloat_Check<T>(arg)) {
     Py_INCREF(arg);
     return arg;
-  } else if (CastToBfloat16(arg, &value)) {
-    return PyBfloat16_FromBfloat16(value).release();
+  } else if (CastToCustomFloat<T>(arg, &value)) {
+    return PyCustomFloat_FromT<T>(value).release();
   } else if (PyArray_Check(arg)) {
     PyArrayObject* arr = reinterpret_cast<PyArrayObject*>(arg);
-    if (PyArray_TYPE(arr) != npy_bfloat16) {
-      return PyArray_Cast(arr, npy_bfloat16);
+    if (PyArray_TYPE(arr) != TypeDescriptor<T>::Dtype()) {
+      return PyArray_Cast(arr, TypeDescriptor<T>::Dtype());
     } else {
       Py_INCREF(arg);
       return arg;
@@ -290,10 +327,11 @@
   return nullptr;
 }
 
-// Comparisons on PyBfloat16s.
-PyObject* PyBfloat16_RichCompare(PyObject* a, PyObject* b, int op) {
-  bfloat16 x, y;
-  if (!SafeCastToBfloat16(a, &x) || !SafeCastToBfloat16(b, &y)) {
+// Comparisons on PyCustomFloats.
+template <typename T>
+PyObject* PyCustomFloat_RichCompare(PyObject* a, PyObject* b, int op) {
+  T x, y;
+  if (!SafeCastToCustomFloat<T>(a, &x) || !SafeCastToCustomFloat<T>(b, &y)) {
     return PyGenericArrType_Type.tp_richcompare(a, b, op);
   }
   bool result;
@@ -322,16 +360,18 @@
   return PyBool_FromLong(result);
 }
 
-// Implementation of repr() for PyBfloat16.
-PyObject* PyBfloat16_Repr(PyObject* self) {
-  bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
+// Implementation of repr() for PyCustomFloat.
+template <typename T>
+PyObject* PyCustomFloat_Repr(PyObject* self) {
+  T x = reinterpret_cast<PyCustomFloat<T>*>(self)->value;
   std::string v = absl::StrCat(static_cast<float>(x));
   return PyUnicode_FromString(v.c_str());
 }
 
-// Implementation of str() for PyBfloat16.
-PyObject* PyBfloat16_Str(PyObject* self) {
-  bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
+// Implementation of str() for PyCustomFloat.
+template <typename T>
+PyObject* PyCustomFloat_Str(PyObject* self) {
+  T x = reinterpret_cast<PyCustomFloat<T>*>(self)->value;
   std::string v = absl::StrCat(static_cast<float>(x));
   return PyUnicode_FromString(v.c_str());
 }
@@ -350,93 +390,89 @@
   return hash_double(value);
 }
 
-// Hash function for PyBfloat16.
-Py_hash_t PyBfloat16_Hash(PyObject* self) {
-  bfloat16 x = reinterpret_cast<PyBfloat16*>(self)->value;
+// Hash function for PyCustomFloat.
+template <typename T>
+Py_hash_t PyCustomFloat_Hash(PyObject* self) {
+  T x = reinterpret_cast<PyCustomFloat<T>*>(self)->value;
   return HashImpl(&_Py_HashDouble, self, static_cast<double>(x));
 }
 
-// Python type for PyBfloat16 objects.
-PyTypeObject bfloat16_type = {
-    PyVarObject_HEAD_INIT(nullptr, 0) "bfloat16",  // tp_name
-    sizeof(PyBfloat16),                            // tp_basicsize
-    0,                                             // tp_itemsize
-    nullptr,                                       // tp_dealloc
+// Python type for PyCustomFloat objects.
+template <typename T>
+PyTypeObject CustomFloatTypeDescriptor<T>::type = {
+    PyVarObject_HEAD_INIT(nullptr, 0) TypeDescriptor<T>::kTypeName,  // tp_name
+    sizeof(PyCustomFloat<T>),  // tp_basicsize
+    0,                         // tp_itemsize
+    nullptr,                   // tp_dealloc
 #if PY_VERSION_HEX < 0x03080000
     nullptr,  // tp_print
 #else
     0,  // tp_vectorcall_offset
 #endif
-    nullptr,               // tp_getattr
-    nullptr,               // tp_setattr
-    nullptr,               // tp_compare / tp_reserved
-    PyBfloat16_Repr,       // tp_repr
-    &PyBfloat16_AsNumber,  // tp_as_number
-    nullptr,               // tp_as_sequence
-    nullptr,               // tp_as_mapping
-    PyBfloat16_Hash,       // tp_hash
-    nullptr,               // tp_call
-    PyBfloat16_Str,        // tp_str
-    nullptr,               // tp_getattro
-    nullptr,               // tp_setattro
-    nullptr,               // tp_as_buffer
-                           // tp_flags
+    nullptr,                                        // tp_getattr
+    nullptr,                                        // tp_setattr
+    nullptr,                                        // tp_compare / tp_reserved
+    PyCustomFloat_Repr<T>,                          // tp_repr
+    &CustomFloatTypeDescriptor<T>::number_methods,  // tp_as_number
+    nullptr,                                        // tp_as_sequence
+    nullptr,                                        // tp_as_mapping
+    PyCustomFloat_Hash<T>,                          // tp_hash
+    nullptr,                                        // tp_call
+    PyCustomFloat_Str<T>,                           // tp_str
+    nullptr,                                        // tp_getattro
+    nullptr,                                        // tp_setattro
+    nullptr,                                        // tp_as_buffer
+                                                    // tp_flags
     Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,
-    "bfloat16 floating-point values",  // tp_doc
-    nullptr,                           // tp_traverse
-    nullptr,                           // tp_clear
-    PyBfloat16_RichCompare,            // tp_richcompare
-    0,                                 // tp_weaklistoffset
-    nullptr,                           // tp_iter
-    nullptr,                           // tp_iternext
-    nullptr,                           // tp_methods
-    nullptr,                           // tp_members
-    nullptr,                           // tp_getset
-    nullptr,                           // tp_base
-    nullptr,                           // tp_dict
-    nullptr,                           // tp_descr_get
-    nullptr,                           // tp_descr_set
-    0,                                 // tp_dictoffset
-    nullptr,                           // tp_init
-    nullptr,                           // tp_alloc
-    PyBfloat16_New,                    // tp_new
-    nullptr,                           // tp_free
-    nullptr,                           // tp_is_gc
-    nullptr,                           // tp_bases
-    nullptr,                           // tp_mro
-    nullptr,                           // tp_cache
-    nullptr,                           // tp_subclasses
-    nullptr,                           // tp_weaklist
-    nullptr,                           // tp_del
-    0,                                 // tp_version_tag
+    TypeDescriptor<T>::kTpDoc,     // tp_doc
+    nullptr,                       // tp_traverse
+    nullptr,                       // tp_clear
+    PyCustomFloat_RichCompare<T>,  // tp_richcompare
+    0,                             // tp_weaklistoffset
+    nullptr,                       // tp_iter
+    nullptr,                       // tp_iternext
+    nullptr,                       // tp_methods
+    nullptr,                       // tp_members
+    nullptr,                       // tp_getset
+    nullptr,                       // tp_base
+    nullptr,                       // tp_dict
+    nullptr,                       // tp_descr_get
+    nullptr,                       // tp_descr_set
+    0,                             // tp_dictoffset
+    nullptr,                       // tp_init
+    nullptr,                       // tp_alloc
+    PyCustomFloat_New<T>,          // tp_new
+    nullptr,                       // tp_free
+    nullptr,                       // tp_is_gc
+    nullptr,                       // tp_bases
+    nullptr,                       // tp_mro
+    nullptr,                       // tp_cache
+    nullptr,                       // tp_subclasses
+    nullptr,                       // tp_weaklist
+    nullptr,                       // tp_del
+    0,                             // tp_version_tag
 };
 
 // Numpy support
+template <typename T>
+PyArray_ArrFuncs CustomFloatTypeDescriptor<T>::arr_funcs;
 
-PyArray_ArrFuncs NPyBfloat16_ArrFuncs;
-
-PyArray_Descr NPyBfloat16_Descr = {
+template <typename T>
+PyArray_Descr CustomFloatTypeDescriptor<T>::npy_descr = {
     PyObject_HEAD_INIT(nullptr)  //
                                  /*typeobj=*/
-    (&bfloat16_type),
-    // We must register bfloat16 with a kind other than "f", because numpy
-    // considers two types with the same kind and size to be equal, but
-    // float16 != bfloat16.
-    // The downside of this is that NumPy scalar promotion does not work with
-    // bfloat16 values.
-    /*kind=*/'V',
-    // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
-    // character is unique.
-    /*type=*/'E',
-    /*byteorder=*/'=',
+    (&TypeDescriptor<T>::type),
+    /*kind=*/TypeDescriptor<T>::kNpyDescrKind,
+    /*type=*/TypeDescriptor<T>::kNpyDescrType,
+    /*byteorder=*/TypeDescriptor<T>::kNpyDescrByteorder,
     /*flags=*/NPY_NEEDS_PYAPI | NPY_USE_GETITEM | NPY_USE_SETITEM,
     /*type_num=*/0,
-    /*elsize=*/sizeof(bfloat16),
-    /*alignment=*/alignof(bfloat16),
+    /*elsize=*/sizeof(T),
+    /*alignment=*/alignof(T),
     /*subarray=*/nullptr,
     /*fields=*/nullptr,
     /*names=*/nullptr,
-    /*f=*/&NPyBfloat16_ArrFuncs,
+    /*f=*/&CustomFloatTypeDescriptor<T>::arr_funcs,
     /*metadata=*/nullptr,
     /*c_metadata=*/nullptr,
     /*hash=*/-1,  // -1 means "not computed yet".
@@ -444,20 +480,22 @@
 
 // Implementations of NumPy array methods.
 
-PyObject* NPyBfloat16_GetItem(void* data, void* arr) {
-  bfloat16 x;
-  memcpy(&x, data, sizeof(bfloat16));
-  return PyBfloat16_FromBfloat16(x).release();
+template <typename T>
+PyObject* NPyCustomFloat_GetItem(void* data, void* arr) {
+  T x;
+  memcpy(&x, data, sizeof(T));
+  return PyCustomFloat_FromT<T>(x).release();
 }
 
-int NPyBfloat16_SetItem(PyObject* item, void* data, void* arr) {
-  bfloat16 x;
-  if (!CastToBfloat16(item, &x)) {
+template <typename T>
+int NPyCustomFloat_SetItem(PyObject* item, void* data, void* arr) {
+  T x;
+  if (!CastToCustomFloat<T>(item, &x)) {
     PyErr_Format(PyExc_TypeError, "expected number, got %s",
                  Py_TYPE(item)->tp_name);
     return -1;
   }
-  memcpy(data, &x, sizeof(bfloat16));
+  memcpy(data, &x, sizeof(T));
   return 0;
 }
 
@@ -466,96 +504,110 @@
   std::swap(p[0], p[1]);
 }
 
-int NPyBfloat16_Compare(const void* a, const void* b, void* arr) {
-  bfloat16 x;
-  memcpy(&x, a, sizeof(bfloat16));
+template <typename T>
+int NPyCustomFloat_Compare(const void* a, const void* b, void* arr) {
+  T x;
+  memcpy(&x, a, sizeof(T));
 
-  bfloat16 y;
-  memcpy(&y, b, sizeof(bfloat16));
+  T y;
+  memcpy(&y, b, sizeof(T));
+  float fy(y);
+  float fx(x);
 
-  if (x < y) {
+  if (fx < fy) {
     return -1;
   }
-  if (y < x) {
+  if (fy < fx) {
     return 1;
   }
   // NaNs sort to the end.
-  if (!Eigen::numext::isnan(x) && Eigen::numext::isnan(y)) {
+  if (!Eigen::numext::isnan(fx) && Eigen::numext::isnan(fy)) {
     return -1;
   }
-  if (Eigen::numext::isnan(x) && !Eigen::numext::isnan(y)) {
+  if (Eigen::numext::isnan(fx) && !Eigen::numext::isnan(fy)) {
     return 1;
   }
   return 0;
 }
 
-void NPyBfloat16_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
-                           npy_intp sstride, npy_intp n, int swap, void* arr) {
+template <typename T>
+void NPyCustomFloat_CopySwapN(void* dstv, npy_intp dstride, void* srcv,
+                              npy_intp sstride, npy_intp n, int swap,
+                              void* arr) {
+  static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t),
+                "Not supported");
   char* dst = reinterpret_cast<char*>(dstv);
   char* src = reinterpret_cast<char*>(srcv);
   if (!src) {
     return;
   }
-  if (swap) {
+  if (swap && sizeof(T) == sizeof(int16_t)) {
     for (npy_intp i = 0; i < n; i++) {
       char* r = dst + dstride * i;
-      memcpy(r, src + sstride * i, sizeof(uint16_t));
+      memcpy(r, src + sstride * i, sizeof(T));
       ByteSwap16(r);
     }
-  } else if (dstride == sizeof(uint16_t) && sstride == sizeof(uint16_t)) {
-    memcpy(dst, src, n * sizeof(uint16_t));
+  } else if (dstride == sizeof(T) && sstride == sizeof(T)) {
+    memcpy(dst, src, n * sizeof(T));
   } else {
     for (npy_intp i = 0; i < n; i++) {
-      memcpy(dst + dstride * i, src + sstride * i, sizeof(uint16_t));
+      memcpy(dst + dstride * i, src + sstride * i, sizeof(T));
     }
   }
 }
 
-void NPyBfloat16_CopySwap(void* dst, void* src, int swap, void* arr) {
+template <typename T>
+void NPyCustomFloat_CopySwap(void* dst, void* src, int swap, void* arr) {
   if (!src) {
     return;
   }
-  memcpy(dst, src, sizeof(uint16_t));
-  if (swap) {
+  memcpy(dst, src, sizeof(T));  // T may be 1 byte; never copy a fixed 2.
+  static_assert(sizeof(T) == sizeof(int16_t) || sizeof(T) == sizeof(int8_t),
+                "Not supported");
+  if (swap && sizeof(T) == sizeof(int16_t)) {
     ByteSwap16(dst);
   }
 }
 
-npy_bool NPyBfloat16_NonZero(void* data, void* arr) {
-  bfloat16 x;
+template <typename T>
+npy_bool NPyCustomFloat_NonZero(void* data, void* arr) {
+  T x;
   memcpy(&x, data, sizeof(x));
-  return x != static_cast<bfloat16>(0);
+  return x != static_cast<T>(0);
 }
 
-int NPyBfloat16_Fill(void* buffer_raw, npy_intp length, void* ignored) {
-  bfloat16* const buffer = reinterpret_cast<bfloat16*>(buffer_raw);
+template <typename T>
+int NPyCustomFloat_Fill(void* buffer_raw, npy_intp length, void* ignored) {
+  T* const buffer = reinterpret_cast<T*>(buffer_raw);
   const float start(buffer[0]);
   const float delta = static_cast<float>(buffer[1]) - start;
   for (npy_intp i = 2; i < length; ++i) {
-    buffer[i] = static_cast<bfloat16>(start + i * delta);
+    buffer[i] = static_cast<T>(start + i * delta);
   }
   return 0;
 }
 
-void NPyBfloat16_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2,
-                         void* op, npy_intp n, void* arr) {
+template <typename T>
+void NPyCustomFloat_DotFunc(void* ip1, npy_intp is1, void* ip2, npy_intp is2,
+                            void* op, npy_intp n, void* arr) {
   char* c1 = reinterpret_cast<char*>(ip1);
   char* c2 = reinterpret_cast<char*>(ip2);
   float acc = 0.0f;
   for (npy_intp i = 0; i < n; ++i) {
-    bfloat16* const b1 = reinterpret_cast<bfloat16*>(c1);
-    bfloat16* const b2 = reinterpret_cast<bfloat16*>(c2);
+    T* const b1 = reinterpret_cast<T*>(c1);
+    T* const b2 = reinterpret_cast<T*>(c2);
     acc += static_cast<float>(*b1) * static_cast<float>(*b2);
     c1 += is1;
     c2 += is2;
   }
-  bfloat16* out = reinterpret_cast<bfloat16*>(op);
-  *out = static_cast<bfloat16>(acc);
+  T* out = reinterpret_cast<T*>(op);
+  *out = static_cast<T>(acc);
 }
 
-int NPyBfloat16_CompareFunc(const void* v1, const void* v2, void* arr) {
-  bfloat16 b1 = *reinterpret_cast<const bfloat16*>(v1);
-  bfloat16 b2 = *reinterpret_cast<const bfloat16*>(v2);
+template <typename T>
+int NPyCustomFloat_CompareFunc(const void* v1, const void* v2, void* arr) {
+  T b1 = *reinterpret_cast<const T*>(v1);
+  T b2 = *reinterpret_cast<const T*>(v2);
   if (b1 < b2) {
     return -1;
   }
@@ -565,9 +617,10 @@
   return 0;
 }
 
-int NPyBfloat16_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind,
-                           void* arr) {
-  const bfloat16* bdata = reinterpret_cast<const bfloat16*>(data);
+template <typename T>
+int NPyCustomFloat_ArgMaxFunc(void* data, npy_intp n, npy_intp* max_ind,
+                              void* arr) {
+  const T* bdata = reinterpret_cast<const T*>(data);
   // Start with a max_val of NaN, this results in the first iteration preferring
   // bdata[0].
   float max_val = std::numeric_limits<float>::quiet_NaN();
@@ -585,9 +638,10 @@
   return 0;
 }
 
-int NPyBfloat16_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind,
-                           void* arr) {
-  const bfloat16* bdata = reinterpret_cast<const bfloat16*>(data);
+template <typename T>
+int NPyCustomFloat_ArgMinFunc(void* data, npy_intp n, npy_intp* min_ind,
+                              void* arr) {
+  const T* bdata = reinterpret_cast<const T*>(data);
   float min_val = std::numeric_limits<float>::quiet_NaN();
   // Start with a min_val of NaN, this results in the first iteration preferring
   // bdata[0].
@@ -605,20 +659,6 @@
   return 0;
 }
 
-// NumPy casts
-
-template <typename T, typename Enable = void>
-struct TypeDescriptor {
-  // typedef ... T;  // Representation type in memory for NumPy values of type
-  // static int Dtype() { return NPY_...; }  // Numpy type number for T.
-};
-
-template <>
-struct TypeDescriptor<bfloat16> {
-  typedef bfloat16 T;
-  static int Dtype() { return npy_bfloat16; }
-};
-
 template <>
 struct TypeDescriptor<unsigned char> {
   typedef unsigned char T;
@@ -730,6 +770,16 @@
   static int Dtype() { return NPY_CLONGDOUBLE; }
 };
 
+template <typename T>
+float CastToFloat(T value) {
+  return static_cast<float>(value);
+}
+
+template <typename T>
+float CastToFloat(std::complex<T> value) {
+  return CastToFloat(value.real());
+}
+
 // Performs a NumPy array cast from type 'From' to 'To'.
 template <typename From, typename To>
 void NPyCast(void* from_void, void* to_void, npy_intp n, void* fromarr,
@@ -738,26 +788,131 @@
       reinterpret_cast<typename TypeDescriptor<From>::T*>(from_void);
   auto* to = reinterpret_cast<typename TypeDescriptor<To>::T*>(to_void);
   for (npy_intp i = 0; i < n; ++i) {
-    to[i] =
-        static_cast<typename TypeDescriptor<To>::T>(static_cast<To>(from[i]));
+    to[i] = static_cast<typename TypeDescriptor<To>::T>(
+        static_cast<To>(CastToFloat(from[i])));
   }
 }
 
-// Registers a cast between bfloat16 and type 'T'. 'numpy_type' is the NumPy
-// type corresponding to 'T'.
-template <typename T>
-bool RegisterBfloat16Cast(int numpy_type) {
+// Registers a cast between T (a reduced float) and type 'OtherT'. 'numpy_type'
+// is the NumPy type corresponding to 'OtherT'.
+template <typename T, typename OtherT>
+bool RegisterCustomFloatCast(int numpy_type = TypeDescriptor<OtherT>::Dtype()) {
   PyArray_Descr* descr = PyArray_DescrFromType(numpy_type);
-  if (PyArray_RegisterCastFunc(descr, npy_bfloat16, NPyCast<T, bfloat16>) < 0) {
+  if (PyArray_RegisterCastFunc(descr, TypeDescriptor<T>::Dtype(),
+                               NPyCast<OtherT, T>) < 0) {
     return false;
   }
-  if (PyArray_RegisterCastFunc(&NPyBfloat16_Descr, numpy_type,
-                               NPyCast<bfloat16, T>) < 0) {
+  if (PyArray_RegisterCastFunc(&CustomFloatTypeDescriptor<T>::npy_descr,
+                               numpy_type, NPyCast<T, OtherT>) < 0) {
     return false;
   }
   return true;
 }
 
+template <typename T>
+bool RegisterCasts() {
+  if (!RegisterCustomFloatCast<T, Eigen::half>(NPY_HALF)) {
+    return false;
+  }
+
+  if (!RegisterCustomFloatCast<T, float>(NPY_FLOAT)) {
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, double>(NPY_DOUBLE)) {
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, long double>(NPY_LONGDOUBLE)) {
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, bool>(NPY_BOOL)) {
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, unsigned char>(NPY_UBYTE)) {
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, unsigned short>(NPY_USHORT)) {  // NOLINT
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, unsigned int>(NPY_UINT)) {
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, unsigned long>(NPY_ULONG)) {  // NOLINT
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, unsigned long long>(  // NOLINT
+          NPY_ULONGLONG)) {
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, signed char>(NPY_BYTE)) {
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, short>(NPY_SHORT)) {  // NOLINT
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, int>(NPY_INT)) {
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, long>(NPY_LONG)) {  // NOLINT
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, long long>(NPY_LONGLONG)) {  // NOLINT
+    return false;
+  }
+  // Following the NumPy convention, the imaginary part is dropped when
+  // converting to float.
+  if (!RegisterCustomFloatCast<T, std::complex<float>>(NPY_CFLOAT)) {
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, std::complex<double>>(NPY_CDOUBLE)) {
+    return false;
+  }
+  if (!RegisterCustomFloatCast<T, std::complex<long double>>(NPY_CLONGDOUBLE)) {
+    return false;
+  }
+
+  // Safe casts from T to other types
+  if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_FLOAT,
+                              NPY_NOSCALAR) < 0) {
+    return false;
+  }
+  if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_DOUBLE,
+                              NPY_NOSCALAR) < 0) {
+    return false;
+  }
+  if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_LONGDOUBLE,
+                              NPY_NOSCALAR) < 0) {
+    return false;
+  }
+  if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CFLOAT,
+                              NPY_NOSCALAR) < 0) {
+    return false;
+  }
+  if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CDOUBLE,
+                              NPY_NOSCALAR) < 0) {
+    return false;
+  }
+  if (PyArray_RegisterCanCast(&TypeDescriptor<T>::npy_descr, NPY_CLONGDOUBLE,
+                              NPY_NOSCALAR) < 0) {
+    return false;
+  }
+
+  // Safe casts to T from other types
+  if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BOOL),
+                              TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
+    return false;
+  }
+  if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UBYTE),
+                              TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
+    return false;
+  }
+  if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BYTE),
+                              TypeDescriptor<T>::Dtype(), NPY_NOSCALAR) < 0) {
+    return false;
+  }
+
+  return true;
+}
+
 template <typename InType, typename OutType, typename Functor>
 struct UnaryUFunc {
   static std::vector<int> Types() {
@@ -847,7 +1002,7 @@
   }
 };
 
-template <typename UFunc>
+template <typename UFunc, typename CustomFloatT>
 bool RegisterUFunc(PyObject* numpy, const char* name) {
   std::vector<int> types = UFunc::Types();
   PyUFuncGenericFunction fn =
@@ -863,8 +1018,8 @@
                  ufunc->nargs, types.size());
     return false;
   }
-  if (PyUFunc_RegisterLoopForType(ufunc, npy_bfloat16, fn,
-                                  const_cast<int*>(types.data()),
+  if (PyUFunc_RegisterLoopForType(ufunc, TypeDescriptor<CustomFloatT>::Dtype(),
+                                  fn, const_cast<int*>(types.data()),
                                   nullptr) < 0) {
     return false;
   }
@@ -873,20 +1028,24 @@
 
 namespace ufuncs {
 
+template <typename T>
 struct Add {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a + b; }
+  T operator()(T a, T b) { return a + b; }
 };
+template <typename T>
 struct Subtract {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a - b; }
+  T operator()(T a, T b) { return a - b; }
 };
+template <typename T>
 struct Multiply {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a * b; }
+  T operator()(T a, T b) { return a * b; }
 };
+template <typename T>
 struct TrueDivide {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) { return a / b; }
+  T operator()(T a, T b) { return a / b; }
 };
 
-std::pair<float, float> divmod(float a, float b) {
+inline std::pair<float, float> divmod(float a, float b) {
   if (b == 0.0f) {
     float nan = std::numeric_limits<float>::quiet_NaN();
     return {nan, nan};
@@ -914,20 +1073,23 @@
   return {floordiv, mod};
 }
 
+template <typename T>
 struct FloorDivide {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    return bfloat16(divmod(static_cast<float>(a), static_cast<float>(b)).first);
+  T operator()(T a, T b) {
+    return T(divmod(static_cast<float>(a), static_cast<float>(b)).first);
   }
 };
+template <typename T>
 struct Remainder {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    return bfloat16(
-        divmod(static_cast<float>(a), static_cast<float>(b)).second);
+  T operator()(T a, T b) {
+    return T(divmod(static_cast<float>(a), static_cast<float>(b)).second);
   }
 };
+template <typename T>
 struct DivmodUFunc {
   static std::vector<int> Types() {
-    return {npy_bfloat16, npy_bfloat16, npy_bfloat16, npy_bfloat16};
+    return {TypeDescriptor<T>::Dtype(), TypeDescriptor<T>::Dtype(),
+            TypeDescriptor<T>::Dtype(), TypeDescriptor<T>::Dtype()};
   }
   static void Call(char** args, npy_intp* dimensions, npy_intp* steps,
                    void* data) {
@@ -936,13 +1098,13 @@
     char* o0 = args[2];
     char* o1 = args[3];
     for (npy_intp k = 0; k < *dimensions; k++) {
-      bfloat16 x = *reinterpret_cast<const bfloat16*>(i0);
-      bfloat16 y = *reinterpret_cast<const bfloat16*>(i1);
+      T x = *reinterpret_cast<const T*>(i0);
+      T y = *reinterpret_cast<const T*>(i1);
       float floordiv, mod;
       std::tie(floordiv, mod) =
           divmod(static_cast<float>(x), static_cast<float>(y));
-      *reinterpret_cast<bfloat16*>(o0) = bfloat16(floordiv);
-      *reinterpret_cast<bfloat16*>(o1) = bfloat16(mod);
+      *reinterpret_cast<T*>(o0) = T(floordiv);
+      *reinterpret_cast<T*>(o1) = T(mod);
       i0 += steps[0];
       i1 += steps[1];
       o0 += steps[2];
@@ -950,132 +1112,127 @@
     }
   }
 };
+template <typename T>
 struct Fmod {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    return bfloat16(std::fmod(static_cast<float>(a), static_cast<float>(b)));
+  T operator()(T a, T b) {
+    return T(std::fmod(static_cast<float>(a), static_cast<float>(b)));
   }
 };
+template <typename T>
 struct Negative {
-  bfloat16 operator()(bfloat16 a) { return -a; }
+  T operator()(T a) { return -a; }
 };
+template <typename T>
 struct Positive {
-  bfloat16 operator()(bfloat16 a) { return a; }
+  T operator()(T a) { return a; }
 };
+template <typename T>
 struct Power {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    return bfloat16(std::pow(static_cast<float>(a), static_cast<float>(b)));
+  T operator()(T a, T b) {
+    return T(std::pow(static_cast<float>(a), static_cast<float>(b)));
   }
 };
+template <typename T>
 struct Abs {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::abs(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::abs(static_cast<float>(a))); }
 };
+template <typename T>
 struct Cbrt {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::cbrt(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::cbrt(static_cast<float>(a))); }
 };
+template <typename T>
 struct Ceil {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::ceil(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::ceil(static_cast<float>(a))); }
 };
-struct CopySign {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    // LLVM is smart enough to turn this into (a & 0x7fff) | (b & 0x8000).
-    bfloat16 abs_a = Eigen::numext::abs(a);
-    return std::signbit(static_cast<float>(b)) ? -abs_a : abs_a;
-  }
-};
+template <typename T>
+struct CopySign;
+
+template <typename T>
 struct Exp {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::exp(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::exp(static_cast<float>(a))); }
 };
+template <typename T>
 struct Exp2 {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::exp2(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::exp2(static_cast<float>(a))); }
 };
+template <typename T>
 struct Expm1 {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::expm1(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::expm1(static_cast<float>(a))); }
 };
+template <typename T>
 struct Floor {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::floor(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::floor(static_cast<float>(a))); }
 };
+template <typename T>
 struct Frexp {
-  std::pair<bfloat16, int> operator()(bfloat16 a) {
+  std::pair<T, int> operator()(T a) {
     int exp;
     float f = std::frexp(static_cast<float>(a), &exp);
-    return {bfloat16(f), exp};
+    return {T(f), exp};
   }
 };
+template <typename T>
 struct Heaviside {
-  bfloat16 operator()(bfloat16 bx, bfloat16 h0) {
+  T operator()(T bx, T h0) {
     float x = static_cast<float>(bx);
     if (Eigen::numext::isnan(x)) {
       return bx;
     }
     if (x < 0) {
-      return bfloat16(0.0f);
+      return T(0.0f);
     }
     if (x > 0) {
-      return bfloat16(1.0f);
+      return T(1.0f);
     }
     return h0;  // x == 0
   }
 };
+template <typename T>
 struct Conjugate {
-  bfloat16 operator()(bfloat16 a) { return a; }
+  T operator()(T a) { return a; }
 };
+template <typename T>
 struct IsFinite {
-  bool operator()(bfloat16 a) { return std::isfinite(static_cast<float>(a)); }
+  bool operator()(T a) { return std::isfinite(static_cast<float>(a)); }
 };
+template <typename T>
 struct IsInf {
-  bool operator()(bfloat16 a) { return std::isinf(static_cast<float>(a)); }
+  bool operator()(T a) { return std::isinf(static_cast<float>(a)); }
 };
+template <typename T>
 struct IsNan {
-  bool operator()(bfloat16 a) {
-    return Eigen::numext::isnan(static_cast<float>(a));
-  }
+  bool operator()(T a) { return Eigen::numext::isnan(static_cast<float>(a)); }
 };
+template <typename T>
 struct Ldexp {
-  bfloat16 operator()(bfloat16 a, int exp) {
-    return bfloat16(std::ldexp(static_cast<float>(a), exp));
+  T operator()(T a, int exp) {
+    return T(std::ldexp(static_cast<float>(a), exp));
   }
 };
+template <typename T>
 struct Log {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::log(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::log(static_cast<float>(a))); }
 };
+template <typename T>
 struct Log2 {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::log2(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::log2(static_cast<float>(a))); }
 };
+template <typename T>
 struct Log10 {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::log10(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::log10(static_cast<float>(a))); }
 };
+template <typename T>
 struct Log1p {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::log1p(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::log1p(static_cast<float>(a))); }
 };
+template <typename T>
 struct LogAddExp {
-  bfloat16 operator()(bfloat16 bx, bfloat16 by) {
+  T operator()(T bx, T by) {
     float x = static_cast<float>(bx);
     float y = static_cast<float>(by);
     if (x == y) {
       // Handles infinities of the same sign.
-      return bfloat16(x + std::log(2.0f));
+      return T(x + std::log(2.0f));
     }
     float out = std::numeric_limits<float>::quiet_NaN();
     if (x > y) {
@@ -1083,16 +1240,17 @@
     } else if (x < y) {
       out = y + std::log1p(std::exp(x - y));
     }
-    return bfloat16(out);
+    return T(out);
   }
 };
+template <typename T>
 struct LogAddExp2 {
-  bfloat16 operator()(bfloat16 bx, bfloat16 by) {
+  T operator()(T bx, T by) {
     float x = static_cast<float>(bx);
     float y = static_cast<float>(by);
     if (x == y) {
       // Handles infinities of the same sign.
-      return bfloat16(x + 1.0f);
+      return T(x + 1.0f);
     }
     float out = std::numeric_limits<float>::quiet_NaN();
     if (x > y) {
@@ -1100,202 +1258,432 @@
     } else if (x < y) {
       out = y + std::log1p(std::exp2(x - y)) / std::log(2.0f);
     }
-    return bfloat16(out);
+    return T(out);
   }
 };
+template <typename T>
 struct Modf {
-  std::pair<bfloat16, bfloat16> operator()(bfloat16 a) {
+  std::pair<T, T> operator()(T a) {
     float integral;
     float f = std::modf(static_cast<float>(a), &integral);
-    return {bfloat16(f), bfloat16(integral)};
+    return {T(f), T(integral)};
   }
 };
 
+template <typename T>
 struct Reciprocal {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(1.f / static_cast<float>(a));
-  }
+  T operator()(T a) { return T(1.f / static_cast<float>(a)); }
 };
+template <typename T>
 struct Rint {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::rint(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::rint(static_cast<float>(a))); }
 };
+template <typename T>
 struct Sign {
-  bfloat16 operator()(bfloat16 a) {
+  T operator()(T a) {
     float f(a);
     if (f < 0) {
-      return bfloat16(-1);
+      return T(-1);
     }
     if (f > 0) {
-      return bfloat16(1);
+      return T(1);
     }
     return a;
   }
 };
+template <typename T>
 struct SignBit {
-  bool operator()(bfloat16 a) { return std::signbit(static_cast<float>(a)); }
+  bool operator()(T a) { return std::signbit(static_cast<float>(a)); }
 };
+template <typename T>
 struct Sqrt {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::sqrt(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::sqrt(static_cast<float>(a))); }
 };
+template <typename T>
 struct Square {
-  bfloat16 operator()(bfloat16 a) {
+  T operator()(T a) {
     float f(a);
-    return bfloat16(f * f);
+    return T(f * f);
   }
 };
+template <typename T>
 struct Trunc {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::trunc(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::trunc(static_cast<float>(a))); }
 };
 
 // Trigonometric functions
+template <typename T>
 struct Sin {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::sin(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::sin(static_cast<float>(a))); }
 };
+template <typename T>
 struct Cos {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::cos(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::cos(static_cast<float>(a))); }
 };
+template <typename T>
 struct Tan {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::tan(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::tan(static_cast<float>(a))); }
 };
+template <typename T>
 struct Arcsin {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::asin(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::asin(static_cast<float>(a))); }
 };
+template <typename T>
 struct Arccos {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::acos(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::acos(static_cast<float>(a))); }
 };
+template <typename T>
 struct Arctan {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::atan(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::atan(static_cast<float>(a))); }
 };
+template <typename T>
 struct Arctan2 {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    return bfloat16(std::atan2(static_cast<float>(a), static_cast<float>(b)));
+  T operator()(T a, T b) {
+    return T(std::atan2(static_cast<float>(a), static_cast<float>(b)));
   }
 };
+template <typename T>
 struct Hypot {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
-    return bfloat16(std::hypot(static_cast<float>(a), static_cast<float>(b)));
+  T operator()(T a, T b) {
+    return T(std::hypot(static_cast<float>(a), static_cast<float>(b)));
   }
 };
+template <typename T>
 struct Sinh {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::sinh(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::sinh(static_cast<float>(a))); }
 };
+template <typename T>
 struct Cosh {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::cosh(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::cosh(static_cast<float>(a))); }
 };
+template <typename T>
 struct Tanh {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::tanh(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::tanh(static_cast<float>(a))); }
 };
+template <typename T>
 struct Arcsinh {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::asinh(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::asinh(static_cast<float>(a))); }
 };
+template <typename T>
 struct Arccosh {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::acosh(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::acosh(static_cast<float>(a))); }
 };
+template <typename T>
 struct Arctanh {
-  bfloat16 operator()(bfloat16 a) {
-    return bfloat16(std::atanh(static_cast<float>(a)));
-  }
+  T operator()(T a) { return T(std::atanh(static_cast<float>(a))); }
 };
+template <typename T>
 struct Deg2rad {
-  bfloat16 operator()(bfloat16 a) {
+  T operator()(T a) {
     static constexpr float radians_per_degree = M_PI / 180.0f;
-    return bfloat16(static_cast<float>(a) * radians_per_degree);
+    return T(static_cast<float>(a) * radians_per_degree);
   }
 };
+template <typename T>
 struct Rad2deg {
-  bfloat16 operator()(bfloat16 a) {
+  T operator()(T a) {
     static constexpr float degrees_per_radian = 180.0f / M_PI;
-    return bfloat16(static_cast<float>(a) * degrees_per_radian);
+    return T(static_cast<float>(a) * degrees_per_radian);
   }
 };
 
+template <typename T>
 struct Eq {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a == b; }
+  npy_bool operator()(T a, T b) { return a == b; }
 };
+template <typename T>
 struct Ne {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a != b; }
+  npy_bool operator()(T a, T b) { return a != b; }
 };
+template <typename T>
 struct Lt {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a < b; }
+  npy_bool operator()(T a, T b) { return a < b; }
 };
+template <typename T>
 struct Gt {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a > b; }
+  npy_bool operator()(T a, T b) { return a > b; }
 };
+template <typename T>
 struct Le {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a <= b; }
+  npy_bool operator()(T a, T b) { return a <= b; }
 };
+template <typename T>
 struct Ge {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a >= b; }
+  npy_bool operator()(T a, T b) { return a >= b; }
 };
+template <typename T>
 struct Maximum {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+  T operator()(T a, T b) {
     float fa(a), fb(b);
     return Eigen::numext::isnan(fa) || fa > fb ? a : b;
   }
 };
+template <typename T>
 struct Minimum {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+  T operator()(T a, T b) {
     float fa(a), fb(b);
     return Eigen::numext::isnan(fa) || fa < fb ? a : b;
   }
 };
+template <typename T>
 struct Fmax {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+  T operator()(T a, T b) {
     float fa(a), fb(b);
     return Eigen::numext::isnan(fb) || fa > fb ? a : b;
   }
 };
+template <typename T>
 struct Fmin {
-  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+  T operator()(T a, T b) {
     float fa(a), fb(b);
     return Eigen::numext::isnan(fb) || fa < fb ? a : b;
   }
 };
 
+template <typename T>
 struct LogicalNot {
-  npy_bool operator()(bfloat16 a) { return !a; }
+  npy_bool operator()(T a) { return !a; }
 };
+template <typename T>
 struct LogicalAnd {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a && b; }
+  npy_bool operator()(T a, T b) { return a && b; }
 };
+template <typename T>
 struct LogicalOr {
-  npy_bool operator()(bfloat16 a, bfloat16 b) { return a || b; }
+  npy_bool operator()(T a, T b) { return a || b; }
 };
+template <typename T>
 struct LogicalXor {
-  npy_bool operator()(bfloat16 a, bfloat16 b) {
+  npy_bool operator()(T a, T b) {
     return static_cast<bool>(a) ^ static_cast<bool>(b);
   }
 };
 
-struct NextAfter {
+template <typename T>
+struct NextAfter;
+
+template <typename T>
+struct Spacing {
+  T operator()(T x) {
+    // Compute the distance between the input and the next number with greater
+    // magnitude. The result should have the sign of the input.
+    T away(std::copysign(std::numeric_limits<float>::infinity(),
+                         static_cast<float>(x)));
+    return NextAfter<T>()(x, away) - x;
+  }
+};
+
+template <typename T>
+bool RegisterUFuncs(PyObject* numpy) {
+  bool ok =
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::Add<T>>, T>(numpy, "add") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::Subtract<T>>, T>(numpy,
+                                                               "subtract") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::Multiply<T>>, T>(numpy,
+                                                               "multiply") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::TrueDivide<T>>, T>(numpy,
+                                                                 "divide") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::LogAddExp<T>>, T>(numpy,
+                                                                "logaddexp") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::LogAddExp2<T>>, T>(
+          numpy, "logaddexp2") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Negative<T>>, T>(numpy,
+                                                              "negative") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Positive<T>>, T>(numpy,
+                                                              "positive") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::TrueDivide<T>>, T>(
+          numpy, "true_divide") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::FloorDivide<T>>, T>(
+          numpy, "floor_divide") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::Power<T>>, T>(numpy, "power") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(numpy,
+                                                                "remainder") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::Remainder<T>>, T>(numpy, "mod") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmod<T>>, T>(numpy, "fmod") &&
+      RegisterUFunc<ufuncs::DivmodUFunc<T>, T>(numpy, "divmod") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Abs<T>>, T>(numpy, "absolute") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Abs<T>>, T>(numpy, "fabs") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Rint<T>>, T>(numpy, "rint") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sign<T>>, T>(numpy, "sign") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::Heaviside<T>>, T>(numpy,
+                                                                "heaviside") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Conjugate<T>>, T>(numpy,
+                                                               "conjugate") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Exp<T>>, T>(numpy, "exp") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Exp2<T>>, T>(numpy, "exp2") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Expm1<T>>, T>(numpy, "expm1") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log<T>>, T>(numpy, "log") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log2<T>>, T>(numpy, "log2") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log10<T>>, T>(numpy, "log10") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Log1p<T>>, T>(numpy, "log1p") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sqrt<T>>, T>(numpy, "sqrt") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Square<T>>, T>(numpy, "square") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cbrt<T>>, T>(numpy, "cbrt") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Reciprocal<T>>, T>(numpy,
+                                                                "reciprocal") &&
+
+      // Trigonometric functions
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sin<T>>, T>(numpy, "sin") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cos<T>>, T>(numpy, "cos") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Tan<T>>, T>(numpy, "tan") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arcsin<T>>, T>(numpy, "arcsin") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arccos<T>>, T>(numpy, "arccos") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arctan<T>>, T>(numpy, "arctan") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::Arctan2<T>>, T>(numpy,
+                                                              "arctan2") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::Hypot<T>>, T>(numpy, "hypot") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Sinh<T>>, T>(numpy, "sinh") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Cosh<T>>, T>(numpy, "cosh") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Tanh<T>>, T>(numpy, "tanh") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arcsinh<T>>, T>(numpy,
+                                                             "arcsinh") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arccosh<T>>, T>(numpy,
+                                                             "arccosh") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Arctanh<T>>, T>(numpy,
+                                                             "arctanh") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Deg2rad<T>>, T>(numpy,
+                                                             "deg2rad") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Rad2deg<T>>, T>(numpy,
+                                                             "rad2deg") &&
+
+      // Comparison functions
+      RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Eq<T>>, T>(numpy, "equal") &&
+      RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Ne<T>>, T>(numpy,
+                                                            "not_equal") &&
+      RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Lt<T>>, T>(numpy, "less") &&
+      RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Gt<T>>, T>(numpy, "greater") &&
+      RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Le<T>>, T>(numpy,
+                                                            "less_equal") &&
+      RegisterUFunc<BinaryUFunc<T, bool, ufuncs::Ge<T>>, T>(numpy,
+                                                            "greater_equal") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::Maximum<T>>, T>(numpy,
+                                                              "maximum") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::Minimum<T>>, T>(numpy,
+                                                              "minimum") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmax<T>>, T>(numpy, "fmax") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::Fmin<T>>, T>(numpy, "fmin") &&
+      RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalAnd<T>>, T>(
+          numpy, "logical_and") &&
+      RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalOr<T>>, T>(
+          numpy, "logical_or") &&
+      RegisterUFunc<BinaryUFunc<T, bool, ufuncs::LogicalXor<T>>, T>(
+          numpy, "logical_xor") &&
+      RegisterUFunc<UnaryUFunc<T, bool, ufuncs::LogicalNot<T>>, T>(
+          numpy, "logical_not") &&
+
+      // Floating point functions
+      RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsFinite<T>>, T>(numpy,
+                                                                 "isfinite") &&
+      RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsInf<T>>, T>(numpy, "isinf") &&
+      RegisterUFunc<UnaryUFunc<T, bool, ufuncs::IsNan<T>>, T>(numpy, "isnan") &&
+      RegisterUFunc<UnaryUFunc<T, bool, ufuncs::SignBit<T>>, T>(numpy,
+                                                                "signbit") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::CopySign<T>>, T>(numpy,
+                                                               "copysign") &&
+      RegisterUFunc<UnaryUFunc2<T, T, T, ufuncs::Modf<T>>, T>(numpy, "modf") &&
+      RegisterUFunc<BinaryUFunc2<T, int, T, ufuncs::Ldexp<T>>, T>(numpy,
+                                                                  "ldexp") &&
+      RegisterUFunc<UnaryUFunc2<T, T, int, ufuncs::Frexp<T>>, T>(numpy,
+                                                                 "frexp") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Floor<T>>, T>(numpy, "floor") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Ceil<T>>, T>(numpy, "ceil") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Trunc<T>>, T>(numpy, "trunc") &&
+      RegisterUFunc<BinaryUFunc<T, T, ufuncs::NextAfter<T>>, T>(numpy,
+                                                                "nextafter") &&
+      RegisterUFunc<UnaryUFunc<T, T, ufuncs::Spacing<T>>, T>(numpy, "spacing");
+
+  return ok;
+}
+
+}  // namespace ufuncs
+
+template <typename T>
+bool RegisterNumpyDtype(PyObject* numpy) {
+  // If another module (presumably either TF or JAX) has already registered a
+  // type named kTypeName (e.g. bfloat16), use it. We don't want two such types
+  // if we can avoid it, since having two different types with the same name
+  // leads to confusion. This assumes that the other module has a sufficiently
+  // complete implementation. The only known NumPy bfloat16 extension at the
+  // time of writing is this one (distributed in TF and JAX).
+  // TODO(phawkins): distribute the bfloat16 extension as its own pip package,
+  // so we can unambiguously refer to a single canonical definition of bfloat16.
+  int typenum =
+      PyArray_TypeNumFromName(const_cast<char*>(TypeDescriptor<T>::kTypeName));
+  if (typenum != NPY_NOTYPE) {
+    PyArray_Descr* descr = PyArray_DescrFromType(typenum);
+    // The test for an argmax function here is to verify that the
+    // existing implementation is sufficiently new, and, say, not from
+    // an older version of TF or JAX.
+    if (descr && descr->f && descr->f->argmax) {
+      TypeDescriptor<T>::npy_type = typenum;
+      TypeDescriptor<T>::type_ptr = descr->typeobj;
+      return true;
+    }
+  }
+
+  TypeDescriptor<T>::type.tp_base = &PyGenericArrType_Type;
+
+  if (PyType_Ready(&TypeDescriptor<T>::type) < 0) {
+    return false;
+  }
+
+  // Initializes the NumPy descriptor.
+  PyArray_ArrFuncs& arr_funcs = CustomFloatTypeDescriptor<T>::arr_funcs;
+  PyArray_InitArrFuncs(&arr_funcs);
+  arr_funcs.getitem = NPyCustomFloat_GetItem<T>;
+  arr_funcs.setitem = NPyCustomFloat_SetItem<T>;
+  arr_funcs.compare = NPyCustomFloat_Compare<T>;
+  arr_funcs.copyswapn = NPyCustomFloat_CopySwapN<T>;
+  arr_funcs.copyswap = NPyCustomFloat_CopySwap<T>;
+  arr_funcs.nonzero = NPyCustomFloat_NonZero<T>;
+  arr_funcs.fill = NPyCustomFloat_Fill<T>;
+  arr_funcs.dotfunc = NPyCustomFloat_DotFunc<T>;
+  arr_funcs.compare = NPyCustomFloat_CompareFunc<T>;
+  arr_funcs.argmax = NPyCustomFloat_ArgMaxFunc<T>;
+  arr_funcs.argmin = NPyCustomFloat_ArgMinFunc<T>;
+
+  Py_TYPE(&CustomFloatTypeDescriptor<T>::npy_descr) = &PyArrayDescr_Type;
+  TypeDescriptor<T>::npy_type =
+      PyArray_RegisterDataType(&CustomFloatTypeDescriptor<T>::npy_descr);
+  TypeDescriptor<T>::type_ptr = &TypeDescriptor<T>::type;
+  if (TypeDescriptor<T>::Dtype() < 0) {
+    return false;
+  }
+
+  Safe_PyObjectPtr typeDict_obj =
+      make_safe(PyObject_GetAttrString(numpy, "sctypeDict"));
+  if (!typeDict_obj) return false;
+  // Add the type object to `numpy.sctypeDict`: that makes
+  // `numpy.dtype(type_name)` work.
+  if (PyDict_SetItemString(
+          typeDict_obj.get(), TypeDescriptor<T>::kTypeName,
+          reinterpret_cast<PyObject*>(&TypeDescriptor<T>::type)) < 0) {
+    return false;
+  }
+
+  // Support dtype(type_name)
+  if (PyDict_SetItemString(TypeDescriptor<T>::type.tp_dict, "dtype",
+                           reinterpret_cast<PyObject*>(
+                               &CustomFloatTypeDescriptor<T>::npy_descr)) < 0) {
+    return false;
+  }
+
+  return RegisterCasts<T>() && ufuncs::RegisterUFuncs<T>(numpy);
+}
+
+namespace ufuncs {
+
+template <>
+struct CopySign<bfloat16> {
+  bfloat16 operator()(bfloat16 a, bfloat16 b) {
+    // LLVM is smart enough to turn this into (a & 0x7fff) | (b & 0x8000).
+    bfloat16 abs_a = Eigen::numext::abs(a);
+    return std::signbit(static_cast<float>(b)) ? -abs_a : abs_a;
+  }
+};
+
+template <>
+struct NextAfter<bfloat16> {
   bfloat16 operator()(bfloat16 from, bfloat16 to) {
     uint16_t from_as_int, to_as_int;
     const uint16_t sign_mask = 1 << 15;
@@ -1333,13 +1721,80 @@
   }
 };
 
-struct Spacing {
-  bfloat16 operator()(bfloat16 x) {
-    // Compute the distance between the input and the next number with greater
-    // magnitude. The result should have the sign of the input.
-    bfloat16 away(std::copysign(std::numeric_limits<float>::infinity(),
-                                static_cast<float>(x)));
-    return NextAfter()(x, away) - x;
+}  // namespace ufuncs
+
+using bfloat16 = Eigen::bfloat16;
+
+template <>
+struct TypeDescriptor<bfloat16> : CustomFloatTypeDescriptor<bfloat16> {
+  typedef bfloat16 T;
+  static constexpr const char* kTypeName = "bfloat16";
+  static constexpr const char* kTpDoc = "bfloat16 floating-point values";
+  // We must register bfloat16 with a kind other than "f", because numpy
+  // considers two types with the same kind and size to be equal, but
+  // float16 != bfloat16.
+  // The downside of this is that NumPy scalar promotion does not work with
+  // bfloat16 values.
+  static constexpr char kNpyDescrKind = 'V';
+  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
+  // character is unique.
+  static constexpr char kNpyDescrType = 'E';
+  static constexpr char kNpyDescrByteorder = '=';
+};
+
+template <>
+struct TypeDescriptor<float8_e4m3b11>
+    : CustomFloatTypeDescriptor<float8_e4m3b11> {
+  typedef float8_e4m3b11 T;
+  static constexpr const char* kTypeName = "float8_e4m3b11";
+  static constexpr const char* kTpDoc = "float8_e4m3b11 floating-point values";
+  // We must register float8_e4m3b11 with a kind other than "f", because numpy
+  // considers two types with the same kind and size to be equal, and we
+  // expect multiple 1 byte floating point types.
+  // The downside of this is that NumPy scalar promotion does not work with
+  // float8_e4m3b11 values.
+  static constexpr char kNpyDescrKind = 'V';
+  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
+  // character is unique.
+  static constexpr char kNpyDescrType = 'L';
+  static constexpr char kNpyDescrByteorder = '=';
+};
+
+namespace ufuncs {
+
+template <>
+struct CopySign<float8_e4m3b11> {
+  float8_e4m3b11 operator()(float8_e4m3b11 a, float8_e4m3b11 b) {
+    return float8_e4m3b11::FromRep((a.rep() & 0x7f) | (b.rep() & 0x80));
+  }
+};
+
+// Steps `from` to the adjacent encoding in the direction of `to`.
+template <>
+struct NextAfter<float8_e4m3b11> {
+  float8_e4m3b11 operator()(float8_e4m3b11 from, float8_e4m3b11 to) {
+    uint8_t from_rep = from.rep();
+    uint8_t to_rep = to.rep();
+    // If either input is 0x80 the result is 0x80 (presumably the NaN encoding;
+    // TODO confirm against float8_e4m3b11.h).
+    if (from_rep == 0x80 || to_rep == 0x80) {
+      return float8_e4m3b11::FromRep(0x80);
+    }
+    if (from_rep == to_rep) return to;
+    // From zero, step to the lowest nonzero bits (0x01) with the sign of `to`.
+    if (from_rep == 0) return float8_e4m3b11::FromRep(0x01 | (to_rep & 0x80));
+    const uint8_t sign_mask = 0x80;
+    uint8_t from_sign = from_rep & sign_mask;
+    uint8_t to_sign = to_rep & sign_mask;
+    uint8_t from_abs = from_rep & ~sign_mask;
+    uint8_t to_abs = to_rep & ~sign_mask;
+    // Adding 0xFF wraps modulo 256, i.e. subtracts one from the encoding.
+    uint8_t magnitude_adjustment =
+        (from_abs > to_abs || from_sign != to_sign) ? 0xFF : 0x01;
+    uint8_t out_int = from_rep + magnitude_adjustment;
+    // Stepping down from 0x81 yields 0x80; return zero (0x00) instead.
+    if (out_int == 0x80) out_int = 0x0;
+    return float8_e4m3b11::FromRep(out_int);
   }
 };
 
@@ -1361,332 +1816,18 @@
     return false;
   }
 
-  // If another module (presumably either TF or JAX) has registered a bfloat16
-  // type, use it. We don't want two bfloat16 types if we can avoid it since it
-  // leads to confusion if we have two different types with the same name. This
-  // assumes that the other module has a sufficiently complete bfloat16
-  // implementation. The only known NumPy bfloat16 extension at the time of
-  // writing is this one (distributed in TF and JAX).
-  // TODO(phawkins): distribute the bfloat16 extension as its own pip package,
-  // so we can unambiguously refer to a single canonical definition of bfloat16.
-  int typenum = PyArray_TypeNumFromName(const_cast<char*>("bfloat16"));
-  if (typenum != NPY_NOTYPE) {
-    PyArray_Descr* descr = PyArray_DescrFromType(typenum);
-    // The test for an argmax function here is to verify that the
-    // bfloat16 implementation is sufficiently new, and, say, not from
-    // an older version of TF or JAX.
-    if (descr && descr->f && descr->f->argmax) {
-      npy_bfloat16 = typenum;
-      bfloat16_type_ptr = descr->typeobj;
-      return true;
-    }
-  }
-
-  bfloat16_type.tp_base = &PyGenericArrType_Type;
-
-  if (PyType_Ready(&bfloat16_type) < 0) {
+  if (!RegisterNumpyDtype<bfloat16>(numpy.get())) {
     return false;
   }
-
-  // Initializes the NumPy descriptor.
-  PyArray_InitArrFuncs(&NPyBfloat16_ArrFuncs);
-  NPyBfloat16_ArrFuncs.getitem = NPyBfloat16_GetItem;
-  NPyBfloat16_ArrFuncs.setitem = NPyBfloat16_SetItem;
-  NPyBfloat16_ArrFuncs.compare = NPyBfloat16_Compare;
-  NPyBfloat16_ArrFuncs.copyswapn = NPyBfloat16_CopySwapN;
-  NPyBfloat16_ArrFuncs.copyswap = NPyBfloat16_CopySwap;
-  NPyBfloat16_ArrFuncs.nonzero = NPyBfloat16_NonZero;
-  NPyBfloat16_ArrFuncs.fill = NPyBfloat16_Fill;
-  NPyBfloat16_ArrFuncs.dotfunc = NPyBfloat16_DotFunc;
-  NPyBfloat16_ArrFuncs.compare = NPyBfloat16_CompareFunc;
-  NPyBfloat16_ArrFuncs.argmax = NPyBfloat16_ArgMaxFunc;
-  NPyBfloat16_ArrFuncs.argmin = NPyBfloat16_ArgMinFunc;
-
-  Py_TYPE(&NPyBfloat16_Descr) = &PyArrayDescr_Type;
-  npy_bfloat16 = PyArray_RegisterDataType(&NPyBfloat16_Descr);
-  bfloat16_type_ptr = &bfloat16_type;
-  if (npy_bfloat16 < 0) {
+  if (!RegisterNumpyDtype<float8_e4m3b11>(numpy.get())) {
     return false;
   }
-
-  Safe_PyObjectPtr typeDict_obj =
-      make_safe(PyObject_GetAttrString(numpy.get(), "sctypeDict"));
-  if (!typeDict_obj) return false;
-  // Add the type object to `numpy.typeDict`: that makes
-  // `numpy.dtype('bfloat16')` work.
-  if (PyDict_SetItemString(typeDict_obj.get(), "bfloat16",
-                           reinterpret_cast<PyObject*>(&bfloat16_type)) < 0) {
-    return false;
-  }
-
-  // Support dtype(bfloat16)
-  if (PyDict_SetItemString(bfloat16_type.tp_dict, "dtype",
-                           reinterpret_cast<PyObject*>(&NPyBfloat16_Descr)) <
-      0) {
-    return false;
-  }
-
-  // Register casts
-  if (!RegisterBfloat16Cast<Eigen::half>(NPY_HALF)) {
-    return false;
-  }
-
-  if (!RegisterBfloat16Cast<float>(NPY_FLOAT)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<double>(NPY_DOUBLE)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<long double>(NPY_LONGDOUBLE)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<bool>(NPY_BOOL)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<unsigned char>(NPY_UBYTE)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<unsigned short>(NPY_USHORT)) {  // NOLINT
-    return false;
-  }
-  if (!RegisterBfloat16Cast<unsigned int>(NPY_UINT)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<unsigned long>(NPY_ULONG)) {  // NOLINT
-    return false;
-  }
-  if (!RegisterBfloat16Cast<unsigned long long>(NPY_ULONGLONG)) {  // NOLINT
-    return false;
-  }
-  if (!RegisterBfloat16Cast<signed char>(NPY_BYTE)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<short>(NPY_SHORT)) {  // NOLINT
-    return false;
-  }
-  if (!RegisterBfloat16Cast<int>(NPY_INT)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<long>(NPY_LONG)) {  // NOLINT
-    return false;
-  }
-  if (!RegisterBfloat16Cast<long long>(NPY_LONGLONG)) {  // NOLINT
-    return false;
-  }
-  // Following the numpy convention. imag part is dropped when converting to
-  // float.
-  if (!RegisterBfloat16Cast<std::complex<float>>(NPY_CFLOAT)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<std::complex<double>>(NPY_CDOUBLE)) {
-    return false;
-  }
-  if (!RegisterBfloat16Cast<std::complex<long double>>(NPY_CLONGDOUBLE)) {
-    return false;
-  }
-
-  // Safe casts from bfloat16 to other types
-  if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_FLOAT, NPY_NOSCALAR) <
-      0) {
-    return false;
-  }
-  if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_DOUBLE, NPY_NOSCALAR) <
-      0) {
-    return false;
-  }
-  if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_LONGDOUBLE,
-                              NPY_NOSCALAR) < 0) {
-    return false;
-  }
-  if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_CFLOAT, NPY_NOSCALAR) <
-      0) {
-    return false;
-  }
-  if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_CDOUBLE, NPY_NOSCALAR) <
-      0) {
-    return false;
-  }
-  if (PyArray_RegisterCanCast(&NPyBfloat16_Descr, NPY_CLONGDOUBLE,
-                              NPY_NOSCALAR) < 0) {
-    return false;
-  }
-
-  // Safe casts to bfloat16 from other types
-  if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BOOL), npy_bfloat16,
-                              NPY_NOSCALAR) < 0) {
-    return false;
-  }
-  if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_UBYTE), npy_bfloat16,
-                              NPY_NOSCALAR) < 0) {
-    return false;
-  }
-  if (PyArray_RegisterCanCast(PyArray_DescrFromType(NPY_BYTE), npy_bfloat16,
-                              NPY_NOSCALAR) < 0) {
-    return false;
-  }
-
-  bool ok =
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Add>>(numpy.get(),
-                                                                  "add") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Subtract>>(
-          numpy.get(), "subtract") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Multiply>>(
-          numpy.get(), "multiply") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::TrueDivide>>(
-          numpy.get(), "divide") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::LogAddExp>>(
-          numpy.get(), "logaddexp") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::LogAddExp2>>(
-          numpy.get(), "logaddexp2") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Negative>>(
-          numpy.get(), "negative") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Positive>>(
-          numpy.get(), "positive") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::TrueDivide>>(
-          numpy.get(), "true_divide") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::FloorDivide>>(
-          numpy.get(), "floor_divide") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Power>>(numpy.get(),
-                                                                    "power") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Remainder>>(
-          numpy.get(), "remainder") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Remainder>>(
-          numpy.get(), "mod") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmod>>(numpy.get(),
-                                                                   "fmod") &&
-      RegisterUFunc<ufuncs::DivmodUFunc>(numpy.get(), "divmod") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Abs>>(numpy.get(),
-                                                                 "absolute") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Abs>>(numpy.get(),
-                                                                 "fabs") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Rint>>(numpy.get(),
-                                                                  "rint") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sign>>(numpy.get(),
-                                                                  "sign") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Heaviside>>(
-          numpy.get(), "heaviside") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Conjugate>>(
-          numpy.get(), "conjugate") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Exp>>(numpy.get(),
-                                                                 "exp") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Exp2>>(numpy.get(),
-                                                                  "exp2") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Expm1>>(numpy.get(),
-                                                                   "expm1") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log>>(numpy.get(),
-                                                                 "log") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log2>>(numpy.get(),
-                                                                  "log2") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log10>>(numpy.get(),
-                                                                   "log10") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Log1p>>(numpy.get(),
-                                                                   "log1p") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sqrt>>(numpy.get(),
-                                                                  "sqrt") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Square>>(numpy.get(),
-                                                                    "square") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cbrt>>(numpy.get(),
-                                                                  "cbrt") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Reciprocal>>(
-          numpy.get(), "reciprocal") &&
-
-      // Trigonometric functions
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sin>>(numpy.get(),
-                                                                 "sin") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cos>>(numpy.get(),
-                                                                 "cos") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Tan>>(numpy.get(),
-                                                                 "tan") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arcsin>>(numpy.get(),
-                                                                    "arcsin") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arccos>>(numpy.get(),
-                                                                    "arccos") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arctan>>(numpy.get(),
-                                                                    "arctan") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Arctan2>>(
-          numpy.get(), "arctan2") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Hypot>>(numpy.get(),
-                                                                    "hypot") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Sinh>>(numpy.get(),
-                                                                  "sinh") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Cosh>>(numpy.get(),
-                                                                  "cosh") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Tanh>>(numpy.get(),
-                                                                  "tanh") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arcsinh>>(
-          numpy.get(), "arcsinh") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arccosh>>(
-          numpy.get(), "arccosh") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Arctanh>>(
-          numpy.get(), "arctanh") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Deg2rad>>(
-          numpy.get(), "deg2rad") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Rad2deg>>(
-          numpy.get(), "rad2deg") &&
-
-      // Comparison functions
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Eq>>(numpy.get(),
-                                                             "equal") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Ne>>(numpy.get(),
-                                                             "not_equal") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Lt>>(numpy.get(),
-                                                             "less") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Gt>>(numpy.get(),
-                                                             "greater") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Le>>(numpy.get(),
-                                                             "less_equal") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::Ge>>(numpy.get(),
-                                                             "greater_equal") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Maximum>>(
-          numpy.get(), "maximum") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Minimum>>(
-          numpy.get(), "minimum") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmax>>(numpy.get(),
-                                                                   "fmax") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::Fmin>>(numpy.get(),
-                                                                   "fmin") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalAnd>>(
-          numpy.get(), "logical_and") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalOr>>(
-          numpy.get(), "logical_or") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bool, ufuncs::LogicalXor>>(
-          numpy.get(), "logical_xor") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::LogicalNot>>(
-          numpy.get(), "logical_not") &&
-
-      // Floating point functions
-      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsFinite>>(numpy.get(),
-                                                                  "isfinite") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsInf>>(numpy.get(),
-                                                               "isinf") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::IsNan>>(numpy.get(),
-                                                               "isnan") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bool, ufuncs::SignBit>>(numpy.get(),
-                                                                 "signbit") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::CopySign>>(
-          numpy.get(), "copysign") &&
-      RegisterUFunc<UnaryUFunc2<bfloat16, bfloat16, bfloat16, ufuncs::Modf>>(
-          numpy.get(), "modf") &&
-      RegisterUFunc<BinaryUFunc2<bfloat16, int, bfloat16, ufuncs::Ldexp>>(
-          numpy.get(), "ldexp") &&
-      RegisterUFunc<UnaryUFunc2<bfloat16, bfloat16, int, ufuncs::Frexp>>(
-          numpy.get(), "frexp") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Floor>>(numpy.get(),
-                                                                   "floor") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Ceil>>(numpy.get(),
-                                                                  "ceil") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Trunc>>(numpy.get(),
-                                                                   "trunc") &&
-      RegisterUFunc<BinaryUFunc<bfloat16, bfloat16, ufuncs::NextAfter>>(
-          numpy.get(), "nextafter") &&
-      RegisterUFunc<UnaryUFunc<bfloat16, bfloat16, ufuncs::Spacing>>(
-          numpy.get(), "spacing");
-
-  return ok;
+  // TODO(parkers): Enable CanCast to-from fp8 and bf16 and f16.
+  return true;
 }
 
 bool RegisterNumpyBfloat16() {
-  if (npy_bfloat16 != NPY_NOTYPE) {
+  if (TypeDescriptor<bfloat16>::Dtype() != NPY_NOTYPE) {
     // Already initialized.
     return true;
   }
@@ -1701,9 +1842,13 @@
 }
 
 PyObject* Bfloat16Dtype() {
-  return reinterpret_cast<PyObject*>(bfloat16_type_ptr);
+  return reinterpret_cast<PyObject*>(TypeDescriptor<bfloat16>::type_ptr);
 }
 
-int Bfloat16NumpyType() { return npy_bfloat16; }
+int Bfloat16NumpyType() { return TypeDescriptor<bfloat16>::Dtype(); }
+
+PyObject* Float8_E4M3B11Dtype() {
+  return reinterpret_cast<PyObject*>(TypeDescriptor<float8_e4m3b11>::type_ptr);
+}
 
 }  // namespace tensorflow
diff --git a/tensorflow/python/lib/core/bfloat16.h b/tensorflow/python/lib/core/bfloat16.h
index e40207b..6c1e971 100644
--- a/tensorflow/python/lib/core/bfloat16.h
+++ b/tensorflow/python/lib/core/bfloat16.h
@@ -29,6 +29,9 @@
 // Returns the id number of the bfloat16 numpy type.
 int Bfloat16NumpyType();
 
+// Returns a pointer to the float8_e4m3b11 dtype object.
+PyObject* Float8_E4M3B11Dtype();
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_PYTHON_LIB_CORE_BFLOAT16_H_
diff --git a/tensorflow/python/lib/core/bfloat16_test.py b/tensorflow/python/lib/core/bfloat16_test.py
index 73ebe5c..e0d442b 100644
--- a/tensorflow/python/lib/core/bfloat16_test.py
+++ b/tensorflow/python/lib/core/bfloat16_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Test cases for the bfloat16 Python type."""
+"""Test cases for the bfloat16,float8_e4m3b11 Python types."""
 
 import collections
 import copy
@@ -32,266 +32,330 @@
 from tensorflow.python.platform import test
 
 bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
+float8_e4m3b11 = _pywrap_bfloat16.TF_float8_e4m3b11_type()
 
 
-def numpy_assert_allclose(a, b, **kwargs):
-  a = a.astype(np.float32) if a.dtype == bfloat16 else a
-  b = b.astype(np.float32) if b.dtype == bfloat16 else b
+def numpy_assert_allclose(a, b, float_type, **kwargs):
+  a = a.astype(np.float32) if a.dtype == float_type else a
+  b = b.astype(np.float32) if b.dtype == float_type else b
   return np.testing.assert_allclose(a, b, **kwargs)
 
 
-def numpy_promote_types(a: Type[np.generic],
-                        b: Type[np.generic]) -> Type[np.generic]:
-  if a == bfloat16 and b == bfloat16:
-    return bfloat16
-  if a == bfloat16:
-    a = np.float32
-  if b == bfloat16:
-    b = np.float32
+def numpy_promote_types(
+    a: Type[np.generic], b: Type[np.generic], float_type: Type[np.generic],
+    next_largest_fp_type: Type[np.generic]) -> Type[np.generic]:
+  if a == float_type and b == float_type:
+    return float_type
+  if a == float_type:
+    a = next_largest_fp_type
+  if b == float_type:
+    b = next_largest_fp_type
   return np.promote_types(a, b)
 
 
-epsilon = float.fromhex("1.0p-7")
+def truncate(x, float_type):
+  """Rounds x through float_type; arrays come back as float32, else type(x)."""
+  if isinstance(x, np.ndarray):
+    return x.astype(float_type).astype(np.float32)
+  return type(x)(float_type(x))
+
+
+def test_binary_operation(a, b, op, float_type):
+  a = float_type(a)
+  b = float_type(b)
+  expected = op(np.float32(a), np.float32(b))
+  result = op(a, b)
+  if math.isnan(expected):
+    if not math.isnan(result):
+      raise AssertionError("%s expected to be nan." % repr(result))
+  else:
+    np.testing.assert_equal(
+        truncate(expected, float_type=float_type), float(result))
+
+
+epsilon = {
+    bfloat16: float.fromhex("1.0p-7"),
+    float8_e4m3b11: float.fromhex("1.0p-3"),
+}
 
 # Values that should round trip exactly to float and back.
-FLOAT_VALUES = [
-    0.0, 1.0, -1, 0.5, -0.5, epsilon, 1.0 + epsilon, 1.0 - epsilon,
-    -1.0 - epsilon, -1.0 + epsilon, 3.5, 42.0, 255.0, 256.0,
+FLOAT_VALUES = {}
+FLOAT_VALUES[bfloat16] = [
+    0.0, 1.0, -1, 0.5, -0.5, epsilon[bfloat16], 1.0 + epsilon[bfloat16],
+    1.0 - epsilon[bfloat16], -1.0 - epsilon[bfloat16], -1.0 + epsilon[bfloat16],
+    3.5, 4, 5, 7,
     float("inf"),
     float("-inf"),
     float("nan")
 ]
 
+FLOAT_VALUES[float8_e4m3b11] = [
+    0.0,
+    1.0,
+    -1,
+    0.5,
+    -0.5,
+    epsilon[float8_e4m3b11],
+    1.0 + epsilon[float8_e4m3b11],
+    1.0 - epsilon[float8_e4m3b11],
+    -1.0 - epsilon[float8_e4m3b11],
+    -1.0 + epsilon[float8_e4m3b11],
+    3.5,
+    4,
+    5,
+    7,
+    float(30),  # max float
+    float(-30),  # min float
+    float("nan")
+]
 
-class Bfloat16Test(parameterized.TestCase):
-  """Tests the non-numpy Python methods of the bfloat16 type."""
 
-  def testRoundTripToFloat(self):
-    for v in FLOAT_VALUES:
-      np.testing.assert_equal(v, float(bfloat16(v)))
+# pylint: disable=g-complex-comprehension
+@parameterized.named_parameters(({
+    "testcase_name": "_" + dtype.__name__,
+    "float_type": dtype
+} for dtype in [bfloat16, float8_e4m3b11]))
+class CustomFloatTest(parameterized.TestCase):
+  """Tests the non-numpy Python methods of the custom float type."""
 
-  def testRoundTripNumpyTypes(self):
+  def testRoundTripToFloat(self, float_type):
+    for v in FLOAT_VALUES[float_type]:
+      np.testing.assert_equal(v, float(float_type(v)))
+
+  def testRoundTripNumpyTypes(self, float_type):
     for dtype in [np.float16, np.float32, np.float64, np.longdouble]:
-      np.testing.assert_equal(-3.75, dtype(bfloat16(dtype(-3.75))))
-      np.testing.assert_equal(1.5, float(bfloat16(dtype(1.5))))
-      np.testing.assert_equal(4.5, dtype(bfloat16(np.array(4.5, dtype))))
+      np.testing.assert_equal(-3.75, dtype(float_type(dtype(-3.75))))
+      np.testing.assert_equal(1.5, float(float_type(dtype(1.5))))
+      np.testing.assert_equal(4.5, dtype(float_type(np.array(4.5, dtype))))
       np.testing.assert_equal(
-          np.array([2, 5, -1], bfloat16), bfloat16(np.array([2, 5, -1], dtype)))
+          np.array([2, 5, -1], float_type),
+          float_type(np.array([2, 5, -1], dtype)))
 
-  def testRoundTripToInt(self):
-    for v in [-256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512]:
-      self.assertEqual(v, int(bfloat16(v)))
+  def testRoundTripToInt(self, float_type):
+    for v in {
+        bfloat16: [
+            -256, -255, -34, -2, -1, 0, 1, 2, 10, 47, 128, 255, 256, 512
+        ],
+        float8_e4m3b11: list(range(-30, 30, 2)) + list(range(-15, 15, 2)),
+    }[float_type]:
+      self.assertEqual(v, int(float_type(v)))
 
-  # pylint: disable=g-complex-comprehension
-  @parameterized.named_parameters(({
-      "testcase_name": "_" + dtype.__name__,
-      "dtype": dtype
-  } for dtype in [bfloat16, np.float16, np.float32, np.float64, np.longdouble]))
-  def testRoundTripToNumpy(self, dtype):
-    for v in FLOAT_VALUES:
-      np.testing.assert_equal(v, bfloat16(dtype(v)))
-      np.testing.assert_equal(v, dtype(bfloat16(dtype(v))))
-      np.testing.assert_equal(v, dtype(bfloat16(np.array(v, dtype))))
-    if dtype != bfloat16:
-      np.testing.assert_equal(
-          np.array(FLOAT_VALUES, dtype),
-          bfloat16(np.array(FLOAT_VALUES, dtype)).astype(dtype))
+  def testRoundTripToNumpy(self, float_type):
+    for dtype in [
+        float_type, np.float16, np.float32, np.float64, np.longdouble
+    ]:
+      with self.subTest(dtype.__name__):
+        for v in FLOAT_VALUES[float_type]:
+          np.testing.assert_equal(v, float_type(dtype(v)))
+          np.testing.assert_equal(v, dtype(float_type(dtype(v))))
+          np.testing.assert_equal(v, dtype(float_type(np.array(v, dtype))))
+        if dtype != float_type:
+          np.testing.assert_equal(
+              np.array(FLOAT_VALUES[float_type], dtype),
+              float_type(np.array(FLOAT_VALUES[float_type],
+                                  dtype)).astype(dtype))
 
-  def testStr(self):
-    self.assertEqual("0", str(bfloat16(0.0)))
-    self.assertEqual("1", str(bfloat16(1.0)))
-    self.assertEqual("-3.5", str(bfloat16(-3.5)))
-    self.assertEqual("0.0078125", str(bfloat16(float.fromhex("1.0p-7"))))
-    self.assertEqual("inf", str(bfloat16(float("inf"))))
-    self.assertEqual("-inf", str(bfloat16(float("-inf"))))
-    self.assertEqual("nan", str(bfloat16(float("nan"))))
+  def testStr(self, float_type):
+    for value in [
+        0.0, 1.0, -3.5,
+        float.fromhex("1.0p-7"),
+        float("inf"),
+        float("-inf"),
+        float("nan")
+    ]:
+      self.assertEqual("%.6g" % float(float_type(value)),
+                       str(float_type(value)))
 
-  def testRepr(self):
-    self.assertEqual("0", repr(bfloat16(0)))
-    self.assertEqual("1", repr(bfloat16(1)))
-    self.assertEqual("-3.5", repr(bfloat16(-3.5)))
-    self.assertEqual("0.0078125", repr(bfloat16(float.fromhex("1.0p-7"))))
-    self.assertEqual("inf", repr(bfloat16(float("inf"))))
-    self.assertEqual("-inf", repr(bfloat16(float("-inf"))))
-    self.assertEqual("nan", repr(bfloat16(float("nan"))))
+  def testRepr(self, float_type):
+    for value in [
+        0.0, 1.0, -3.5,
+        float.fromhex("1.0p-7"),
+        float("inf"),
+        float("-inf"),
+        float("nan")
+    ]:
+      self.assertEqual("%.6g" % float(float_type(value)),
+                       repr(float_type(value)))
 
-  def testHashZero(self):
+  def testHashZero(self, float_type):
     """Tests that negative zero and zero hash to the same value."""
-    self.assertEqual(hash(bfloat16(-0.0)), hash(bfloat16(0.0)))
+    self.assertEqual(hash(float_type(-0.0)), hash(float_type(0.0)))
 
-  @parameterized.parameters(np.extract(np.isfinite(FLOAT_VALUES), FLOAT_VALUES))
-  def testHashNumbers(self, value):
-    self.assertEqual(hash(value), hash(bfloat16(value)), str(value))
+  def testHashNumbers(self, float_type):
+    for value in np.extract(
+        np.isfinite(FLOAT_VALUES[float_type]), FLOAT_VALUES[float_type]):
+      with self.subTest(value):
+        self.assertEqual(hash(value), hash(float_type(value)), str(value))
 
-  @parameterized.named_parameters(("PositiveNan", bfloat16(float("nan"))),
-                                  ("NegativeNan", bfloat16(float("-nan"))))
-  def testHashNan(self, nan):
-    nan_hash = hash(nan)
-    nan_object_hash = object.__hash__(nan)
-    # The hash of a NaN is either 0 or a hash of the object pointer.
-    self.assertIn(nan_hash, (sys.hash_info.nan, nan_object_hash), str(nan))
+  def testHashNan(self, float_type):
+    for name, nan in [("PositiveNan", float_type(float("nan"))),
+                      ("NegativeNan", float_type(float("-nan")))]:
+      with self.subTest(name):
+        nan_hash = hash(nan)
+        nan_object_hash = object.__hash__(nan)
+        # The hash of a NaN is either 0 or a hash of the object pointer.
+        self.assertIn(nan_hash, (sys.hash_info.nan, nan_object_hash), str(nan))
 
-  def testHashInf(self):
-    self.assertEqual(sys.hash_info.inf, hash(bfloat16(float("inf"))), "inf")
-    self.assertEqual(-sys.hash_info.inf, hash(bfloat16(float("-inf"))), "-inf")
+  def testHashInf(self, float_type):
+    if float_type == float8_e4m3b11:
+      self.skipTest("Not supported")  # no inf for e4m3b11
+    self.assertEqual(sys.hash_info.inf, hash(float_type(float("inf"))), "inf")
+    self.assertEqual(-sys.hash_info.inf, hash(float_type(float("-inf"))),
+                     "-inf")
 
   # Tests for Python operations
-  def testNegate(self):
-    for v in FLOAT_VALUES:
-      np.testing.assert_equal(-v, float(-bfloat16(v)))
+  def testNegate(self, float_type):
+    for v in FLOAT_VALUES[float_type]:
+      np.testing.assert_equal(
+          float(float_type(-float(float_type(v)))), float(-float_type(v)))
 
-  def testAdd(self):
-    np.testing.assert_equal(0, float(bfloat16(0) + bfloat16(0)))
-    np.testing.assert_equal(1, float(bfloat16(1) + bfloat16(0)))
-    np.testing.assert_equal(0, float(bfloat16(1) + bfloat16(-1)))
-    np.testing.assert_equal(5.5, float(bfloat16(2) + bfloat16(3.5)))
-    np.testing.assert_equal(1.25, float(bfloat16(3.5) + bfloat16(-2.25)))
-    np.testing.assert_equal(
-        float("inf"), float(bfloat16(float("inf")) + bfloat16(-2.25)))
-    np.testing.assert_equal(
-        float("-inf"), float(bfloat16(float("-inf")) + bfloat16(-2.25)))
-    self.assertTrue(math.isnan(float(bfloat16(3.5) + bfloat16(float("nan")))))
+  def testAdd(self, float_type):
+    for a, b in [(0, 0), (1, 0), (1, -1), (2, 3.5), (3.5, -2.25),
+                 (float("inf"), -2.25), (float("-inf"), -2.25),
+                 (3.5, float("nan"))]:
+      test_binary_operation(a, b, op=lambda a, b: a + b, float_type=float_type)
 
-  def testAddScalarTypePromotion(self):
+  def testAddScalarTypePromotion(self, float_type):
     """Tests type promotion against Numpy scalar values."""
-    types = [bfloat16, np.float16, np.float32, np.float64, np.longdouble]
+    types = [float_type, np.float16, np.float32, np.float64, np.longdouble]
     for lhs_type in types:
       for rhs_type in types:
-        expected_type = numpy_promote_types(lhs_type, rhs_type)
+        expected_type = numpy_promote_types(
+            lhs_type,
+            rhs_type,
+            float_type=float_type,
+            next_largest_fp_type={
+                bfloat16: np.float32,
+                float8_e4m3b11: np.float32,
+            }[float_type])
         actual_type = type(lhs_type(3.5) + rhs_type(2.25))
         self.assertEqual(expected_type, actual_type)
 
-  def testAddArrayTypePromotion(self):
+  def testAddArrayTypePromotion(self, float_type):
     self.assertEqual(np.float32,
-                     type(bfloat16(3.5) + np.array(2.25, np.float32)))
+                     type(float_type(3.5) + np.array(2.25, np.float32)))
     self.assertEqual(np.float32,
-                     type(np.array(3.5, np.float32) + bfloat16(2.25)))
+                     type(np.array(3.5, np.float32) + float_type(2.25)))
 
-  def testSub(self):
-    np.testing.assert_equal(0, float(bfloat16(0) - bfloat16(0)))
-    np.testing.assert_equal(1, float(bfloat16(1) - bfloat16(0)))
-    np.testing.assert_equal(2, float(bfloat16(1) - bfloat16(-1)))
-    np.testing.assert_equal(-1.5, float(bfloat16(2) - bfloat16(3.5)))
-    np.testing.assert_equal(5.75, float(bfloat16(3.5) - bfloat16(-2.25)))
-    np.testing.assert_equal(
-        float("-inf"), float(bfloat16(-2.25) - bfloat16(float("inf"))))
-    np.testing.assert_equal(
-        float("inf"), float(bfloat16(-2.25) - bfloat16(float("-inf"))))
-    self.assertTrue(math.isnan(float(bfloat16(3.5) - bfloat16(float("nan")))))
+  def testSub(self, float_type):
+    for a, b in [(0, 0), (1, 0), (1, -1), (2, 3.5), (3.5, -2.25),
+                 (-2.25, float("inf")), (-2.25, float("-inf")),
+                 (3.5, float("nan"))]:
+      test_binary_operation(a, b, op=lambda a, b: a - b, float_type=float_type)
 
-  def testMul(self):
-    np.testing.assert_equal(0, float(bfloat16(0) * bfloat16(0)))
-    np.testing.assert_equal(0, float(bfloat16(1) * bfloat16(0)))
-    np.testing.assert_equal(-1, float(bfloat16(1) * bfloat16(-1)))
-    np.testing.assert_equal(-7.875, float(bfloat16(3.5) * bfloat16(-2.25)))
-    np.testing.assert_equal(
-        float("-inf"), float(bfloat16(float("inf")) * bfloat16(-2.25)))
-    np.testing.assert_equal(
-        float("inf"), float(bfloat16(float("-inf")) * bfloat16(-2.25)))
-    self.assertTrue(math.isnan(float(bfloat16(3.5) * bfloat16(float("nan")))))
+  def testMul(self, float_type):
+    for a, b in [(0, 0), (1, 0), (1, -1), (3.5, -2.25), (float("inf"), -2.25),
+                 (float("-inf"), -2.25), (3.5, float("nan"))]:
+      test_binary_operation(a, b, op=lambda a, b: a * b, float_type=float_type)
 
-  def testDiv(self):
-    self.assertTrue(math.isnan(float(bfloat16(0) / bfloat16(0))))
-    np.testing.assert_equal(float("inf"), float(bfloat16(1) / bfloat16(0)))
-    np.testing.assert_equal(-1, float(bfloat16(1) / bfloat16(-1)))
-    np.testing.assert_equal(-1.75, float(bfloat16(3.5) / bfloat16(-2)))
-    np.testing.assert_equal(
-        float("-inf"), float(bfloat16(float("inf")) / bfloat16(-2.25)))
-    np.testing.assert_equal(
-        float("inf"), float(bfloat16(float("-inf")) / bfloat16(-2.25)))
-    self.assertTrue(math.isnan(float(bfloat16(3.5) / bfloat16(float("nan")))))
+  def testDiv(self, float_type):
+    for a, b in [(0, 0), (1, 0), (1, -1), (2, 3.5), (3.5, -2.25),
+                 (float("inf"), -2.25), (float("-inf"), -2.25),
+                 (3.5, float("nan"))]:
+      test_binary_operation(a, b, op=lambda a, b: a / b, float_type=float_type)
 
-  def testLess(self):
-    for v in FLOAT_VALUES:
-      for w in FLOAT_VALUES:
-        self.assertEqual(v < w, bfloat16(v) < bfloat16(w))
+  def testLess(self, float_type):
+    for v in FLOAT_VALUES[float_type]:
+      for w in FLOAT_VALUES[float_type]:
+        self.assertEqual(v < w, float_type(v) < float_type(w))
 
-  def testLessEqual(self):
-    for v in FLOAT_VALUES:
-      for w in FLOAT_VALUES:
-        self.assertEqual(v <= w, bfloat16(v) <= bfloat16(w))
+  def testLessEqual(self, float_type):
+    for v in FLOAT_VALUES[float_type]:
+      for w in FLOAT_VALUES[float_type]:
+        self.assertEqual(v <= w, float_type(v) <= float_type(w))
 
-  def testGreater(self):
-    for v in FLOAT_VALUES:
-      for w in FLOAT_VALUES:
-        self.assertEqual(v > w, bfloat16(v) > bfloat16(w))
+  def testGreater(self, float_type):
+    for v in FLOAT_VALUES[float_type]:
+      for w in FLOAT_VALUES[float_type]:
+        self.assertEqual(v > w, float_type(v) > float_type(w))
 
-  def testGreaterEqual(self):
-    for v in FLOAT_VALUES:
-      for w in FLOAT_VALUES:
-        self.assertEqual(v >= w, bfloat16(v) >= bfloat16(w))
+  def testGreaterEqual(self, float_type):
+    for v in FLOAT_VALUES[float_type]:
+      for w in FLOAT_VALUES[float_type]:
+        self.assertEqual(v >= w, float_type(v) >= float_type(w))
 
-  def testEqual(self):
-    for v in FLOAT_VALUES:
-      for w in FLOAT_VALUES:
-        self.assertEqual(v == w, bfloat16(v) == bfloat16(w))
+  def testEqual(self, float_type):
+    for v in FLOAT_VALUES[float_type]:
+      for w in FLOAT_VALUES[float_type]:
+        self.assertEqual(v == w, float_type(v) == float_type(w))
 
-  def testNotEqual(self):
-    for v in FLOAT_VALUES:
-      for w in FLOAT_VALUES:
-        self.assertEqual(v != w, bfloat16(v) != bfloat16(w))
+  def testNotEqual(self, float_type):
+    for v in FLOAT_VALUES[float_type]:
+      for w in FLOAT_VALUES[float_type]:
+        self.assertEqual(v != w, float_type(v) != float_type(w))
 
-  def testNan(self):
-    a = np.isnan(bfloat16(float("nan")))
+  def testNan(self, float_type):
+    a = np.isnan(float_type(float("nan")))
     self.assertTrue(a)
-    numpy_assert_allclose(np.array([1.0, a]), np.array([1.0, a]))
-
-    a = np.array([bfloat16(1.34375),
-                  bfloat16(1.4375),
-                  bfloat16(float("nan"))],
-                 dtype=bfloat16)
-    b = np.array(
-        [bfloat16(1.3359375),
-         bfloat16(1.4375),
-         bfloat16(float("nan"))],
-        dtype=bfloat16)
     numpy_assert_allclose(
-        a, b, rtol=0.1, atol=0.1, equal_nan=True, err_msg="", verbose=True)
+        np.array([1.0, a]), np.array([1.0, a]), float_type=float_type)
 
-  def testSort(self):
-    values_to_sort = np.float32(FLOAT_VALUES)
+    a = np.array(
+        [float_type(1.34375),
+         float_type(1.4375),
+         float_type(float("nan"))],
+        dtype=float_type)
+    b = np.array(
+        [float_type(1.3359375),
+         float_type(1.4375),
+         float_type(float("nan"))],
+        dtype=float_type)
+    numpy_assert_allclose(
+        a,
+        b,
+        rtol=0.1,
+        atol=0.1,
+        equal_nan=True,
+        err_msg="",
+        verbose=True,
+        float_type=float_type)
+
+  def testSort(self, float_type):
+    values_to_sort = np.float32(FLOAT_VALUES[float_type])
     sorted_f32 = np.sort(values_to_sort)
-    sorted_bf16 = np.sort(values_to_sort.astype(bfloat16))  # pylint: disable=too-many-function-args
+    sorted_bf16 = np.sort(values_to_sort.astype(float_type))  # pylint: disable=too-many-function-args
     np.testing.assert_equal(sorted_f32, np.float32(sorted_bf16))
 
-  def testArgmax(self):
-    values_to_sort = np.float32(bfloat16(np.float32(FLOAT_VALUES)))
+  def testArgmax(self, float_type):
+    values_to_sort = np.float32(
+        float_type(np.float32(FLOAT_VALUES[float_type])))
     argmax_f32 = np.argmax(values_to_sort)
-    argmax_bf16 = np.argmax(values_to_sort.astype(bfloat16))  # pylint: disable=too-many-function-args
+    argmax_bf16 = np.argmax(values_to_sort.astype(float_type))  # pylint: disable=too-many-function-args
     np.testing.assert_equal(argmax_f32, argmax_bf16)
 
-  def testArgmaxOnNan(self):
+  def testArgmaxOnNan(self, float_type):
     """Ensures we return the right thing for multiple NaNs."""
     one_with_nans = np.array(
         [1.0, float("nan"), float("nan")], dtype=np.float32)
     np.testing.assert_equal(
-        np.argmax(one_with_nans.astype(bfloat16)), np.argmax(one_with_nans))
+        np.argmax(one_with_nans.astype(float_type)), np.argmax(one_with_nans))
 
-  def testArgmaxOnNegativeInfinity(self):
+  def testArgmaxOnNegativeInfinity(self, float_type):
     """Ensures we return the right thing for negative infinities."""
     inf = np.array([float("-inf")], dtype=np.float32)
-    np.testing.assert_equal(np.argmax(inf.astype(bfloat16)), np.argmax(inf))
+    np.testing.assert_equal(np.argmax(inf.astype(float_type)), np.argmax(inf))
 
-  def testArgmin(self):
-    values_to_sort = np.float32(bfloat16(np.float32(FLOAT_VALUES)))
+  def testArgmin(self, float_type):
+    values_to_sort = np.float32(
+        float_type(np.float32(FLOAT_VALUES[float_type])))
     argmin_f32 = np.argmin(values_to_sort)
-    argmin_bf16 = np.argmin(values_to_sort.astype(bfloat16))  # pylint: disable=too-many-function-args
+    argmin_bf16 = np.argmin(values_to_sort.astype(float_type))  # pylint: disable=too-many-function-args
     np.testing.assert_equal(argmin_f32, argmin_bf16)
 
-  def testArgminOnNan(self):
+  def testArgminOnNan(self, float_type):
     """Ensures we return the right thing for multiple NaNs."""
     one_with_nans = np.array(
         [1.0, float("nan"), float("nan")], dtype=np.float32)
     np.testing.assert_equal(
-        np.argmin(one_with_nans.astype(bfloat16)), np.argmin(one_with_nans))
+        np.argmin(one_with_nans.astype(float_type)), np.argmin(one_with_nans))
 
-  def testArgminOnPositiveInfinity(self):
+  def testArgminOnPositiveInfinity(self, float_type):
     """Ensures we return the right thing for positive infinities."""
     inf = np.array([float("inf")], dtype=np.float32)
-    np.testing.assert_equal(np.argmin(inf.astype(bfloat16)), np.argmin(inf))
+    np.testing.assert_equal(np.argmin(inf.astype(float_type)), np.argmin(inf))
 
-  def testDtypeFromString(self):
-    assert np.dtype("bfloat16") == np.dtype(bfloat16)
+  def testDtypeFromString(self, float_type):
+    assert np.dtype(float_type.__name__) == np.dtype(float_type)
 
 
 BinaryOp = collections.namedtuple("BinaryOp", ["op"])
@@ -316,34 +380,39 @@
 ]
 
 
-class Bfloat16NumPyTest(parameterized.TestCase):
-  """Tests the NumPy integration of the bfloat16 type."""
+# pylint: disable=g-complex-comprehension
+@parameterized.named_parameters(({
+    "testcase_name": "_" + dtype.__name__,
+    "float_type": dtype
+} for dtype in [bfloat16, float8_e4m3b11]))
+class CustomFloatNumPyTest(parameterized.TestCase):
+  """Tests the NumPy integration of the bfloat16 and float8_e4m3b11 types."""
 
-  def testDtype(self):
-    self.assertEqual(bfloat16, np.dtype(bfloat16))
+  def testDtype(self, float_type):
+    self.assertEqual(float_type, np.dtype(float_type))
 
-  def testDeepCopyDoesNotAlterHash(self):
+  def testDeepCopyDoesNotAlterHash(self, float_type):
     # For context, see https://github.com/google/jax/issues/4651. If the hash
     # value of the type descriptor is not initialized correctly, a deep copy
     # can change the type hash.
-    dtype = np.dtype(bfloat16)
+    dtype = np.dtype(float_type)
     h = hash(dtype)
     _ = copy.deepcopy(dtype)
     self.assertEqual(h, hash(dtype))
 
-  def testArray(self):
-    x = np.array([[1, 2, 3]], dtype=bfloat16)
-    self.assertEqual(bfloat16, x.dtype)
+  def testArray(self, float_type):
+    x = np.array([[1, 2, 3]], dtype=float_type)
+    self.assertEqual(float_type, x.dtype)
     self.assertEqual("[[1 2 3]]", str(x))
     np.testing.assert_equal(x, x)
-    numpy_assert_allclose(x, x)
+    numpy_assert_allclose(x, x, float_type=float_type)
     self.assertTrue((x == x).all())
 
-  def testComparisons(self):
-    x = np.array([401408, 7, -32], dtype=np.float32)
-    bx = x.astype(bfloat16)
-    y = np.array([82432, 7, 0], dtype=np.float32)
-    by = y.astype(bfloat16)
+  def testComparisons(self, float_type):
+    x = np.array([30, 7, -30], dtype=np.float32)
+    bx = x.astype(float_type)
+    y = np.array([17, 7, 0], dtype=np.float32)
+    by = y.astype(float_type)
     np.testing.assert_equal(x == y, bx == by)
     np.testing.assert_equal(x != y, bx != by)
     np.testing.assert_equal(x < y, bx < by)
@@ -351,22 +420,24 @@
     np.testing.assert_equal(x <= y, bx <= by)
     np.testing.assert_equal(x >= y, bx >= by)
 
-  def testEqual2(self):
-    a = np.array([401408], bfloat16)
-    b = np.array([82432], bfloat16)
+  def testEqual2(self, float_type):
+    if float_type == float8_e4m3b11:
+      self.skipTest("Not supported")  # out of range.
+    a = np.array([401408], float_type)
+    b = np.array([82432], float_type)
     self.assertFalse(a.__eq__(b))
 
-  def testCanCast(self):
+  def testCanCast(self, float_type):
     allowed_casts = [
-        (np.bool_, bfloat16),
-        (np.int8, bfloat16),
-        (np.uint8, bfloat16),
-        (bfloat16, np.float32),
-        (bfloat16, np.float64),
-        (bfloat16, np.longdouble),
-        (bfloat16, np.complex64),
-        (bfloat16, np.complex128),
-        (bfloat16, np.clongdouble),
+        (np.bool_, float_type),
+        (np.int8, float_type),
+        (np.uint8, float_type),
+        (float_type, np.float32),
+        (float_type, np.float64),
+        (float_type, np.longdouble),
+        (float_type, np.complex64),
+        (float_type, np.complex128),
+        (float_type, np.clongdouble),
     ]
     all_dtypes = [
         np.float16, np.float32, np.float64, np.longdouble, np.int8, np.int16,
@@ -375,10 +446,13 @@
         np.longlong, np.uintc, np.ulonglong
     ]
     for d in all_dtypes:
-      self.assertEqual((bfloat16, d) in allowed_casts, np.can_cast(bfloat16, d))
-      self.assertEqual((d, bfloat16) in allowed_casts, np.can_cast(d, bfloat16))
+      with self.subTest(d.__name__):
+        self.assertEqual((float_type, d) in allowed_casts,
+                         np.can_cast(float_type, d))
+        self.assertEqual((d, float_type) in allowed_casts,
+                         np.can_cast(d, float_type))
 
-  def testCasts(self):
+  def testCasts(self, float_type):
     for dtype in [
         np.float16, np.float32, np.float64, np.longdouble, np.int8, np.int16,
         np.int32, np.int64, np.complex64, np.complex128, np.clongdouble,
@@ -386,186 +460,228 @@
         np.longlong, np.uintc, np.ulonglong
     ]:
       x = np.array([[1, 2, 3]], dtype=dtype)
-      y = x.astype(bfloat16)
+      y = x.astype(float_type)
       z = y.astype(dtype)
       self.assertTrue(np.all(x == y))
-      self.assertEqual(bfloat16, y.dtype)
+      self.assertEqual(float_type, y.dtype)
       self.assertTrue(np.all(x == z))
       self.assertEqual(dtype, z.dtype)
 
-  def testConformNumpyComplex(self):
+  def testConformNumpyComplex(self, float_type):
     for dtype in [np.complex64, np.complex128, np.clongdouble]:
-      x = np.array([1.1, 2.2 + 2.2j, 3.3], dtype=dtype)
+      x = np.array([1.5, 2.5 + 2.j, 3.25], dtype=dtype)
       y_np = x.astype(np.float32)
-      y_tf = x.astype(bfloat16)
-      numpy_assert_allclose(y_np, y_tf, atol=2e-2)
+      y_tf = x.astype(float_type)
+      numpy_assert_allclose(y_np, y_tf, atol=2e-2, float_type=float_type)
 
       z_np = y_np.astype(dtype)
       z_tf = y_tf.astype(dtype)
-      numpy_assert_allclose(z_np, z_tf, atol=2e-2)
+      numpy_assert_allclose(z_np, z_tf, atol=2e-2, float_type=float_type)
 
-  def testArange(self):
+  def testArange(self, float_type):
     np.testing.assert_equal(
-        np.arange(100, dtype=np.float32).astype(bfloat16),
-        np.arange(100, dtype=bfloat16))
+        np.arange(100, dtype=np.float32).astype(float_type),
+        np.arange(100, dtype=float_type))
     np.testing.assert_equal(
-        np.arange(-10.5, 7.8, 0.5, dtype=np.float32).astype(bfloat16),
-        np.arange(-10.5, 7.8, 0.5, dtype=bfloat16))
+        np.arange(-16, 16, 1, dtype=np.float32).astype(float_type),
+        np.arange(-16, 16, 1, dtype=float_type))
     np.testing.assert_equal(
-        np.arange(-0., -7., -0.25, dtype=np.float32).astype(bfloat16),
-        np.arange(-0., -7., -0.25, dtype=bfloat16))
+        np.arange(-0., -7., -0.25, dtype=np.float32).astype(float_type),
+        np.arange(-0., -7., -0.25, dtype=float_type))
     np.testing.assert_equal(
-        np.arange(-16384., 16384., 64., dtype=np.float32).astype(bfloat16),
-        np.arange(-16384., 16384., 64., dtype=bfloat16))
+        np.arange(-30., 30., 2., dtype=np.float32).astype(float_type),
+        np.arange(-30., 30., 2., dtype=float_type))
 
-  # pylint: disable=g-complex-comprehension
-  @parameterized.named_parameters(({
-      "testcase_name": "_" + op.__name__,
-      "op": op
-  } for op in UNARY_UFUNCS))
-  def testUnaryUfunc(self, op):
-    rng = np.random.RandomState(seed=42)
-    x = rng.randn(3, 7, 10).astype(bfloat16)
-    numpy_assert_allclose(
-        op(x).astype(np.float32), op(x.astype(np.float32)), rtol=1e-2)
+  def testUnaryUfunc(self, float_type):
+    for op in UNARY_UFUNCS:
+      with self.subTest(op.__name__):
+        rng = np.random.RandomState(seed=42)
+        x = rng.randn(3, 7, 10).astype(float_type)
+        numpy_assert_allclose(
+            op(x).astype(np.float32),
+            truncate(op(x.astype(np.float32)), float_type=float_type),
+            rtol=1e-4,
+            float_type=float_type)
 
-  @parameterized.named_parameters(({
-      "testcase_name": "_" + op.__name__,
-      "op": op
-  } for op in BINARY_UFUNCS))
-  def testBinaryUfunc(self, op):
-    rng = np.random.RandomState(seed=42)
-    x = rng.randn(3, 7, 10).astype(bfloat16)
-    y = rng.randn(4, 1, 7, 10).astype(bfloat16)
-    numpy_assert_allclose(
-        op(x, y).astype(np.float32),
-        op(x.astype(np.float32), y.astype(np.float32)),
-        rtol=1e-2)
+  def testBinaryUfunc(self, float_type):
+    for op in BINARY_UFUNCS:
+      with self.subTest(op.__name__):
+        rng = np.random.RandomState(seed=42)
+        x = rng.randn(3, 7, 10).astype(float_type)
+        y = rng.randn(4, 1, 7, 10).astype(float_type)
+        numpy_assert_allclose(
+            op(x, y).astype(np.float32),
+            truncate(
+                op(x.astype(np.float32), y.astype(np.float32)),
+                float_type=float_type),
+            rtol=1e-4,
+            float_type=float_type)
 
-  @parameterized.named_parameters(({
-      "testcase_name": "_" + op.__name__,
-      "op": op
-  } for op in BINARY_PREDICATE_UFUNCS))
-  def testBinaryPredicateUfunc(self, op):
-    rng = np.random.RandomState(seed=42)
-    x = rng.randn(3, 7).astype(bfloat16)
-    y = rng.randn(4, 1, 7).astype(bfloat16)
-    np.testing.assert_equal(
-        op(x, y), op(x.astype(np.float32), y.astype(np.float32)))
+  def testBinaryPredicateUfunc(self, float_type):
+    for op in BINARY_PREDICATE_UFUNCS:
+      with self.subTest(op.__name__):
+        rng = np.random.RandomState(seed=42)
+        x = rng.randn(3, 7).astype(float_type)
+        y = rng.randn(4, 1, 7).astype(float_type)
+        np.testing.assert_equal(
+            op(x, y), op(x.astype(np.float32), y.astype(np.float32)))
 
-  @parameterized.named_parameters(({
-      "testcase_name": "_" + op.__name__,
-      "op": op
-  } for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]))
-  def testPredicateUfunc(self, op):
-    rng = np.random.RandomState(seed=42)
-    shape = (3, 7, 10)
-    posinf_flips = rng.rand(*shape) < 0.1
-    neginf_flips = rng.rand(*shape) < 0.1
-    nan_flips = rng.rand(*shape) < 0.1
-    vals = rng.randn(*shape)
-    vals = np.where(posinf_flips, np.inf, vals)
-    vals = np.where(neginf_flips, -np.inf, vals)
-    vals = np.where(nan_flips, np.nan, vals)
-    vals = vals.astype(bfloat16)
-    np.testing.assert_equal(op(vals), op(vals.astype(np.float32)))
+  def testPredicateUfunc(self, float_type):
+    for op in [np.isfinite, np.isinf, np.isnan, np.signbit, np.logical_not]:
+      with self.subTest(op.__name__):
+        rng = np.random.RandomState(seed=42)
+        shape = (3, 7, 10)
+        posinf_flips = rng.rand(*shape) < 0.1
+        neginf_flips = rng.rand(*shape) < 0.1
+        nan_flips = rng.rand(*shape) < 0.1
+        vals = rng.randn(*shape)
+        vals = np.where(posinf_flips, np.inf, vals)
+        vals = np.where(neginf_flips, -np.inf, vals)
+        vals = np.where(nan_flips, np.nan, vals)
+        vals = vals.astype(float_type)
+        np.testing.assert_equal(op(vals), op(vals.astype(np.float32)))
 
-  def testDivmod(self):
+  def testDivmod(self, float_type):
     rng = np.random.RandomState(seed=42)
-    x = rng.randn(3, 7).astype(bfloat16)
-    y = rng.randn(4, 1, 7).astype(bfloat16)
+    x = rng.randn(3, 7).astype(float_type)
+    y = rng.randn(4, 1, 7).astype(float_type)
     o1, o2 = np.divmod(x, y)
     e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32))
-    numpy_assert_allclose(o1, e1, rtol=1e-2)
-    numpy_assert_allclose(o2, e2, rtol=1e-2)
+    numpy_assert_allclose(
+        o1,
+        truncate(e1, float_type=float_type),
+        rtol=1e-2,
+        float_type=float_type)
+    numpy_assert_allclose(
+        o2,
+        truncate(e2, float_type=float_type),
+        rtol=1e-2,
+        float_type=float_type)
 
-  def testModf(self):
+  def testModf(self, float_type):
     rng = np.random.RandomState(seed=42)
-    x = rng.randn(3, 7).astype(bfloat16)
+    x = rng.randn(3, 7).astype(float_type)
     o1, o2 = np.modf(x)
     e1, e2 = np.modf(x.astype(np.float32))
-    numpy_assert_allclose(o1.astype(np.float32), e1, rtol=1e-2)
-    numpy_assert_allclose(o2.astype(np.float32), e2, rtol=1e-2)
+    numpy_assert_allclose(
+        o1.astype(np.float32),
+        truncate(e1, float_type=float_type),
+        rtol=1e-2,
+        float_type=float_type)
+    numpy_assert_allclose(
+        o2.astype(np.float32),
+        truncate(e2, float_type=float_type),
+        rtol=1e-2,
+        float_type=float_type)
 
-  def testLdexp(self):
+  def testLdexp(self, float_type):
     rng = np.random.RandomState(seed=42)
-    x = rng.randn(3, 7).astype(bfloat16)
-    y = rng.randint(-50, 50, (1, 7))
+    x = rng.randn(3, 7).astype(float_type)
+    y = rng.randint(-50, 50, (1, 7)).astype(np.int32)
+    self.assertEqual(np.ldexp(x, y).dtype, x.dtype)
     numpy_assert_allclose(
         np.ldexp(x, y).astype(np.float32),
-        np.ldexp(x.astype(np.float32), y),
+        truncate(np.ldexp(x.astype(np.float32), y), float_type=float_type),
         rtol=1e-2,
-        atol=1e-6)
+        atol=1e-6,
+        float_type=float_type)
 
-  def testFrexp(self):
+  def testFrexp(self, float_type):
     rng = np.random.RandomState(seed=42)
-    x = rng.randn(3, 7).astype(bfloat16)
+    x = rng.randn(3, 7).astype(float_type)
     mant1, exp1 = np.frexp(x)
     mant2, exp2 = np.frexp(x.astype(np.float32))
     np.testing.assert_equal(exp1, exp2)
-    numpy_assert_allclose(mant1, mant2, rtol=1e-2)
+    numpy_assert_allclose(mant1, mant2, rtol=1e-2, float_type=float_type)
 
-  @parameterized.parameters(list(range(1, 128)))
-  def testCopySign(self, nan_payload):
-    inf_bits = 0x7f80
-    nan_bits = inf_bits | nan_payload
-    little_endian_uint16 = np.dtype(np.uint16).newbyteorder("L")
-    little_endian_bfloat = np.dtype(bfloat16).newbyteorder("L")
-    nan = little_endian_uint16.type(nan_bits).view(little_endian_bfloat)
-    nan_with_sign = np.copysign(nan, bfloat16(-1))
-    nan_with_sign_bits = nan_with_sign.view(little_endian_uint16)
-    np.testing.assert_equal(nan_bits | (1 << 15), nan_with_sign_bits)
+  def testCopySign(self, float_type):
+    """Checks that np.copysign sets the sign bit and keeps the NaN payload.
+
+    Only bfloat16 is exercised: the bit patterns below assume its 16-bit
+    layout, and float8_e4m3b11 is skipped.
+    """
+    if float_type == float8_e4m3b11:
+      self.skipTest("Not supported")  # Nans don't have payload.
+    # Loop-invariant values: the infinity bit pattern and the little-endian
+    # views used to reinterpret raw bits as bfloat16 and back.
+    inf_bits = 0x7f80
+    little_endian_uint16 = np.dtype(np.uint16).newbyteorder("L")
+    little_endian_bfloat = np.dtype(bfloat16).newbyteorder("L")
+    for nan_payload in range(1, 128):
+      with self.subTest(nan_payload):
+        # Setting any payload bit on top of the infinity pattern gives a NaN.
+        nan_bits = inf_bits | nan_payload
+        nan = little_endian_uint16.type(nan_bits).view(little_endian_bfloat)
+        nan_with_sign = np.copysign(nan, bfloat16(-1))
+        nan_with_sign_bits = nan_with_sign.view(little_endian_uint16)
+        np.testing.assert_equal(nan_bits | (1 << 15), nan_with_sign_bits)
 
-  def testNextAfter(self):
-    one = np.array(1., dtype=bfloat16)
-    two = np.array(2., dtype=bfloat16)
-    zero = np.array(0., dtype=bfloat16)
-    nan = np.array(np.nan, dtype=bfloat16)
-    np.testing.assert_equal(np.nextafter(one, two) - one, epsilon)
-    np.testing.assert_equal(np.nextafter(one, zero) - one, -epsilon / 2)
+  def testNextAfter(self, float_type):
+    one = np.array(1., dtype=float_type)
+    two = np.array(2., dtype=float_type)
+    zero = np.array(0., dtype=float_type)
+    nan = np.array(np.nan, dtype=float_type)
+    np.testing.assert_equal(np.nextafter(one, two) - one, epsilon[float_type])
+    np.testing.assert_equal(
+        np.nextafter(one, zero) - one, -epsilon[float_type] / 2)
     np.testing.assert_equal(np.isnan(np.nextafter(nan, one)), True)
     np.testing.assert_equal(np.isnan(np.nextafter(one, nan)), True)
     np.testing.assert_equal(np.nextafter(one, one), one)
-    smallest_denormal = float.fromhex("1.0p-133")
+    smallest_denormal = {
+        bfloat16: float.fromhex("1.0p-133"),
+        float8_e4m3b11: float.fromhex("1.0p-13"),
+    }[float_type]
     np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal)
     np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal)
-    for a, b in itertools.permutations([0., -0., nan], 2):
+    for a, b in itertools.permutations([0., nan], 2):
       np.testing.assert_equal(
           np.nextafter(
               np.array(a, dtype=np.float32), np.array(b, dtype=np.float32)),
           np.nextafter(
-              np.array(a, dtype=bfloat16), np.array(b, dtype=bfloat16)))
+              np.array(a, dtype=float_type), np.array(b, dtype=float_type)))
 
-  def testSpacing(self):
+  def testSpacing(self, float_type):
     # Sweep a variety of binades to see that spacing gives the proper ULP.
     # All subnormals have a fixed distance of 2^-133.
     with self.subTest(name="Subnormals"):
-      for i in range(-133, -126):
-        power_of_two = bfloat16(2.0**i)
-        distance = float.fromhex("0x1p-133")
+      if float_type == float8_e4m3b11:
+        self.skipTest("Not supported")
+      for i in {
+          float8_e4m3b11: range(-13, -10),
+          bfloat16: range(-133, -126)
+      }[float_type]:
+        power_of_two = float_type(2.0**i)
+        distance = {
+            float8_e4m3b11: float.fromhex("0x1p-13"),
+            bfloat16: float.fromhex("0x1p-133")
+        }[float_type]
         np.testing.assert_equal(np.spacing(power_of_two), distance)
         np.testing.assert_equal(np.spacing(-power_of_two), -distance)
     # Normals have a distance which depends on their binade.
     with self.subTest(name="Normals"):
-      for i in range(-126, 127):
-        power_of_two = bfloat16(2.0**i)
-        distance = epsilon * power_of_two
+      for i in {
+          float8_e4m3b11: range(-10, 4),
+          bfloat16: range(-126, 127)
+      }[float_type]:
+        power_of_two = float_type(2.0**i)
+        distance = epsilon[float_type] * power_of_two
         np.testing.assert_equal(np.spacing(power_of_two), distance)
         np.testing.assert_equal(np.spacing(-power_of_two), -distance)
-    inf = bfloat16(float("inf"))
-    nan = bfloat16(float("nan"))
+    inf = float_type(float("inf"))
+    nan = float_type(float("nan"))
     # Check that spacing agrees with arithmetic involving nextafter.
     with self.subTest(name="NextAfter"):
-      for x in FLOAT_VALUES:
-        x_bfloat16 = bfloat16(x)
-        spacing = np.spacing(x_bfloat16)
-        toward = np.copysign(inf, x_bfloat16)
-        nextup = np.nextafter(x_bfloat16, toward)
-        np.testing.assert_equal(spacing, nextup - x_bfloat16)
+      for x in FLOAT_VALUES[float_type]:
+        x_float_type = float_type(x)
+        spacing = np.spacing(x_float_type)
+        toward = np.copysign(inf, x_float_type)
+        nextup = np.nextafter(x_float_type, toward)
+        np.testing.assert_equal(spacing, nextup - x_float_type)
     # Check that spacing for special values gives the correct answer.
     with self.subTest(name="NonFinite"):
       np.testing.assert_equal(np.spacing(nan), np.spacing(np.float32(nan)))
-      np.testing.assert_equal(np.spacing(inf), np.spacing(np.float32(inf)))
+      if float_type != float8_e4m3b11:  # inf not supported.
+        np.testing.assert_equal(np.spacing(inf), np.spacing(np.float32(inf)))
 
 
 if __name__ == "__main__":
diff --git a/tensorflow/python/lib/core/bfloat16_wrapper.cc b/tensorflow/python/lib/core/bfloat16_wrapper.cc
index 741468b..8a3c6f8 100644
--- a/tensorflow/python/lib/core/bfloat16_wrapper.cc
+++ b/tensorflow/python/lib/core/bfloat16_wrapper.cc
@@ -21,4 +21,7 @@
 
   m.def("TF_bfloat16_type",
         [] { return pybind11::handle(tensorflow::Bfloat16Dtype()); });
+
+  m.def("TF_float8_e4m3b11_type",
+        [] { return pybind11::handle(tensorflow::Float8_E4M3B11Dtype()); });
 }
diff --git a/tensorflow/python/lib/core/float8_e4m3b11.cc b/tensorflow/python/lib/core/float8_e4m3b11.cc
new file mode 100644
index 0000000..7f47da5
--- /dev/null
+++ b/tensorflow/python/lib/core/float8_e4m3b11.cc
@@ -0,0 +1,87 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/python/lib/core/float8_e4m3b11.h"
+
+#include <stdio.h>
+
+namespace tensorflow {
+
+// Converts `v` to its float8_e4m3b11 bit pattern (1 sign, 4 exponent, 3
+// mantissa bits, bias 11). Truncates extra mantissa bits, flushes values
+// below 2^-13 to zero, saturates overflow/inf to 0x7f|sign; NaN is 0x80.
+uint8_t float_to_float8_e4m3b11(float v) {
+  static_assert(sizeof(float) == sizeof(uint32_t), "Invalid");
+  uint32_t bits;
+  std::memcpy(&bits, &v, sizeof(bits));  // Well-defined type punning.
+
+  const uint32_t sign = (bits & 0x80000000) >> 24;
+  const uint32_t exponent = (bits >> 23) & 0xff;
+  const uint32_t mantissa = bits & 0x7fffff;
+  if (exponent < 127 - 10) {  // Below the smallest normal (2^-10).
+    if (exponent < 127 - 14) return 0x00;
+    // Make the implicit one explicit and shift it into the 3-bit field;
+    // `137 - exponent` is the original `10 - (exponent - 127)` without
+    // relying on unsigned wraparound (21..24 here).
+    const uint32_t shifted_mantissa = (0x800000 | mantissa) >> (137 - exponent);
+    // 0x80 is NaN, so a negative value that truncates to zero becomes +0.
+    if (shifted_mantissa == 0) return 0x00;
+    return static_cast<uint8_t>(sign | shifted_mantissa);
+  }
+  if (exponent > 127 + 4) {  // Too large, infinite, or NaN.
+    if (exponent == 255 && mantissa != 0) return 0x80;  // nan.
+    return static_cast<uint8_t>(0x7f | sign);
+  }
+  const uint32_t biased = exponent - (127 - 11);
+  const uint8_t result =
+      static_cast<uint8_t>(sign | (biased << 3) | (mantissa >> 20));
+  return result == 0x80 ? 0 : result;  // Never emit the NaN encoding here.
+}
+
+// Counts leading zero bits of `x`; returns 32 for 0. `__builtin_clz(0)` is
+// undefined, so zero is handled explicitly to match the portable loop.
+static uint32_t clz_uint32(uint32_t x) {
+  if (x == 0) return 32;
+#ifdef __GNUC__
+  return static_cast<uint32_t>(__builtin_clz(x));
+#else
+  uint32_t leading = 32;
+  for (; x != 0; x >>= 1) --leading;
+  return leading;
+#endif
+}
+
+// Expands the float8_e4m3b11 bit pattern `v` to float: 0x80 is NaN and 0x00
+// is zero; all other patterns, including subnormals, are rebuilt bit by bit.
+float float8_e4m3b11_to_float(uint8_t v) {
+  if (v == 0x80) return NAN;
+  if (v == 0) return 0;
+  const uint32_t sign = static_cast<uint32_t>(v & 0x80) << 24;
+  uint32_t exponent = ((v & 0x78) >> 3) + (127 - 11);
+  uint32_t mantissa = static_cast<uint32_t>(v & 0x7) << 20;
+  if ((v & 0x78) == 0) {  // Subnormal: renormalize for the float format.
+    // `v & 0x7` is nonzero here since 0x00 and 0x80 returned above.
+    const uint32_t nzeros = clz_uint32(v & 0x7);  // 29, 30 or 31.
+    // Shift the leading one up to bit 23 and keep only the bits below it.
+    mantissa = (static_cast<uint32_t>(v & 0x7) << (nzeros - 8)) & (0x3u << 21);
+    exponent = 0x72 + 31 - nzeros;  // Float exponent field: 114..116.
+  }
+  const uint32_t bits = sign | (exponent << 23) | mantissa;
+  float result;
+  std::memcpy(&result, &bits, sizeof(result));  // Well-defined type punning.
+  return result;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/python/lib/core/float8_e4m3b11.h b/tensorflow/python/lib/core/float8_e4m3b11.h
new file mode 100644
index 0000000..79a9f31
--- /dev/null
+++ b/tensorflow/python/lib/core/float8_e4m3b11.h
@@ -0,0 +1,64 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_PYTHON_LIB_CORE_FLOAT_E4M3B11_H_
+#define TENSORFLOW_PYTHON_LIB_CORE_FLOAT_E4M3B11_H_
+
+#include <stdint.h>
+
+#include <cmath>
+#include <cstring>
+#include <memory>
+
+namespace tensorflow {
+
+uint8_t float_to_float8_e4m3b11(float v);
+float float8_e4m3b11_to_float(uint8_t v);
+
+// 8-bit float stored as a raw byte. Exponent: 4, Mantissa: 3, bias: 11.
+// Encoding 0x80 is NaN and 0x00 is zero; there is no negative zero.
+class float8_e4m3b11 {
+ public:
+  float8_e4m3b11() = default;  // Zero, via the member initializer below.
+  float8_e4m3b11(float v) : rep_(float_to_float8_e4m3b11(v)) {}  // NOLINT
+
+  operator float() const {  // NOLINT: Allow implicit conversion to float,
+                            // because it is lossless.
+    return float8_e4m3b11_to_float(rep_);
+  }
+
+  // Flips the sign bit, except for 0x00 (zero) and 0x80 (NaN), whose
+  // encodings would turn into each other if the bit were toggled.
+  float8_e4m3b11 operator-() const {
+    if ((rep_ & 0x7f) == 0x00) return *this;
+    return FromRep(static_cast<uint8_t>(rep_ ^ 0x80));
+  }
+
+  uint8_t rep() const { return rep_; }
+
+  // Builds a value directly from a raw bit pattern, without conversion.
+  static float8_e4m3b11 FromRep(uint8_t rep) {
+    float8_e4m3b11 result;
+    result.rep_ = rep;
+    return result;
+  }
+
+ private:
+  uint8_t rep_ = 0;  // Raw encoding; initialized so reads are never UB.
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_PYTHON_LIB_CORE_FLOAT_E4M3B11_H_