Skip to content

Commit

Permalink
Merge pull request #166 from graphcore-research:add-e8m0-datatype
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673897992
  • Loading branch information
The ml_dtypes Authors committed Sep 12, 2024
2 parents 40e66e5 + b6d3659 commit d581b6f
Show file tree
Hide file tree
Showing 9 changed files with 444 additions and 30 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):
## [Unreleased]

* Added new 8-bit float types following IEEE 754 convention:
`ml_dtypes.float8_e4m3` and `ml_dtypes.float8_e3m4`.
`ml_dtypes.float8_e4m3`, `ml_dtypes.float8_e3m4`
* Added the 8-bit floating point type `ml_dtypes.float8_e8m0fnu`, which is the
OpenCompute MX scale format.
* Added new 4-bit and 6-bit float types:
`ml_dtypes.float4_e2m1fn`, `ml_dtypes.float6_e2m3fn` and `ml_dtypes.float6_e3m2fn`.
* Fix outputs of float `divmod` and `floor_divide` when denominator is zero.
Expand Down
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,16 @@ This type has the following characteristics:
* NaNs: Supported with sign bit set to 1, exponent bits and mantissa bits set to all 0s - `0b10000000`
* denormals when exponent is 0

### `float8_e8m0fnu`

[OpenCompute MX](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf)
scale format E8M0, which has the following properties:
* Unsigned format
* 8 exponent bits
* Exponent range from -127 to 127
* No zero and infinity
* Single NaN value (0xFF).

## `int2`, `int4`, `uint2` and `uint4`

2 and 4-bit integer types, where each element is represented unpacked (i.e.,
Expand Down
3 changes: 3 additions & 0 deletions ml_dtypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"float8_e4m3fnuz",
"float8_e5m2",
"float8_e5m2fnuz",
"float8_e8m0fnu",
"iinfo",
"int2",
"int4",
Expand All @@ -49,6 +50,7 @@
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_ext import float8_e5m2
from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
from ml_dtypes._ml_dtypes_ext import float8_e8m0fnu
from ml_dtypes._ml_dtypes_ext import int2
from ml_dtypes._ml_dtypes_ext import int4
from ml_dtypes._ml_dtypes_ext import uint2
Expand All @@ -66,6 +68,7 @@
float8_e4m3fnuz: Type[np.generic]
float8_e5m2: Type[np.generic]
float8_e5m2fnuz: Type[np.generic]
float8_e8m0fnu: Type[np.generic]
int2: Type[np.generic]
int4: Type[np.generic]
uint2: Type[np.generic]
Expand Down
59 changes: 58 additions & 1 deletion ml_dtypes/_finfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz
from ml_dtypes._ml_dtypes_ext import float8_e5m2
from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
from ml_dtypes._ml_dtypes_ext import float8_e8m0fnu
import numpy as np

_bfloat16_dtype = np.dtype(bfloat16)
Expand All @@ -40,6 +41,7 @@
_float8_e4m3fnuz_dtype = np.dtype(float8_e4m3fnuz)
_float8_e5m2_dtype = np.dtype(float8_e5m2)
_float8_e5m2fnuz_dtype = np.dtype(float8_e5m2fnuz)
_float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu)


class _Bfloat16MachArLike:
Expand Down Expand Up @@ -141,6 +143,15 @@ def __init__(self):
self.smallest_subnormal = float8_e5m2fnuz(smallest_subnormal)


class _Float8E8m0fnuMachArLike:

def __init__(self):
smallest_normal = float.fromhex("0x1p-127")
self.smallest_normal = float8_e8m0fnu(smallest_normal)
smallest_subnormal = float.fromhex("0x1p-127")
self.smallest_subnormal = float8_e8m0fnu(smallest_subnormal)


class finfo(np.finfo): # pylint: disable=invalid-name,missing-class-docstring
__doc__ = np.finfo.__doc__
_finfo_cache: Dict[type, np.finfo] = {} # pylint: disable=g-bare-generic
Expand Down Expand Up @@ -628,6 +639,51 @@ def float_to_str(f):
# pylint: enable=protected-access
return obj

@staticmethod
def _float8_e8m0fnu_finfo():
def float_to_str(f):
return "%6.2e" % float(f)

tiny = float.fromhex("0x1p-127")
resolution = 0.1
eps = float.fromhex("0x1p+0")
epsneg = float.fromhex("0x1p-1")
max_ = float.fromhex("0x1p+127")

obj = object.__new__(np.finfo)
obj.dtype = _float8_e8m0fnu_dtype
obj.bits = 8
obj.eps = float8_e8m0fnu(eps)
obj.epsneg = float8_e8m0fnu(epsneg)
obj.machep = 0
obj.negep = -1
obj.max = float8_e8m0fnu(max_)
obj.min = float8_e8m0fnu(tiny)
obj.nexp = 8
obj.nmant = 0
obj.iexp = obj.nexp
obj.maxexp = 128
obj.minexp = -127
obj.precision = 1
obj.resolution = float8_e8m0fnu(resolution)
# pylint: disable=protected-access
obj._machar = _Float8E8m0fnuMachArLike()
if not hasattr(obj, "tiny"):
obj.tiny = float8_e8m0fnu(tiny)
if not hasattr(obj, "smallest_normal"):
obj.smallest_normal = obj._machar.smallest_normal
obj.smallest_subnormal = obj._machar.smallest_subnormal

obj._str_tiny = float_to_str(tiny)
obj._str_smallest_normal = float_to_str(tiny)
obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal)
obj._str_max = float_to_str(max_)
obj._str_epsneg = float_to_str(epsneg)
obj._str_eps = float_to_str(eps)
obj._str_resolution = float_to_str(resolution)
# pylint: enable=protected-access
return obj

