Skip to content

Commit

Permalink
Add float8_e8m0_fnu (E8M0) OCP MX scale format
Browse files Browse the repository at this point in the history
Adding the OCP MX scale format `E8M0`, which has the following properties:
* Unsigned format;
* 8 exponent bits;
* Exponent range from -127 to 127;
* No zero and infinity;
* Single NaN value (0xFF);

`ml_dtypes` `float8_base` C++ class is extended to support floating point formats
which are unsigned and with no zero (i.e. additional `kIsSigned` and `kHasZero` Traits properties).

Base on these traits, `float8_e8m0_fnu` has been implemented using the existing functionalities (convert, unary/binary ops, ...).
Float8 Python unit tests have been extended to be able to cover unsigned floating point formats.
  • Loading branch information
balancap committed Sep 3, 2024
1 parent acf7e8c commit 90281bc
Show file tree
Hide file tree
Showing 7 changed files with 418 additions and 29 deletions.
3 changes: 3 additions & 0 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"float8_e4m3fnuz",
"float8_e5m2",
"float8_e5m2fnuz",
"float8_e8m0fnu",
"iinfo",
"int2",
"int4",
Expand All @@ -43,6 +44,7 @@
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_ext import float8_e5m2
from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
from ml_dtypes._ml_dtypes_ext import float8_e8m0fnu
from ml_dtypes._ml_dtypes_ext import int2
from ml_dtypes._ml_dtypes_ext import int4
from ml_dtypes._ml_dtypes_ext import uint2
Expand All @@ -57,6 +59,7 @@
float8_e4m3fnuz: Type[np.generic]
float8_e5m2: Type[np.generic]
float8_e5m2fnuz: Type[np.generic]
float8_e8m0fnu: Type[np.generic]
int2: Type[np.generic]
int4: Type[np.generic]
uint2: Type[np.generic]
Expand Down
64 changes: 64 additions & 0 deletions ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_ext import float8_e5m2
from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
from ml_dtypes._ml_dtypes_ext import float8_e8m0fnu
import numpy as np

_bfloat16_dtype = np.dtype(bfloat16)
Expand All @@ -34,6 +35,7 @@
_float8_e4m3fnuz_dtype = np.dtype(float8_e4m3fnuz)
_float8_e5m2_dtype = np.dtype(float8_e5m2)
_float8_e5m2fnuz_dtype = np.dtype(float8_e5m2fnuz)
_float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu)


class _Bfloat16MachArLike:
Expand Down Expand Up @@ -108,6 +110,15 @@ def __init__(self):
self.smallest_subnormal = float8_e5m2fnuz(smallest_subnormal)


class _Float8E8m0fnuMachArLike:

