Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add float8_e8m0_fnu (E8M0) OCP MX scale format #166

Merged
merged 1 commit into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"float8_e4m3fnuz",
"float8_e5m2",
"float8_e5m2fnuz",
"float8_e8m0fnu",
"iinfo",
"int2",
"int4",
Expand All @@ -49,6 +50,7 @@
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_ext import float8_e5m2
from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
from ml_dtypes._ml_dtypes_ext import float8_e8m0fnu
from ml_dtypes._ml_dtypes_ext import int2
from ml_dtypes._ml_dtypes_ext import int4
from ml_dtypes._ml_dtypes_ext import uint2
Expand All @@ -66,6 +68,7 @@
float8_e4m3fnuz: Type[np.generic]
float8_e5m2: Type[np.generic]
float8_e5m2fnuz: Type[np.generic]
float8_e8m0fnu: Type[np.generic]
int2: Type[np.generic]
int4: Type[np.generic]
uint2: Type[np.generic]
Expand Down
59 changes: 58 additions & 1 deletion ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_ext import float8_e5m2
from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
from ml_dtypes._ml_dtypes_ext import float8_e8m0fnu
import numpy as np

_bfloat16_dtype = np.dtype(bfloat16)
Expand All @@ -40,6 +41,7 @@
_float8_e4m3fnuz_dtype = np.dtype(float8_e4m3fnuz)
_float8_e5m2_dtype = np.dtype(float8_e5m2)
_float8_e5m2fnuz_dtype = np.dtype(float8_e5m2fnuz)
_float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu)


class _Bfloat16MachArLike:
Expand Down Expand Up @@ -141,6 +143,15 @@ def __init__(self):
self.smallest_subnormal = float8_e5m2fnuz(smallest_subnormal)


class _Float8E8m0fnuMachArLike:
  """Holds the smallest normal and smallest subnormal float8_e8m0fnu values."""

  def __init__(self):
    # Both attributes are built from the same literal, 2**-127.
    min_positive = float.fromhex("0x1p-127")
    self.smallest_normal = float8_e8m0fnu(min_positive)
    self.smallest_subnormal = float8_e8m0fnu(min_positive)


class finfo(np.finfo): # pylint: disable=invalid-name,missing-class-docstring
__doc__ = np.finfo.__doc__
_finfo_cache: Dict[type, np.finfo] = {} # pylint: disable=g-bare-generic
Expand Down Expand Up @@ -628,6 +639,51 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj

@staticmethod
def _float8_e8m0fnu_finfo():
# Returns an np.finfo instance whose fields are filled in by hand with the
# limits of float8_e8m0fnu. The instance is created with object.__new__, so
# np.finfo's own constructor is not invoked here.
def float_to_str(f):
# Formats a value for the _str_* fields set at the end of this function.
return "%6.2e" % float(f)

# Reference values as Python floats: 2**-127, 0.1, 2**0, 2**-1 and 2**127.
tiny = float.fromhex("0x1p-127")
# NOTE(review): obj.resolution below stores float8_e8m0fnu(0.1), while
# _str_resolution is formatted from this raw Python float.
resolution = 0.1
eps = float.fromhex("0x1p+0")
epsneg = float.fromhex("0x1p-1")
max_ = float.fromhex("0x1p+127")

obj = object.__new__(np.finfo)
obj.dtype = _float8_e8m0fnu_dtype
obj.bits = 8
obj.eps = float8_e8m0fnu(eps)
obj.epsneg = float8_e8m0fnu(epsneg)
# machep / negep are the base-2 exponents of eps (2**0) and epsneg (2**-1).
obj.machep = 0
obj.negep = -1
obj.max = float8_e8m0fnu(max_)
# min is built from `tiny` (2**-127) rather than from a negated max.
# NOTE(review): presumably because the format has no sign bit -- confirm.
obj.min = float8_e8m0fnu(tiny)
# 8 exponent bits and 0 mantissa bits.
obj.nexp = 8
obj.nmant = 0
obj.iexp = obj.nexp
obj.maxexp = 128
obj.minexp = -127
obj.precision = 1
obj.resolution = float8_e8m0fnu(resolution)
# pylint: disable=protected-access
obj._machar = _Float8E8m0fnuMachArLike()
# tiny and smallest_normal are only assigned when the instance does not
# already expose them.
# NOTE(review): presumably some NumPy versions provide these as properties
# on np.finfo -- confirm.
if not hasattr(obj, "tiny"):
obj.tiny = float8_e8m0fnu(tiny)
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = obj._machar.smallest_normal
# smallest_subnormal is taken from the MachAr-like helper object.
obj.smallest_subnormal = obj._machar.smallest_subnormal

# Pre-formatted string forms of the limits, all produced by float_to_str.
obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(max_)
obj._str_epsneg = float_to_str(epsneg)
obj._str_eps = float_to_str(eps)
obj._str_resolution = float_to_str(resolution)
# pylint: enable=protected-access
return obj