_finfo_type_map = {
_bfloat16_dtype: _bfloat16_finfo,
_float4_e2m1fn_dtype: _float4_e2m1fn_finfo,
Expand All @@ -640,6 +696,7 @@ def float_to_str(f):
_float8_e4m3b11fnuz_dtype: _float8_e4m3b11fnuz_finfo,
_float8_e5m2_dtype: _float8_e5m2_finfo,
_float8_e5m2fnuz_dtype: _float8_e5m2fnuz_finfo,
_float8_e8m0fnu_dtype: _float8_e8m0fnu_finfo,
}
_finfo_name_map = {t.name: t for t in _finfo_type_map}

Expand All @@ -656,6 +713,6 @@ def __new__(cls, dtype):

init = cls._finfo_type_map.get(key)
if init is not None:
cls._finfo_cache[dtype] = init()
cls._finfo_cache[dtype] = init.__func__() # pytype: disable=attribute-error
return cls._finfo_cache[dtype]
return super().__new__(cls, dtype)
22 changes: 22 additions & 0 deletions ml_dtypes/_src/dtypes.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,21 @@ struct TypeDescriptor<float4_e2m1fn> : CustomFloatType<float4_e2m1fn> {
static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor<float8_e8m0fnu> : CustomFloatType<float8_e8m0fnu> {
typedef float8_e8m0fnu T;
static constexpr bool is_floating = true;
static constexpr bool is_integral = false;
static constexpr const char* kTypeName = "float8_e8m0fnu";
static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e8m0fnu";
static constexpr const char* kTpDoc = "float8_e8m0fnu floating-point values";
static constexpr char kNpyDescrKind = 'V';
// TODO(phawkins): there doesn't seem to be a way of guaranteeing a type
// character is unique.
static constexpr char kNpyDescrType = 'W';
static constexpr char kNpyDescrByteorder = '=';
};

template <>
struct TypeDescriptor<int2> : IntNTypeDescriptor<int2> {
typedef int2 T;
Expand Down Expand Up @@ -379,6 +394,9 @@ bool Initialize() {
!RegisterFloatDtype<float4_e2m1fn>(numpy.get())) {
return false;
}
if (!RegisterFloatDtype<float8_e8m0fnu>(numpy.get())) {
return false;
}

if (!RegisterIntNDtype<int2>(numpy.get()) ||
!RegisterIntNDtype<uint2>(numpy.get()) ||
Expand All @@ -393,6 +411,9 @@ bool Initialize() {
float8_e4m3b11fnuz, float8_e4m3fn, float8_e4m3fnuz,
float8_e5m2, float8_e5m2fnuz, float6_e2m3fn,
float6_e3m2fn, float4_e2m1fn>();
// Only registering to/from BF16 and FP32 for float8_e8m0fnu.
success &= RegisterTwoWayCustomCast<float8_e8m0fnu, bfloat16, float>();
success &= RegisterTwoWayCustomCast<bfloat16, float8_e8m0fnu, float>();
success &= RegisterOneWayCustomCast<int2, int4, int8_t>();
success &= RegisterOneWayCustomCast<uint2, uint4, uint8_t>();
return success;
Expand Down Expand Up @@ -433,6 +454,7 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() {
!InitModuleType<float8_e4m3fnuz>(m.get(), "float8_e4m3fnuz") ||
!InitModuleType<float8_e5m2>(m.get(), "float8_e5m2") ||
!InitModuleType<float8_e5m2fnuz>(m.get(), "float8_e5m2fnuz") ||
!InitModuleType<float8_e8m0fnu>(m.get(), "float8_e8m0fnu") ||
!InitModuleType<bfloat16>(m.get(), "bfloat16") ||
!InitModuleType<int2>(m.get(), "int2") ||
!InitModuleType<int4>(m.get(), "int4") ||
Expand Down
11 changes: 10 additions & 1 deletion ml_dtypes/_src/ufuncs.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,12 @@ using BitsType = typename GetUnsignedInteger<sizeof(T)>::type;

template <typename T>
std::pair<BitsType<T>, BitsType<T>> SignAndMagnitude(T x) {
const BitsType<T> x_bits = Eigen::numext::bit_cast<BitsType<T>>(x);
// Unsigned floating point format (e.g. E8M0) => no sign bit (zero by
// default).
if constexpr (!std::numeric_limits<T>::is_signed) {
return {BitsType<T>(0), x_bits};
}
// For types that represent NaN by -0, (i.e. *fnuz), abs(x) remains -0 without
// flipping the sign. Therefore, we need to explicitly check the
// most-significant bit.
Expand All @@ -332,13 +338,16 @@ std::pair<BitsType<T>, BitsType<T>> SignAndMagnitude(T x) {
constexpr bool has_nan = std::numeric_limits<T>::has_quiet_NaN;
const BitsType<T> x_abs_bits =
Eigen::numext::bit_cast<BitsType<T>>(Eigen::numext::abs(x));
const BitsType<T> x_bits = Eigen::numext::bit_cast<BitsType<T>>(x);
return {has_nan ? x_bits & kSignMask : x_bits ^ x_abs_bits, x_abs_bits};
}

template <typename T>
struct CopySign {
T operator()(T a, T b) {
// Unsigned floating point format => no change.
if constexpr (!std::numeric_limits<T>::is_signed) {
return a;
}
auto [a_sign, a_abs_bits] = SignAndMagnitude(a);
auto [b_sign, b_abs_bits] = SignAndMagnitude(b);
BitsType<T> rep = a_abs_bits | b_sign;
Expand Down
Loading

0 comments on commit d581b6f

Please sign in to comment.