From b6d3659bda3eacb7f56e315eeab41241feeb073f Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Wed, 31 Jul 2024 16:23:51 +0100 Subject: [PATCH] Add float8_e8m0_fnu (E8M0) OCP MX scale format Adding the OCP MX scale format `E8M0`, which has the following properties: * Unsigned format; * 8 exponent bits; * Exponent range from -127 to 127; * No zero and infinity; * Single NaN value (0xFF); `ml_dtypes` `float8_base` C++ class is extended to support floating point formats which are unsigned and with no zero (i.e. additional `kIsSigned` and `kHasZero` Traits properties). Base on these traits, `float8_e8m0_fnu` has been implemented using the existing functionalities (convert, unary/binary ops, ...). Float8 Python unit tests have been extended to be able to cover unsigned floating point formats. --- ml_dtypes/__init__.py | 3 + ml_dtypes/_finfo.py | 59 ++++++- ml_dtypes/_src/dtypes.cc | 22 +++ ml_dtypes/_src/ufuncs.h | 11 +- ml_dtypes/include/float8.h | 248 ++++++++++++++++++++++++++- ml_dtypes/tests/custom_float_test.py | 80 +++++++-- ml_dtypes/tests/finfo_test.py | 36 +++- 7 files changed, 430 insertions(+), 29 deletions(-) diff --git a/ml_dtypes/__init__.py b/ml_dtypes/__init__.py index 094b6ca7..cfc10780 100644 --- a/ml_dtypes/__init__.py +++ b/ml_dtypes/__init__.py @@ -27,6 +27,7 @@ "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz", + "float8_e8m0fnu", "iinfo", "int2", "int4", @@ -49,6 +50,7 @@ from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz from ml_dtypes._ml_dtypes_ext import float8_e5m2 from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz +from ml_dtypes._ml_dtypes_ext import float8_e8m0fnu from ml_dtypes._ml_dtypes_ext import int2 from ml_dtypes._ml_dtypes_ext import int4 from ml_dtypes._ml_dtypes_ext import uint2 @@ -66,6 +68,7 @@ float8_e4m3fnuz: Type[np.generic] float8_e5m2: Type[np.generic] float8_e5m2fnuz: Type[np.generic] +float8_e8m0fnu: Type[np.generic] int2: Type[np.generic] int4: Type[np.generic] uint2: Type[np.generic] diff --git a/ml_dtypes/_finfo.py b/ml_dtypes/_finfo.py index 454969c5..6c2f6651 100644 --- a/ml_dtypes/_finfo.py +++ b/ml_dtypes/_finfo.py @@ -27,6 +27,7 @@ from ml_dtypes._ml_dtypes_ext import float8_e4m3fnuz from ml_dtypes._ml_dtypes_ext import float8_e5m2 from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz +from ml_dtypes._ml_dtypes_ext import float8_e8m0fnu import numpy as np _bfloat16_dtype = np.dtype(bfloat16) @@ -40,6 +41,7 @@ _float8_e4m3fnuz_dtype = np.dtype(float8_e4m3fnuz) _float8_e5m2_dtype = np.dtype(float8_e5m2) _float8_e5m2fnuz_dtype = np.dtype(float8_e5m2fnuz) +_float8_e8m0fnu_dtype = np.dtype(float8_e8m0fnu) class _Bfloat16MachArLike: @@ -141,6 +143,15 @@ def __init__(self): self.smallest_subnormal = float8_e5m2fnuz(smallest_subnormal) +class _Float8E8m0fnuMachArLike: + + def __init__(self): + smallest_normal = float.fromhex("0x1p-127") + self.smallest_normal = float8_e8m0fnu(smallest_normal) + smallest_subnormal = float.fromhex("0x1p-127") + self.smallest_subnormal = float8_e8m0fnu(smallest_subnormal) + + class finfo(np.finfo): # pylint: disable=invalid-name,missing-class-docstring __doc__ = np.finfo.__doc__ _finfo_cache: Dict[type, np.finfo] = {} # pylint: disable=g-bare-generic @@ -628,6 +639,51 @@ def float_to_str(f): # pylint: enable=protected-access return obj + @staticmethod + def _float8_e8m0fnu_finfo(): + def float_to_str(f): + return "%6.2e" % float(f) + + tiny = float.fromhex("0x1p-127") + resolution = 0.1 + eps = float.fromhex("0x1p+0") + epsneg = float.fromhex("0x1p-1") + max_ = float.fromhex("0x1p+127") + + obj = 
object.__new__(np.finfo) + obj.dtype = _float8_e8m0fnu_dtype + obj.bits = 8 + obj.eps = float8_e8m0fnu(eps) + obj.epsneg = float8_e8m0fnu(epsneg) + obj.machep = 0 + obj.negep = -1 + obj.max = float8_e8m0fnu(max_) + obj.min = float8_e8m0fnu(tiny) + obj.nexp = 8 + obj.nmant = 0 + obj.iexp = obj.nexp + obj.maxexp = 128 + obj.minexp = -127 + obj.precision = 1 + obj.resolution = float8_e8m0fnu(resolution) + # pylint: disable=protected-access + obj._machar = _Float8E8m0fnuMachArLike() + if not hasattr(obj, "tiny"): + obj.tiny = float8_e8m0fnu(tiny) + if not hasattr(obj, "smallest_normal"): + obj.smallest_normal = obj._machar.smallest_normal + obj.smallest_subnormal = obj._machar.smallest_subnormal + + obj._str_tiny = float_to_str(tiny) + obj._str_smallest_normal = float_to_str(tiny) + obj._str_smallest_subnormal = float_to_str(obj.smallest_subnormal) + obj._str_max = float_to_str(max_) + obj._str_epsneg = float_to_str(epsneg) + obj._str_eps = float_to_str(eps) + obj._str_resolution = float_to_str(resolution) + # pylint: enable=protected-access + return obj + _finfo_type_map = { _bfloat16_dtype: _bfloat16_finfo, _float4_e2m1fn_dtype: _float4_e2m1fn_finfo, @@ -640,6 +696,7 @@ def float_to_str(f): _float8_e4m3b11fnuz_dtype: _float8_e4m3b11fnuz_finfo, _float8_e5m2_dtype: _float8_e5m2_finfo, _float8_e5m2fnuz_dtype: _float8_e5m2fnuz_finfo, + _float8_e8m0fnu_dtype: _float8_e8m0fnu_finfo, } _finfo_name_map = {t.name: t for t in _finfo_type_map} @@ -656,6 +713,6 @@ def __new__(cls, dtype): init = cls._finfo_type_map.get(key) if init is not None: - cls._finfo_cache[dtype] = init() + cls._finfo_cache[dtype] = init.__func__() return cls._finfo_cache[dtype] return super().__new__(cls, dtype) diff --git a/ml_dtypes/_src/dtypes.cc b/ml_dtypes/_src/dtypes.cc index a48fb17f..06213dba 100644 --- a/ml_dtypes/_src/dtypes.cc +++ b/ml_dtypes/_src/dtypes.cc @@ -216,6 +216,21 @@ struct TypeDescriptor : CustomFloatType { static constexpr char kNpyDescrByteorder = '='; }; +template <> +struct TypeDescriptor : CustomFloatType { + typedef float8_e8m0fnu T; + static constexpr bool is_floating = true; + static constexpr bool is_integral = false; + static constexpr const char* kTypeName = "float8_e8m0fnu"; + static constexpr const char* kQualifiedTypeName = "ml_dtypes.float8_e8m0fnu"; + static constexpr const char* kTpDoc = "float8_e8m0fnu floating-point values"; + static constexpr char kNpyDescrKind = 'V'; + // TODO(phawkins): there doesn't seem to be a way of guaranteeing a type + // character is unique. + static constexpr char kNpyDescrType = 'W'; + static constexpr char kNpyDescrByteorder = '='; +}; + template <> struct TypeDescriptor : IntNTypeDescriptor { typedef int2 T; @@ -379,6 +394,9 @@ bool Initialize() { !RegisterFloatDtype(numpy.get())) { return false; } + if (!RegisterFloatDtype(numpy.get())) { + return false; + } if (!RegisterIntNDtype(numpy.get()) || !RegisterIntNDtype(numpy.get()) || @@ -393,6 +411,9 @@ bool Initialize() { float8_e4m3b11fnuz, float8_e4m3fn, float8_e4m3fnuz, float8_e5m2, float8_e5m2fnuz, float6_e2m3fn, float6_e3m2fn, float4_e2m1fn>(); + // Only registering to/from BF16 and FP32 for float8_e8m0fnu. 
+ success &= RegisterTwoWayCustomCast(); + success &= RegisterTwoWayCustomCast(); success &= RegisterOneWayCustomCast(); success &= RegisterOneWayCustomCast(); return success; @@ -433,6 +454,7 @@ extern "C" EXPORT_SYMBOL PyObject* PyInit__ml_dtypes_ext() { !InitModuleType(m.get(), "float8_e4m3fnuz") || !InitModuleType(m.get(), "float8_e5m2") || !InitModuleType(m.get(), "float8_e5m2fnuz") || + !InitModuleType(m.get(), "float8_e8m0fnu") || !InitModuleType(m.get(), "bfloat16") || !InitModuleType(m.get(), "int2") || !InitModuleType(m.get(), "int4") || diff --git a/ml_dtypes/_src/ufuncs.h b/ml_dtypes/_src/ufuncs.h index 19a7a00e..9672a5ac 100644 --- a/ml_dtypes/_src/ufuncs.h +++ b/ml_dtypes/_src/ufuncs.h @@ -322,6 +322,12 @@ using BitsType = typename GetUnsignedInteger::type; template std::pair, BitsType> SignAndMagnitude(T x) { + const BitsType x_bits = Eigen::numext::bit_cast>(x); + // Unsigned floating point format (e.g. E8M0) => no sign bit (zero by + // default). + if constexpr (!std::numeric_limits::is_signed) { + return {BitsType(0), x_bits}; + } // For types that represent NaN by -0, (i.e. *fnuz), abs(x) remains -0 without // flipping the sign. Therefore, we need to explicitly check the // most-significant bit. @@ -332,13 +338,16 @@ std::pair, BitsType> SignAndMagnitude(T x) { constexpr bool has_nan = std::numeric_limits::has_quiet_NaN; const BitsType x_abs_bits = Eigen::numext::bit_cast>(Eigen::numext::abs(x)); - const BitsType x_bits = Eigen::numext::bit_cast>(x); return {has_nan ? x_bits & kSignMask : x_bits ^ x_abs_bits, x_abs_bits}; } template struct CopySign { T operator()(T a, T b) { + // Unsigned floating point format => no change. + if constexpr (!std::numeric_limits::is_signed) { + return a; + } auto [a_sign, a_abs_bits] = SignAndMagnitude(a); auto [b_sign, b_abs_bits] = SignAndMagnitude(b); BitsType rep = a_abs_bits | b_sign; diff --git a/ml_dtypes/include/float8.h b/ml_dtypes/include/float8.h index 4d32305e..cde14ed6 100644 --- a/ml_dtypes/include/float8.h +++ b/ml_dtypes/include/float8.h @@ -18,6 +18,8 @@ limitations under the License. // 8-bit Floating Point Interchange Format, as described by // https://arxiv.org/abs/2209.05433 +// https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1 +// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf #include #include @@ -51,6 +53,7 @@ class float8_e4m3fnuz; class float8_e4m3b11fnuz; class float8_e5m2; class float8_e5m2fnuz; +class float8_e8m0fnu; template class float8_base { @@ -423,6 +426,78 @@ class float8_e5m2fnuz : public float8_base { explicit EIGEN_DEVICE_FUNC operator bool() const { return rep() != 0; } }; +class float8_e8m0fnu : public float8_base { + // 8-bit floating point with 8 bit exponent, no sign and zero mantissa. + // + // See: + // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + // + // An 8-bit floating point type with no sign bit, 8 bits exponent and 0 bits + // mantissa. The suffix "fnuz" is consistent with LLVM/MLIR naming and is + // derived from the differences to IEEE floating point conventions. `F` is + // for "finite" (no infinities), `N` for with special NaN encoding, `U` for + // unsigned. 
+ // + // This type has the following characteristics: + // * bit encoding: S0E8M0 - `0bEEEEEEEE` + // * exponent bias: 127 + // * infinities: Not supported + // * NaNs: Supported with exponent bits set to 1s - `0b11111111` + private: + using Base = float8_base; + friend class float8_base; + using Base::Base; + + public: + template = 0> + explicit EIGEN_DEVICE_FUNC float8_e8m0fnu(T f8) + : float8_e8m0fnu(ConvertFrom(f8)) {} + + constexpr float8_e8m0fnu operator-() const { + // No negative numbers supported in E8M0 => NaN + return float8_e8m0fnu::FromRep(0xFF); + } + + float8_e8m0fnu operator-(const float8_e8m0fnu& other) const { + return Base::operator-(other); + } + + explicit EIGEN_DEVICE_FUNC operator bool() const { + // No zero supported in E8M0 format. + return true; + } + + // Comparison simplified to uint8_t compare. + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<( + const float8_e8m0fnu& other) const { + if (Eigen::numext::isnan(*this) || Eigen::numext::isnan(other)) { + return false; + } + return rep() < other.rep(); + } + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=( + const float8_e8m0fnu& other) const { + if (Eigen::numext::isnan(*this) || Eigen::numext::isnan(other)) { + return false; + } + return rep() <= other.rep(); + } + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>( + const float8_e8m0fnu& other) const { + if (Eigen::numext::isnan(*this) || Eigen::numext::isnan(other)) { + return false; + } + return rep() > other.rep(); + } + EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=( + const float8_e8m0fnu& other) const { + if (Eigen::numext::isnan(*this) || Eigen::numext::isnan(other)) { + return false; + } + return rep() >= other.rep(); + } +}; + constexpr double ConstexprAbs(double x) { return x < 0.0 ? -x : x; } constexpr double ConstexprCeil(double x) { @@ -472,9 +547,11 @@ constexpr int MinExponent10FromMinExponent(int min_exponent) { // emax * log10(2)) constexpr int MaxExponent10FromMaxExponentAndDigits(int max_exponent, int digits) { - // We only support digits in {2,5}. This table would grow if we wanted to - // handle more values. + // We only support digits in {1,2,3,4,5}. This table would grow if we wanted + // to handle more values. constexpr double kLog10OfOnePredecessor[] = { + // log10(1 - 2**-1) + -0.3010299956639812, // log10(1 - 2**-2) -0.12493873660829993, // log10(1 - 2**-3) @@ -484,7 +561,7 @@ constexpr int MaxExponent10FromMaxExponentAndDigits(int max_exponent, // log10(1 - 2**-5) -0.013788284485633295, }; - return static_cast(ConstexprFloor(kLog10OfOnePredecessor[digits - 2] + + return static_cast(ConstexprFloor(kLog10OfOnePredecessor[digits - 1] + max_exponent * kLog10Of2)); } @@ -929,6 +1006,64 @@ struct numeric_limits_float8_e5m2fnuz : public numeric_limits_float8_base { } }; +struct numeric_limits_float8_e8m0fnu : public numeric_limits_float8_base { + private: + static inline constexpr const int kExponentBias = 127; + static inline constexpr const int kMantissaBits = 0; + + public: + // NOLINTBEGIN: these names must match std::numeric_limits. + static inline constexpr const bool is_signed = false; + static inline constexpr const std::float_denorm_style has_denorm = + std::denorm_absent; + static inline constexpr const int digits = kMantissaBits + 1; + static inline constexpr const int digits10 = Digits10FromDigits(digits); + static inline constexpr const int max_digits10 = + MaxDigits10FromDigits(digits); + // 2**-127 smallest valid normalized value.. 
+ static inline constexpr const int min_exponent = -127 + 1; + static inline constexpr const int min_exponent10 = + MinExponent10FromMinExponent(min_exponent); + // 128 encoding using for NaN + static inline constexpr const int max_exponent = 127; + static inline constexpr const int max_exponent10 = + MaxExponent10FromMaxExponentAndDigits(max_exponent, digits); + static inline constexpr const bool is_iec559 = false; + static inline constexpr const bool has_infinity = false; + static inline constexpr const bool has_signaling_NaN = false; + // NOLINTEND + + static constexpr float8_e8m0fnu min() { + return float8_e8m0fnu::FromRep(0x00); + } + static constexpr float8_e8m0fnu lowest() { + return float8_e8m0fnu::FromRep(0x00); + } + static constexpr float8_e8m0fnu max() { + return float8_e8m0fnu::FromRep(0xfe); + } + static constexpr float8_e8m0fnu epsilon() { + return float8_e8m0fnu::FromRep((-kMantissaBits + kExponentBias) + << kMantissaBits); + } + static constexpr float8_e8m0fnu round_error() { + return float8_e8m0fnu::FromRep((-1 + kExponentBias) << kMantissaBits); + } + static constexpr float8_e8m0fnu infinity() { + return float8_e8m0fnu::FromRep(0xFF); + } // NaN. + static constexpr float8_e8m0fnu quiet_NaN() { + return float8_e8m0fnu::FromRep(0xFF); + } + static constexpr float8_e8m0fnu signaling_NaN() { + return float8_e8m0fnu::FromRep(0xFF); + } + static constexpr float8_e8m0fnu denorm_min() { + // No denorm => smallest value. + return float8_e8m0fnu::FromRep(0x00); + } +}; + } // namespace float8_internal } // namespace ml_dtypes @@ -961,6 +1096,10 @@ struct numeric_limits template <> struct numeric_limits : public ml_dtypes::float8_internal::numeric_limits_float8_e5m2fnuz {}; + +template <> +struct numeric_limits + : public ml_dtypes::float8_internal::numeric_limits_float8_e8m0fnu {}; } // namespace std namespace ml_dtypes { @@ -1028,6 +1167,12 @@ constexpr inline bool(isnan)(const float8_e5m2fnuz& a) { return a.rep() == 0x80; } +constexpr inline float8_e8m0fnu abs(const float8_e8m0fnu& a) { return a; } + +constexpr inline bool(isnan)(const float8_e8m0fnu& a) { + return a.rep() == 0xff; +} + template constexpr inline bool(isinf)(const float8_base& a) { if constexpr (std::numeric_limits::has_infinity) { @@ -1054,6 +1199,32 @@ std::ostream& operator<<(std::ostream& os, const float8_base& f8) { // Inline conversion routines between float8 and other types. //============================================================================== +template +bool constexpr IsPowerOfTwo(T x) { + return (x != 0) && ((x & (x - 1)) == 0); +} +// Helper for getting a bytes size which is a power of two. +template +struct NextPowerOfTwo { + static constexpr int value = Size; +}; +template <> +struct NextPowerOfTwo<3> { + static constexpr int value = 4; +}; +template <> +struct NextPowerOfTwo<5> { + static constexpr int value = 8; +}; +template <> +struct NextPowerOfTwo<6> { + static constexpr int value = 8; +}; +template <> +struct NextPowerOfTwo<7> { + static constexpr int value = 8; +}; + // Helper for getting a bit representation provided a byte size. 
template using GetUnsignedInteger = @@ -1079,9 +1250,13 @@ struct ConvertImpl struct TraitsBase { using BitsType = GetUnsignedInteger; + static constexpr bool kIsSigned = std::numeric_limits::is_signed; + static constexpr bool kHasZero = true; + static constexpr int kBits = sizeof(Float) * CHAR_BIT; static constexpr int kMantissaBits = Eigen::NumTraits::digits() - 1; - static constexpr int kExponentBits = kBits - kMantissaBits - 1; + // Extra bit used in exponent for unsigned float. + static constexpr int kExponentBits = kBits - kMantissaBits - int(kIsSigned); static constexpr BitsType kExponentMask = ((BitsType{1} << kExponentBits) - 1) << kMantissaBits; static constexpr BitsType kMantissaMask = (BitsType{1} << kMantissaBits) - 1; @@ -1108,6 +1283,13 @@ struct Traits : public TraitsBase { static constexpr int kExponentBias = Base::kExponentBias + 1; }; +template <> +struct Traits : public TraitsBase { + using Base = TraitsBase; + // No zero in E8MO OCP MX format description. + static constexpr bool kHasZero = false; +}; + template constexpr inline Bits RoundBitsToNearestEven(Bits bits, int roundoff) { // Round to nearest even by adding a bias term. @@ -1190,6 +1372,9 @@ struct ConvertImpl>> { using FromTraits = Traits; using FromBits = typename FromTraits::BitsType; + static constexpr bool kFromIsSigned = FromTraits::kIsSigned; + static constexpr bool kFromHasZero = FromTraits::kHasZero; + static constexpr int kFromBits = FromTraits::kBits; static constexpr int kFromMantissaBits = FromTraits::kMantissaBits; static constexpr int kFromExponentBits = FromTraits::kExponentBits; static constexpr int kFromExponentBias = FromTraits::kExponentBias; @@ -1197,6 +1382,9 @@ struct ConvertImpl; using ToBits = typename ToTraits::BitsType; + static constexpr bool kToIsSigned = ToTraits::kIsSigned; + static constexpr bool kToHasZero = ToTraits::kHasZero; + static constexpr int kToBits = ToTraits::kBits; static constexpr int kToMantissaBits = ToTraits::kMantissaBits; static constexpr int kToExponentBits = ToTraits::kExponentBits; static constexpr int kToExponentBias = ToTraits::kExponentBias; @@ -1207,15 +1395,22 @@ struct ConvertImpl::value; + using WideBits = GetUnsignedInteger; + static_assert(!std::is_void_v, + "`WideBits` type can not be void type."); + static constexpr int kExponentOffset = kToExponentBias - kFromExponentBias; static constexpr int kDigitShift = kToMantissaBits - kFromMantissaBits; static EIGEN_DEVICE_FUNC inline To run(From from) { // Shift bits to destination type, without sign bit. const bool from_sign_bit = - Eigen::numext::bit_cast(from) >> (FromTraits::kBits - 1); + Eigen::numext::bit_cast(from) >> (kFromBits - 1) && + kFromIsSigned; const FromBits from_bits = Eigen::numext::bit_cast(Eigen::numext::abs(from)); @@ -1228,8 +1423,22 @@ struct ConvertImpl::quiet_NaN() : Eigen::NumTraits::quiet_NaN(); } - if (from_bits == 0) { - return from_sign_bit ? -To{} : To{}; + // Dealing with zero, when `From` has one. + if (from_bits == 0 && kFromHasZero) { + if constexpr (kToHasZero) { + // Keep the sign, if `To` supports it. + return from_sign_bit && kToIsSigned ? -To{} : To{}; + } else { + return kSaturate ? std::numeric_limits::denorm_min() + : Eigen::NumTraits::quiet_NaN(); + } + } + // `To` unsigned floating format: NaN or saturate. + if constexpr (!kToIsSigned && kFromIsSigned) { + if (from_sign_bit) { + return kSaturate ? 
std::numeric_limits::lowest() + : Eigen::NumTraits::quiet_NaN(); + } } const int biased_from_exponent = from_bits >> kFromMantissaBits; @@ -1284,7 +1493,9 @@ struct ConvertImpl 0 ? 1 : 0); + // Zero exponent valid if From has no zero representation. + FromBits from_has_leading_one = + (biased_from_exponent > 0 || !kFromHasZero ? 1 : 0); int exponent_shift = -kDigitShift - biased_to_exponent + from_has_leading_one; // Insert the implicit leading 1 bit on the mantissa for normalized @@ -1473,6 +1684,7 @@ using float8_e4m3fnuz = float8_internal::float8_e4m3fnuz; using float8_e4m3b11fnuz = float8_internal::float8_e4m3b11fnuz; using float8_e5m2 = float8_internal::float8_e5m2; using float8_e5m2fnuz = float8_internal::float8_e5m2fnuz; +using float8_e8m0fnu = float8_internal::float8_e8m0fnu; } // namespace ml_dtypes @@ -1575,6 +1787,12 @@ EIGEN_DEVICE_FUNC inline bool isinf_impl( return ml_dtypes::float8_internal::isinf(x); } +template <> +EIGEN_DEVICE_FUNC inline bool isinf_impl( + const ml_dtypes::float8_e8m0fnu& x) { + return ml_dtypes::float8_internal::isinf(x); +} + template <> EIGEN_DEVICE_FUNC inline bool isnan_impl( const ml_dtypes::float8_e3m4& x) { @@ -1617,6 +1835,12 @@ EIGEN_DEVICE_FUNC inline bool isnan_impl( return ml_dtypes::float8_internal::isnan(x); } +template <> +EIGEN_DEVICE_FUNC inline bool isnan_impl( + const ml_dtypes::float8_e8m0fnu& x) { + return ml_dtypes::float8_internal::isnan(x); +} + template <> EIGEN_DEVICE_FUNC inline bool isfinite_impl( const ml_dtypes::float8_e3m4& x) { @@ -1659,6 +1883,12 @@ EIGEN_DEVICE_FUNC inline bool isfinite_impl( return ml_dtypes::float8_internal::isfinite(x); } +template <> +EIGEN_DEVICE_FUNC inline bool isfinite_impl( + const ml_dtypes::float8_e8m0fnu& x) { + return ml_dtypes::float8_internal::isfinite(x); +} + } // namespace internal } // namespace Eigen diff --git a/ml_dtypes/tests/custom_float_test.py b/ml_dtypes/tests/custom_float_test.py index f30d7a7f..36c5d8a4 100644 --- a/ml_dtypes/tests/custom_float_test.py +++ b/ml_dtypes/tests/custom_float_test.py @@ -41,6 +41,7 @@ float8_e4m3fnuz = ml_dtypes.float8_e4m3fnuz float8_e5m2 = ml_dtypes.float8_e5m2 float8_e5m2fnuz = ml_dtypes.float8_e5m2fnuz +float8_e8m0fnu = ml_dtypes.float8_e8m0fnu try: @@ -116,6 +117,11 @@ def dtype_has_nan(dtype): return False +def dtype_is_signed(dtype): + """Determines if the floating dtype has a sign bit.""" + return ml_dtypes.finfo(dtype).min < 0 + + FLOAT_DTYPES = [ bfloat16, float4_e2m1fn, @@ -128,6 +134,7 @@ def dtype_has_nan(dtype): float8_e4m3fnuz, float8_e5m2, float8_e5m2fnuz, + float8_e8m0fnu, ] # Values that should round trip exactly to float and back. @@ -157,6 +164,16 @@ def dtype_has_nan(dtype): ] for dtype in FLOAT_DTYPES } +# E8M0 specific values +FLOAT_VALUES[float8_e8m0fnu] = [ + 0.125, + 1.0, + 0.5, + 1.0 + float(ml_dtypes.finfo(float8_e8m0fnu).eps), + 4, + float(ml_dtypes.finfo(float8_e8m0fnu).max), + float("nan"), +] # Remove values unsupported by some types. 
FLOAT_VALUES[float4_e2m1fn] = [ @@ -200,6 +217,7 @@ def dtype_has_nan(dtype): range(1 << n, 2 << n, 1 << max(0, n - 2)) for n in range(16) ) ), + float8_e8m0fnu: [1, 2, 256], } @@ -236,6 +254,9 @@ def testRoundTripToFloat(self, float_type): def testRoundTripNumpyTypes(self, float_type): for dtype in [np.float16, np.float32, np.float64, np.longdouble]: for f in FLOAT_VALUES[float_type]: + # Ignore values converting to NaN/Inf + if np.abs(f) > np.finfo(dtype).max: + continue np.testing.assert_equal(dtype(f), dtype(float_type(dtype(f)))) np.testing.assert_equal(float(dtype(f)), float(float_type(dtype(f)))) np.testing.assert_equal(dtype(f), dtype(float_type(np.array(f, dtype)))) @@ -248,7 +269,8 @@ def testRoundTripNumpyTypes(self, float_type): def testRoundTripToInt(self, float_type): for v in INT_VALUES[float_type]: self.assertEqual(v, int(float_type(v))) - self.assertEqual(-v, int(float_type(-v))) + if dtype_is_signed(float_type): + self.assertEqual(-v, int(float_type(-v))) @ignore_warning(category=RuntimeWarning, message="overflow encountered") def testRoundTripToNumpy(self, float_type): @@ -261,12 +283,18 @@ def testRoundTripToNumpy(self, float_type): ]: with self.subTest(dtype.__name__): for v in FLOAT_VALUES[float_type]: + if np.abs(v) > ml_dtypes.finfo(dtype).max: + continue np.testing.assert_equal(dtype(v), dtype(float_type(dtype(v)))) np.testing.assert_equal(dtype(v), dtype(float_type(dtype(v)))) np.testing.assert_equal( dtype(v), dtype(float_type(np.array(v, dtype))) ) - if dtype != float_type: + + if ( + dtype != float_type + and ml_dtypes.finfo(float_type).max <= ml_dtypes.finfo(dtype).max + ): np.testing.assert_equal( np.array(FLOAT_VALUES[float_type], dtype), float_type(np.array(FLOAT_VALUES[float_type], dtype)).astype( @@ -276,6 +304,12 @@ def testRoundTripToNumpy(self, float_type): def testCastBetweenCustomTypes(self, float_type): for dtype in FLOAT_DTYPES: + # float8_e8m0 only registering cast <=> bfloat16 + if ( + float_type == float8_e8m0fnu or dtype == float8_e8m0fnu + ) and dtype != bfloat16: + continue + x = np.array(FLOAT_VALUES[float_type], dtype=dtype) y = x.astype(float_type) z = x.astype(float).astype(float_type) @@ -307,6 +341,9 @@ def testItem(self, float_type): def testHashZero(self, float_type): """Tests that negative zero and zero hash to the same value.""" + if float_type == float8_e8m0fnu: + raise self.skipTest("Skip hash zero test for E8M0 datatype.") + self.assertEqual(hash(float_type(-0.0)), hash(float_type(0.0))) def testHashNumbers(self, float_type): @@ -663,9 +700,9 @@ def testDeepCopyDoesNotAlterHash(self, float_type): self.assertEqual(h, hash(dtype)) def testArray(self, float_type): - x = np.array([[1, 2, 3]], dtype=float_type) + x = np.array([[1, 2, 4]], dtype=float_type) self.assertEqual(float_type, x.dtype) - self.assertEqual("[[1 2 3]]", str(x)) + self.assertEqual("[[1 2 4]]", str(x)) np.testing.assert_equal(x, x) numpy_assert_allclose(x, x, float_type=float_type) self.assertTrue((x == x).all()) @@ -673,9 +710,15 @@ def testArray(self, float_type): def testComparisons(self, float_type): x0, x1, y0 = 6, 1, 3 x = np.array([x0, x1, -x0], dtype=np.float32) - bx = x.astype(float_type) y = np.array([y0, x1, 0], dtype=np.float32) + + if float_type == float8_e8m0fnu: + x = np.array([30, 7, 1], dtype=np.float32) + y = np.array([17, 7, 0.125], dtype=np.float32) + + bx = x.astype(float_type) by = y.astype(float_type) + np.testing.assert_equal(x == y, bx == by) np.testing.assert_equal(x != y, bx != by) np.testing.assert_equal(x < y, bx < by) @@ -757,7 +800,7 
@@ def testCasts(self, float_type): np.uintc, np.ulonglong, ]: - x = np.array([[1, 2, 3]], dtype=dtype) + x = np.array([[1, 2, 4]], dtype=dtype) y = x.astype(float_type) z = y.astype(dtype) self.assertTrue(np.all(x == y)) @@ -768,7 +811,7 @@ def testCasts(self, float_type): @ignore_warning(category=ComplexWarning) def testConformNumpyComplex(self, float_type): for dtype in [np.complex64, np.complex128, np.clongdouble]: - x = np.array([0.5, 1.5 + 2.0j, 4.0], dtype=dtype) + x = np.array([0.5, 1.0 + 2.0j, 4.0], dtype=dtype) y_np = x.astype(np.float32) y_tf = x.astype(float_type) numpy_assert_allclose(y_np, y_tf, atol=2e-2, float_type=float_type) @@ -779,9 +822,12 @@ def testConformNumpyComplex(self, float_type): def testArange(self, float_type): np.testing.assert_equal( - np.arange(100, dtype=np.float32).astype(float_type), - np.arange(100, dtype=float_type), + np.arange(1, 100, dtype=np.float32).astype(float_type), + np.arange(1, 100, dtype=float_type), ) + if float_type == float8_e8m0fnu: + raise self.skipTest("Skip negative ranges for E8M0.") + np.testing.assert_equal( np.arange(-6, 6, 2, dtype=np.float32).astype(float_type), np.arange(-6, 6, 2, dtype=float_type), @@ -852,7 +898,11 @@ def testDivmod(self, float_type): rng = np.random.RandomState(seed=42) x = rng.randn(3, 7).astype(float_type) y = rng.randn(4, 1, 7).astype(float_type) + + x = np.where(np.isfinite(x), x, float_type(1)) + y = np.where(np.isfinite(y), y, float_type(1)) y = np.where(y == 0, float_type(1), y) + o1, o2 = np.divmod(x, y) e1, e2 = np.divmod(x.astype(np.float32), y.astype(np.float32)) numpy_assert_allclose( @@ -944,6 +994,7 @@ def testLdexp(self, float_type): def testFrexp(self, float_type): rng = np.random.RandomState(seed=42) x = rng.randn(3, 7).astype(float_type) + x = np.where(np.isfinite(x), x, float_type(1)) mant1, exp1 = np.frexp(x) mant2, exp2 = np.frexp(x.astype(np.float32)) np.testing.assert_equal(exp1, exp2) @@ -956,6 +1007,9 @@ def testFrexp(self, float_type): numpy_assert_allclose(mant1, mant2, float_type=float_type, **kwargs) def testCopySign(self, float_type): + if not dtype_is_signed(float_type): + raise self.skipTest("Skip copy sign test for unsigned floating formats.") + bits_type = np.uint16 if float_type == bfloat16 else np.uint8 bit_size = ml_dtypes.finfo(float_type).bits bit_sign = 1 << (bit_size - 1) @@ -979,8 +1033,9 @@ def testNextAfter(self, float_type): ) np.testing.assert_equal(np.nextafter(one, one), one) smallest_denormal = ml_dtypes.finfo(float_type).smallest_subnormal - np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal) - np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal) + if dtype_is_signed(float_type): + np.testing.assert_equal(np.nextafter(zero, one), smallest_denormal) + np.testing.assert_equal(np.nextafter(zero, -one), -smallest_denormal) if dtype_has_nan(float_type): nan = np.array(np.nan, dtype=float_type) @@ -1017,7 +1072,8 @@ def testSpacing(self, float_type): power_of_two = float_type(2.0**i) distance = ml_dtypes.finfo(float_type).eps * power_of_two np.testing.assert_equal(np.spacing(power_of_two), distance) - np.testing.assert_equal(np.spacing(-power_of_two), -distance) + if dtype_is_signed(float_type): + np.testing.assert_equal(np.spacing(-power_of_two), -distance) # Check that spacing agrees with arithmetic involving nextafter. 
with self.subTest(name="NextAfter"): diff --git a/ml_dtypes/tests/finfo_test.py b/ml_dtypes/tests/finfo_test.py index 0823b471..d15311c1 100644 --- a/ml_dtypes/tests/finfo_test.py +++ b/ml_dtypes/tests/finfo_test.py @@ -30,6 +30,7 @@ ml_dtypes.float8_e4m3fnuz, ml_dtypes.float8_e5m2, ml_dtypes.float8_e5m2fnuz, + ml_dtypes.float8_e8m0fnu, ] DTYPES_WITH_NO_INFINITY = [ @@ -37,6 +38,7 @@ ml_dtypes.float8_e4m3fn, ml_dtypes.float8_e4m3fnuz, ml_dtypes.float8_e5m2fnuz, + ml_dtypes.float8_e8m0fnu, ] DTYPES_WITH_NO_INFINITY_AND_NO_NAN = [ @@ -96,21 +98,39 @@ def assert_zero(val): if info.bits >= 8: self.assertEqual(info.bits, np.array(0, dtype).itemsize * 8) - self.assertEqual(info.nmant + info.nexp + 1, info.bits) + # Unsigned float => no sign bit. + if info.min >= 0.0: + self.assertEqual(info.nmant + info.nexp, info.bits) + else: + self.assertEqual(info.nmant + info.nexp + 1, info.bits) assert_representable(info.tiny) assert_representable(info.max) assert_representable(info.min) if dtype not in DTYPES_WITH_NO_INFINITY_AND_NO_NAN: assert_infinite(np.spacing(info.max)) + assert info.max > 0 + + if info.min < 0 and dtype not in DTYPES_WITH_NO_INFINITY_AND_NO_NAN: + # Only valid for signed floating format. assert_infinite(-np.spacing(info.min)) + elif info.min > 0: + # No zero in floating point format. + assert_infinite(0) + assert_infinite(make_val(-1)) + elif info.min == 0: + # Zero supported, but not negative values. + self.assertEqual(make_val(0), 0) + assert_infinite(make_val(-1)) assert_representable(2.0 ** (info.maxexp - 1)) assert_infinite(2.0**info.maxexp) assert_representable(info.smallest_subnormal) - assert_zero(info.smallest_subnormal * 0.5) + if info.min < 0: + assert_zero(info.smallest_subnormal * 0.5) + self.assertGreater(info.smallest_normal, 0) self.assertEqual(info.tiny, info.smallest_normal) # Identities according to the documentation: @@ -119,11 +139,15 @@ def assert_zero(val): self.assertEqual(info.eps, make_val(2**info.machep)) self.assertEqual(info.iexp, info.nexp) - # Check that minexp is consistent with nmant - self.assertEqual( - make_val(2**info.minexp).view(UINT_TYPES[info.bits]), - 2**info.nmant, + is_min_exponent_valid_normal = ( + make_val(2**info.minexp) == info.smallest_normal ) + # Check that minexp is consistent with nmant (subnormal representation) + if not is_min_exponent_valid_normal and info.nmant > 0: + self.assertEqual( + make_val(2**info.minexp).view(UINT_TYPES[info.bits]), + 2**info.nmant, + ) if __name__ == "__main__":
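
Example usage (illustrative only, not part of the patch): the snippet below sketches the user-facing behaviour this change is intended to provide, assuming a build of `ml_dtypes` that includes it. The finfo values and the NaN/no-zero behaviour come from the hunks above; the exact rounding of inputs that are not powers of two is deliberately not asserted.

    import math

    import ml_dtypes
    import numpy as np

    e8m0 = ml_dtypes.float8_e8m0fnu

    # finfo fields populated by _float8_e8m0fnu_finfo above.
    info = ml_dtypes.finfo(e8m0)
    assert info.bits == 8 and info.nexp == 8 and info.nmant == 0
    assert float(info.max) == 2.0**127
    assert float(info.tiny) == 2.0**-127
    assert float(info.min) == 2.0**-127  # unsigned: min is the smallest positive value
    assert float(info.eps) == 1.0        # next representable value above 1.0 is 2.0

    # Powers of two inside the exponent range round-trip exactly
    # (compare FLOAT_VALUES[float8_e8m0fnu] in custom_float_test.py).
    x = np.array([0.125, 1.0, 4.0, 2.0**127], dtype=e8m0)
    np.testing.assert_array_equal(
        x.astype(np.float32), np.array([0.125, 1.0, 4.0, 2.0**127], np.float32)
    )

    # Among the custom types, only bfloat16 <-> float8_e8m0fnu casts are registered.
    np.testing.assert_array_equal(x, x.astype(ml_dtypes.bfloat16).astype(e8m0))

    # Single NaN encoding (0xFF); no zero, no infinity, no sign bit. Inputs with
    # no E8M0 representation are expected to map to NaN under the default
    # (non-saturating) conversion added in ConvertImpl.
    assert math.isnan(float(e8m0(float("nan"))))
    assert math.isnan(float(e8m0(0.0)))
    assert math.isnan(float(e8m0(-1.0)))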