_finfo_type_map = {
_bfloat16_dtype: _bfloat16_finfo,
_float4_e2m1fn_dtype: _float4_e2m1fn_finfo,
Expand All @@ -640,6 +696,7 @@ def float_to_str(f):
_float8_e4m3b11fnuz_dtype: _float8_e4m3b11fnuz_finfo,
_float8_e5m2_dtype: _float8_e5m2_finfo,
_float8_e5m2fnuz_dtype: _float8_e5m2fnuz_finfo,
_float8_e8m0fnu_dtype: _float8_e8m0fnu_finfo,
}
_finfo_name_map = {t.name: t for t in _finfo_type_map}

Expand All @@ -656,6 +713,6 @@ def __new__(cls, dtype):

init = cls._finfo_type_map.get(key)
if init is not None:
cls._finfo_cache[dtype] = init()
cls._finfo_cache[dtype] = init.__func__()
return cls._finfo_cache[dtype]
return super().__new__(cls, dtype)
22 changes: 22 additions & 0 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,21 @@ struct TypeDescriptor<float4_e2m1fn> : CustomFloatType<float4_e2m1fn> {
static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor<float8_e8m0fnu> : CustomFloatType<float8_e8m0fnu> {
  using T = float8_e8m0fnu;

  // Classification flags: a floating-point type, not an integral one.
  static constexpr bool is_integral = false;
  static constexpr bool is_floating = true;

  // Identifying strings for the type.
  static constexpr const char* kTpDoc = "float8_e8m0fnu floating-point values";
  static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e8m0fnu";
  static constexpr const char* kTypeName = "float8_e8m0fnu";

  // kNpyDescr* character codes.
  static constexpr char kNpyDescrByteorder = '=';
  static constexpr char kNpyDescrKind = 'V';
  // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
  // character is unique.
  static constexpr char kNpyDescrType = 'W';
};

template <>
struct TypeDescriptor<int2> : IntNTypeDescriptor<int2> {
typedef int2 T;
Expand Down Expand Up @@ -379,6 +394,9 @@ bool Initialize() {
!RegisterFloatDtype<float4_e2m1fn>(numpy.get())) {
return false;
}
if (!RegisterFloatDtype<float8_e8m0fnu>(numpy.get())) {
return false;
}

if (!RegisterIntNDtype<int2>(numpy.get()) ||
!RegisterIntNDtype<uint2>(numpy.get()) ||
Expand All @@ -393,6 +411,9 @@ bool Initialize() {
float8_e4m3b11fnuz, float8_e4m3fn, float8_e4m3fnuz,
float8_e5m2, float8_e5m2fnuz, float6_e2m3fn,
float6_e3m2fn, float4_e2m1fn>();
// Only registering to/from BF16 and FP32 for float8_e8m0fnu.
success &= RegisterTwoWayCustomCast<float8_e8m0fnu, bfloat16, float>();
success &= RegisterTwoWayCustomCast<bfloat16, float8_e8m0fnu, float>();
success &= RegisterOneWayCustomCast<int2, int4, int8_t>();
success &= RegisterOneWayCustomCast<uint2, uint4, uint8_t>();
return success;
Expand Down Expand Up @@ -433,6 +454,7 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() {
!InitModuleType<float8_e4m3fnuz>(m.get(), "float8_e4m3fnuz") ||
!InitModuleType<float8_e5m2>(m.get(), "float8_e5m2") ||
!InitModuleType<float8_e5m2fnuz>(m.get(), "float8_e5m2fnuz") ||
!InitModuleType<float8_e8m0fnu>(m.get(), "float8_e8m0fnu") ||
!InitModuleType<bfloat16>(m.get(), "bfloat16") ||
!InitModuleType<int2>(m.get(), "int2") ||
!InitModuleType<int4>(m.get(), "int4") ||
Expand Down
11 changes: 10 additions & 1 deletion ml_dtypes/_src/ufuncs.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,12 @@ using BitsType = typename GetUnsignedInteger<sizeof(T)>::type;

template <typename T>
std::pair<BitsType<T>, BitsType<T>> SignAndMagnitude(T x) {
const BitsType<T> x_bits = Eigen::numext::bit_cast<BitsType<T>>(x);
// Unsigned floating point format (e.g. E8M0) => no sign bit (zero by
// default).
if constexpr (!std::numeric_limits<T>::is_signed) {
return {BitsType<T>(0), x_bits};
}
// For types that represent NaN by -0, (i.e. *fnuz), abs(x) remains -0 without
// flipping the sign. Therefore, we need to explicitly check the
// most-significant bit.
Expand All @@ -332,13 +338,16 @@ std::pair<BitsType<T>, BitsType<T>> SignAndMagnitude(T x) {
constexpr bool has_nan = std::numeric_limits<T>::has_quiet_NaN;
const BitsType<T> x_abs_bits =
Eigen::numext::bit_cast<BitsType<T>>(Eigen::numext::abs(x));
const BitsType<T> x_bits = Eigen::numext::bit_cast<BitsType<T>>(x);
return {has_nan ? x_bits & kSignMask : x_bits ^ x_abs_bits, x_abs_bits};
}

template <typename T>
struct CopySign {
T operator()(T a, T b) {
// Unsigned floating point format => no change.
if constexpr (!std::numeric_limits<T>::is_signed) {
return a;
}
auto [a_sign, a_abs_bits] = SignAndMagnitude(a);
auto [b_sign, b_abs_bits] = SignAndMagnitude(b);
BitsType<T> rep = a_abs_bits | b_sign;
Expand Down
Loading