From 850fae2c3b14c6afa9f26ed695ae75d9effc9f4c Mon Sep 17 00:00:00 2001 From: Durgadoss R Date: Thu, 29 Aug 2024 18:09:14 +0530 Subject: [PATCH] [APFloat] Add APFloat support for E8M0 type This patch adds an APFloat type for unsigned E8M0 format. This format is used for representing the "scale-format" in the MX specification: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf This format does not support {Inf, denorms, zeroes}. Like FP32, this format's exponent is 8 bits wide (it occupies all of the format's bits) and the bias value is 127. However, it differs from IEEE-FP32 in that the minExponent is -127 (instead of -126). The APFloat utility functions are updated to handle these constraints for this format: * The bias calculation is different and convertIEEE* APIs are updated to handle this. * Since there are no significand bits, the isSignificandAll{Zeros/Ones} methods are updated accordingly. * Although the format does not have any significand bits, the precision field in the fltSemantics is set to 1 for consistency with APFloat's internal representation. * Many utility functions are updated to handle the fact that this format does not support Zero. * Provide a separate initFromAPInt() implementation to handle the quirks of the format. * Add specific tests to verify the range of values for this format. Signed-off-by: Durgadoss R --- llvm/include/llvm/ADT/APFloat.h | 35 ++++ llvm/lib/Support/APFloat.cpp | 137 ++++++++++--- llvm/unittests/ADT/APFloatTest.cpp | 315 +++++++++++++++++++++++++++-- 3 files changed, 445 insertions(+), 42 deletions(-) diff --git a/llvm/include/llvm/ADT/APFloat.h b/llvm/include/llvm/ADT/APFloat.h index 7039e961bff82d..be9ddcf78cac65 100644 --- a/llvm/include/llvm/ADT/APFloat.h +++ b/llvm/include/llvm/ADT/APFloat.h @@ -195,6 +195,12 @@ struct APFloatBase { // improved range compared to half (16-bit) formats, at (potentially) // greater throughput than single precision (32-bit) formats. 
S_FloatTF32, + // 8-bit floating point number with (all the) 8 bits for the exponent + // like in FP32. There are no zeroes, no infinities, and no denormal values. + // NaN is represented with all bits set to 1. Bias is 127. + // This represents the scale data type in the MX specification from + // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + S_Float8E8M0FN, // 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754 // types, there are no infinity or NaN values. The format is detailed in // https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf @@ -229,6 +235,7 @@ struct APFloatBase { static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE; static const fltSemantics &Float8E3M4() LLVM_READNONE; static const fltSemantics &FloatTF32() LLVM_READNONE; + static const fltSemantics &Float8E8M0FN() LLVM_READNONE; static const fltSemantics &Float6E3M2FN() LLVM_READNONE; static const fltSemantics &Float6E2M3FN() LLVM_READNONE; static const fltSemantics &Float4E2M1FN() LLVM_READNONE; @@ -591,6 +598,7 @@ class IEEEFloat final : public APFloatBase { unsigned int significandLSB() const; unsigned int significandMSB() const; void zeroSignificand(); + unsigned int getNumHighBits() const; /// Return true if the significand excluding the integral bit is all ones. 
bool isSignificandAllOnes() const; bool isSignificandAllOnesExceptLSB() const; @@ -652,6 +660,7 @@ class IEEEFloat final : public APFloatBase { APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const; APInt convertFloat8E3M4APFloatToAPInt() const; APInt convertFloatTF32APFloatToAPInt() const; + APInt convertFloat8E8M0FNAPFloatToAPInt() const; APInt convertFloat6E3M2FNAPFloatToAPInt() const; APInt convertFloat6E2M3FNAPFloatToAPInt() const; APInt convertFloat4E2M1FNAPFloatToAPInt() const; @@ -672,6 +681,7 @@ class IEEEFloat final : public APFloatBase { void initFromFloat8E4M3B11FNUZAPInt(const APInt &api); void initFromFloat8E3M4APInt(const APInt &api); void initFromFloatTF32APInt(const APInt &api); + void initFromFloat8E8M0FNAPInt(const APInt &api); void initFromFloat6E3M2FNAPInt(const APInt &api); void initFromFloat6E2M3FNAPInt(const APInt &api); void initFromFloat4E2M1FNAPInt(const APInt &api); @@ -1079,6 +1089,9 @@ class APFloat : public APFloatBase { /// \param Semantics - type float semantics static APFloat getAllOnesValue(const fltSemantics &Semantics); + /// Returns true if the given semantics supports either NaN or Infinity. + /// + /// \param Sem - type float semantics static bool hasNanOrInf(const fltSemantics &Sem) { switch (SemanticsToEnum(Sem)) { default: @@ -1091,6 +1104,28 @@ class APFloat : public APFloatBase { } } + /// Returns true if the given semantics can represent Zero. + /// + /// \param Sem - type float semantics + static bool hasZero(const fltSemantics &Sem) { + return &Sem != &Float8E8M0FN(); + } + + /// Returns true if the given semantics has actual significand. + /// + /// \param Sem - type float semantics + static bool hasSignificand(const fltSemantics &Sem) { + return &Sem != &Float8E8M0FN(); + } + + /// Returns true if the given semantics has only exponent + /// and no significand. 
+ /// + /// \param Sem - type float semantics + static bool hasExponentOnly(const fltSemantics &Sem) { + return !hasSignificand(Sem); + } + /// Used to insert APFloat objects, or objects that contain APFloat objects, /// into FoldingSets. void Profile(FoldingSetNodeID &NID) const; diff --git a/llvm/lib/Support/APFloat.cpp b/llvm/lib/Support/APFloat.cpp index 7f68c5ab9b7cf7..9254a86f3a9946 100644 --- a/llvm/lib/Support/APFloat.cpp +++ b/llvm/lib/Support/APFloat.cpp @@ -145,6 +145,9 @@ static constexpr fltSemantics semFloat8E4M3B11FNUZ = { 4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero}; static constexpr fltSemantics semFloat8E3M4 = {3, -2, 5, 8}; static constexpr fltSemantics semFloatTF32 = {127, -126, 11, 19}; +static constexpr fltSemantics semFloat8E8M0FN = { + 127, -127, 1, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes}; + static constexpr fltSemantics semFloat6E3M2FN = { 4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly}; static constexpr fltSemantics semFloat6E2M3FN = { @@ -222,6 +225,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) { return Float8E3M4(); case S_FloatTF32: return FloatTF32(); + case S_Float8E8M0FN: + return Float8E8M0FN(); case S_Float6E3M2FN: return Float6E3M2FN(); case S_Float6E2M3FN: @@ -264,6 +269,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) { return S_Float8E3M4; else if (&Sem == &llvm::APFloat::FloatTF32()) return S_FloatTF32; + else if (&Sem == &llvm::APFloat::Float8E8M0FN()) + return S_Float8E8M0FN; else if (&Sem == &llvm::APFloat::Float6E3M2FN()) return S_Float6E3M2FN; else if (&Sem == &llvm::APFloat::Float6E2M3FN()) @@ -294,6 +301,7 @@ const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() { } const fltSemantics &APFloatBase::Float8E3M4() { return semFloat8E3M4; } const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; } +const fltSemantics &APFloatBase::Float8E8M0FN() { return semFloat8E8M0FN; } const fltSemantics &APFloatBase::Float6E3M2FN() { 
return semFloat6E3M2FN; } const fltSemantics &APFloatBase::Float6E2M3FN() { return semFloat6E2M3FN; } const fltSemantics &APFloatBase::Float4E2M1FN() { return semFloat4E2M1FN; } @@ -396,7 +404,8 @@ static inline Error createError(const Twine &Err) { } static constexpr inline unsigned int partCountForBits(unsigned int bits) { - return ((bits) + APFloatBase::integerPartWidth - 1) / APFloatBase::integerPartWidth; + return std::max(1u, ((bits) + APFloatBase::integerPartWidth - 1) / + APFloatBase::integerPartWidth); } /* Returns 0U-9U. Return values >= 10U are not digits. */ @@ -955,6 +964,9 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) { significand[part] = 0; } + if (!APFloat::hasSignificand(*semantics)) + return; + unsigned QNaNBit = semantics->precision - 2; if (SNaN) { @@ -1025,6 +1037,19 @@ bool IEEEFloat::isSmallestNormalized() const { isSignificandAllZerosExceptMSB(); } +unsigned int IEEEFloat::getNumHighBits() const { + const unsigned int PartCount = partCountForBits(semantics->precision); + const unsigned int Bits = PartCount * integerPartWidth; + + // Compute how many bits are used in the final word. + // When precision is just 1, it represents the 'Pth' + // Precision bit and not the actual significand bit. + const unsigned int NumHighBits = (semantics->precision > 1) + ? (Bits - semantics->precision + 1) + : (Bits - semantics->precision); + return NumHighBits; +} + bool IEEEFloat::isSignificandAllOnes() const { // Test if the significand excluding the integral bit is all ones. This allows // us to test for binade boundaries. @@ -1035,13 +1060,12 @@ bool IEEEFloat::isSignificandAllOnes() const { return false; // Set the unused high bits to all ones when we compare. 
- const unsigned NumHighBits = - PartCount*integerPartWidth - semantics->precision + 1; + const unsigned NumHighBits = getNumHighBits(); assert(NumHighBits <= integerPartWidth && NumHighBits > 0 && "Can not have more high bits to fill than integerPartWidth"); const integerPart HighBitFill = ~integerPart(0) << (integerPartWidth - NumHighBits); - if (~(Parts[PartCount - 1] | HighBitFill)) + if ((semantics->precision <= 1) || (~(Parts[PartCount - 1] | HighBitFill))) return false; return true; @@ -1062,8 +1086,7 @@ bool IEEEFloat::isSignificandAllOnesExceptLSB() const { } // Set the unused high bits to all ones when we compare. - const unsigned NumHighBits = - PartCount * integerPartWidth - semantics->precision + 1; + const unsigned NumHighBits = getNumHighBits(); assert(NumHighBits <= integerPartWidth && NumHighBits > 0 && "Can not have more high bits to fill than integerPartWidth"); const integerPart HighBitFill = ~integerPart(0) @@ -1085,13 +1108,12 @@ bool IEEEFloat::isSignificandAllZeros() const { return false; // Compute how many bits are used in the final word. 
- const unsigned NumHighBits = - PartCount*integerPartWidth - semantics->precision + 1; + const unsigned NumHighBits = getNumHighBits(); assert(NumHighBits < integerPartWidth && "Can not have more high bits to " "clear than integerPartWidth"); const integerPart HighBitMask = ~integerPart(0) >> NumHighBits; - if (Parts[PartCount - 1] & HighBitMask) + if ((semantics->precision > 1) && (Parts[PartCount - 1] & HighBitMask)) return false; return true; @@ -1106,25 +1128,26 @@ bool IEEEFloat::isSignificandAllZerosExceptMSB() const { return false; } - const unsigned NumHighBits = - PartCount * integerPartWidth - semantics->precision + 1; - return Parts[PartCount - 1] == integerPart(1) - << (integerPartWidth - NumHighBits); + const unsigned NumHighBits = getNumHighBits(); + const integerPart MSBMask = integerPart(1) + << (integerPartWidth - NumHighBits); + return ((semantics->precision <= 1) || (Parts[PartCount - 1] == MSBMask)); } bool IEEEFloat::isLargest() const { + bool IsMaxExp = isFiniteNonZero() && exponent == semantics->maxExponent; if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly && semantics->nanEncoding == fltNanEncoding::AllOnes) { // The largest number by magnitude in our format will be the floating point // number with maximum exponent and with significand that is all ones except // the LSB. - return isFiniteNonZero() && exponent == semantics->maxExponent && - isSignificandAllOnesExceptLSB(); + return (IsMaxExp && APFloat::hasSignificand(*semantics)) + ? isSignificandAllOnesExceptLSB() + : IsMaxExp; } else { // The largest number by magnitude in our format will be the floating point // number with maximum exponent and with significand that is all ones. 
- return isFiniteNonZero() && exponent == semantics->maxExponent && - isSignificandAllOnes(); + return IsMaxExp && isSignificandAllOnes(); } } @@ -1165,7 +1188,8 @@ IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics, integerPart value) { IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics) { initialize(&ourSemantics); - makeZero(false); + APFloat::hasZero(ourSemantics) ? makeZero(false) + : makeSmallestNormalized(false); } // Delegate to the previous constructor, because later copy constructor may @@ -1729,6 +1753,11 @@ IEEEFloat::opStatus IEEEFloat::normalize(roundingMode rounding_mode, category = fcZero; if (semantics->nanEncoding == fltNanEncoding::NegativeZero) sign = false; + // This condition handles the case where the semantics + // does not have zero but uses the all-zero encoding + // to represent the smallest normal value. + if (!APFloat::hasZero(*semantics)) + makeSmallestNormalized(false); } /* The fcZero case is a denormal that underflowed to zero. */ @@ -2606,6 +2635,8 @@ IEEEFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics, fs = opOK; } + if (category == fcZero && !APFloat::hasZero(*semantics)) + makeSmallestNormalized(false); return fs; } @@ -3070,6 +3101,8 @@ IEEEFloat::convertFromDecimalString(StringRef str, roundingMode rounding_mode) { fs = opOK; if (semantics->nanEncoding == fltNanEncoding::NegativeZero) sign = false; + if (!APFloat::hasZero(*semantics)) + makeSmallestNormalized(false); /* Check whether the normalized exponent is high enough to overflow max during the log-rebasing in the max-exponent check below. */ @@ -3533,15 +3566,16 @@ APInt IEEEFloat::convertPPCDoubleDoubleAPFloatToAPInt() const { template APInt IEEEFloat::convertIEEEFloatToAPInt() const { assert(semantics == &S); - - constexpr int bias = -(S.minExponent - 1); + const int bias = + (semantics == &semFloat8E8M0FN) ? 
-S.minExponent : -(S.minExponent - 1); constexpr unsigned int trailing_significand_bits = S.precision - 1; constexpr int integer_bit_part = trailing_significand_bits / integerPartWidth; constexpr integerPart integer_bit = integerPart{1} << (trailing_significand_bits % integerPartWidth); constexpr uint64_t significand_mask = integer_bit - 1; constexpr unsigned int exponent_bits = - S.sizeInBits - 1 - trailing_significand_bits; + trailing_significand_bits ? (S.sizeInBits - 1 - trailing_significand_bits) + : S.sizeInBits; static_assert(exponent_bits < 64); constexpr uint64_t exponent_mask = (uint64_t{1} << exponent_bits) - 1; @@ -3557,6 +3591,8 @@ APInt IEEEFloat::convertIEEEFloatToAPInt() const { !(significandParts()[integer_bit_part] & integer_bit)) myexponent = 0; // denormal } else if (category == fcZero) { + if (semantics == &semFloat8E8M0FN) + llvm_unreachable("semantics does not support zero!"); myexponent = ::exponentZero(S) + bias; mysignificand.fill(0); } else if (category == fcInfinity) { @@ -3659,6 +3695,11 @@ APInt IEEEFloat::convertFloatTF32APFloatToAPInt() const { return convertIEEEFloatToAPInt(); } +APInt IEEEFloat::convertFloat8E8M0FNAPFloatToAPInt() const { + assert(partCount() == 1); + return convertIEEEFloatToAPInt(); +} + APInt IEEEFloat::convertFloat6E3M2FNAPFloatToAPInt() const { assert(partCount() == 1); return convertIEEEFloatToAPInt(); @@ -3721,6 +3762,9 @@ APInt IEEEFloat::bitcastToAPInt() const { if (semantics == (const llvm::fltSemantics *)&semFloatTF32) return convertFloatTF32APFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semFloat8E8M0FN) + return convertFloat8E8M0FNAPFloatToAPInt(); + if (semantics == (const llvm::fltSemantics *)&semFloat6E3M2FN) return convertFloat6E3M2FNAPFloatToAPInt(); @@ -3819,6 +3863,40 @@ void IEEEFloat::initFromPPCDoubleDoubleAPInt(const APInt &api) { } } +// The E8M0 format has the following characteristics: +// It is an 8-bit unsigned format with only exponents (no actual significand) +// No 
encodings for {zero, infinities or denorms} +// NaN is represented by all 1's +// Bias is 127 +void IEEEFloat::initFromFloat8E8M0FNAPInt(const APInt &api) { + const uint64_t exponent_mask = 0xff; + uint64_t val = api.getRawData()[0]; + uint64_t myexponent = (val & exponent_mask); + + initialize(&semFloat8E8M0FN); + assert(partCount() == 1); + + // This format has unsigned representation only + sign = 0; + + // Set the significand + // This format does not have any significand but the 'Pth' precision bit is + // always set to 1 for consistency in APFloat's internal representation. + uint64_t mysignificand = 1; + significandParts()[0] = mysignificand; + + // This format can either have a NaN or fcNormal + // All 1's i.e. 255 is a NaN + if (val == exponent_mask) { + category = fcNaN; + exponent = exponentNaN(); + return; + } + // Handle fcNormal... + category = fcNormal; + exponent = myexponent - 127; // 127 is bias + return; +} template void IEEEFloat::initFromIEEEAPInt(const APInt &api) { assert(api.getBitWidth() == S.sizeInBits); @@ -3999,6 +4077,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) { return initFromFloat8E3M4APInt(api); if (Sem == &semFloatTF32) return initFromFloatTF32APInt(api); + if (Sem == &semFloat8E8M0FN) + return initFromFloat8E8M0FNAPInt(api); if (Sem == &semFloat6E3M2FN) return initFromFloat6E3M2FNAPInt(api); if (Sem == &semFloat6E2M3FN) @@ -4032,9 +4112,9 @@ void IEEEFloat::makeLargest(bool Negative) { significand[PartCount - 1] = (NumUnusedHighBits < integerPartWidth) ? 
(~integerPart(0) >> NumUnusedHighBits) : 0; - if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly && - semantics->nanEncoding == fltNanEncoding::AllOnes) + semantics->nanEncoding == fltNanEncoding::AllOnes && + (semantics->precision > 1)) significand[0] &= ~integerPart(1); } @@ -4509,6 +4589,8 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) { exponent = 0; if (semantics->nanEncoding == fltNanEncoding::NegativeZero) sign = false; + if (!APFloat::hasZero(*semantics)) + makeSmallestNormalized(false); break; } @@ -4574,7 +4656,11 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) { // the integral bit to 1, and increment the exponent. If we have a // denormal always increment since moving denormals and the numbers in the // smallest normal binade have the same exponent in our representation. - bool WillCrossBinadeBoundary = !isDenormal() && isSignificandAllOnes(); + // If there are only exponents, any increment always crosses the + // BinadeBoundary. + bool WillCrossBinadeBoundary = + !isDenormal() && + (APFloat::hasSignificand(*semantics) ? 
isSignificandAllOnes() : true); if (WillCrossBinadeBoundary) { integerPart *Parts = significandParts(); @@ -4626,6 +4712,9 @@ void IEEEFloat::makeInf(bool Negative) { } void IEEEFloat::makeZero(bool Negative) { + if (!APFloat::hasZero(*semantics)) + llvm_unreachable("This floating point format does not support Zero"); + category = fcZero; sign = Negative; if (semantics->nanEncoding == fltNanEncoding::NegativeZero) { diff --git a/llvm/unittests/ADT/APFloatTest.cpp b/llvm/unittests/ADT/APFloatTest.cpp index 6c49d78e5c8ea9..caf57522460dfa 100644 --- a/llvm/unittests/ADT/APFloatTest.cpp +++ b/llvm/unittests/ADT/APFloatTest.cpp @@ -814,8 +814,10 @@ TEST(APFloatTest, IsSmallestNormalized) { const fltSemantics &Semantics = APFloat::EnumToSemantics(static_cast(I)); - EXPECT_FALSE(APFloat::getZero(Semantics, false).isSmallestNormalized()); - EXPECT_FALSE(APFloat::getZero(Semantics, true).isSmallestNormalized()); + if (APFloat::hasZero(Semantics)) { + EXPECT_FALSE(APFloat::getZero(Semantics, false).isSmallestNormalized()); + EXPECT_FALSE(APFloat::getZero(Semantics, true).isSmallestNormalized()); + } if (APFloat::hasNanOrInf(Semantics)) { EXPECT_FALSE(APFloat::getInf(Semantics, false).isSmallestNormalized()); @@ -828,11 +830,23 @@ TEST(APFloatTest, IsSmallestNormalized) { EXPECT_FALSE(APFloat::getLargest(Semantics).isSmallestNormalized()); EXPECT_FALSE(APFloat::getLargest(Semantics, true).isSmallestNormalized()); - EXPECT_FALSE(APFloat::getSmallest(Semantics).isSmallestNormalized()); - EXPECT_FALSE(APFloat::getSmallest(Semantics, true).isSmallestNormalized()); + if (I != APFloat::S_Float8E8M0FN) { + EXPECT_FALSE(APFloat::getSmallest(Semantics).isSmallestNormalized()); + EXPECT_FALSE( + APFloat::getSmallest(Semantics, true).isSmallestNormalized()); + } else { + // For E8M0 format, there are no denorms. + // So, getSmallest is equal to isSmallestNormalized(). 
+ EXPECT_TRUE(APFloat::getSmallest(Semantics).isSmallestNormalized()); + EXPECT_TRUE(APFloat::getSmallest(Semantics, true).isSmallestNormalized()); + } EXPECT_FALSE(APFloat::getAllOnesValue(Semantics).isSmallestNormalized()); + // For E8M0 format, the below cases are tested through Float8E8M0FNNext. + if (I == APFloat::S_Float8E8M0FN) + continue; + APFloat PosSmallestNormalized = APFloat::getSmallestNormalized(Semantics, false); APFloat NegSmallestNormalized = @@ -1907,6 +1921,57 @@ TEST(DoubleAPFloatTest, isInteger) { EXPECT_FALSE(T3.isInteger()); } +// Test to check if the full range of Float8E8M0FN +// values are being represented correctly. +TEST(APFloatTest, Float8E8M0FNValues) { + // High end of the range + auto test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p127"); + EXPECT_EQ(0x1.0p127, test.convertToDouble()); + + test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p126"); + EXPECT_EQ(0x1.0p126, test.convertToDouble()); + + test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p125"); + EXPECT_EQ(0x1.0p125, test.convertToDouble()); + + // tests the fix in makeLargest() + test = APFloat::getLargest(APFloat::Float8E8M0FN()); + EXPECT_EQ(0x1.0p127, test.convertToDouble()); + + // tests overflow to nan + APFloat nan = APFloat(APFloat::Float8E8M0FN(), "nan"); + test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p128"); + EXPECT_TRUE(test.bitwiseIsEqual(nan)); + + // Mid of the range + test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p0"); + EXPECT_EQ(1.0, test.convertToDouble()); + + test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p1"); + EXPECT_EQ(2.0, test.convertToDouble()); + + test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p2"); + EXPECT_EQ(4.0, test.convertToDouble()); + + // Low end of the range + test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p-125"); + EXPECT_EQ(0x1.0p-125, test.convertToDouble()); + + test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p-126"); + EXPECT_EQ(0x1.0p-126, test.convertToDouble()); + + test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p-127"); + 
EXPECT_EQ(0x1.0p-127, test.convertToDouble()); + + // Smallest value + test = APFloat::getSmallest(APFloat::Float8E8M0FN()); + EXPECT_EQ(0x1.0p-127, test.convertToDouble()); + + // Value below the smallest, but clamped to the smallest + test = APFloat(APFloat::Float8E8M0FN(), "0x1.0p-128"); + EXPECT_EQ(0x1.0p-127, test.convertToDouble()); +} + TEST(APFloatTest, getLargest) { EXPECT_EQ(3.402823466e+38f, APFloat::getLargest(APFloat::IEEEsingle()).convertToFloat()); EXPECT_EQ(1.7976931348623158e+308, APFloat::getLargest(APFloat::IEEEdouble()).convertToDouble()); @@ -1919,6 +1984,8 @@ TEST(APFloatTest, getLargest) { 30, APFloat::getLargest(APFloat::Float8E4M3B11FNUZ()).convertToDouble()); EXPECT_EQ(3.40116213421e+38f, APFloat::getLargest(APFloat::FloatTF32()).convertToFloat()); + EXPECT_EQ(1.701411834e+38f, + APFloat::getLargest(APFloat::Float8E8M0FN()).convertToDouble()); EXPECT_EQ(28, APFloat::getLargest(APFloat::Float6E3M2FN()).convertToDouble()); EXPECT_EQ(7.5, APFloat::getLargest(APFloat::Float6E2M3FN()).convertToDouble()); @@ -2002,6 +2069,13 @@ TEST(APFloatTest, getSmallest) { EXPECT_TRUE(test.isFiniteNonZero()); EXPECT_TRUE(test.isDenormal()); EXPECT_TRUE(test.bitwiseIsEqual(expected)); + + test = APFloat::getSmallest(APFloat::Float8E8M0FN()); + expected = APFloat(APFloat::Float8E8M0FN(), "0x1.0p-127"); + EXPECT_FALSE(test.isNegative()); + EXPECT_TRUE(test.isFiniteNonZero()); + EXPECT_FALSE(test.isDenormal()); + EXPECT_TRUE(test.bitwiseIsEqual(expected)); } TEST(APFloatTest, getSmallestNormalized) { @@ -2108,6 +2182,14 @@ TEST(APFloatTest, getSmallestNormalized) { EXPECT_FALSE(test.isDenormal()); EXPECT_TRUE(test.bitwiseIsEqual(expected)); EXPECT_TRUE(test.isSmallestNormalized()); + + test = APFloat::getSmallestNormalized(APFloat::Float8E8M0FN(), false); + expected = APFloat(APFloat::Float8E8M0FN(), "0x1.0p-127"); + EXPECT_FALSE(test.isNegative()); + EXPECT_TRUE(test.isFiniteNonZero()); + EXPECT_FALSE(test.isDenormal()); + 
EXPECT_TRUE(test.bitwiseIsEqual(expected)); + EXPECT_TRUE(test.isSmallestNormalized()); } TEST(APFloatTest, getZero) { @@ -5791,6 +5873,46 @@ TEST(APFloatTest, Float8E4M3FNExhaustive) { } } +TEST(APFloatTest, Float8E8M0FNExhaustive) { + // Test each of the 256 Float8E8M0FN values. + for (int i = 0; i < 256; i++) { + APFloat test(APFloat::Float8E8M0FN(), APInt(8, i)); + SCOPED_TRACE("i=" + std::to_string(i)); + + // isLargest + if (i == 254) { + EXPECT_TRUE(test.isLargest()); + EXPECT_EQ(abs(test).convertToDouble(), 0x1.0p127); + } else { + EXPECT_FALSE(test.isLargest()); + } + + // isSmallest + if (i == 0) { + EXPECT_TRUE(test.isSmallest()); + EXPECT_EQ(abs(test).convertToDouble(), 0x1.0p-127); + } else { + EXPECT_FALSE(test.isSmallest()); + } + + // convert to Double + bool losesInfo; + std::string val = std::to_string(i - 127); // 127 is the bias + llvm::SmallString<16> str("0x1.0p"); + str += val; + APFloat test2(APFloat::IEEEdouble(), str); + + APFloat::opStatus status = test.convert( + APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &losesInfo); + EXPECT_EQ(status, APFloat::opOK); + EXPECT_FALSE(losesInfo); + if (i == 255) + EXPECT_TRUE(test.isNaN()); + else + EXPECT_EQ(test.convertToDouble(), test2.convertToDouble()); + } +} + TEST(APFloatTest, Float8E5M2FNUZNext) { APFloat test(APFloat::Float8E5M2FNUZ(), APFloat::uninitialized); APFloat expected(APFloat::Float8E5M2FNUZ(), APFloat::uninitialized); @@ -7081,10 +7203,17 @@ TEST(APFloatTest, getExactLog2) { int Precision = APFloat::semanticsPrecision(Semantics); EXPECT_EQ(0, One.getExactLog2()); - EXPECT_EQ(INT_MIN, APFloat(Semantics, "3.0").getExactLog2()); - EXPECT_EQ(INT_MIN, APFloat(Semantics, "-3.0").getExactLog2()); - EXPECT_EQ(INT_MIN, APFloat(Semantics, "3.0").getExactLog2Abs()); - EXPECT_EQ(INT_MIN, APFloat(Semantics, "-3.0").getExactLog2Abs()); + if (APFloat::hasExponentOnly(Semantics)) { + EXPECT_EQ(2, APFloat(Semantics, "3.0").getExactLog2()); + EXPECT_EQ(INT_MIN, APFloat(Semantics, 
"-3.0").getExactLog2()); + EXPECT_EQ(2, APFloat(Semantics, "3.0").getExactLog2Abs()); + EXPECT_EQ(2, APFloat(Semantics, "-3.0").getExactLog2Abs()); + } else { + EXPECT_EQ(INT_MIN, APFloat(Semantics, "3.0").getExactLog2()); + EXPECT_EQ(INT_MIN, APFloat(Semantics, "-3.0").getExactLog2()); + EXPECT_EQ(INT_MIN, APFloat(Semantics, "3.0").getExactLog2Abs()); + EXPECT_EQ(INT_MIN, APFloat(Semantics, "-3.0").getExactLog2Abs()); + } if (I == APFloat::S_Float6E2M3FN || I == APFloat::S_Float4E2M1FN) { EXPECT_EQ(2, APFloat(Semantics, "4.0").getExactLog2()); @@ -7102,10 +7231,12 @@ TEST(APFloatTest, getExactLog2) { EXPECT_EQ(3, APFloat(Semantics, "-8.0").getExactLog2Abs()); } - EXPECT_EQ(INT_MIN, APFloat::getZero(Semantics, false).getExactLog2()); - EXPECT_EQ(INT_MIN, APFloat::getZero(Semantics, true).getExactLog2()); - EXPECT_EQ(INT_MIN, APFloat::getZero(Semantics, false).getExactLog2Abs()); - EXPECT_EQ(INT_MIN, APFloat::getZero(Semantics, true).getExactLog2Abs()); + if (APFloat::hasZero(Semantics)) { + EXPECT_EQ(INT_MIN, APFloat::getZero(Semantics, false).getExactLog2()); + EXPECT_EQ(INT_MIN, APFloat::getZero(Semantics, true).getExactLog2()); + EXPECT_EQ(INT_MIN, APFloat::getZero(Semantics, false).getExactLog2Abs()); + EXPECT_EQ(INT_MIN, APFloat::getZero(Semantics, true).getExactLog2Abs()); + } if (APFloat::hasNanOrInf(Semantics)) { EXPECT_EQ(INT_MIN, APFloat::getInf(Semantics).getExactLog2()); @@ -7119,12 +7250,21 @@ TEST(APFloatTest, getExactLog2) { EXPECT_EQ(INT_MIN, APFloat::getNaN(Semantics, true).getExactLog2Abs()); } - EXPECT_EQ(INT_MIN, - scalbn(One, MinExp - Precision - 1, APFloat::rmNearestTiesToEven) - .getExactLog2()); - EXPECT_EQ(INT_MIN, - scalbn(One, MinExp - Precision, APFloat::rmNearestTiesToEven) - .getExactLog2()); + if (APFloat::hasExponentOnly(Semantics)) { + EXPECT_EQ(-127, scalbn(One, MinExp - Precision - 1, + APFloat::rmNearestTiesToEven) + .getExactLog2()); + EXPECT_EQ(-127, + scalbn(One, MinExp - Precision, APFloat::rmNearestTiesToEven) + 
.getExactLog2()); + } else { + EXPECT_EQ(INT_MIN, scalbn(One, MinExp - Precision - 1, + APFloat::rmNearestTiesToEven) + .getExactLog2()); + EXPECT_EQ(INT_MIN, + scalbn(One, MinExp - Precision, APFloat::rmNearestTiesToEven) + .getExactLog2()); + } EXPECT_EQ( INT_MIN, @@ -7136,6 +7276,145 @@ TEST(APFloatTest, getExactLog2) { } } +TEST(APFloatTest, Float8E8M0FNGetZero) { +#ifdef GTEST_HAS_DEATH_TEST +#ifndef NDEBUG + EXPECT_DEATH(APFloat::getZero(APFloat::Float8E8M0FN(), false), + "This floating point format does not support Zero"); + EXPECT_DEATH(APFloat::getZero(APFloat::Float8E8M0FN(), true), + "This floating point format does not support Zero"); +#endif +#endif +} + +TEST(APFloatTest, Float8E8M0FNGetInf) { + // The E8M0 format does not support infinity and the + // all ones representation is treated as NaN. + APFloat t = APFloat::getInf(APFloat::Float8E8M0FN()); + EXPECT_TRUE(t.isNaN()); + EXPECT_FALSE(t.isInfinity()); + + t = APFloat::getInf(APFloat::Float8E8M0FN(), true); + EXPECT_TRUE(t.isNaN()); + EXPECT_FALSE(t.isInfinity()); +} + +TEST(APFloatTest, Float8E8M0FNFromString) { + // Exactly representable + EXPECT_EQ(64, APFloat(APFloat::Float8E8M0FN(), "64").convertToDouble()); + // Overflow to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E8M0FN(), "0x1.0p128").isNaN()); + // Inf converted to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E8M0FN(), "inf").isNaN()); + // NaN converted to NaN + EXPECT_TRUE(APFloat(APFloat::Float8E8M0FN(), "nan").isNaN()); +} + +TEST(APFloatTest, Float8E8M0FNDivideByZero) { + APFloat x(APFloat::Float8E8M0FN(), "1"); + APFloat zero(APFloat::Float8E8M0FN(), "0"); + x.divide(zero, APFloat::rmNearestTiesToEven); + + // Zero is represented as the smallest normalized value + // in this format i.e 2^-127. + // This tests the fix in convertFromDecimalString() function. 
+ EXPECT_EQ(0x1.0p-127, zero.convertToDouble()); + + // [1 / (2^-127)] = 2^127 + EXPECT_EQ(0x1.0p127, x.convertToDouble()); +} + +TEST(APFloatTest, Float8E8M0FNNext) { + APFloat test(APFloat::getSmallest(APFloat::Float8E8M0FN())); + EXPECT_EQ(0x1.0p-127, test.convertToDouble()); + EXPECT_TRUE(test.isSmallestNormalized()); + EXPECT_EQ(fcPosNormal, test.classify()); + + // Increment of 1 should reach 2^-126 + EXPECT_EQ(APFloat::opOK, test.next(false)); + EXPECT_FALSE(test.isSmallestNormalized()); + EXPECT_EQ(0x1.0p-126, test.convertToDouble()); + + // Decrement of 1, again, should reach 2^-127 + // i.e. smallest normalized + EXPECT_EQ(APFloat::opOK, test.next(true)); + EXPECT_TRUE(test.isSmallestNormalized()); + + // Decrement again, but gets capped at the smallest normalized + EXPECT_EQ(APFloat::opOK, test.next(true)); + EXPECT_TRUE(test.isSmallestNormalized()); +} + +TEST(APFloatTest, ConvertDoubleToE8M0FN) { + bool losesInfo; + APFloat test(APFloat::IEEEdouble(), "1.0"); + APFloat::opStatus status = test.convert( + APFloat::Float8E8M0FN(), APFloat::rmNearestTiesToEven, &losesInfo); + EXPECT_EQ(1.0, test.convertToDouble()); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(status, APFloat::opOK); + + // For E8M0, zero encoding is represented as the smallest normalized value. + test = APFloat(APFloat::IEEEdouble(), "0.0"); + status = test.convert(APFloat::Float8E8M0FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(test.isSmallestNormalized()); + EXPECT_EQ(0x1.0p-127, test.convertToDouble()); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(status, APFloat::opOK); + + // Test that the conversion of a power-of-two value is precise. + test = APFloat(APFloat::IEEEdouble(), "8.0"); + status = test.convert(APFloat::Float8E8M0FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(8.0f, test.convertToDouble()); + EXPECT_FALSE(losesInfo); + EXPECT_EQ(status, APFloat::opOK); + + // Test to check round-down conversion to power-of-two. 
+ // The fractional part of 9 is "001" (i.e. 1.125x2^3=9). + test = APFloat(APFloat::IEEEdouble(), "9.0"); + status = test.convert(APFloat::Float8E8M0FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(8.0f, test.convertToDouble()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opInexact); + + // Test to check round-up conversion to power-of-two. + // The fractional part of 13 is "101" (i.e. 1.625x2^3=13). + test = APFloat(APFloat::IEEEdouble(), "13.0"); + status = test.convert(APFloat::Float8E8M0FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(16.0f, test.convertToDouble()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opInexact); + + // Test to check round-up conversion to power-of-two. + // The fractional part of 12 is "100" (i.e. 1.5x2^3=12). + test = APFloat(APFloat::IEEEdouble(), "12.0"); + status = test.convert(APFloat::Float8E8M0FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_EQ(16.0f, test.convertToDouble()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opInexact); + + // Overflow to NaN. + test = APFloat(APFloat::IEEEdouble(), "0x1.0p128"); + status = test.convert(APFloat::Float8E8M0FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(test.isNaN()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opOverflow | APFloat::opInexact); + + // Underflow to smallest normalized value. + test = APFloat(APFloat::IEEEdouble(), "0x1.0p-128"); + status = test.convert(APFloat::Float8E8M0FN(), APFloat::rmNearestTiesToEven, + &losesInfo); + EXPECT_TRUE(test.isSmallestNormalized()); + EXPECT_TRUE(losesInfo); + EXPECT_EQ(status, APFloat::opUnderflow | APFloat::opInexact); +} + TEST(APFloatTest, Float6E3M2FNFromString) { // Exactly representable EXPECT_EQ(28, APFloat(APFloat::Float6E3M2FN(), "28").convertToDouble());