def __init__(self):
smallest_normal = float.fromhex("0x1p-127")
self.smallest_normal = float8_e8m0fnu(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-127")
self.smallest_subnormal = float8_e8m0fnu(smallest_subnormal)


class finfo(np.finfo): # pylint: disable=invalid-name,missing-class-docstring
__doc__ = np.finfo.__doc__
_finfo_cache: Dict[np.dtype, np.finfo] = {}
Expand Down Expand Up @@ -472,6 +483,51 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj

@staticmethod
def _float8_e8m0fnu_finfo():
def float_to_str(f):
return "%6.2e" % float(f)

tiny = float.fromhex("0x1p-127")
resolution = 0.1
eps = float.fromhex("0x1p+0")
epsneg = float.fromhex("0x1p-1")
max_ = float.fromhex("0x1p+127")

obj = object.__new__(np.finfo)
obj.dtype = _float8_e8m0fnu_dtype
obj.bits = 8
obj.eps = float8_e8m0fnu(eps)
obj.epsneg = float8_e8m0fnu(epsneg)
obj.machep = 0
obj.negep = -1
obj.max = float8_e8m0fnu(max_)
obj.min = float8_e8m0fnu(tiny)
obj.nexp = 8
obj.nmant = 0
obj.iexp = obj.nexp
obj.maxexp = 128
obj.minexp = -127
obj.precision = 1
obj.resolution = float8_e8m0fnu(resolution)
# pylint: disable=protected-access
obj._machar = _Float8E8m0fnuMachArLike()
if not hasattr(obj, "tiny"):
obj.tiny = float8_e8m0fnu(tiny)
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = obj._machar.smallest_normal
obj.smallest_subnormal = obj._machar.smallest_subnormal

obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(max_)
obj._str_epsneg = float_to_str(epsneg)
obj._str_eps = float_to_str(eps)
obj._str_resolution = float_to_str(resolution)
# pylint: enable=protected-access
return obj

def __new__(cls, dtype):
if (
isinstance(dtype, str)
Expand Down Expand Up @@ -539,4 +595,12 @@ def __new__(cls, dtype):
if _float8_e5m2fnuz_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e5m2fnuz_dtype] = cls._float8_e5m2fnuz_finfo()
return cls._finfo_cache[_float8_e5m2fnuz_dtype]
if (
isinstance(dtype, str)
and dtype == "float8_e8m0fnu"
or dtype == _float8_e8m0fnu_dtype
):
if _float8_e8m0fnu_dtype not in cls._finfo_cache:
cls._finfo_cache[_float8_e8m0fnu_dtype] = cls._float8_e8m0fnu_finfo()
return cls._finfo_cache[_float8_e8m0fnu_dtype]
return super().__new__(cls, dtype)
26 changes: 26 additions & 0 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,21 @@ struct TypeDescriptor<float8_e5m2fnuz> : CustomFloatType<float8_e5m2fnuz> {
static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor<float8_e8m0fnu> : CustomFloatType<float8_e8m0fnu> {
typedef float8_e8m0fnu T;
static constexpr bool is_floating = true;
static constexpr bool is_integral = false;
static constexpr const char* kTypeName = "float8_e8m0fnu";
static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e8m0fnu";
static constexpr const char* kTpDoc = "float8_e8m0fnu floating-point values";
static constexpr char kNpyDescrKind = 'V';
// TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
// character is unique.
static constexpr char kNpyDescrType = 'W';
static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor<int2> : IntNTypeDescriptor<int2> {
typedef int2 T;
Expand Down Expand Up @@ -318,6 +333,9 @@ bool Initialize() {
if (!RegisterFloatDtype<float8_e5m2fnuz>(numpy.get())) {
return false;
}
if (!RegisterFloatDtype<float8_e8m0fnu>(numpy.get())) {
return false;
}

if (!RegisterIntNDtype<int2>(numpy.get())) {
return false;
Expand Down Expand Up @@ -366,6 +384,8 @@ bool Initialize() {
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e4m3fn, float>();
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e5m2, float>();
success &= RegisterTwoWayCustomCast<float8_e3m4, float8_e4m3, float>();
success &= RegisterTwoWayCustomCast<float8_e8m0fnu, bfloat16, float>();
success &= RegisterTwoWayCustomCast<bfloat16, float8_e8m0fnu, float>();
success &= RegisterOneWayCustomCast<int2, int4, int8_t>();
success &= RegisterOneWayCustomCast<uint2, uint4, uint8_t>();
return success;
Expand Down Expand Up @@ -435,6 +455,12 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() {
0) {
return nullptr;
}
if (PyObject_SetAttrString(m.get(), "float8_e8m0fnu",
reinterpret_cast<PyObject*>(
TypeDescriptor<float8_e8m0fnu>::type_ptr)) <
0) {
return nullptr;
}
if (PyObject_SetAttrString(m.get(), "bfloat16",
reinterpret_cast<PyObject*>(
TypeDescriptor<bfloat16>::type_ptr)) < 0) {
Expand Down
10 changes: 9 additions & 1 deletion ml_dtypes/_src/ufuncs.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,20 +322,28 @@ using BitsType = typename GetUnsignedInteger<sizeof(T)>::type;

template <typename T>
std::pair<BitsType<T>, BitsType<T>> SignAndMagnitude(T x) {
const BitsType<T> x_bits = Eigen::numext::bit_cast<BitsType<T>>(x);
// Unsigned floating point format (e.g. E8M0) => no sign bit (zero by default).
if constexpr(!std::numeric_limits<T>::is_signed) {
return {BitsType<T>(0), x_bits};
}
// For types that represent NaN by -0, (i.e. *fnuz), abs(x) remains -0 without
// flipping the sign. Therefore, we need to explicitly check the
// most-significant bit.
constexpr BitsType<T> kSignMask = BitsType<T>(1)
<< (sizeof(BitsType<T>) * CHAR_BIT - 1);
const BitsType<T> x_abs_bits =
Eigen::numext::bit_cast<BitsType<T>>(Eigen::numext::abs(x));
const BitsType<T> x_bits = Eigen::numext::bit_cast<BitsType<T>>(x);
return {x_bits & kSignMask, x_abs_bits};
}

template <typename T>
struct CopySign {
T operator()(T a, T b) {
// Unsigned floating point format => no change.
if constexpr(!std::numeric_limits<T>::is_signed) {
return a;
}
auto [a_sign, a_abs_bits] = SignAndMagnitude(a);
auto [b_sign, b_abs_bits] = SignAndMagnitude(b);
BitsType<T> rep = a_abs_bits | b_sign;
Expand Down
Loading

0 comments on commit 90281bc

Please sign in to comment.